scan_on_cpu.hpp 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2016 Jakub Szuppe <j.szuppe@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_CPU_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_CPU_HPP
  12. #include <iterator>
  13. #include <boost/compute/device.hpp>
  14. #include <boost/compute/kernel.hpp>
  15. #include <boost/compute/command_queue.hpp>
  16. #include <boost/compute/algorithm/detail/serial_scan.hpp>
  17. #include <boost/compute/detail/meta_kernel.hpp>
  18. #include <boost/compute/detail/iterator_range_size.hpp>
  19. #include <boost/compute/detail/parameter_cache.hpp>
  20. namespace boost {
  21. namespace compute {
  22. namespace detail {
  23. template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
  24. inline OutputIterator scan_on_cpu(InputIterator first,
  25. InputIterator last,
  26. OutputIterator result,
  27. bool exclusive,
  28. T init,
  29. BinaryOperator op,
  30. command_queue &queue)
  31. {
  32. typedef typename
  33. std::iterator_traits<InputIterator>::value_type input_type;
  34. typedef typename
  35. std::iterator_traits<OutputIterator>::value_type output_type;
  36. const context &context = queue.get_context();
  37. const device &device = queue.get_device();
  38. const size_t compute_units = queue.get_device().compute_units();
  39. boost::shared_ptr<parameter_cache> parameters =
  40. detail::parameter_cache::get_global_cache(device);
  41. std::string cache_key =
  42. "__boost_scan_cpu_" + boost::lexical_cast<std::string>(sizeof(T));
  43. // for inputs smaller than serial_scan_threshold
  44. // serial_scan algorithm is used
  45. uint_ serial_scan_threshold =
  46. parameters->get(cache_key, "serial_scan_threshold", 16384 * sizeof(T));
  47. serial_scan_threshold =
  48. (std::max)(serial_scan_threshold, uint_(compute_units));
  49. size_t count = detail::iterator_range_size(first, last);
  50. if(count == 0){
  51. return result;
  52. }
  53. else if(count < serial_scan_threshold) {
  54. return serial_scan(first, last, result, exclusive, init, op, queue);
  55. }
  56. buffer block_partial_sums(context, sizeof(output_type) * compute_units );
  57. // create scan kernel
  58. meta_kernel k("scan_on_cpu_block_scan");
  59. // Arguments
  60. size_t count_arg = k.add_arg<uint_>("count");
  61. size_t init_arg = k.add_arg<output_type>("initial_value");
  62. size_t block_partial_sums_arg =
  63. k.add_arg<output_type *>(memory_object::global_memory, "block_partial_sums");
  64. k <<
  65. "uint block = (count + get_global_size(0))/(get_global_size(0) + 1);\n" <<
  66. "uint index = get_global_id(0) * block;\n" <<
  67. "uint end = min(count, index + block);\n" <<
  68. "if(index >= end) return;\n";
  69. if(!exclusive){
  70. k <<
  71. k.decl<output_type>("sum") << " = " <<
  72. first[k.var<uint_>("index")] << ";\n" <<
  73. result[k.var<uint_>("index")] << " = sum;\n" <<
  74. "index++;\n";
  75. }
  76. else {
  77. k <<
  78. k.decl<output_type>("sum") << ";\n" <<
  79. "if(index == 0){\n" <<
  80. "sum = initial_value;\n" <<
  81. "}\n" <<
  82. "else {\n" <<
  83. "sum = " << first[k.var<uint_>("index")] << ";\n" <<
  84. "index++;\n" <<
  85. "}\n";
  86. }
  87. k <<
  88. "while(index < end){\n" <<
  89. // load next value
  90. k.decl<const input_type>("value") << " = "
  91. << first[k.var<uint_>("index")] << ";\n";
  92. if(exclusive){
  93. k <<
  94. "if(get_global_id(0) == 0){\n" <<
  95. result[k.var<uint_>("index")] << " = sum;\n" <<
  96. "}\n";
  97. }
  98. k <<
  99. "sum = " << op(k.var<output_type>("sum"),
  100. k.var<output_type>("value")) << ";\n";
  101. if(!exclusive){
  102. k <<
  103. "if(get_global_id(0) == 0){\n" <<
  104. result[k.var<uint_>("index")] << " = sum;\n" <<
  105. "}\n";
  106. }
  107. k <<
  108. "index++;\n" <<
  109. "}\n" << // end while
  110. "block_partial_sums[get_global_id(0)] = sum;\n";
  111. // compile scan kernel
  112. kernel block_scan_kernel = k.compile(context);
  113. // setup kernel arguments
  114. block_scan_kernel.set_arg(count_arg, static_cast<uint_>(count));
  115. block_scan_kernel.set_arg(init_arg, static_cast<output_type>(init));
  116. block_scan_kernel.set_arg(block_partial_sums_arg, block_partial_sums);
  117. // execute the kernel
  118. size_t global_work_size = compute_units;
  119. queue.enqueue_1d_range_kernel(block_scan_kernel, 0, global_work_size, 0);
  120. // scan is done
  121. if(compute_units < 2) {
  122. return result + count;
  123. }
  124. // final scan kernel
  125. meta_kernel l("scan_on_cpu_final_scan");
  126. // Arguments
  127. count_arg = l.add_arg<uint_>("count");
  128. block_partial_sums_arg =
  129. l.add_arg<output_type *>(memory_object::global_memory, "block_partial_sums");
  130. l <<
  131. "uint block = (count + get_global_size(0))/(get_global_size(0) + 1);\n" <<
  132. "uint index = block + get_global_id(0) * block;\n" <<
  133. "uint end = min(count, index + block);\n" <<
  134. k.decl<output_type>("sum") << " = block_partial_sums[0];\n" <<
  135. "for(uint i = 0; i < get_global_id(0); i++) {\n" <<
  136. "sum = " << op(k.var<output_type>("sum"),
  137. k.var<output_type>("block_partial_sums[i + 1]")) << ";\n" <<
  138. "}\n" <<
  139. "while(index < end){\n";
  140. if(exclusive){
  141. l <<
  142. l.decl<output_type>("value") << " = "
  143. << first[k.var<uint_>("index")] << ";\n" <<
  144. result[k.var<uint_>("index")] << " = sum;\n" <<
  145. "sum = " << op(k.var<output_type>("sum"),
  146. k.var<output_type>("value")) << ";\n";
  147. }
  148. else {
  149. l <<
  150. "sum = " << op(k.var<output_type>("sum"),
  151. first[k.var<uint_>("index")]) << ";\n" <<
  152. result[k.var<uint_>("index")] << " = sum;\n";
  153. }
  154. l <<
  155. "index++;\n" <<
  156. "}\n";
  157. // compile scan kernel
  158. kernel final_scan_kernel = l.compile(context);
  159. // setup kernel arguments
  160. final_scan_kernel.set_arg(count_arg, static_cast<uint_>(count));
  161. final_scan_kernel.set_arg(block_partial_sums_arg, block_partial_sums);
  162. // execute the kernel
  163. global_work_size = compute_units;
  164. queue.enqueue_1d_range_kernel(final_scan_kernel, 0, global_work_size, 0);
  165. // return iterator pointing to the end of the result range
  166. return result + count;
  167. }
  168. } // end detail namespace
  169. } // end compute namespace
  170. } // end boost namespace
  171. #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_CPU_HPP