// File: count_if_with_threads.hpp (4.1 KB)
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_COUNT_IF_WITH_THREADS_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_DETAIL_COUNT_IF_WITH_THREADS_HPP
  12. #include <numeric>
  13. #include <boost/compute/detail/meta_kernel.hpp>
  14. #include <boost/compute/container/vector.hpp>
  15. namespace boost {
  16. namespace compute {
  17. namespace detail {
  18. template<class InputIterator, class Predicate>
  19. class count_if_with_threads_kernel : meta_kernel
  20. {
  21. public:
  22. typedef typename
  23. std::iterator_traits<InputIterator>::value_type
  24. value_type;
  25. count_if_with_threads_kernel()
  26. : meta_kernel("count_if_with_threads")
  27. {
  28. }
  29. void set_args(InputIterator first,
  30. InputIterator last,
  31. Predicate predicate)
  32. {
  33. typedef typename std::iterator_traits<InputIterator>::value_type T;
  34. m_size = detail::iterator_range_size(first, last);
  35. m_size_arg = add_arg<const ulong_>("size");
  36. m_counts_arg = add_arg<ulong_ *>(memory_object::global_memory, "counts");
  37. *this <<
  38. // thread parameters
  39. "const uint gid = get_global_id(0);\n" <<
  40. "const uint block_size = size / get_global_size(0);\n" <<
  41. "const uint start = block_size * gid;\n" <<
  42. "uint end = 0;\n" <<
  43. "if(gid == get_global_size(0) - 1)\n" <<
  44. " end = size;\n" <<
  45. "else\n" <<
  46. " end = block_size * gid + block_size;\n" <<
  47. // count values
  48. "uint count = 0;\n" <<
  49. "for(uint i = start; i < end; i++){\n" <<
  50. decl<const T>("value") << "="
  51. << first[expr<uint_>("i")] << ";\n" <<
  52. if_(predicate(var<const T>("value"))) << "{\n" <<
  53. "count++;\n" <<
  54. "}\n" <<
  55. "}\n" <<
  56. // write count
  57. "counts[gid] = count;\n";
  58. }
  59. size_t exec(command_queue &queue)
  60. {
  61. const device &device = queue.get_device();
  62. const context &context = queue.get_context();
  63. size_t threads = device.compute_units();
  64. const size_t minimum_block_size = 2048;
  65. if(m_size / threads < minimum_block_size){
  66. threads = static_cast<size_t>(
  67. (std::max)(
  68. std::ceil(float(m_size) / minimum_block_size),
  69. 1.0f
  70. )
  71. );
  72. }
  73. // storage for counts
  74. ::boost::compute::vector<ulong_> counts(threads, context);
  75. // exec kernel
  76. set_arg(m_size_arg, static_cast<ulong_>(m_size));
  77. set_arg(m_counts_arg, counts.get_buffer());
  78. exec_1d(queue, 0, threads, 1);
  79. // copy counts to the host
  80. std::vector<ulong_> host_counts(threads);
  81. ::boost::compute::copy(counts.begin(), counts.end(), host_counts.begin(), queue);
  82. // return sum of counts
  83. return std::accumulate(host_counts.begin(), host_counts.end(), size_t(0));
  84. }
  85. private:
  86. size_t m_size;
  87. size_t m_size_arg;
  88. size_t m_counts_arg;
  89. };
  90. // counts values that match the predicate using one thread per block. this is
  91. // optimized for cpu-type devices with a small number of compute units.
  92. template<class InputIterator, class Predicate>
  93. inline size_t count_if_with_threads(InputIterator first,
  94. InputIterator last,
  95. Predicate predicate,
  96. command_queue &queue)
  97. {
  98. count_if_with_threads_kernel<InputIterator, Predicate> kernel;
  99. kernel.set_args(first, last, predicate);
  100. return kernel.exec(queue);
  101. }
  102. } // end detail namespace
  103. } // end compute namespace
  104. } // end boost namespace
  105. #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_COUNT_IF_WITH_THREADS_HPP