//
// fork_join.cpp
// ~~~~~~~~~~~~~
// A fixed-size fork/join thread pool built on Boost.Asio, with a joinable
// fork_executor, demonstrated by a parallel merge sort.
//
#include <boost/asio/dispatch.hpp>
#include <boost/asio/execution_context.hpp>
#include <boost/asio/thread_pool.hpp>
#include <chrono>
#include <condition_variable>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <numeric>
#include <queue>
#include <thread>
#include <type_traits>
#include <utility>

using boost::asio::dispatch;
using boost::asio::execution_context;
using boost::asio::thread_pool;
// A fixed-size thread pool used to implement fork/join semantics. Functions
// are scheduled using a simple FIFO queue. Implementing work stealing, or
// using a queue based on atomic operations, are left as tasks for the reader.
//
// All mutable state (queue_, use_count_, thread_count_) is protected by
// mutex_. use_count_ counts active "work sources": the pool itself (counted
// once at construction, released by stop_threads()) plus one for every
// fork_executor work group that currently has outstanding functions. Worker
// threads run until use_count_ drops to zero.
class fork_join_pool : public execution_context
{
public:
  // The constructor starts a thread pool with the specified number of threads.
  // Note that the thread_count is not a fixed limit on the pool's concurrency.
  // Additional threads may temporarily be added to the pool if they join a
  // fork_executor.
  //
  // NOTE(review): std::thread::hardware_concurrency() may return 0, which
  // would create a pool with no worker threads; work would then only run on
  // threads borrowed via fork_executor::join(). Confirm whether that is
  // acceptable for callers relying on the default.
  explicit fork_join_pool(
      std::size_t thread_count = std::thread::hardware_concurrency() * 2)
    : use_count_(1),     // the pool itself holds one use count until shutdown
      threads_(thread_count)
  {
    try
    {
      // Ask each thread in the pool to dequeue and execute functions until
      // it is time to shut down, i.e. the use count is zero.
      for (thread_count_ = 0; thread_count_ < thread_count; ++thread_count_)
      {
        dispatch(threads_, [&]
        {
          std::unique_lock<std::mutex> lock(mutex_);
          while (use_count_ > 0)
            if (!execute_next(lock))
              condition_.wait(lock);
        });
      }
    }
    catch (...)
    {
      // Creation of a worker failed: tell any workers already started to shut
      // down, wait for them, then propagate the error to the caller.
      stop_threads();
      threads_.join();
      throw;
    }
  }

  // The destructor waits for the pool to finish executing functions.
  ~fork_join_pool()
  {
    stop_threads();
    threads_.join();
  }

private:
  friend class fork_executor;

  // The base for all functions that are queued in the pool.
  struct function_base
  {
    std::shared_ptr<std::size_t> work_count_;            // owning work group's counter
    void (*execute_)(std::shared_ptr<function_base>& p); // type-erased invoker
  };

  // Execute the next function from the queue, if any. Returns true if a
  // function was executed, and false if the queue was empty.
  //
  // Lock protocol: the caller must hold `lock`; the mutex is released while
  // the dequeued function runs (inside execute()) and is held again when
  // this function returns.
  bool execute_next(std::unique_lock<std::mutex>& lock)
  {
    if (queue_.empty())
      return false;
    auto p(queue_.front());
    queue_.pop();
    lock.unlock();
    execute(lock, p);
    return true;
  }

  // Execute a function and decrement the outstanding work.
  //
  // Entered with `lock` unlocked; runs the user function without the mutex,
  // then re-acquires it to decrement the work count. The count is
  // decremented even if the function throws, so a throwing task cannot
  // deadlock a join().
  void execute(std::unique_lock<std::mutex>& lock,
      std::shared_ptr<function_base>& p)
  {
    // Take the counter out first: execute_ may destroy *p (see the
    // function<> wrapper, which resets the pointer before invoking).
    std::shared_ptr<std::size_t> work_count(std::move(p->work_count_));
    try
    {
      p->execute_(p);
      lock.lock();
      do_work_finished(work_count);
    }
    catch (...)
    {
      lock.lock();
      do_work_finished(work_count);
      throw;
    }
  }

  // Increment outstanding work. Must be called with mutex_ held. The first
  // unit of work in a group also bumps the pool-wide use count, keeping the
  // worker threads alive.
  void do_work_started(const std::shared_ptr<std::size_t>& work_count) noexcept
  {
    if (++(*work_count) == 1)
      ++use_count_;
  }

  // Decrement outstanding work. Notify waiting threads if we run out.
  // Must be called with mutex_ held. notify_all (not notify_one) because
  // both idle workers and joining threads may be waiting on the condition.
  void do_work_finished(const std::shared_ptr<std::size_t>& work_count) noexcept
  {
    if (--(*work_count) == 0)
    {
      --use_count_;
      condition_.notify_all();
    }
  }

  // Dispatch a function, executing it immediately if the queue is already
  // loaded. Otherwise adds the function to the queue and wakes a thread.
  // Running inline on the caller's thread applies back-pressure when the
  // queue exceeds 16 entries per worker.
  void do_dispatch(std::shared_ptr<function_base> p,
      const std::shared_ptr<std::size_t>& work_count)
  {
    std::unique_lock<std::mutex> lock(mutex_);
    if (queue_.size() > thread_count_ * 16)
    {
      do_work_started(work_count);
      lock.unlock();
      execute(lock, p);
    }
    else
    {
      queue_.push(p);
      do_work_started(work_count);
      condition_.notify_one();
    }
  }

