request_handlers.hpp 19 KB


  1. // Copyright (C) 2018 Alain Miniussi <alain.miniussi@oca.eu>.
  2. // Use, modification and distribution is subject to the Boost Software
  3. // License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
  4. // http://www.boost.org/LICENSE_1_0.txt)
  5. // Request implementation details
  6. // This header should be included only after the communicator and request
  7. // classes have been defined.
  8. #ifndef BOOST_MPI_REQUEST_HANDLERS_HPP
  9. #define BOOST_MPI_REQUEST_HANDLERS_HPP
  10. #include <boost/mpi/skeleton_and_content_types.hpp>
  11. namespace boost { namespace mpi {
  12. namespace detail {
  13. /**
  14. * Internal data structure that stores everything required to manage
  15. * the receipt of serialized data via a request object.
  16. */
  17. template<typename T>
  18. struct serialized_irecv_data {
  19. serialized_irecv_data(const communicator& comm, T& value)
  20. : m_ia(comm), m_value(value) {}
  21. void deserialize(status& stat)
  22. {
  23. m_ia >> m_value;
  24. stat.m_count = 1;
  25. }
  26. std::size_t m_count;
  27. packed_iarchive m_ia;
  28. T& m_value;
  29. };
  30. template<>
  31. struct serialized_irecv_data<packed_iarchive>
  32. {
  33. serialized_irecv_data(communicator const&, packed_iarchive& ia) : m_ia(ia) { }
  34. void deserialize(status&) { /* Do nothing. */ }
  35. std::size_t m_count;
  36. packed_iarchive& m_ia;
  37. };
  38. /**
  39. * Internal data structure that stores everything required to manage
  40. * the receipt of an array of serialized data via a request object.
  41. */
  42. template<typename T>
  43. struct serialized_array_irecv_data
  44. {
  45. serialized_array_irecv_data(const communicator& comm, T* values, int n)
  46. : m_count(0), m_ia(comm), m_values(values), m_nb(n) {}
  47. void deserialize(status& stat);
  48. std::size_t m_count;
  49. packed_iarchive m_ia;
  50. T* m_values;
  51. int m_nb;
  52. };
  53. template<typename T>
  54. void serialized_array_irecv_data<T>::deserialize(status& stat)
  55. {
  56. T* v = m_values;
  57. T* end = m_values+m_nb;
  58. while (v < end) {
  59. m_ia >> *v++;
  60. }
  61. stat.m_count = m_nb;
  62. }
  63. /**
  64. * Internal data structure that stores everything required to manage
  65. * the receipt of an array of primitive data but unknown size.
  66. * Such an array can have been send with blocking operation and so must
  67. * be compatible with the (size_t,raw_data[]) format.
  68. */
  69. template<typename T, class A>
  70. struct dynamic_array_irecv_data
  71. {
  72. BOOST_STATIC_ASSERT_MSG(is_mpi_datatype<T>::value, "Can only be specialized for MPI datatypes.");
  73. dynamic_array_irecv_data(std::vector<T,A>& values)
  74. : m_count(-1), m_values(values) {}
  75. std::size_t m_count;
  76. std::vector<T,A>& m_values;
  77. };
  78. template<typename T>
  79. struct serialized_irecv_data<const skeleton_proxy<T> >
  80. {
  81. serialized_irecv_data(const communicator& comm, skeleton_proxy<T> proxy)
  82. : m_isa(comm), m_ia(m_isa.get_skeleton()), m_proxy(proxy) { }
  83. void deserialize(status& stat)
  84. {
  85. m_isa >> m_proxy.object;
  86. stat.m_count = 1;
  87. }
  88. std::size_t m_count;
  89. packed_skeleton_iarchive m_isa;
  90. packed_iarchive& m_ia;
  91. skeleton_proxy<T> m_proxy;
  92. };
  93. template<typename T>
  94. struct serialized_irecv_data<skeleton_proxy<T> >
  95. : public serialized_irecv_data<const skeleton_proxy<T> >
  96. {
  97. typedef serialized_irecv_data<const skeleton_proxy<T> > inherited;
  98. serialized_irecv_data(const communicator& comm, const skeleton_proxy<T>& proxy)
  99. : inherited(comm, proxy) { }
  100. };
  101. }
  102. #if BOOST_MPI_VERSION >= 3
  103. template<class Data>
  104. class request::probe_handler
  105. : public request::handler,
  106. protected Data {
  107. protected:
  108. template<typename I1>
  109. probe_handler(communicator const& comm, int source, int tag, I1& i1)
  110. : Data(comm, i1),
  111. m_comm(comm),
  112. m_source(source),
  113. m_tag(tag) {}
  114. // no variadic template for now
  115. template<typename I1, typename I2>
  116. probe_handler(communicator const& comm, int source, int tag, I1& i1, I2& i2)
  117. : Data(comm, i1, i2),
  118. m_comm(comm),
  119. m_source(source),
  120. m_tag(tag) {}
  121. public:
  122. bool active() const { return m_source != MPI_PROC_NULL; }
  123. optional<MPI_Request&> trivial() { return boost::none; }
  124. void cancel() { m_source = MPI_PROC_NULL; }
  125. status wait() {
  126. MPI_Message msg;
  127. status stat;
  128. BOOST_MPI_CHECK_RESULT(MPI_Mprobe, (m_source,m_tag,m_comm,&msg,&stat.m_status));
  129. return unpack(msg, stat);
  130. }
  131. optional<status> test() {
  132. status stat;
  133. int flag = 0;
  134. MPI_Message msg;
  135. BOOST_MPI_CHECK_RESULT(MPI_Improbe, (m_source,m_tag,m_comm,&flag,&msg,&stat.m_status));
  136. if (flag) {
  137. return unpack(msg, stat);
  138. } else {
  139. return optional<status>();
  140. }
  141. }
  142. protected:
  143. friend class request;
  144. status unpack(MPI_Message& msg, status& stat) {
  145. int count;
  146. MPI_Datatype datatype = this->Data::datatype();
  147. BOOST_MPI_CHECK_RESULT(MPI_Get_count, (&stat.m_status, datatype, &count));
  148. this->Data::resize(count);
  149. BOOST_MPI_CHECK_RESULT(MPI_Mrecv, (this->Data::buffer(), count, datatype, &msg, &stat.m_status));
  150. this->Data::deserialize();
  151. m_source = MPI_PROC_NULL;
  152. stat.m_count = 1;
  153. return stat;
  154. }
  155. communicator const& m_comm;
  156. int m_source;
  157. int m_tag;
  158. };
  159. #endif // BOOST_MPI_VERSION >= 3
  160. namespace detail {
  161. template<class A>
  162. struct dynamic_primitive_array_data {
  163. dynamic_primitive_array_data(communicator const&, A& arr) : m_buffer(arr) {}
  164. void* buffer() { return m_buffer.data(); }
  165. void resize(std::size_t sz) { m_buffer.resize(sz); }
  166. void deserialize() {}
  167. MPI_Datatype datatype() { return get_mpi_datatype<typename A::value_type>(); }
  168. A& m_buffer;
  169. };
  170. template<typename T>
  171. struct serialized_data {
  172. serialized_data(communicator const& comm, T& value) : m_archive(comm), m_value(value) {}
  173. void* buffer() { return m_archive.address(); }
  174. void resize(std::size_t sz) { m_archive.resize(sz); }
  175. void deserialize() { m_archive >> m_value; }
  176. MPI_Datatype datatype() { return MPI_PACKED; }
  177. packed_iarchive m_archive;
  178. T& m_value;
  179. };
  180. template<>
  181. struct serialized_data<packed_iarchive> {
  182. serialized_data(communicator const& comm, packed_iarchive& ar) : m_archive(ar) {}
  183. void* buffer() { return m_archive.address(); }
  184. void resize(std::size_t sz) { m_archive.resize(sz); }
  185. void deserialize() {}
  186. MPI_Datatype datatype() { return MPI_PACKED; }
  187. packed_iarchive& m_archive;
  188. };
  189. template<typename T>
  190. struct serialized_data<const skeleton_proxy<T> > {
  191. serialized_data(communicator const& comm, skeleton_proxy<T> skel)
  192. : m_proxy(skel),
  193. m_archive(comm) {}
  194. void* buffer() { return m_archive.get_skeleton().address(); }
  195. void resize(std::size_t sz) { m_archive.get_skeleton().resize(sz); }
  196. void deserialize() { m_archive >> m_proxy.object; }
  197. MPI_Datatype datatype() { return MPI_PACKED; }
  198. skeleton_proxy<T> m_proxy;
  199. packed_skeleton_iarchive m_archive;
  200. };
  201. template<typename T>
  202. struct serialized_data<skeleton_proxy<T> >
  203. : public serialized_data<const skeleton_proxy<T> > {
  204. typedef serialized_data<const skeleton_proxy<T> > super;
  205. serialized_data(communicator const& comm, skeleton_proxy<T> skel)
  206. : super(comm, skel) {}
  207. };
  208. template<typename T>
  209. struct serialized_array_data {
  210. serialized_array_data(communicator const& comm, T* values, int nb)
  211. : m_archive(comm), m_values(values), m_nb(nb) {}
  212. void* buffer() { return m_archive.address(); }
  213. void resize(std::size_t sz) { m_archive.resize(sz); }
  214. void deserialize() {
  215. T* end = m_values + m_nb;
  216. T* v = m_values;
  217. while (v != end) {
  218. m_archive >> *v++;
  219. }
  220. }
  221. MPI_Datatype datatype() { return MPI_PACKED; }
  222. packed_iarchive m_archive;
  223. T* m_values;
  224. int m_nb;
  225. };
  226. }
  227. class BOOST_MPI_DECL request::legacy_handler : public request::handler {
  228. public:
  229. legacy_handler(communicator const& comm, int source, int tag);
  230. void cancel() {
  231. for (int i = 0; i < 2; ++i) {
  232. if (m_requests[i] != MPI_REQUEST_NULL) {
  233. BOOST_MPI_CHECK_RESULT(MPI_Cancel, (m_requests+i));
  234. }
  235. }
  236. }
  237. bool active() const;
  238. optional<MPI_Request&> trivial();
  239. MPI_Request m_requests[2];
  240. communicator m_comm;
  241. int m_source;
  242. int m_tag;
  243. };
  244. template<typename T>
  245. class request::legacy_serialized_handler
  246. : public request::legacy_handler,
  247. protected detail::serialized_irecv_data<T> {
  248. public:
  249. typedef detail::serialized_irecv_data<T> extra;
  250. legacy_serialized_handler(communicator const& comm, int source, int tag, T& value)
  251. : legacy_handler(comm, source, tag),
  252. extra(comm, value) {
  253. BOOST_MPI_CHECK_RESULT(MPI_Irecv,
  254. (&this->extra::m_count, 1,
  255. get_mpi_datatype(this->extra::m_count),
  256. source, tag, comm, m_requests+0));
  257. }
  258. status wait() {
  259. status stat;
  260. if (m_requests[1] == MPI_REQUEST_NULL) {
  261. // Wait for the count message to complete
  262. BOOST_MPI_CHECK_RESULT(MPI_Wait,
  263. (m_requests, &stat.m_status));
  264. // Resize our buffer and get ready to receive its data
  265. this->extra::m_ia.resize(this->extra::m_count);
  266. BOOST_MPI_CHECK_RESULT(MPI_Irecv,
  267. (this->extra::m_ia.address(), this->extra::m_ia.size(), MPI_PACKED,
  268. stat.source(), stat.tag(),
  269. MPI_Comm(m_comm), m_requests + 1));
  270. }
  271. // Wait until we have received the entire message
  272. BOOST_MPI_CHECK_RESULT(MPI_Wait,
  273. (m_requests + 1, &stat.m_status));
  274. this->deserialize(stat);
  275. return stat;
  276. }
  277. optional<status> test() {
  278. status stat;
  279. int flag = 0;
  280. if (m_requests[1] == MPI_REQUEST_NULL) {
  281. // Check if the count message has completed
  282. BOOST_MPI_CHECK_RESULT(MPI_Test,
  283. (m_requests, &flag, &stat.m_status));
  284. if (flag) {
  285. // Resize our buffer and get ready to receive its data
  286. this->extra::m_ia.resize(this->extra::m_count);
  287. BOOST_MPI_CHECK_RESULT(MPI_Irecv,
  288. (this->extra::m_ia.address(), this->extra::m_ia.size(),MPI_PACKED,
  289. stat.source(), stat.tag(),
  290. MPI_Comm(m_comm), m_requests + 1));
  291. } else
  292. return optional<status>(); // We have not finished yet
  293. }
  294. // Check if we have received the message data
  295. BOOST_MPI_CHECK_RESULT(MPI_Test,
  296. (m_requests + 1, &flag, &stat.m_status));
  297. if (flag) {
  298. this->deserialize(stat);
  299. return stat;
  300. } else
  301. return optional<status>();
  302. }
  303. };
  304. template<typename T>
  305. class request::legacy_serialized_array_handler
  306. : public request::legacy_handler,
  307. protected detail::serialized_array_irecv_data<T> {
  308. typedef detail::serialized_array_irecv_data<T> extra;
  309. public:
  310. legacy_serialized_array_handler(communicator const& comm, int source, int tag, T* values, int n)
  311. : legacy_handler(comm, source, tag),
  312. extra(comm, values, n) {
  313. BOOST_MPI_CHECK_RESULT(MPI_Irecv,
  314. (&this->extra::m_count, 1,
  315. get_mpi_datatype(this->extra::m_count),
  316. source, tag, comm, m_requests+0));
  317. }
  318. status wait() {
  319. status stat;
  320. if (m_requests[1] == MPI_REQUEST_NULL) {
  321. // Wait for the count message to complete
  322. BOOST_MPI_CHECK_RESULT(MPI_Wait,
  323. (m_requests, &stat.m_status));
  324. // Resize our buffer and get ready to receive its data
  325. this->extra::m_ia.resize(this->extra::m_count);
  326. BOOST_MPI_CHECK_RESULT(MPI_Irecv,
  327. (this->extra::m_ia.address(), this->extra::m_ia.size(), MPI_PACKED,
  328. stat.source(), stat.tag(),
  329. MPI_Comm(m_comm), m_requests + 1));
  330. }
  331. // Wait until we have received the entire message
  332. BOOST_MPI_CHECK_RESULT(MPI_Wait,
  333. (m_requests + 1, &stat.m_status));
  334. this->deserialize(stat);
  335. return stat;
  336. }
  337. optional<status> test() {
  338. status stat;
  339. int flag = 0;
  340. if (m_requests[1] == MPI_REQUEST_NULL) {
  341. // Check if the count message has completed
  342. BOOST_MPI_CHECK_RESULT(MPI_Test,
  343. (m_requests, &flag, &stat.m_status));
  344. if (flag) {
  345. // Resize our buffer and get ready to receive its data
  346. this->extra::m_ia.resize(this->extra::m_count);
  347. BOOST_MPI_CHECK_RESULT(MPI_Irecv,
  348. (this->extra::m_ia.address(), this->extra::m_ia.size(),MPI_PACKED,
  349. stat.source(), stat.tag(),
  350. MPI_Comm(m_comm), m_requests + 1));
  351. } else
  352. return optional<status>(); // We have not finished yet
  353. }
  354. // Check if we have received the message data
  355. BOOST_MPI_CHECK_RESULT(MPI_Test,
  356. (m_requests + 1, &flag, &stat.m_status));
  357. if (flag) {
  358. this->deserialize(stat);
  359. return stat;
  360. } else
  361. return optional<status>();
  362. }
  363. };
  364. template<typename T, class A>
  365. class request::legacy_dynamic_primitive_array_handler
  366. : public request::legacy_handler,
  367. protected detail::dynamic_array_irecv_data<T,A>
  368. {
  369. typedef detail::dynamic_array_irecv_data<T,A> extra;
  370. public:
  371. legacy_dynamic_primitive_array_handler(communicator const& comm, int source, int tag, std::vector<T,A>& values)
  372. : legacy_handler(comm, source, tag),
  373. extra(values) {
  374. BOOST_MPI_CHECK_RESULT(MPI_Irecv,
  375. (&this->extra::m_count, 1,
  376. get_mpi_datatype(this->extra::m_count),
  377. source, tag, comm, m_requests+0));
  378. }
  379. status wait() {
  380. status stat;
  381. if (m_requests[1] == MPI_REQUEST_NULL) {
  382. // Wait for the count message to complete
  383. BOOST_MPI_CHECK_RESULT(MPI_Wait,
  384. (m_requests, &stat.m_status));
  385. // Resize our buffer and get ready to receive its data
  386. this->extra::m_values.resize(this->extra::m_count);
  387. BOOST_MPI_CHECK_RESULT(MPI_Irecv,
  388. (&(this->extra::m_values[0]), this->extra::m_values.size(), get_mpi_datatype<T>(),
  389. stat.source(), stat.tag(),
  390. MPI_Comm(m_comm), m_requests + 1));
  391. }
  392. // Wait until we have received the entire message
  393. BOOST_MPI_CHECK_RESULT(MPI_Wait,
  394. (m_requests + 1, &stat.m_status));
  395. return stat;
  396. }
  397. optional<status> test() {
  398. status stat;
  399. int flag = 0;
  400. if (m_requests[1] == MPI_REQUEST_NULL) {
  401. // Check if the count message has completed
  402. BOOST_MPI_CHECK_RESULT(MPI_Test,
  403. (m_requests, &flag, &stat.m_status));
  404. if (flag) {
  405. // Resize our buffer and get ready to receive its data
  406. this->extra::m_values.resize(this->extra::m_count);
  407. BOOST_MPI_CHECK_RESULT(MPI_Irecv,
  408. (&(this->extra::m_values[0]), this->extra::m_values.size(), get_mpi_datatype<T>(),
  409. stat.source(), stat.tag(),
  410. MPI_Comm(m_comm), m_requests + 1));
  411. } else
  412. return optional<status>(); // We have not finished yet
  413. }
  414. // Check if we have received the message data
  415. BOOST_MPI_CHECK_RESULT(MPI_Test,
  416. (m_requests + 1, &flag, &stat.m_status));
  417. if (flag) {
  418. return stat;
  419. } else
  420. return optional<status>();
  421. }
  422. };
  423. class BOOST_MPI_DECL request::trivial_handler : public request::handler {
  424. public:
  425. trivial_handler();
  426. status wait();
  427. optional<status> test();
  428. void cancel();
  429. bool active() const;
  430. optional<MPI_Request&> trivial();
  431. private:
  432. friend class request;
  433. MPI_Request m_request;
  434. };
  435. class request::dynamic_handler : public request::handler {
  436. dynamic_handler();
  437. status wait();
  438. optional<status> test();
  439. void cancel();
  440. bool active() const;
  441. optional<MPI_Request&> trivial();
  442. private:
  443. friend class request;
  444. MPI_Request m_requests[2];
  445. };
  446. template<typename T>
  447. request request::make_serialized(communicator const& comm, int source, int tag, T& value) {
  448. #if defined(BOOST_MPI_USE_IMPROBE)
  449. return request(new probe_handler<detail::serialized_data<T> >(comm, source, tag, value));
  450. #else
  451. return request(new legacy_serialized_handler<T>(comm, source, tag, value));
  452. #endif
  453. }
  454. template<typename T>
  455. request request::make_serialized_array(communicator const& comm, int source, int tag, T* values, int n) {
  456. #if defined(BOOST_MPI_USE_IMPROBE)
  457. return request(new probe_handler<detail::serialized_array_data<T> >(comm, source, tag, values, n));
  458. #else
  459. return request(new legacy_serialized_array_handler<T>(comm, source, tag, values, n));
  460. #endif
  461. }
  462. template<typename T, class A>
  463. request request::make_dynamic_primitive_array_recv(communicator const& comm, int source, int tag,
  464. std::vector<T,A>& values) {
  465. #if defined(BOOST_MPI_USE_IMPROBE)
  466. return request(new probe_handler<detail::dynamic_primitive_array_data<std::vector<T,A> > >(comm,source,tag,values));
  467. #else
  468. return request(new legacy_dynamic_primitive_array_handler<T,A>(comm, source, tag, values));
  469. #endif
  470. }
  471. template<typename T>
  472. request
  473. request::make_trivial_send(communicator const& comm, int dest, int tag, T const* values, int n) {
  474. trivial_handler* handler = new trivial_handler;
  475. BOOST_MPI_CHECK_RESULT(MPI_Isend,
  476. (const_cast<T*>(values), n,
  477. get_mpi_datatype<T>(),
  478. dest, tag, comm, &handler->m_request));
  479. return request(handler);
  480. }
  481. template<typename T>
  482. request
  483. request::make_trivial_send(communicator const& comm, int dest, int tag, T const& value) {
  484. return make_trivial_send(comm, dest, tag, &value, 1);
  485. }
  486. template<typename T>
  487. request
  488. request::make_trivial_recv(communicator const& comm, int dest, int tag, T* values, int n) {
  489. trivial_handler* handler = new trivial_handler;
  490. BOOST_MPI_CHECK_RESULT(MPI_Irecv,
  491. (values, n,
  492. get_mpi_datatype<T>(),
  493. dest, tag, comm, &handler->m_request));
  494. return request(handler);
  495. }
  496. template<typename T>
  497. request
  498. request::make_trivial_recv(communicator const& comm, int dest, int tag, T& value) {
  499. return make_trivial_recv(comm, dest, tag, &value, 1);
  500. }
  501. template<typename T, class A>
  502. request request::make_dynamic_primitive_array_send(communicator const& comm, int dest, int tag,
  503. std::vector<T,A> const& values) {
  504. #if defined(BOOST_MPI_USE_IMPROBE)
  505. return make_trivial_send(comm, dest, tag, values.data(), values.size());
  506. #else
  507. {
  508. // non blocking recv by legacy_dynamic_primitive_array_handler
  509. // blocking recv by status recv_vector(source,tag,value,primitive)
  510. boost::shared_ptr<std::size_t> size(new std::size_t(values.size()));
  511. dynamic_handler* handler = new dynamic_handler;
  512. request req(handler);
  513. req.preserve(size);
  514. BOOST_MPI_CHECK_RESULT(MPI_Isend,
  515. (size.get(), 1,
  516. get_mpi_datatype(*size),
  517. dest, tag, comm, handler->m_requests+0));
  518. BOOST_MPI_CHECK_RESULT(MPI_Isend,
  519. (const_cast<T*>(values.data()), *size,
  520. get_mpi_datatype<T>(),
  521. dest, tag, comm, handler->m_requests+1));
  522. return req;
  523. }
  524. #endif
  525. }
  526. inline
  527. request::legacy_handler::legacy_handler(communicator const& comm, int source, int tag)
  528. : m_comm(comm),
  529. m_source(source),
  530. m_tag(tag)
  531. {
  532. m_requests[0] = MPI_REQUEST_NULL;
  533. m_requests[1] = MPI_REQUEST_NULL;
  534. }
  535. }}
  536. #endif // BOOST_MPI_REQUEST_HANDLERS_HPP