reduce_test.cpp 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. // Copyright (C) 2005, 2006 Douglas Gregor <doug.gregor -at- gmail.com>.
  2. // Use, modification and distribution is subject to the Boost Software
  3. // License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
  4. // http://www.boost.org/LICENSE_1_0.txt)
  5. // A test of the reduce() collective.
  6. #include <boost/mpi/collectives/reduce.hpp>
  7. #include <boost/mpi/communicator.hpp>
  8. #include <boost/mpi/environment.hpp>
  9. #include <algorithm>
  10. #include <boost/serialization/string.hpp>
  11. #include <boost/iterator/counting_iterator.hpp>
  12. #include <boost/lexical_cast.hpp>
  13. #include <numeric>
  14. #define BOOST_TEST_MODULE mpi_reduce_test
  15. #include <boost/test/included/unit_test.hpp>
  16. using boost::mpi::communicator;
  17. // A simple point class that we can build, add, compare, and
  18. // serialize.
  19. struct point
  20. {
  21. point() : x(0), y(0), z(0) { }
  22. point(int x, int y, int z) : x(x), y(y), z(z) { }
  23. int x;
  24. int y;
  25. int z;
  26. private:
  27. template<typename Archiver>
  28. void serialize(Archiver& ar, unsigned int /*version*/)
  29. {
  30. ar & x & y & z;
  31. }
  32. friend class boost::serialization::access;
  33. };
  34. std::ostream& operator<<(std::ostream& out, const point& p)
  35. {
  36. return out << p.x << ' ' << p.y << ' ' << p.z;
  37. }
  38. bool operator==(const point& p1, const point& p2)
  39. {
  40. return p1.x == p2.x && p1.y == p2.y && p1.z == p2.z;
  41. }
  42. bool operator!=(const point& p1, const point& p2)
  43. {
  44. return !(p1 == p2);
  45. }
  46. point operator+(const point& p1, const point& p2)
  47. {
  48. return point(p1.x + p2.x, p1.y + p2.y, p1.z + p2.z);
  49. }
  50. namespace boost { namespace mpi {
  51. template <>
  52. struct is_mpi_datatype<point> : public mpl::true_ { };
  53. } } // end namespace boost::mpi
  54. template<typename Generator, typename Op>
  55. void
  56. reduce_test(const communicator& comm, Generator generator,
  57. const char* type_kind, Op op, const char* op_kind,
  58. typename Generator::result_type init,
  59. int root = -1)
  60. {
  61. typedef typename Generator::result_type value_type;
  62. value_type value = generator(comm.rank());
  63. if (root == -1) {
  64. for (root = 0; root < comm.size(); ++root)
  65. reduce_test(comm, generator, type_kind, op, op_kind, init, root);
  66. } else {
  67. using boost::mpi::reduce;
  68. if (comm.rank() == root) {
  69. std::cout << "Reducing to " << op_kind << " of " << type_kind
  70. << " at root " << root << "...";
  71. std::cout.flush();
  72. value_type result_value;
  73. reduce(comm, value, result_value, op, root);
  74. // Compute expected result
  75. std::vector<value_type> generated_values;
  76. for (int p = 0; p < comm.size(); ++p)
  77. generated_values.push_back(generator(p));
  78. value_type expected_result = std::accumulate(generated_values.begin(),
  79. generated_values.end(),
  80. init, op);
  81. BOOST_CHECK(result_value == expected_result);
  82. if (result_value == expected_result)
  83. std::cout << "OK." << std::endl;
  84. } else {
  85. reduce(comm, value, op, root);
  86. }
  87. }
  88. (comm.barrier)();
  89. }
  // Generates integers to test with reduce(): rank p contributes base + p.
  struct int_generator
  {
    typedef int result_type;

    // base: value produced for rank 0 (defaults to 1). Explicit so a bare
    // int is never silently converted into a generator.
    explicit int_generator(int base = 1) : base_(base) { }

    // Value contributed by rank p.
    int operator()(int p) const { return base_ + p; }

  private:
    int base_;
  };
  99. // Generate points to test with reduce()
  100. struct point_generator
  101. {
  102. typedef point result_type;
  103. point_generator(point origin) : origin(origin) { }
  104. point operator()(int p) const
  105. {
  106. return point(origin.x + 1, origin.y + 1, origin.z + 1);
  107. }
  108. private:
  109. point origin;
  110. };
  111. struct string_generator
  112. {
  113. typedef std::string result_type;
  114. std::string operator()(int p) const
  115. {
  116. std::string result = boost::lexical_cast<std::string>(p);
  117. result += " rosebud";
  118. if (p != 1) result += 's';
  119. return result;
  120. }
  121. };
  // Bitwise AND packaged as an ordinary function object; the test case uses
  // it as a user-defined operation on a built-in datatype.
  struct secret_int_bit_and
  {
    int operator()(int lhs, int rhs) const
    {
      const int common_bits = lhs & rhs;
      return common_bits;
    }
  };
  // An int wrapped in a class. Unlike point, it has no is_mpi_datatype
  // specialization in this file; the test case uses it as an "arbitrary
  // type" that is transferred via its serialize() member.
  struct wrapped_int
  {
    int value;

    wrapped_int() : value(0) { }
    explicit wrapped_int(int v) : value(v) { }

    // Serialization hook: the wrapped int is the only state.
    template<typename Archive>
    void serialize(Archive& ar, unsigned int /* version */)
    {
      ar & value;
    }
  };
  137. wrapped_int operator+(const wrapped_int& x, const wrapped_int& y)
  138. {
  139. return wrapped_int(x.value + y.value);
  140. }
  141. bool operator==(const wrapped_int& x, const wrapped_int& y)
  142. {
  143. return x.value == y.value;
  144. }
  145. // Generates wrapped_its to test with reduce()
  146. struct wrapped_int_generator
  147. {
  148. typedef wrapped_int result_type;
  149. wrapped_int_generator(int base = 1) : base(base) { }
  150. wrapped_int operator()(int p) const { return wrapped_int(base + p); }
  151. private:
  152. int base;
  153. };
  154. namespace boost { namespace mpi {
  155. // Make std::plus<wrapped_int> commutative.
  156. template<>
  157. struct is_commutative<std::plus<wrapped_int>, wrapped_int>
  158. : mpl::true_ { };
  159. } } // end namespace boost::mpi
  160. BOOST_AUTO_TEST_CASE(reduce_check)
  161. {
  162. using namespace boost::mpi;
  163. environment env;
  164. communicator comm;
  165. // Built-in MPI datatypes with built-in MPI operations
  166. reduce_test(comm, int_generator(), "integers", std::plus<int>(), "sum", 0);
  167. reduce_test(comm, int_generator(), "integers", std::multiplies<int>(),
  168. "product", 1);
  169. reduce_test(comm, int_generator(), "integers", maximum<int>(),
  170. "maximum", 0);
  171. reduce_test(comm, int_generator(), "integers", minimum<int>(),
  172. "minimum", 2);
  173. // User-defined MPI datatypes with operations that have the
  174. // same name as built-in operations.
  175. reduce_test(comm, point_generator(point(0,0,0)), "points",
  176. std::plus<point>(), "sum", point());
  177. // Built-in MPI datatypes with user-defined operations
  178. reduce_test(comm, int_generator(17), "integers", secret_int_bit_and(),
  179. "bitwise and", -1);
  180. // Arbitrary types with user-defined, commutative operations.
  181. reduce_test(comm, wrapped_int_generator(17), "wrapped integers",
  182. std::plus<wrapped_int>(), "sum", wrapped_int(0));
  183. // Arbitrary types with (non-commutative) user-defined operations
  184. reduce_test(comm, string_generator(), "strings",
  185. std::plus<std::string>(), "concatenation", std::string());
  186. }