fork_join.cpp 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  #include <boost/asio/dispatch.hpp>
  #include <boost/asio/execution_context.hpp>
  #include <boost/asio/thread_pool.hpp>
  #include <cerrno>
  #include <chrono>
  #include <condition_variable>
  #include <cstddef>
  #include <cstdlib>
  #include <memory>
  #include <mutex>
  #include <numeric>
  #include <queue>
  #include <thread>
  #include <type_traits>
  #include <utility>
  10. using boost::asio::dispatch;
  11. using boost::asio::execution_context;
  12. using boost::asio::thread_pool;
  13. // A fixed-size thread pool used to implement fork/join semantics. Functions
  14. // are scheduled using a simple FIFO queue. Implementing work stealing, or
  15. // using a queue based on atomic operations, are left as tasks for the reader.
  16. class fork_join_pool : public execution_context
  17. {
  18. public:
  19. // The constructor starts a thread pool with the specified number of threads.
  20. // Note that the thread_count is not a fixed limit on the pool's concurrency.
  21. // Additional threads may temporarily be added to the pool if they join a
  22. // fork_executor.
  23. explicit fork_join_pool(
  24. std::size_t thread_count = std::thread::hardware_concurrency() * 2)
  25. : use_count_(1),
  26. threads_(thread_count)
  27. {
  28. try
  29. {
  30. // Ask each thread in the pool to dequeue and execute functions until
  31. // it is time to shut down, i.e. the use count is zero.
  32. for (thread_count_ = 0; thread_count_ < thread_count; ++thread_count_)
  33. {
  34. dispatch(threads_, [&]
  35. {
  36. std::unique_lock<std::mutex> lock(mutex_);
  37. while (use_count_ > 0)
  38. if (!execute_next(lock))
  39. condition_.wait(lock);
  40. });
  41. }
  42. }
  43. catch (...)
  44. {
  45. stop_threads();
  46. threads_.join();
  47. throw;
  48. }
  49. }
  50. // The destructor waits for the pool to finish executing functions.
  51. ~fork_join_pool()
  52. {
  53. stop_threads();
  54. threads_.join();
  55. }
  56. private:
  57. friend class fork_executor;
  58. // The base for all functions that are queued in the pool.
  59. struct function_base
  60. {
  61. std::shared_ptr<std::size_t> work_count_;
  62. void (*execute_)(std::shared_ptr<function_base>& p);
  63. };
  64. // Execute the next function from the queue, if any. Returns true if a
  65. // function was executed, and false if the queue was empty.
  66. bool execute_next(std::unique_lock<std::mutex>& lock)
  67. {
  68. if (queue_.empty())
  69. return false;
  70. auto p(queue_.front());
  71. queue_.pop();
  72. lock.unlock();
  73. execute(lock, p);
  74. return true;
  75. }
  76. // Execute a function and decrement the outstanding work.
  77. void execute(std::unique_lock<std::mutex>& lock,
  78. std::shared_ptr<function_base>& p)
  79. {
  80. std::shared_ptr<std::size_t> work_count(std::move(p->work_count_));
  81. try
  82. {
  83. p->execute_(p);
  84. lock.lock();
  85. do_work_finished(work_count);
  86. }
  87. catch (...)
  88. {
  89. lock.lock();
  90. do_work_finished(work_count);
  91. throw;
  92. }
  93. }
  94. // Increment outstanding work.
  95. void do_work_started(const std::shared_ptr<std::size_t>& work_count) noexcept
  96. {
  97. if (++(*work_count) == 1)
  98. ++use_count_;
  99. }
  100. // Decrement outstanding work. Notify waiting threads if we run out.
  101. void do_work_finished(const std::shared_ptr<std::size_t>& work_count) noexcept
  102. {
  103. if (--(*work_count) == 0)
  104. {
  105. --use_count_;
  106. condition_.notify_all();
  107. }
  108. }
  109. // Dispatch a function, executing it immediately if the queue is already
  110. // loaded. Otherwise adds the function to the queue and wakes a thread.
  111. void do_dispatch(std::shared_ptr<function_base> p,
  112. const std::shared_ptr<std::size_t>& work_count)
  113. {
  114. std::unique_lock<std::mutex> lock(mutex_);
  115. if (queue_.size() > thread_count_ * 16)
  116. {
  117. do_work_started(work_count);
  118. lock.unlock();
  119. execute(lock, p);
  120. }
  121. else
  122. {
  123. queue_.push(p);
  124. do_work_started(work_count);
  125. condition_.notify_one();
  126. }
  127. }
  128. // Add a function to the queue and wake a thread.
  129. void do_post(std::shared_ptr<function_base> p,
  130. const std::shared_ptr<std::size_t>& work_count)
  131. {
  132. std::lock_guard<std::mutex> lock(mutex_);
  133. queue_.push(p);
  134. do_work_started(work_count);
  135. condition_.notify_one();
  136. }
  137. // Ask all threads to shut down.
  138. void stop_threads()
  139. {
  140. std::lock_guard<std::mutex> lock(mutex_);
  141. --use_count_;
  142. condition_.notify_all();
  143. }
  144. std::mutex mutex_;
  145. std::condition_variable condition_;
  146. std::queue<std::shared_ptr<function_base>> queue_;
  147. std::size_t use_count_;
  148. std::size_t thread_count_;
  149. thread_pool threads_;
  150. };
  151. // A class that satisfies the Executor requirements. Every function or piece of
  152. // work associated with a fork_executor is part of a single, joinable group.
  153. class fork_executor
  154. {
  155. public:
  156. fork_executor(fork_join_pool& ctx)
  157. : context_(ctx),
  158. work_count_(std::make_shared<std::size_t>(0))
  159. {
  160. }
  161. fork_join_pool& context() const noexcept
  162. {
  163. return context_;
  164. }
  165. void on_work_started() const noexcept
  166. {
  167. std::lock_guard<std::mutex> lock(context_.mutex_);
  168. context_.do_work_started(work_count_);
  169. }
  170. void on_work_finished() const noexcept
  171. {
  172. std::lock_guard<std::mutex> lock(context_.mutex_);
  173. context_.do_work_finished(work_count_);
  174. }
  175. template <class Func, class Alloc>
  176. void dispatch(Func&& f, const Alloc& a) const
  177. {
  178. auto p(std::allocate_shared<function<Func>>(
  179. typename std::allocator_traits<Alloc>::template rebind_alloc<char>(a),
  180. std::move(f), work_count_));
  181. context_.do_dispatch(p, work_count_);
  182. }
  183. template <class Func, class Alloc>
  184. void post(Func f, const Alloc& a) const
  185. {
  186. auto p(std::allocate_shared<function<Func>>(
  187. typename std::allocator_traits<Alloc>::template rebind_alloc<char>(a),
  188. std::move(f), work_count_));
  189. context_.do_post(p, work_count_);
  190. }
  191. template <class Func, class Alloc>
  192. void defer(Func&& f, const Alloc& a) const
  193. {
  194. post(std::forward<Func>(f), a);
  195. }
  196. friend bool operator==(const fork_executor& a,
  197. const fork_executor& b) noexcept
  198. {
  199. return a.work_count_ == b.work_count_;
  200. }
  201. friend bool operator!=(const fork_executor& a,
  202. const fork_executor& b) noexcept
  203. {
  204. return a.work_count_ != b.work_count_;
  205. }
  206. // Block until all work associated with the executor is complete. While it is
  207. // waiting, the thread may be borrowed to execute functions from the queue.
  208. void join() const
  209. {
  210. std::unique_lock<std::mutex> lock(context_.mutex_);
  211. while (*work_count_ > 0)
  212. if (!context_.execute_next(lock))
  213. context_.condition_.wait(lock);
  214. }
  215. private:
  216. template <class Func>
  217. struct function : fork_join_pool::function_base
  218. {
  219. explicit function(Func f, const std::shared_ptr<std::size_t>& w)
  220. : function_(std::move(f))
  221. {
  222. work_count_ = w;
  223. execute_ = [](std::shared_ptr<fork_join_pool::function_base>& p)
  224. {
  225. Func tmp(std::move(static_cast<function*>(p.get())->function_));
  226. p.reset();
  227. tmp();
  228. };
  229. }
  230. Func function_;
  231. };
  232. fork_join_pool& context_;
  233. std::shared_ptr<std::size_t> work_count_;
  234. };
  235. // Helper class to automatically join a fork_executor when exiting a scope.
  236. class join_guard
  237. {
  238. public:
  239. explicit join_guard(const fork_executor& ex) : ex_(ex) {}
  240. join_guard(const join_guard&) = delete;
  241. join_guard(join_guard&&) = delete;
  242. ~join_guard() { ex_.join(); }
  243. private:
  244. fork_executor ex_;
  245. };
  246. //------------------------------------------------------------------------------
  247. #include <algorithm>
  248. #include <iostream>
  249. #include <random>
  250. #include <vector>
  251. fork_join_pool pool;
  252. template <class Iterator>
  253. void fork_join_sort(Iterator begin, Iterator end)
  254. {
  255. std::size_t n = end - begin;
  256. if (n > 32768)
  257. {
  258. {
  259. fork_executor fork(pool);
  260. join_guard join(fork);
  261. dispatch(fork, [=]{ fork_join_sort(begin, begin + n / 2); });
  262. dispatch(fork, [=]{ fork_join_sort(begin + n / 2, end); });
  263. }
  264. std::inplace_merge(begin, begin + n / 2, end);
  265. }
  266. else
  267. {
  268. std::sort(begin, end);
  269. }
  270. }
  271. int main(int argc, char* argv[])
  272. {
  273. if (argc != 2)
  274. {
  275. std::cerr << "Usage: fork_join <size>\n";
  276. return 1;
  277. }
  278. std::vector<double> vec(std::atoll(argv[1]));
  279. std::iota(vec.begin(), vec.end(), 0);
  280. std::random_device rd;
  281. std::mt19937 g(rd());
  282. std::shuffle(vec.begin(), vec.end(), g);
  283. std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
  284. fork_join_sort(vec.begin(), vec.end());
  285. std::chrono::steady_clock::duration elapsed = std::chrono::steady_clock::now() - start;
  286. std::cout << "sort took ";
  287. std::cout << std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count();
  288. std::cout << " microseconds" << std::endl;
  289. }