all_reduce_test.cpp 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. // Copyright (C) 2005, 2006 Douglas Gregor.
  2. // Use, modification and distribution is subject to the Boost Software
  3. // License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
  4. // http://www.boost.org/LICENSE_1_0.txt)
  5. // A test of the all_reduce() collective.
  6. #include <boost/mpi/collectives/all_reduce.hpp>
  7. #include <boost/mpi/communicator.hpp>
  8. #include <boost/mpi/environment.hpp>
  9. #include <vector>
  10. #include <algorithm>
  11. #include <boost/serialization/string.hpp>
  12. #include <boost/iterator/counting_iterator.hpp>
  13. #include <boost/lexical_cast.hpp>
  14. #include <numeric>
  15. #define BOOST_TEST_MODULE mpi_all_reduce
  16. #include <boost/test/included/unit_test.hpp>
  17. using boost::mpi::communicator;
  18. // A simple point class that we can build, add, compare, and
  19. // serialize.
  20. struct point
  21. {
  22. point() : x(0), y(0), z(0) { }
  23. point(int x, int y, int z) : x(x), y(y), z(z) { }
  24. int x;
  25. int y;
  26. int z;
  27. private:
  28. template<typename Archiver>
  29. void serialize(Archiver& ar, unsigned int /*version*/)
  30. {
  31. ar & x & y & z;
  32. }
  33. friend class boost::serialization::access;
  34. };
  35. std::ostream& operator<<(std::ostream& out, const point& p)
  36. {
  37. return out << p.x << ' ' << p.y << ' ' << p.z;
  38. }
  39. bool operator==(const point& p1, const point& p2)
  40. {
  41. return p1.x == p2.x && p1.y == p2.y && p1.z == p2.z;
  42. }
  43. bool operator!=(const point& p1, const point& p2)
  44. {
  45. return !(p1 == p2);
  46. }
  47. point operator+(const point& p1, const point& p2)
  48. {
  49. return point(p1.x + p2.x, p1.y + p2.y, p1.z + p2.z);
  50. }
  51. // test lexical order
  52. bool operator<(const point& p1, const point& p2)
  53. {
  54. return (p1.x < p2.x
  55. ? true
  56. : (p1.x > p2.x
  57. ? false
  58. : p1.y < p2.y ));
  59. }
  60. namespace boost { namespace mpi {
  61. template <>
  62. struct is_mpi_datatype<point> : public mpl::true_ { };
  63. } } // end namespace boost::mpi
  64. template<typename Generator, typename Op>
  65. void
  66. all_reduce_one_test(const communicator& comm, Generator generator,
  67. const char* type_kind, Op op, const char* op_kind,
  68. typename Generator::result_type init, bool in_place)
  69. {
  70. typedef typename Generator::result_type value_type;
  71. value_type value = generator(comm.rank());
  72. using boost::mpi::all_reduce;
  73. using boost::mpi::inplace;
  74. if (comm.rank() == 0) {
  75. std::cout << "Reducing to " << op_kind << " of " << type_kind << "...";
  76. std::cout.flush();
  77. }
  78. value_type result_value;
  79. if (in_place) {
  80. all_reduce(comm, inplace(value), op);
  81. result_value = value;
  82. } else {
  83. result_value = all_reduce(comm, value, op);
  84. }
  85. // Compute expected result
  86. std::vector<value_type> generated_values;
  87. for (int p = 0; p < comm.size(); ++p)
  88. generated_values.push_back(generator(p));
  89. value_type expected_result = std::accumulate(generated_values.begin(),
  90. generated_values.end(),
  91. init, op);
  92. BOOST_CHECK(result_value == expected_result);
  93. if (result_value == expected_result && comm.rank() == 0)
  94. std::cout << "OK." << std::endl;
  95. (comm.barrier)();
  96. }
  97. template<typename Generator, typename Op>
  98. void
  99. all_reduce_array_test(const communicator& comm, Generator generator,
  100. const char* type_kind, Op op, const char* op_kind,
  101. typename Generator::result_type init, bool in_place)
  102. {
  103. typedef typename Generator::result_type value_type;
  104. value_type value = generator(comm.rank());
  105. std::vector<value_type> send(10, value);
  106. using boost::mpi::all_reduce;
  107. using boost::mpi::inplace;
  108. if (comm.rank() == 0) {
  109. char const* place = in_place ? "in place" : "out of place";
  110. std::cout << "Reducing (" << place << ") array to " << op_kind << " of " << type_kind << "...";
  111. std::cout.flush();
  112. }
  113. std::vector<value_type> result;
  114. if (in_place) {
  115. all_reduce(comm, inplace(&(send[0])), send.size(), op);
  116. result.swap(send);
  117. } else {
  118. std::vector<value_type> recv(10, value_type());
  119. all_reduce(comm, &(send[0]), send.size(), &(recv[0]), op);
  120. result.swap(recv);
  121. }
  122. // Compute expected result
  123. std::vector<value_type> generated_values;
  124. for (int p = 0; p < comm.size(); ++p)
  125. generated_values.push_back(generator(p));
  126. value_type expected_result = std::accumulate(generated_values.begin(),
  127. generated_values.end(),
  128. init, op);
  129. bool got_expected_result = (std::equal_range(result.begin(), result.end(),
  130. expected_result)
  131. == std::make_pair(result.begin(), result.end()));
  132. BOOST_CHECK(got_expected_result);
  133. if (got_expected_result && comm.rank() == 0)
  134. std::cout << "OK." << std::endl;
  135. (comm.barrier)();
  136. }
  137. // Test the 4 families of all reduce: (value, array) X (in place, out of place)
  138. template<typename Generator, typename Op>
  139. void
  140. all_reduce_test(const communicator& comm, Generator generator,
  141. const char* type_kind, Op op, const char* op_kind,
  142. typename Generator::result_type init)
  143. {
  144. const bool in_place = true;
  145. const bool out_of_place = false;
  146. all_reduce_one_test(comm, generator, type_kind, op, op_kind, init, in_place);
  147. all_reduce_one_test(comm, generator, type_kind, op, op_kind, init, out_of_place);
  148. all_reduce_array_test(comm, generator, type_kind, op, op_kind,
  149. init, in_place);
  150. all_reduce_array_test(comm, generator, type_kind, op, op_kind,
  151. init, out_of_place);
  152. }
  153. // Generates integers to test with all_reduce()
  154. struct int_generator
  155. {
  156. typedef int result_type;
  157. int_generator(int base = 1) : base(base) { }
  158. int operator()(int p) const { return base + p; }
  159. private:
  160. int base;
  161. };
  162. // Generate points to test with all_reduce()
  163. struct point_generator
  164. {
  165. typedef point result_type;
  166. point_generator(point origin) : origin(origin) { }
  167. point operator()(int p) const
  168. {
  169. return point(origin.x + 1, origin.y + 1, origin.z + 1);
  170. }
  171. private:
  172. point origin;
  173. };
  174. struct string_generator
  175. {
  176. typedef std::string result_type;
  177. std::string operator()(int p) const
  178. {
  179. std::string result = boost::lexical_cast<std::string>(p);
  180. result += " rosebud";
  181. if (p != 1) result += 's';
  182. return result;
  183. }
  184. };
  185. struct secret_int_bit_and
  186. {
  187. int operator()(int x, int y) const { return x & y; }
  188. };
  189. struct wrapped_int
  190. {
  191. wrapped_int() : value(0) { }
  192. explicit wrapped_int(int value) : value(value) { }
  193. template<typename Archive>
  194. void serialize(Archive& ar, unsigned int /* version */)
  195. {
  196. ar & value;
  197. }
  198. int value;
  199. };
  200. wrapped_int operator+(const wrapped_int& x, const wrapped_int& y)
  201. {
  202. return wrapped_int(x.value + y.value);
  203. }
  204. bool operator==(const wrapped_int& x, const wrapped_int& y)
  205. {
  206. return x.value == y.value;
  207. }
  208. bool operator<(const wrapped_int& x, const wrapped_int& y)
  209. {
  210. return x.value < y.value;
  211. }
  212. // Generates wrapped_its to test with all_reduce()
  213. struct wrapped_int_generator
  214. {
  215. typedef wrapped_int result_type;
  216. wrapped_int_generator(int base = 1) : base(base) { }
  217. wrapped_int operator()(int p) const { return wrapped_int(base + p); }
  218. private:
  219. int base;
  220. };
  221. namespace boost { namespace mpi {
  222. // Make std::plus<wrapped_int> commutative.
  223. template<>
  224. struct is_commutative<std::plus<wrapped_int>, wrapped_int>
  225. : mpl::true_ { };
  226. } } // end namespace boost::mpi
  227. BOOST_AUTO_TEST_CASE(test_all_reduce)
  228. {
  229. using namespace boost::mpi;
  230. environment env;
  231. communicator comm;
  232. // Built-in MPI datatypes with built-in MPI operations
  233. all_reduce_test(comm, int_generator(), "integers", std::plus<int>(), "sum", 0);
  234. all_reduce_test(comm, int_generator(), "integers", std::multiplies<int>(), "product", 1);
  235. all_reduce_test(comm, int_generator(), "integers", maximum<int>(), "maximum", 0);
  236. all_reduce_test(comm, int_generator(), "integers", minimum<int>(), "minimum", 2);
  237. // User-defined MPI datatypes with operations that have the
  238. // same name as built-in operations.
  239. all_reduce_test(comm, point_generator(point(0,0,0)), "points", std::plus<point>(),
  240. "sum", point());
  241. // Built-in MPI datatypes with user-defined operations
  242. all_reduce_test(comm, int_generator(17), "integers", secret_int_bit_and(),
  243. "bitwise and", -1);
  244. // Arbitrary types with user-defined, commutative operations.
  245. all_reduce_test(comm, wrapped_int_generator(17), "wrapped integers",
  246. std::plus<wrapped_int>(), "sum", wrapped_int(0));
  247. // Arbitrary types with (non-commutative) user-defined operations
  248. all_reduce_test(comm, string_generator(), "strings",
  249. std::plus<std::string>(), "concatenation", std::string());
  250. }