reduce_by_key_with_scan.hpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2015 Jakub Szuppe <j.szuppe@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_REDUCE_BY_KEY_WITH_SCAN_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_DETAIL_REDUCE_BY_KEY_WITH_SCAN_HPP
  12. #include <algorithm>
  13. #include <iterator>
  14. #include <boost/compute/command_queue.hpp>
  15. #include <boost/compute/functional.hpp>
  16. #include <boost/compute/algorithm/inclusive_scan.hpp>
  17. #include <boost/compute/container/vector.hpp>
  18. #include <boost/compute/container/detail/scalar.hpp>
  19. #include <boost/compute/detail/meta_kernel.hpp>
  20. #include <boost/compute/detail/iterator_range_size.hpp>
  21. #include <boost/compute/detail/read_write_single_value.hpp>
  22. #include <boost/compute/type_traits.hpp>
  23. #include <boost/compute/utility/program_cache.hpp>
  24. namespace boost {
  25. namespace compute {
  26. namespace detail {
  27. /// \internal_
  28. ///
  29. /// Fills \p new_keys_first with unsigned integer keys generated from vector
  30. /// of original keys \p keys_first. New keys can be distinguish by simple equality
  31. /// predicate.
  32. ///
  33. /// \param keys_first iterator pointing to the first key
  34. /// \param number_of_keys number of keys
  35. /// \param predicate binary predicate for key comparison
  36. /// \param new_keys_first iterator pointing to the new keys vector
  37. /// \param preferred_work_group_size preferred work group size
  38. /// \param queue command queue to perform the operation
  39. ///
  40. /// Binary function \p predicate must take two keys as arguments and
  41. /// return true only if they are considered the same.
  42. ///
  43. /// The first new key equals zero and the last equals number of unique keys
  44. /// minus one.
  45. ///
  46. /// No local memory usage.
  47. template<class InputKeyIterator, class BinaryPredicate>
  48. inline void generate_uint_keys(InputKeyIterator keys_first,
  49. size_t number_of_keys,
  50. BinaryPredicate predicate,
  51. vector<uint_>::iterator new_keys_first,
  52. size_t preferred_work_group_size,
  53. command_queue &queue)
  54. {
  55. typedef typename
  56. std::iterator_traits<InputKeyIterator>::value_type key_type;
  57. detail::meta_kernel k("reduce_by_key_new_key_flags");
  58. k.add_set_arg<const uint_>("count", uint_(number_of_keys));
  59. k <<
  60. k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
  61. k.decl<uint_>("value") << " = 0;\n" <<
  62. "if(gid >= count){\n return;\n}\n" <<
  63. "if(gid > 0){ \n" <<
  64. k.decl<key_type>("key") << " = " <<
  65. keys_first[k.var<const uint_>("gid")] << ";\n" <<
  66. k.decl<key_type>("previous_key") << " = " <<
  67. keys_first[k.var<const uint_>("gid - 1")] << ";\n" <<
  68. " value = " << predicate(k.var<key_type>("previous_key"),
  69. k.var<key_type>("key")) <<
  70. " ? 0 : 1;\n" <<
  71. "}\n else {\n" <<
  72. " value = 0;\n" <<
  73. "}\n" <<
  74. new_keys_first[k.var<const uint_>("gid")] << " = value;\n";
  75. const context &context = queue.get_context();
  76. kernel kernel = k.compile(context);
  77. size_t work_group_size = preferred_work_group_size;
  78. size_t work_groups_no = static_cast<size_t>(
  79. std::ceil(float(number_of_keys) / work_group_size)
  80. );
  81. queue.enqueue_1d_range_kernel(kernel,
  82. 0,
  83. work_groups_no * work_group_size,
  84. work_group_size);
  85. inclusive_scan(new_keys_first, new_keys_first + number_of_keys,
  86. new_keys_first, queue);
  87. }
  88. /// \internal_
  89. /// Calculate carry-out for each work group.
  90. /// Carry-out is a pair of the last key processed by a work group and sum of all
  91. /// values under this key in this work group.
  92. template<class InputValueIterator, class OutputValueIterator, class BinaryFunction>
  93. inline void carry_outs(vector<uint_>::iterator keys_first,
  94. InputValueIterator values_first,
  95. size_t count,
  96. vector<uint_>::iterator carry_out_keys_first,
  97. OutputValueIterator carry_out_values_first,
  98. BinaryFunction function,
  99. size_t work_group_size,
  100. command_queue &queue)
  101. {
  102. typedef typename
  103. std::iterator_traits<OutputValueIterator>::value_type value_out_type;
  104. detail::meta_kernel k("reduce_by_key_with_scan_carry_outs");
  105. k.add_set_arg<const uint_>("count", uint_(count));
  106. size_t local_keys_arg = k.add_arg<uint_ *>(memory_object::local_memory, "lkeys");
  107. size_t local_vals_arg = k.add_arg<value_out_type *>(memory_object::local_memory, "lvals");
  108. k <<
  109. k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
  110. k.decl<const uint_>("wg_size") << " = get_local_size(0);\n" <<
  111. k.decl<const uint_>("lid") << " = get_local_id(0);\n" <<
  112. k.decl<const uint_>("group_id") << " = get_group_id(0);\n" <<
  113. k.decl<uint_>("key") << ";\n" <<
  114. k.decl<value_out_type>("value") << ";\n" <<
  115. "if(gid < count){\n" <<
  116. k.var<uint_>("key") << " = " <<
  117. keys_first[k.var<const uint_>("gid")] << ";\n" <<
  118. k.var<value_out_type>("value") << " = " <<
  119. values_first[k.var<const uint_>("gid")] << ";\n" <<
  120. "lkeys[lid] = key;\n" <<
  121. "lvals[lid] = value;\n" <<
  122. "}\n" <<
  123. // Calculate carry out for each work group by performing Hillis/Steele scan
  124. // where only last element (key-value pair) is saved
  125. k.decl<value_out_type>("result") << " = value;\n" <<
  126. k.decl<uint_>("other_key") << ";\n" <<
  127. k.decl<value_out_type>("other_value") << ";\n" <<
  128. "for(" << k.decl<uint_>("offset") << " = 1; " <<
  129. "offset < wg_size; offset *= 2){\n"
  130. " barrier(CLK_LOCAL_MEM_FENCE);\n" <<
  131. " if(lid >= offset){\n"
  132. " other_key = lkeys[lid - offset];\n" <<
  133. " if(other_key == key){\n" <<
  134. " other_value = lvals[lid - offset];\n" <<
  135. " result = " << function(k.var<value_out_type>("result"),
  136. k.var<value_out_type>("other_value")) << ";\n" <<
  137. " }\n" <<
  138. " }\n" <<
  139. " barrier(CLK_LOCAL_MEM_FENCE);\n" <<
  140. " lvals[lid] = result;\n" <<
  141. "}\n" <<
  142. // save carry out
  143. "if(lid == (wg_size - 1)){\n" <<
  144. carry_out_keys_first[k.var<const uint_>("group_id")] << " = key;\n" <<
  145. carry_out_values_first[k.var<const uint_>("group_id")] << " = result;\n" <<
  146. "}\n";
  147. size_t work_groups_no = static_cast<size_t>(
  148. std::ceil(float(count) / work_group_size)
  149. );
  150. const context &context = queue.get_context();
  151. kernel kernel = k.compile(context);
  152. kernel.set_arg(local_keys_arg, local_buffer<uint_>(work_group_size));
  153. kernel.set_arg(local_vals_arg, local_buffer<value_out_type>(work_group_size));
  154. queue.enqueue_1d_range_kernel(kernel,
  155. 0,
  156. work_groups_no * work_group_size,
  157. work_group_size);
  158. }
  159. /// \internal_
  160. /// Calculate carry-in by performing inclusive scan by key on carry-outs vector.
  161. template<class OutputValueIterator, class BinaryFunction>
  162. inline void carry_ins(vector<uint_>::iterator carry_out_keys_first,
  163. OutputValueIterator carry_out_values_first,
  164. OutputValueIterator carry_in_values_first,
  165. size_t carry_out_size,
  166. BinaryFunction function,
  167. size_t work_group_size,
  168. command_queue &queue)
  169. {
  170. typedef typename
  171. std::iterator_traits<OutputValueIterator>::value_type value_out_type;
  172. uint_ values_pre_work_item = static_cast<uint_>(
  173. std::ceil(float(carry_out_size) / work_group_size)
  174. );
  175. detail::meta_kernel k("reduce_by_key_with_scan_carry_ins");
  176. k.add_set_arg<const uint_>("carry_out_size", uint_(carry_out_size));
  177. k.add_set_arg<const uint_>("values_per_work_item", values_pre_work_item);
  178. size_t local_keys_arg = k.add_arg<uint_ *>(memory_object::local_memory, "lkeys");
  179. size_t local_vals_arg = k.add_arg<value_out_type *>(memory_object::local_memory, "lvals");
  180. k <<
  181. k.decl<uint_>("id") << " = get_global_id(0) * values_per_work_item;\n" <<
  182. k.decl<uint_>("idx") << " = id;\n" <<
  183. k.decl<const uint_>("wg_size") << " = get_local_size(0);\n" <<
  184. k.decl<const uint_>("lid") << " = get_local_id(0);\n" <<
  185. k.decl<const uint_>("group_id") << " = get_group_id(0);\n" <<
  186. k.decl<uint_>("key") << ";\n" <<
  187. k.decl<value_out_type>("value") << ";\n" <<
  188. k.decl<uint_>("previous_key") << ";\n" <<
  189. k.decl<value_out_type>("result") << ";\n" <<
  190. "if(id < carry_out_size){\n" <<
  191. k.var<uint_>("previous_key") << " = " <<
  192. carry_out_keys_first[k.var<const uint_>("id")] << ";\n" <<
  193. k.var<value_out_type>("result") << " = " <<
  194. carry_out_values_first[k.var<const uint_>("id")] << ";\n" <<
  195. carry_in_values_first[k.var<const uint_>("id")] << " = result;\n" <<
  196. "}\n" <<
  197. k.decl<const uint_>("end") << " = (id + values_per_work_item) <= carry_out_size" <<
  198. " ? (values_per_work_item + id) : carry_out_size;\n" <<
  199. "for(idx = idx + 1; idx < end; idx += 1){\n" <<
  200. " key = " << carry_out_keys_first[k.var<const uint_>("idx")] << ";\n" <<
  201. " value = " << carry_out_values_first[k.var<const uint_>("idx")] << ";\n" <<
  202. " if(previous_key == key){\n" <<
  203. " result = " << function(k.var<value_out_type>("result"),
  204. k.var<value_out_type>("value")) << ";\n" <<
  205. " }\n else { \n" <<
  206. " result = value;\n"
  207. " }\n" <<
  208. " " << carry_in_values_first[k.var<const uint_>("idx")] << " = result;\n" <<
  209. " previous_key = key;\n"
  210. "}\n" <<
  211. // save the last key and result to local memory
  212. "lkeys[lid] = previous_key;\n" <<
  213. "lvals[lid] = result;\n" <<
  214. // Hillis/Steele scan
  215. "for(" << k.decl<uint_>("offset") << " = 1; " <<
  216. "offset < wg_size; offset *= 2){\n"
  217. " barrier(CLK_LOCAL_MEM_FENCE);\n" <<
  218. " if(lid >= offset){\n"
  219. " key = lkeys[lid - offset];\n" <<
  220. " if(previous_key == key){\n" <<
  221. " value = lvals[lid - offset];\n" <<
  222. " result = " << function(k.var<value_out_type>("result"),
  223. k.var<value_out_type>("value")) << ";\n" <<
  224. " }\n" <<
  225. " }\n" <<
  226. " barrier(CLK_LOCAL_MEM_FENCE);\n" <<
  227. " lvals[lid] = result;\n" <<
  228. "}\n" <<
  229. "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
  230. "if(lid > 0){\n" <<
  231. // load key-value reduced by previous work item
  232. " previous_key = lkeys[lid - 1];\n" <<
  233. " result = lvals[lid - 1];\n" <<
  234. "}\n" <<
  235. // add key-value reduced by previous work item
  236. "for(idx = id; idx < id + values_per_work_item; idx += 1){\n" <<
  237. // make sure all carry-ins are saved in global memory
  238. " barrier( CLK_GLOBAL_MEM_FENCE );\n" <<
  239. " if(lid > 0 && idx < carry_out_size) {\n"
  240. " key = " << carry_out_keys_first[k.var<const uint_>("idx")] << ";\n" <<
  241. " value = " << carry_in_values_first[k.var<const uint_>("idx")] << ";\n" <<
  242. " if(previous_key == key){\n" <<
  243. " value = " << function(k.var<value_out_type>("result"),
  244. k.var<value_out_type>("value")) << ";\n" <<
  245. " }\n" <<
  246. " " << carry_in_values_first[k.var<const uint_>("idx")] << " = value;\n" <<
  247. " }\n" <<
  248. "}\n";
  249. const context &context = queue.get_context();
  250. kernel kernel = k.compile(context);
  251. kernel.set_arg(local_keys_arg, local_buffer<uint_>(work_group_size));
  252. kernel.set_arg(local_vals_arg, local_buffer<value_out_type>(work_group_size));
  253. queue.enqueue_1d_range_kernel(kernel,
  254. 0,
  255. work_group_size,
  256. work_group_size);
  257. }
  258. /// \internal_
  259. ///
  260. /// Perform final reduction by key. Each work item:
  261. /// 1. Perform local work-group reduction (Hillis/Steele scan)
  262. /// 2. Add carry-in (if keys are right)
  263. /// 3. Save reduced value if next key is different than processed one
  264. template<class InputKeyIterator, class InputValueIterator,
  265. class OutputKeyIterator, class OutputValueIterator,
  266. class BinaryFunction>
  267. inline void final_reduction(InputKeyIterator keys_first,
  268. InputValueIterator values_first,
  269. OutputKeyIterator keys_result,
  270. OutputValueIterator values_result,
  271. size_t count,
  272. BinaryFunction function,
  273. vector<uint_>::iterator new_keys_first,
  274. vector<uint_>::iterator carry_in_keys_first,
  275. OutputValueIterator carry_in_values_first,
  276. size_t carry_in_size,
  277. size_t work_group_size,
  278. command_queue &queue)
  279. {
  280. typedef typename
  281. std::iterator_traits<OutputValueIterator>::value_type value_out_type;
  282. detail::meta_kernel k("reduce_by_key_with_scan_final_reduction");
  283. k.add_set_arg<const uint_>("count", uint_(count));
  284. size_t local_keys_arg = k.add_arg<uint_ *>(memory_object::local_memory, "lkeys");
  285. size_t local_vals_arg = k.add_arg<value_out_type *>(memory_object::local_memory, "lvals");
  286. k <<
  287. k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
  288. k.decl<const uint_>("wg_size") << " = get_local_size(0);\n" <<
  289. k.decl<const uint_>("lid") << " = get_local_id(0);\n" <<
  290. k.decl<const uint_>("group_id") << " = get_group_id(0);\n" <<
  291. k.decl<uint_>("key") << ";\n" <<
  292. k.decl<value_out_type>("value") << ";\n"
  293. "if(gid < count){\n" <<
  294. k.var<uint_>("key") << " = " <<
  295. new_keys_first[k.var<const uint_>("gid")] << ";\n" <<
  296. k.var<value_out_type>("value") << " = " <<
  297. values_first[k.var<const uint_>("gid")] << ";\n" <<
  298. "lkeys[lid] = key;\n" <<
  299. "lvals[lid] = value;\n" <<
  300. "}\n" <<
  301. // Hillis/Steele scan
  302. k.decl<value_out_type>("result") << " = value;\n" <<
  303. k.decl<uint_>("other_key") << ";\n" <<
  304. k.decl<value_out_type>("other_value") << ";\n" <<
  305. "for(" << k.decl<uint_>("offset") << " = 1; " <<
  306. "offset < wg_size ; offset *= 2){\n"
  307. " barrier(CLK_LOCAL_MEM_FENCE);\n" <<
  308. " if(lid >= offset) {\n" <<
  309. " other_key = lkeys[lid - offset];\n" <<
  310. " if(other_key == key){\n" <<
  311. " other_value = lvals[lid - offset];\n" <<
  312. " result = " << function(k.var<value_out_type>("result"),
  313. k.var<value_out_type>("other_value")) << ";\n" <<
  314. " }\n" <<
  315. " }\n" <<
  316. " barrier(CLK_LOCAL_MEM_FENCE);\n" <<
  317. " lvals[lid] = result;\n" <<
  318. "}\n" <<
  319. "if(gid >= count) {\n return;\n};\n" <<
  320. k.decl<const bool>("save") << " = (gid < (count - 1)) ?"
  321. << new_keys_first[k.var<const uint_>("gid + 1")] << " != key" <<
  322. ": true;\n" <<
  323. // Add carry in
  324. k.decl<uint_>("carry_in_key") << ";\n" <<
  325. "if(group_id > 0 && save) {\n" <<
  326. " carry_in_key = " << carry_in_keys_first[k.var<const uint_>("group_id - 1")] << ";\n" <<
  327. " if(key == carry_in_key){\n" <<
  328. " other_value = " << carry_in_values_first[k.var<const uint_>("group_id - 1")] << ";\n" <<
  329. " result = " << function(k.var<value_out_type>("result"),
  330. k.var<value_out_type>("other_value")) << ";\n" <<
  331. " }\n" <<
  332. "}\n" <<
  333. // Save result only if the next key is different or it's the last element.
  334. "if(save){\n" <<
  335. keys_result[k.var<uint_>("key")] << " = " << keys_first[k.var<const uint_>("gid")] << ";\n" <<
  336. values_result[k.var<uint_>("key")] << " = result;\n" <<
  337. "}\n"
  338. ;
  339. size_t work_groups_no = static_cast<size_t>(
  340. std::ceil(float(count) / work_group_size)
  341. );
  342. const context &context = queue.get_context();
  343. kernel kernel = k.compile(context);
  344. kernel.set_arg(local_keys_arg, local_buffer<uint_>(work_group_size));
  345. kernel.set_arg(local_vals_arg, local_buffer<value_out_type>(work_group_size));
  346. queue.enqueue_1d_range_kernel(kernel,
  347. 0,
  348. work_groups_no * work_group_size,
  349. work_group_size);
  350. }
  351. /// \internal_
  352. /// Returns preferred work group size for reduce by key with scan algorithm.
  353. template<class KeyType, class ValueType>
  354. inline size_t get_work_group_size(const device& device)
  355. {
  356. std::string cache_key = std::string("__boost_reduce_by_key_with_scan")
  357. + "k_" + type_name<KeyType>() + "_v_" + type_name<ValueType>();
  358. // load parameters
  359. boost::shared_ptr<parameter_cache> parameters =
  360. detail::parameter_cache::get_global_cache(device);
  361. return (std::max)(
  362. static_cast<size_t>(parameters->get(cache_key, "wgsize", 256)),
  363. static_cast<size_t>(device.get_info<CL_DEVICE_MAX_WORK_GROUP_SIZE>())
  364. );
  365. }
  366. /// \internal_
  367. ///
  368. /// 1. For each work group carry-out value is calculated (it's done by key-oriented
  369. /// Hillis/Steele scan). Carry-out is a pair of the last key processed by work
  370. /// group and sum of all values under this key in work group.
  371. /// 2. From every carry-out carry-in is calculated by performing inclusive scan
  372. /// by key.
  373. /// 3. Final reduction by key is performed (key-oriented Hillis/Steele scan),
  374. /// carry-in values are added where needed.
  375. template<class InputKeyIterator, class InputValueIterator,
  376. class OutputKeyIterator, class OutputValueIterator,
  377. class BinaryFunction, class BinaryPredicate>
  378. inline size_t reduce_by_key_with_scan(InputKeyIterator keys_first,
  379. InputKeyIterator keys_last,
  380. InputValueIterator values_first,
  381. OutputKeyIterator keys_result,
  382. OutputValueIterator values_result,
  383. BinaryFunction function,
  384. BinaryPredicate predicate,
  385. command_queue &queue)
  386. {
  387. typedef typename
  388. std::iterator_traits<InputValueIterator>::value_type value_type;
  389. typedef typename
  390. std::iterator_traits<InputKeyIterator>::value_type key_type;
  391. typedef typename
  392. std::iterator_traits<OutputValueIterator>::value_type value_out_type;
  393. const context &context = queue.get_context();
  394. size_t count = detail::iterator_range_size(keys_first, keys_last);
  395. if(count == 0){
  396. return size_t(0);
  397. }
  398. const device &device = queue.get_device();
  399. size_t work_group_size = get_work_group_size<value_type, key_type>(device);
  400. // Replace original key with unsigned integer keys generated based on given
  401. // predicate. New key is also an index for keys_result and values_result vectors,
  402. // which points to place where reduced value should be saved.
  403. vector<uint_> new_keys(count, context);
  404. vector<uint_>::iterator new_keys_first = new_keys.begin();
  405. generate_uint_keys(keys_first, count, predicate, new_keys_first,
  406. work_group_size, queue);
  407. // Calculate carry-out and carry-in vectors size
  408. const size_t carry_out_size = static_cast<size_t>(
  409. std::ceil(float(count) / work_group_size)
  410. );
  411. vector<uint_> carry_out_keys(carry_out_size, context);
  412. vector<value_out_type> carry_out_values(carry_out_size, context);
  413. carry_outs(new_keys_first, values_first, count, carry_out_keys.begin(),
  414. carry_out_values.begin(), function, work_group_size, queue);
  415. vector<value_out_type> carry_in_values(carry_out_size, context);
  416. carry_ins(carry_out_keys.begin(), carry_out_values.begin(),
  417. carry_in_values.begin(), carry_out_size, function, work_group_size,
  418. queue);
  419. final_reduction(keys_first, values_first, keys_result, values_result,
  420. count, function, new_keys_first, carry_out_keys.begin(),
  421. carry_in_values.begin(), carry_out_size, work_group_size,
  422. queue);
  423. const size_t result = read_single_value<uint_>(new_keys.get_buffer(),
  424. count - 1, queue);
  425. return result + 1;
  426. }
  427. /// \internal_
  428. /// Return true if requirements for running reduce by key with scan on given
  429. /// device are met (at least one work group of preferred size can be run).
  430. template<class InputKeyIterator, class InputValueIterator,
  431. class OutputKeyIterator, class OutputValueIterator>
  432. bool reduce_by_key_with_scan_requirements_met(InputKeyIterator keys_first,
  433. InputValueIterator values_first,
  434. OutputKeyIterator keys_result,
  435. OutputValueIterator values_result,
  436. const size_t count,
  437. command_queue &queue)
  438. {
  439. typedef typename
  440. std::iterator_traits<InputValueIterator>::value_type value_type;
  441. typedef typename
  442. std::iterator_traits<InputKeyIterator>::value_type key_type;
  443. typedef typename
  444. std::iterator_traits<OutputValueIterator>::value_type value_out_type;
  445. (void) keys_first;
  446. (void) values_first;
  447. (void) keys_result;
  448. (void) values_result;
  449. const device &device = queue.get_device();
  450. // device must have dedicated local memory storage
  451. if(device.get_info<CL_DEVICE_LOCAL_MEM_TYPE>() != CL_LOCAL)
  452. {
  453. return false;
  454. }
  455. // local memory size in bytes (per compute unit)
  456. const size_t local_mem_size = device.get_info<CL_DEVICE_LOCAL_MEM_SIZE>();
  457. // preferred work group size
  458. size_t work_group_size = get_work_group_size<key_type, value_type>(device);
  459. // local memory size needed to perform parallel reduction
  460. size_t required_local_mem_size = 0;
  461. // keys size
  462. required_local_mem_size += sizeof(uint_) * work_group_size;
  463. // reduced values size
  464. required_local_mem_size += sizeof(value_out_type) * work_group_size;
  465. return (required_local_mem_size <= local_mem_size);
  466. }
  467. } // end detail namespace
  468. } // end compute namespace
  469. } // end boost namespace
  470. #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_REDUCE_BY_KEY_WITH_SCAN_HPP