gather.hpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. // Copyright (C) 2005, 2006 Douglas Gregor.
  2. // Use, modification and distribution is subject to the Boost Software
  3. // License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
  4. // http://www.boost.org/LICENSE_1_0.txt)
  5. // Message Passing Interface 1.1 -- Section 4.5. Gather
  6. #ifndef BOOST_MPI_GATHER_HPP
  7. #define BOOST_MPI_GATHER_HPP
  8. #include <cassert>
  9. #include <cstddef>
  10. #include <numeric>
  11. #include <boost/mpi/exception.hpp>
  12. #include <boost/mpi/datatype.hpp>
  13. #include <vector>
  14. #include <boost/mpi/packed_oarchive.hpp>
  15. #include <boost/mpi/packed_iarchive.hpp>
  16. #include <boost/mpi/detail/point_to_point.hpp>
  17. #include <boost/mpi/communicator.hpp>
  18. #include <boost/mpi/environment.hpp>
  19. #include <boost/mpi/detail/offsets.hpp>
  20. #include <boost/mpi/detail/antiques.hpp>
  21. #include <boost/assert.hpp>
  22. namespace boost { namespace mpi {
  23. namespace detail {
  24. // We're gathering at the root for a type that has an associated MPI
  25. // datatype, so we'll use MPI_Gather to do all of the work.
  26. template<typename T>
  27. void
  28. gather_impl(const communicator& comm, const T* in_values, int n,
  29. T* out_values, int root, mpl::true_)
  30. {
  31. MPI_Datatype type = get_mpi_datatype<T>(*in_values);
  32. BOOST_MPI_CHECK_RESULT(MPI_Gather,
  33. (const_cast<T*>(in_values), n, type,
  34. out_values, n, type, root, comm));
  35. }
  36. // We're gathering from a non-root for a type that has an associated MPI
  37. // datatype, so we'll use MPI_Gather to do all of the work.
  38. template<typename T>
  39. void
  40. gather_impl(const communicator& comm, const T* in_values, int n, int root,
  41. mpl::true_ is_mpi_type)
  42. {
  43. assert(comm.rank() != root);
  44. gather_impl(comm, in_values, n, (T*)0, root, is_mpi_type);
  45. }
  46. // We're gathering at the root for a type that does not have an
  47. // associated MPI datatype, so we'll need to serialize
  48. // it.
  49. template<typename T>
  50. void
  51. gather_impl(const communicator& comm, const T* in_values, int n, T* out_values,
  52. int const* nslot, int const* nskip, int root, mpl::false_)
  53. {
  54. int nproc = comm.size();
  55. // first, gather all size, these size can be different for
  56. // each process
  57. packed_oarchive oa(comm);
  58. for (int i = 0; i < n; ++i) {
  59. oa << in_values[i];
  60. }
  61. bool is_root = comm.rank() == root;
  62. std::vector<int> oasizes(is_root ? nproc : 0);
  63. int oasize = oa.size();
  64. BOOST_MPI_CHECK_RESULT(MPI_Gather,
  65. (&oasize, 1, MPI_INT,
  66. c_data(oasizes), 1, MPI_INT,
  67. root, MPI_Comm(comm)));
  68. // Gather the archives, which can be of different sizes, so
  69. // we need to use gatherv.
  70. // Everything is contiguous (in the transmitted archive), so
  71. // the offsets can be deduced from the collected sizes.
  72. std::vector<int> offsets;
  73. if (is_root) sizes2offsets(oasizes, offsets);
  74. packed_iarchive::buffer_type recv_buffer(is_root ? std::accumulate(oasizes.begin(), oasizes.end(), 0) : 0);
  75. BOOST_MPI_CHECK_RESULT(MPI_Gatherv,
  76. (const_cast<void*>(oa.address()), int(oa.size()), MPI_BYTE,
  77. c_data(recv_buffer), c_data(oasizes), c_data(offsets), MPI_BYTE,
  78. root, MPI_Comm(comm)));
  79. if (is_root) {
  80. for (int src = 0; src < nproc; ++src) {
  81. // handle variadic case
  82. int nb = nslot ? nslot[src] : n;
  83. int skip = nskip ? nskip[src] : 0;
  84. std::advance(out_values, skip);
  85. if (src == root) {
  86. BOOST_ASSERT(nb == n);
  87. for (int i = 0; i < nb; ++i) {
  88. *out_values++ = *in_values++;
  89. }
  90. } else {
  91. packed_iarchive ia(comm, recv_buffer, boost::archive::no_header, offsets[src]);
  92. for (int i = 0; i < nb; ++i) {
  93. ia >> *out_values++;
  94. }
  95. }
  96. }
  97. }
  98. }
  99. // We're gathering at a non-root for a type that does not have an
  100. // associated MPI datatype, so we'll need to serialize
  101. // it.
  102. template<typename T>
  103. void
  104. gather_impl(const communicator& comm, const T* in_values, int n, T* out_values,int root,
  105. mpl::false_ is_mpi_type)
  106. {
  107. gather_impl(comm, in_values, n, out_values, (int const*)0, (int const*)0, root, is_mpi_type);
  108. }
  109. } // end namespace detail
  110. template<typename T>
  111. void
  112. gather(const communicator& comm, const T& in_value, T* out_values, int root)
  113. {
  114. BOOST_ASSERT(out_values || (comm.rank() != root));
  115. detail::gather_impl(comm, &in_value, 1, out_values, root, is_mpi_datatype<T>());
  116. }
  117. template<typename T>
  118. void gather(const communicator& comm, const T& in_value, int root)
  119. {
  120. BOOST_ASSERT(comm.rank() != root);
  121. detail::gather_impl(comm, &in_value, 1, (T*)0, root, is_mpi_datatype<T>());
  122. }
  123. template<typename T>
  124. void
  125. gather(const communicator& comm, const T& in_value, std::vector<T>& out_values,
  126. int root)
  127. {
  128. using detail::c_data;
  129. if (comm.rank() == root) {
  130. out_values.resize(comm.size());
  131. }
  132. ::boost::mpi::gather(comm, in_value, c_data(out_values), root);
  133. }
  134. template<typename T>
  135. void
  136. gather(const communicator& comm, const T* in_values, int n, T* out_values,
  137. int root)
  138. {
  139. detail::gather_impl(comm, in_values, n, out_values, root,
  140. is_mpi_datatype<T>());
  141. }
  142. template<typename T>
  143. void
  144. gather(const communicator& comm, const T* in_values, int n,
  145. std::vector<T>& out_values, int root)
  146. {
  147. if (comm.rank() == root) {
  148. out_values.resize(comm.size() * n);
  149. }
  150. ::boost::mpi::gather(comm, in_values, n, out_values.data(), root);
  151. }
  152. template<typename T>
  153. void gather(const communicator& comm, const T* in_values, int n, int root)
  154. {
  155. BOOST_ASSERT(comm.rank() != root);
  156. detail::gather_impl(comm, in_values, n, root, is_mpi_datatype<T>());
  157. }
  158. } } // end namespace boost::mpi
  159. #endif // BOOST_MPI_GATHER_HPP