all_gather.hpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. // Copyright (C) 2005, 2006 Douglas Gregor.
  2. // Use, modification and distribution is subject to the Boost Software
  3. // License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
  4. // http://www.boost.org/LICENSE_1_0.txt)
// Message Passing Interface 1.1 -- Section 4.7. Gather-to-all
  6. #ifndef BOOST_MPI_ALLGATHER_HPP
  7. #define BOOST_MPI_ALLGATHER_HPP
  8. #include <cassert>
  9. #include <cstddef>
  10. #include <numeric>
  11. #include <boost/mpi/exception.hpp>
  12. #include <boost/mpi/datatype.hpp>
  13. #include <vector>
  14. #include <boost/mpi/packed_oarchive.hpp>
  15. #include <boost/mpi/packed_iarchive.hpp>
  16. #include <boost/mpi/detail/point_to_point.hpp>
  17. #include <boost/mpi/communicator.hpp>
  18. #include <boost/mpi/environment.hpp>
  19. #include <boost/mpi/detail/offsets.hpp>
  20. #include <boost/mpi/detail/antiques.hpp>
  21. #include <boost/assert.hpp>
  22. namespace boost { namespace mpi {
  23. namespace detail {
  24. // We're all-gathering for a type that has an associated MPI
  25. // datatype, so we'll use MPI_Gather to do all of the work.
  26. template<typename T>
  27. void
  28. all_gather_impl(const communicator& comm, const T* in_values, int n,
  29. T* out_values, mpl::true_)
  30. {
  31. MPI_Datatype type = get_mpi_datatype<T>(*in_values);
  32. BOOST_MPI_CHECK_RESULT(MPI_Allgather,
  33. (const_cast<T*>(in_values), n, type,
  34. out_values, n, type, comm));
  35. }
  36. // We're all-gathering for a type that does not have an
  37. // associated MPI datatype, so we'll need to serialize
  38. // it.
  39. template<typename T>
  40. void
  41. all_gather_impl(const communicator& comm, const T* in_values, int n,
  42. T* out_values, int const* sizes, int const* skips, mpl::false_)
  43. {
  44. int nproc = comm.size();
  45. // first, gather all size, these size can be different for
  46. // each process
  47. packed_oarchive oa(comm);
  48. for (int i = 0; i < n; ++i) {
  49. oa << in_values[i];
  50. }
  51. std::vector<int> oasizes(nproc);
  52. int oasize = oa.size();
  53. BOOST_MPI_CHECK_RESULT(MPI_Allgather,
  54. (&oasize, 1, MPI_INT,
  55. c_data(oasizes), 1, MPI_INT,
  56. MPI_Comm(comm)));
  57. // Gather the archives, which can be of different sizes, so
  58. // we need to use allgatherv.
  59. // Every thing is contiguous, so the offsets can be
  60. // deduced from the collected sizes.
  61. std::vector<int> offsets(nproc);
  62. sizes2offsets(oasizes, offsets);
  63. packed_iarchive::buffer_type recv_buffer(std::accumulate(oasizes.begin(), oasizes.end(), 0));
  64. BOOST_MPI_CHECK_RESULT(MPI_Allgatherv,
  65. (const_cast<void*>(oa.address()), int(oa.size()), MPI_BYTE,
  66. c_data(recv_buffer), c_data(oasizes), c_data(offsets), MPI_BYTE,
  67. MPI_Comm(comm)));
  68. for (int src = 0; src < nproc; ++src) {
  69. int nb = sizes ? sizes[src] : n;
  70. int skip = skips ? skips[src] : 0;
  71. std::advance(out_values, skip);
  72. if (src == comm.rank()) { // this is our local data
  73. for (int i = 0; i < nb; ++i) {
  74. *out_values++ = *in_values++;
  75. }
  76. } else {
  77. packed_iarchive ia(comm, recv_buffer, boost::archive::no_header, offsets[src]);
  78. for (int i = 0; i < nb; ++i) {
  79. ia >> *out_values++;
  80. }
  81. }
  82. }
  83. }
  84. // We're all-gathering for a type that does not have an
  85. // associated MPI datatype, so we'll need to serialize
  86. // it.
  87. template<typename T>
  88. void
  89. all_gather_impl(const communicator& comm, const T* in_values, int n,
  90. T* out_values, mpl::false_ isnt_mpi_type)
  91. {
  92. all_gather_impl(comm, in_values, n, out_values, (int const*)0, (int const*)0, isnt_mpi_type);
  93. }
  94. } // end namespace detail
  95. template<typename T>
  96. void
  97. all_gather(const communicator& comm, const T& in_value, T* out_values)
  98. {
  99. detail::all_gather_impl(comm, &in_value, 1, out_values, is_mpi_datatype<T>());
  100. }
  101. template<typename T>
  102. void
  103. all_gather(const communicator& comm, const T& in_value, std::vector<T>& out_values)
  104. {
  105. using detail::c_data;
  106. out_values.resize(comm.size());
  107. ::boost::mpi::all_gather(comm, in_value, c_data(out_values));
  108. }
  109. template<typename T>
  110. void
  111. all_gather(const communicator& comm, const T* in_values, int n, T* out_values)
  112. {
  113. detail::all_gather_impl(comm, in_values, n, out_values, is_mpi_datatype<T>());
  114. }
  115. template<typename T>
  116. void
  117. all_gather(const communicator& comm, const T* in_values, int n, std::vector<T>& out_values)
  118. {
  119. using detail::c_data;
  120. out_values.resize(comm.size() * n);
  121. ::boost::mpi::all_gather(comm, in_values, n, c_data(out_values));
  122. }
  123. } } // end namespace boost::mpi
#endif // BOOST_MPI_ALLGATHER_HPP