  // Add a function to the queue and wake a thread. Unlike do_dispatch, never
  // runs the function on the caller's thread.
  void do_post(std::shared_ptr<function_base> p,
      const std::shared_ptr<std::size_t>& work_count)
  {
    std::lock_guard<std::mutex> lock(mutex_);
    queue_.push(p);
    do_work_started(work_count);
    condition_.notify_one();
  }

  // Ask all threads to shut down: release the pool's own use count (taken in
  // the constructor) and wake every waiting worker so it can re-check.
  void stop_threads()
  {
    std::lock_guard<std::mutex> lock(mutex_);
    --use_count_;
    condition_.notify_all();
  }

  std::mutex mutex_;                                   // guards all members below
  std::condition_variable condition_;                  // signalled on new work / shutdown
  std::queue<std::shared_ptr<function_base>> queue_;   // FIFO of pending functions
  std::size_t use_count_;                              // active work sources (see class comment)
  std::size_t thread_count_;                           // number of workers actually started
  thread_pool threads_;
};
- // A class that satisfies the Executor requirements. Every function or piece of
- // work associated with a fork_executor is part of a single, joinable group.
- class fork_executor
- {
- public:
- fork_executor(fork_join_pool& ctx)
- : context_(ctx),
- work_count_(std::make_shared<std::size_t>(0))
- {
- }
- fork_join_pool& context() const noexcept
- {
- return context_;
- }
- void on_work_started() const noexcept
- {
- std::lock_guard<std::mutex> lock(context_.mutex_);
- context_.do_work_started(work_count_);
- }
- void on_work_finished() const noexcept
- {
- std::lock_guard<std::mutex> lock(context_.mutex_);
- context_.do_work_finished(work_count_);
- }
- template <class Func, class Alloc>
- void dispatch(Func&& f, const Alloc& a) const
- {
- auto p(std::allocate_shared<function<Func>>(
- typename std::allocator_traits<Alloc>::template rebind_alloc<char>(a),
- std::move(f), work_count_));
- context_.do_dispatch(p, work_count_);
- }
- template <class Func, class Alloc>
- void post(Func f, const Alloc& a) const
- {
- auto p(std::allocate_shared<function<Func>>(
- typename std::allocator_traits<Alloc>::template rebind_alloc<char>(a),
- std::move(f), work_count_));
- context_.do_post(p, work_count_);
- }
- template <class Func, class Alloc>
- void defer(Func&& f, const Alloc& a) const
- {
- post(std::forward<Func>(f), a);
- }
- friend bool operator==(const fork_executor& a,
- const fork_executor& b) noexcept
- {
- return a.work_count_ == b.work_count_;
- }
- friend bool operator!=(const fork_executor& a,
- const fork_executor& b) noexcept
- {
- return a.work_count_ != b.work_count_;
- }
- // Block until all work associated with the executor is complete. While it is
- // waiting, the thread may be borrowed to execute functions from the queue.
- void join() const
- {
- std::unique_lock<std::mutex> lock(context_.mutex_);
- while (*work_count_ > 0)
- if (!context_.execute_next(lock))
- context_.condition_.wait(lock);
- }
- private:
- template <class Func>
- struct function : fork_join_pool::function_base
- {
- explicit function(Func f, const std::shared_ptr<std::size_t>& w)
- : function_(std::move(f))
- {
- work_count_ = w;
- execute_ = [](std::shared_ptr<fork_join_pool::function_base>& p)
- {
- Func tmp(std::move(static_cast<function*>(p.get())->function_));
- p.reset();
- tmp();
- };
- }
- Func function_;
- };
- fork_join_pool& context_;
- std::shared_ptr<std::size_t> work_count_;
- };
- // Helper class to automatically join a fork_executor when exiting a scope.
- class join_guard
- {
- public:
- explicit join_guard(const fork_executor& ex) : ex_(ex) {}
- join_guard(const join_guard&) = delete;
- join_guard(join_guard&&) = delete;
- ~join_guard() { ex_.join(); }
- private:
- fork_executor ex_;
- };
- //------------------------------------------------------------------------------
- #include <algorithm>
- #include <iostream>
- #include <random>
- #include <vector>
// Single process-wide pool shared by every fork_join_sort invocation below.
fork_join_pool pool;
- template <class Iterator>
- void fork_join_sort(Iterator begin, Iterator end)
- {
- std::size_t n = end - begin;
- if (n > 32768)
- {
- {
- fork_executor fork(pool);
- join_guard join(fork);
- dispatch(fork, [=]{ fork_join_sort(begin, begin + n / 2); });
- dispatch(fork, [=]{ fork_join_sort(begin + n / 2, end); });
- }
- std::inplace_merge(begin, begin + n / 2, end);
- }
- else
- {
- std::sort(begin, end);
- }
- }
- int main(int argc, char* argv[])
- {
- if (argc != 2)
- {
- std::cerr << "Usage: fork_join <size>\n";
- return 1;
- }
- std::vector<double> vec(std::atoll(argv[1]));
- std::iota(vec.begin(), vec.end(), 0);
- std::random_device rd;
- std::mt19937 g(rd());
- std::shuffle(vec.begin(), vec.end(), g);
- std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
- fork_join_sort(vec.begin(), vec.end());
- std::chrono::steady_clock::duration elapsed = std::chrono::steady_clock::now() - start;
- std::cout << "sort took ";
- std::cout << std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count();
- std::cout << " microseconds" << std::endl;
- }