scan_on_gpu.hpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP
  12. #include <boost/compute/kernel.hpp>
  13. #include <boost/compute/detail/meta_kernel.hpp>
  14. #include <boost/compute/command_queue.hpp>
  15. #include <boost/compute/container/vector.hpp>
  16. #include <boost/compute/detail/iterator_range_size.hpp>
  17. #include <boost/compute/memory/local_buffer.hpp>
  18. #include <boost/compute/iterator/buffer_iterator.hpp>
  19. namespace boost {
  20. namespace compute {
  21. namespace detail {
  22. template<class InputIterator, class OutputIterator, class BinaryOperator>
  23. class local_scan_kernel : public meta_kernel
  24. {
  25. public:
  26. local_scan_kernel(InputIterator first,
  27. InputIterator last,
  28. OutputIterator result,
  29. bool exclusive,
  30. BinaryOperator op)
  31. : meta_kernel("local_scan")
  32. {
  33. typedef typename std::iterator_traits<InputIterator>::value_type T;
  34. (void) last;
  35. bool checked = true;
  36. m_block_sums_arg = add_arg<T *>(memory_object::global_memory, "block_sums");
  37. m_scratch_arg = add_arg<T *>(memory_object::local_memory, "scratch");
  38. m_block_size_arg = add_arg<const cl_uint>("block_size");
  39. m_count_arg = add_arg<const cl_uint>("count");
  40. m_init_value_arg = add_arg<const T>("init");
  41. // work-item parameters
  42. *this <<
  43. "const uint gid = get_global_id(0);\n" <<
  44. "const uint lid = get_local_id(0);\n";
  45. // check against data size
  46. if(checked){
  47. *this <<
  48. "if(gid < count){\n";
  49. }
  50. // copy values from input to local memory
  51. if(exclusive){
  52. *this <<
  53. decl<const T>("local_init") << "= (gid == 0) ? init : 0;\n" <<
  54. "if(lid == 0){ scratch[lid] = local_init; }\n" <<
  55. "else { scratch[lid] = " << first[expr<cl_uint>("gid-1")] << "; }\n";
  56. }
  57. else{
  58. *this <<
  59. "scratch[lid] = " << first[expr<cl_uint>("gid")] << ";\n";
  60. }
  61. if(checked){
  62. *this <<
  63. "}\n"
  64. "else {\n" <<
  65. " scratch[lid] = 0;\n" <<
  66. "}\n";
  67. }
  68. // wait for all threads to read from input
  69. *this <<
  70. "barrier(CLK_LOCAL_MEM_FENCE);\n";
  71. // perform scan
  72. *this <<
  73. "for(uint i = 1; i < block_size; i <<= 1){\n" <<
  74. " " << decl<const T>("x") << " = lid >= i ? scratch[lid-i] : 0;\n" <<
  75. " barrier(CLK_LOCAL_MEM_FENCE);\n" <<
  76. " if(lid >= i){\n" <<
  77. " scratch[lid] = " << op(var<T>("scratch[lid]"), var<T>("x")) << ";\n" <<
  78. " }\n" <<
  79. " barrier(CLK_LOCAL_MEM_FENCE);\n" <<
  80. "}\n";
  81. // copy results to output
  82. if(checked){
  83. *this <<
  84. "if(gid < count){\n";
  85. }
  86. *this <<
  87. result[expr<cl_uint>("gid")] << " = scratch[lid];\n";
  88. if(checked){
  89. *this << "}\n";
  90. }
  91. // store sum for the block
  92. if(exclusive){
  93. *this <<
  94. "if(lid == block_size - 1 && gid < count) {\n" <<
  95. " block_sums[get_group_id(0)] = " <<
  96. op(first[expr<cl_uint>("gid")], var<T>("scratch[lid]")) <<
  97. ";\n" <<
  98. "}\n";
  99. }
  100. else {
  101. *this <<
  102. "if(lid == block_size - 1){\n" <<
  103. " block_sums[get_group_id(0)] = scratch[lid];\n" <<
  104. "}\n";
  105. }
  106. }
  107. size_t m_block_sums_arg;
  108. size_t m_scratch_arg;
  109. size_t m_block_size_arg;
  110. size_t m_count_arg;
  111. size_t m_init_value_arg;
  112. };
  113. template<class T, class BinaryOperator>
  114. class write_scanned_output_kernel : public meta_kernel
  115. {
  116. public:
  117. write_scanned_output_kernel(BinaryOperator op)
  118. : meta_kernel("write_scanned_output")
  119. {
  120. bool checked = true;
  121. m_output_arg = add_arg<T *>(memory_object::global_memory, "output");
  122. m_block_sums_arg = add_arg<const T *>(memory_object::global_memory, "block_sums");
  123. m_count_arg = add_arg<const cl_uint>("count");
  124. // work-item parameters
  125. *this <<
  126. "const uint gid = get_global_id(0);\n" <<
  127. "const uint block_id = get_group_id(0);\n";
  128. // check against data size
  129. if(checked){
  130. *this << "if(gid < count){\n";
  131. }
  132. // write output
  133. *this <<
  134. "output[gid] = " <<
  135. op(var<T>("block_sums[block_id]"), var<T>("output[gid] ")) << ";\n";
  136. if(checked){
  137. *this << "}\n";
  138. }
  139. }
  140. size_t m_output_arg;
  141. size_t m_block_sums_arg;
  142. size_t m_count_arg;
  143. };
  /// Chooses the work-group size for the local scan: the smallest power of
  /// two that is not less than the number of elements, capped at 256.
  /// Returns 0 for an empty range.
  template<class InputIterator>
  inline size_t pick_scan_block_size(InputIterator first, InputIterator last)
  {
      const size_t n = iterator_range_size(first, last);
      if(n == 0){
          return 0;
      }

      const size_t max_block = 256;
      size_t block = 1;
      while(block < n && block < max_block){
          block <<= 1;
      }
      return block;
  }
  159. template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
  160. inline OutputIterator scan_impl(InputIterator first,
  161. InputIterator last,
  162. OutputIterator result,
  163. bool exclusive,
  164. T init,
  165. BinaryOperator op,
  166. command_queue &queue)
  167. {
  168. typedef typename
  169. std::iterator_traits<InputIterator>::value_type
  170. input_type;
  171. typedef typename
  172. std::iterator_traits<InputIterator>::difference_type
  173. difference_type;
  174. typedef typename
  175. std::iterator_traits<OutputIterator>::value_type
  176. output_type;
  177. const context &context = queue.get_context();
  178. const size_t count = detail::iterator_range_size(first, last);
  179. size_t block_size = pick_scan_block_size(first, last);
  180. size_t block_count = count / block_size;
  181. if(block_count * block_size < count){
  182. block_count++;
  183. }
  184. ::boost::compute::vector<input_type> block_sums(block_count, context);
  185. // zero block sums
  186. input_type zero;
  187. std::memset(&zero, 0, sizeof(input_type));
  188. ::boost::compute::fill(block_sums.begin(), block_sums.end(), zero, queue);
  189. // local scan
  190. local_scan_kernel<InputIterator, OutputIterator, BinaryOperator>
  191. local_scan_kernel(first, last, result, exclusive, op);
  192. ::boost::compute::kernel kernel = local_scan_kernel.compile(context);
  193. kernel.set_arg(local_scan_kernel.m_scratch_arg, local_buffer<input_type>(block_size));
  194. kernel.set_arg(local_scan_kernel.m_block_sums_arg, block_sums);
  195. kernel.set_arg(local_scan_kernel.m_block_size_arg, static_cast<cl_uint>(block_size));
  196. kernel.set_arg(local_scan_kernel.m_count_arg, static_cast<cl_uint>(count));
  197. kernel.set_arg(local_scan_kernel.m_init_value_arg, static_cast<output_type>(init));
  198. queue.enqueue_1d_range_kernel(kernel,
  199. 0,
  200. block_count * block_size,
  201. block_size);
  202. // inclusive scan block sums
  203. if(block_count > 1){
  204. scan_impl(block_sums.begin(),
  205. block_sums.end(),
  206. block_sums.begin(),
  207. false,
  208. init,
  209. op,
  210. queue
  211. );
  212. }
  213. // add block sums to each block
  214. if(block_count > 1){
  215. write_scanned_output_kernel<input_type, BinaryOperator>
  216. write_output_kernel(op);
  217. kernel = write_output_kernel.compile(context);
  218. kernel.set_arg(write_output_kernel.m_output_arg, result.get_buffer());
  219. kernel.set_arg(write_output_kernel.m_block_sums_arg, block_sums);
  220. kernel.set_arg(write_output_kernel.m_count_arg, static_cast<cl_uint>(count));
  221. queue.enqueue_1d_range_kernel(kernel,
  222. block_size,
  223. block_count * block_size,
  224. block_size);
  225. }
  226. return result + static_cast<difference_type>(count);
  227. }
  228. template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
  229. inline OutputIterator dispatch_scan(InputIterator first,
  230. InputIterator last,
  231. OutputIterator result,
  232. bool exclusive,
  233. T init,
  234. BinaryOperator op,
  235. command_queue &queue)
  236. {
  237. return scan_impl(first, last, result, exclusive, init, op, queue);
  238. }
  239. template<class InputIterator, class T, class BinaryOperator>
  240. inline InputIterator dispatch_scan(InputIterator first,
  241. InputIterator last,
  242. InputIterator result,
  243. bool exclusive,
  244. T init,
  245. BinaryOperator op,
  246. command_queue &queue)
  247. {
  248. typedef typename std::iterator_traits<InputIterator>::value_type value_type;
  249. if(first == result){
  250. // scan input in-place
  251. const context &context = queue.get_context();
  252. // make a temporary copy the input
  253. size_t count = iterator_range_size(first, last);
  254. vector<value_type> tmp(count, context);
  255. copy(first, last, tmp.begin(), queue);
  256. // scan from temporary values
  257. return scan_impl(tmp.begin(), tmp.end(), first, exclusive, init, op, queue);
  258. }
  259. else {
  260. // scan input to output
  261. return scan_impl(first, last, result, exclusive, init, op, queue);
  262. }
  263. }
  264. template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
  265. inline OutputIterator scan_on_gpu(InputIterator first,
  266. InputIterator last,
  267. OutputIterator result,
  268. bool exclusive,
  269. T init,
  270. BinaryOperator op,
  271. command_queue &queue)
  272. {
  273. if(first == last){
  274. return result;
  275. }
  276. return dispatch_scan(first, last, result, exclusive, init, op, queue);
  277. }
  278. } // end detail namespace
  279. } // end compute namespace
  280. } // end boost namespace
  281. #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP