all_reduce.hpp 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. // Copyright (C) 2005-2006 Douglas Gregor <doug.gregor -at- gmail.com>
  2. // Copyright (C) 2004 The Trustees of Indiana University
  3. // Use, modification and distribution is subject to the Boost Software
  4. // License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
  5. // http://www.boost.org/LICENSE_1_0.txt)
  6. // Authors: Douglas Gregor
  7. // Andrew Lumsdaine
  8. // Message Passing Interface 1.1 -- Section 4.9.5. All-Reduce
  9. #ifndef BOOST_MPI_ALL_REDUCE_HPP
  10. #define BOOST_MPI_ALL_REDUCE_HPP
  11. #include <vector>
  12. #include <boost/mpi/inplace.hpp>
  13. // All-reduce falls back to reduce() + broadcast() in some cases.
  14. #include <boost/mpi/collectives/broadcast.hpp>
  15. #include <boost/mpi/collectives/reduce.hpp>
  16. namespace boost { namespace mpi {
  17. namespace detail {
  18. /**********************************************************************
  19. * Simple reduction with MPI_Allreduce *
  20. **********************************************************************/
  21. // We are reducing for a type that has an associated MPI
  22. // datatype and operation, so we'll use MPI_Allreduce directly.
  23. template<typename T, typename Op>
  24. void
  25. all_reduce_impl(const communicator& comm, const T* in_values, int n,
  26. T* out_values, Op /*op*/, mpl::true_ /*is_mpi_op*/,
  27. mpl::true_ /*is_mpi_datatype*/)
  28. {
  29. BOOST_MPI_CHECK_RESULT(MPI_Allreduce,
  30. (const_cast<T*>(in_values), out_values, n,
  31. boost::mpi::get_mpi_datatype<T>(*in_values),
  32. (is_mpi_op<Op, T>::op()), comm));
  33. }
  34. /**********************************************************************
  35. * User-defined reduction with MPI_Allreduce *
  36. **********************************************************************/
  37. // We are reducing at the root for a type that has an associated MPI
  38. // datatype but with a custom operation. We'll use MPI_Reduce
  39. // directly, but we'll need to create an MPI_Op manually.
  40. template<typename T, typename Op>
  41. void
  42. all_reduce_impl(const communicator& comm, const T* in_values, int n,
  43. T* out_values, Op /* op */, mpl::false_ /*is_mpi_op*/,
  44. mpl::true_ /*is_mpi_datatype*/)
  45. {
  46. user_op<Op, T> mpi_op;
  47. BOOST_MPI_CHECK_RESULT(MPI_Allreduce,
  48. (const_cast<T*>(in_values), out_values, n,
  49. boost::mpi::get_mpi_datatype<T>(*in_values),
  50. mpi_op.get_mpi_op(), comm));
  51. }
  52. /**********************************************************************
  53. * User-defined, tree-based reduction for non-MPI data types *
  54. **********************************************************************/
  55. // We are reducing at the root for a type that has no associated MPI
  56. // datatype and operation, so we'll use a simple tree-based
  57. // algorithm.
  58. template<typename T, typename Op>
  59. void
  60. all_reduce_impl(const communicator& comm, const T* in_values, int n,
  61. T* out_values, Op op, mpl::false_ /*is_mpi_op*/,
  62. mpl::false_ /*is_mpi_datatype*/)
  63. {
  64. if (in_values == MPI_IN_PLACE) {
  65. // if in_values matches the in place tag, then the output
  66. // buffer actually contains the input data.
  67. // But we can just go back to the out of place
  68. // implementation in this case.
  69. // it's not clear how/if we can avoid the copy.
  70. std::vector<T> tmp_in( out_values, out_values + n);
  71. reduce(comm, &(tmp_in[0]), n, out_values, op, 0);
  72. } else {
  73. reduce(comm, in_values, n, out_values, op, 0);
  74. }
  75. broadcast(comm, out_values, n, 0);
  76. }
  77. } // end namespace detail
  78. template<typename T, typename Op>
  79. inline void
  80. all_reduce(const communicator& comm, const T* in_values, int n, T* out_values,
  81. Op op)
  82. {
  83. detail::all_reduce_impl(comm, in_values, n, out_values, op,
  84. is_mpi_op<Op, T>(), is_mpi_datatype<T>());
  85. }
  86. template<typename T, typename Op>
  87. inline void
  88. all_reduce(const communicator& comm, inplace_t<T*> inout_values, int n, Op op)
  89. {
  90. all_reduce(comm, static_cast<const T*>(MPI_IN_PLACE), n, inout_values.buffer, op);
  91. }
  92. template<typename T, typename Op>
  93. inline void
  94. all_reduce(const communicator& comm, inplace_t<T> inout_values, Op op)
  95. {
  96. all_reduce(comm, static_cast<const T*>(MPI_IN_PLACE), 1, &(inout_values.buffer), op);
  97. }
  98. template<typename T, typename Op>
  99. inline void
  100. all_reduce(const communicator& comm, const T& in_value, T& out_value, Op op)
  101. {
  102. detail::all_reduce_impl(comm, &in_value, 1, &out_value, op,
  103. is_mpi_op<Op, T>(), is_mpi_datatype<T>());
  104. }
  105. template<typename T, typename Op>
  106. T all_reduce(const communicator& comm, const T& in_value, Op op)
  107. {
  108. T result;
  109. ::boost::mpi::all_reduce(comm, in_value, result, op);
  110. return result;
  111. }
  112. } } // end namespace boost::mpi
  113. #endif // BOOST_MPI_ALL_REDUCE_HPP