//---------------------------------------------------------------------------//
// Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
//
// Distributed under the Boost Software License, Version 1.0
// See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt
//
// See http://boostorg.github.com/compute for more information.
//---------------------------------------------------------------------------//
#ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP
#define BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP

#include <climits>  // CHAR_BIT, used to compute the number of radix passes
#include <iterator>
#include <sstream>  // std::stringstream, used to build program options

#include <boost/assert.hpp>
#include <boost/type_traits/is_signed.hpp>
#include <boost/type_traits/is_floating_point.hpp>

#include <boost/mpl/and.hpp>
#include <boost/mpl/not.hpp>

#include <boost/compute/kernel.hpp>
#include <boost/compute/program.hpp>
#include <boost/compute/command_queue.hpp>
#include <boost/compute/algorithm/exclusive_scan.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/compute/detail/iterator_range_size.hpp>
#include <boost/compute/detail/parameter_cache.hpp>
#include <boost/compute/type_traits/type_name.hpp>
#include <boost/compute/type_traits/is_fundamental.hpp>
#include <boost/compute/type_traits/is_vector_type.hpp>
#include <boost/compute/utility/program_cache.hpp>
namespace boost {
namespace compute {
namespace detail {

// meta-function returning true if type T is radix-sortable
template<class T>
struct is_radix_sortable :
    boost::mpl::and_<
        typename ::boost::compute::is_fundamental<T>::type,
        typename boost::mpl::not_<typename is_vector_type<T>::type>::type
    >
{
};
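
// maps a key size in bytes to the unsigned integer type of the same width;
// the kernels operate on keys reinterpreted as this type so that radix
// digits can be extracted with shifts and masks regardless of the original
// key type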
template<size_t N>
struct radix_sort_value_type
{
};

template<>
struct radix_sort_value_type<1>
{
    typedef uchar_ type;
};

template<>
struct radix_sort_value_type<2>
{
    typedef ushort_ type;
};

template<>
struct radix_sort_value_type<4>
{
    typedef uint_ type;
};

template<>
struct radix_sort_value_type<8>
{
    typedef ulong_ type;
};
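
// returns the program option that tells the kernel source whether the
// sort-by-key value type requires the cl_khr_fp64 extension (only the
// double specialization enables it)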
template<typename T>
inline const char* enable_double()
{
    return " -DT2_double=0";
}

template<>
inline const char* enable_double<double>()
{
    return " -DT2_double=1";
}
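
// OpenCL source for the count/scan/scatter kernels; compiled with -D
// options (T, T2, K_BITS, BLOCK_SIZE, IS_SIGNED, IS_FLOATING_POINT, ASC,
// SORT_BY_KEY) assembled in radix_sort_impl() below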
const char radix_sort_source[] =
"#if T2_double\n"
"#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
"#endif\n"

"#define K2_BITS (1 << K_BITS)\n"
"#define RADIX_MASK ((((T)(1)) << K_BITS) - 1)\n"
"#define SIGN_BIT ((sizeof(T) * CHAR_BIT) - 1)\n"
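
// radix() maps a key to its K_BITS-wide digit starting at low_bit while
// preserving sort order: an IEEE float has all bits flipped when negative
// and only its sign bit flipped otherwise, a signed integer has its sign
// bit flipped, and an unsigned integer is used as-is (T is always the
// unsigned type of the key's width, so the shifts are well-defined).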
"#if defined(ASC)\n" // ascending order
"inline uint radix(const T x, const uint low_bit)\n"
"{\n"
"#if defined(IS_FLOATING_POINT)\n"
"    const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
"    return ((x ^ mask) >> low_bit) & RADIX_MASK;\n"
"#elif defined(IS_SIGNED)\n"
"    return ((x ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
"#else\n"
"    return (x >> low_bit) & RADIX_MASK;\n"
"#endif\n"
"}\n"
"#else\n" // descending order
// For descending order the key is inverted before its digit is extracted:
// signed types are negated, and unsigned types are subtracted from the
// maximum value of their type ((T)(-1) is the maximum value of an unsigned
// type T).
"inline uint radix(const T x, const uint low_bit)\n"
"{\n"
"#if defined(IS_FLOATING_POINT)\n"
"    const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
"    return (((-x) ^ mask) >> low_bit) & RADIX_MASK;\n"
"#elif defined(IS_SIGNED)\n"
"    return (((-x) ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
"#else\n"
"    return (((T)(-1) - x) >> low_bit) & RADIX_MASK;\n"
"#endif\n"
"}\n"
"#endif\n" // #if defined(ASC)
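
// count: builds a per-block histogram of the K2_BITS digit values using
// local-memory atomics and writes it to global_counts. The last block also
// stores its raw counts in global_offsets so that scan() below can recover
// the per-bucket totals after the host has exclusive-scanned global_counts.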
"__kernel void count(__global const T *input,\n"
"                    const uint input_offset,\n"
"                    const uint input_size,\n"
"                    __global uint *global_counts,\n"
"                    __global uint *global_offsets,\n"
"                    __local uint *local_counts,\n"
"                    const uint low_bit)\n"
"{\n"
// work-item parameters
"    const uint gid = get_global_id(0);\n"
"    const uint lid = get_local_id(0);\n"

// zero local counts
"    if(lid < K2_BITS){\n"
"        local_counts[lid] = 0;\n"
"    }\n"
"    barrier(CLK_LOCAL_MEM_FENCE);\n"

// reduce local counts
"    if(gid < input_size){\n"
"        T value = input[input_offset+gid];\n"
"        uint bucket = radix(value, low_bit);\n"
"        atomic_inc(local_counts + bucket);\n"
"    }\n"
"    barrier(CLK_LOCAL_MEM_FENCE);\n"

// write block-relative offsets
"    if(lid < K2_BITS){\n"
"        global_counts[K2_BITS*get_group_id(0) + lid] = local_counts[lid];\n"

// write global offsets
"        if(get_group_id(0) == (get_num_groups(0) - 1)){\n"
"            global_offsets[lid] = local_counts[lid];\n"
"        }\n"
"    }\n"
"}\n"
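
// scan: a single-task kernel that finishes the offset computation. By the
// time it runs the host has exclusive-scanned the per-block counts, so
// global_offsets[i] (the last block's raw count, written by count()) plus
// the last block's scanned count gives the total size of bucket i; an
// exclusive scan over those totals yields each bucket's global start.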
"__kernel void scan(__global const uint *block_offsets,\n"
"                   __global uint *global_offsets,\n"
"                   const uint block_count)\n"
"{\n"
"    __global const uint *last_block_offsets =\n"
"        block_offsets + K2_BITS * (block_count - 1);\n"

// calculate and scan global_offsets
"    uint sum = 0;\n"
"    for(uint i = 0; i < K2_BITS; i++){\n"
"        uint x = global_offsets[i] + last_block_offsets[i];\n"
"        mem_fence(CLK_GLOBAL_MEM_FENCE);\n" // work around the RX 500/Vega bug, see #811
"        global_offsets[i] = sum;\n"
"        sum += x;\n"
"        mem_fence(CLK_GLOBAL_MEM_FENCE);\n" // work around the RX 500/Vega bug, see #811
"    }\n"
"}\n"
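
// scatter: writes each element to its sorted position: the bucket's global
// start, plus the number of same-bucket elements in earlier blocks (from
// the scanned counts), plus the element's rank among same-bucket elements
// earlier in its own block, computed serially over local memory so that
// the sort remains stable.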
"__kernel void scatter(__global const T *input,\n"
"                      const uint input_offset,\n"
"                      const uint input_size,\n"
"                      const uint low_bit,\n"
"                      __global const uint *counts,\n"
"                      __global const uint *global_offsets,\n"
"#ifndef SORT_BY_KEY\n"
"                      __global T *output,\n"
"                      const uint output_offset)\n"
"#else\n"
"                      __global T *keys_output,\n"
"                      const uint keys_output_offset,\n"
"                      __global T2 *values_input,\n"
"                      const uint values_input_offset,\n"
"                      __global T2 *values_output,\n"
"                      const uint values_output_offset)\n"
"#endif\n"
"{\n"
// work-item parameters
"    const uint gid = get_global_id(0);\n"
"    const uint lid = get_local_id(0);\n"

// compute the radix digit and store it in local memory
"    T value;\n"
"    uint bucket;\n"
"    __local uint local_input[BLOCK_SIZE];\n"
"    if(gid < input_size){\n"
"        value = input[input_offset+gid];\n"
"        bucket = radix(value, low_bit);\n"
"        local_input[lid] = bucket;\n"
"    }\n"

// copy this block's scanned counts to local memory
"    __local uint local_counts[(1 << K_BITS)];\n"
"    if(lid < K2_BITS){\n"
"        local_counts[lid] = counts[get_group_id(0) * K2_BITS + lid];\n"
"    }\n"

// wait until local memory is ready
"    barrier(CLK_LOCAL_MEM_FENCE);\n"

"    if(gid >= input_size){\n"
"        return;\n"
"    }\n"

// get global offset
"    uint offset = global_offsets[bucket] + local_counts[bucket];\n"

// calculate local offset
"    uint local_offset = 0;\n"
"    for(uint i = 0; i < lid; i++){\n"
"        if(local_input[i] == bucket)\n"
"            local_offset++;\n"
"    }\n"

"#ifndef SORT_BY_KEY\n"
// write value to output
"    output[output_offset + offset + local_offset] = value;\n"
"#else\n"
// write key and value if doing sort_by_key
"    keys_output[keys_output_offset + offset + local_offset] = value;\n"
"    values_output[values_output_offset + offset + local_offset] =\n"
"        values_input[values_input_offset+gid];\n"
"#endif\n"
"}\n";
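
// Host-side driver: a least-significant-digit radix sort. Each pass
// processes k bits of the key (count -> host-side exclusive scan -> scan ->
// scatter) and ping-pongs between the input buffer and a temporary buffer.
// When values_first refers to a real buffer, the values are permuted along
// with the keys (sort-by-key).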
template<class T, class T2>
inline void radix_sort_impl(const buffer_iterator<T> first,
                            const buffer_iterator<T> last,
                            const buffer_iterator<T2> values_first,
                            const bool ascending,
                            command_queue &queue)
{
    typedef T value_type;
    typedef typename radix_sort_value_type<sizeof(T)>::type sort_type;

    const device &device = queue.get_device();
    const context &context = queue.get_context();

    // if we have a valid values iterator then we are doing a
    // sort by key and have to set up the values buffer
    bool sort_by_key = (values_first.get_buffer().get() != 0);
    // load (or create) radix sort program
    std::string cache_key =
        std::string("__boost_radix_sort_") + type_name<value_type>();
    if(sort_by_key){
        cache_key += std::string("_with_") + type_name<T2>();
    }
    boost::shared_ptr<program_cache> cache =
        program_cache::get_global_cache(context);
    boost::shared_ptr<parameter_cache> parameters =
        detail::parameter_cache::get_global_cache(device);

    // sort parameters
    const uint_ k = parameters->get(cache_key, "k", 4); // digit bits per pass
    const uint_ k2 = 1 << k; // number of digit values (buckets)
    const uint_ block_size = parameters->get(cache_key, "tpb", 128);

    // sort program compiler options
    std::stringstream options;
    options << "-DK_BITS=" << k;
    options << " -DT=" << type_name<sort_type>();
    options << " -DBLOCK_SIZE=" << block_size;
    if(boost::is_floating_point<value_type>::value){
        options << " -DIS_FLOATING_POINT";
    }
    if(boost::is_signed<value_type>::value){
        options << " -DIS_SIGNED";
    }
    if(sort_by_key){
        options << " -DSORT_BY_KEY";
        options << " -DT2=" << type_name<T2>();
        options << enable_double<T2>();
    }
    if(ascending){
        options << " -DASC";
    }

    // get type definition if it is a custom struct
    std::string custom_type_def = boost::compute::type_definition<T2>() + "\n";

    // load radix sort program
    program radix_sort_program = cache->get_or_build(
        cache_key, options.str(), custom_type_def + radix_sort_source, context
    );

    kernel count_kernel(radix_sort_program, "count");
    kernel scan_kernel(radix_sort_program, "scan");
    kernel scatter_kernel(radix_sort_program, "scatter");
    size_t count = detail::iterator_range_size(first, last);

    uint_ block_count = static_cast<uint_>(count / block_size);
    if(block_count * block_size != count){
        block_count++;
    }

    // setup temporary buffers
    vector<value_type> output(count, context);
    vector<T2> values_output(sort_by_key ? count : 0, context);
    vector<uint_> offsets(k2, context);
    vector<uint_> counts(block_count * k2, context);

    const buffer *input_buffer = &first.get_buffer();
    uint_ input_offset = static_cast<uint_>(first.get_index());
    const buffer *output_buffer = &output.get_buffer();
    uint_ output_offset = 0;
    const buffer *values_input_buffer = &values_first.get_buffer();
    uint_ values_input_offset = static_cast<uint_>(values_first.get_index());
    const buffer *values_output_buffer = &values_output.get_buffer();
    uint_ values_output_offset = 0;
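
    // one pass per k-bit digit, from the lowest digit to the highest; with
    // the digit widths used here (k = 1, 2 or 4) the pass count is always
    // even, so the final swap below leaves the sorted result back in the
    // caller's buffers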
    for(uint_ i = 0; i < sizeof(sort_type) * CHAR_BIT / k; i++){
        // write counts
        count_kernel.set_arg(0, *input_buffer);
        count_kernel.set_arg(1, input_offset);
        count_kernel.set_arg(2, static_cast<uint_>(count));
        count_kernel.set_arg(3, counts);
        count_kernel.set_arg(4, offsets);
        count_kernel.set_arg(5, block_size * sizeof(uint_), 0);
        count_kernel.set_arg(6, i * k);
        queue.enqueue_1d_range_kernel(count_kernel,
                                      0,
                                      block_count * block_size,
                                      block_size);

        // scan counts
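        // viewing the counts buffer as vectors of k2 uints makes
        // exclusive_scan add the per-block histograms component-wise,
        // scanning all k2 bucket counters in a single pass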
        if(k == 1){
            typedef uint2_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 2),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else if(k == 2){
            typedef uint4_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 4),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else if(k == 4){
            typedef uint16_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 16),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else {
            BOOST_ASSERT(false && "unknown k");
            break;
        }
        // scan global offsets
        scan_kernel.set_arg(0, counts);
        scan_kernel.set_arg(1, offsets);
        scan_kernel.set_arg(2, block_count);
        queue.enqueue_task(scan_kernel);

        // scatter values
        scatter_kernel.set_arg(0, *input_buffer);
        scatter_kernel.set_arg(1, input_offset);
        scatter_kernel.set_arg(2, static_cast<uint_>(count));
        scatter_kernel.set_arg(3, i * k);
        scatter_kernel.set_arg(4, counts);
        scatter_kernel.set_arg(5, offsets);
        scatter_kernel.set_arg(6, *output_buffer);
        scatter_kernel.set_arg(7, output_offset);
        if(sort_by_key){
            scatter_kernel.set_arg(8, *values_input_buffer);
            scatter_kernel.set_arg(9, values_input_offset);
            scatter_kernel.set_arg(10, *values_output_buffer);
            scatter_kernel.set_arg(11, values_output_offset);
        }
        queue.enqueue_1d_range_kernel(scatter_kernel,
                                      0,
                                      block_count * block_size,
                                      block_size);

        // swap buffers
        std::swap(input_buffer, output_buffer);
        std::swap(values_input_buffer, values_output_buffer);
        std::swap(input_offset, output_offset);
        std::swap(values_input_offset, values_output_offset);
    }
}
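
// Convenience entry points. These are normally reached through
// boost::compute::sort() and sort_by_key(), but can be called directly on
// buffer iterators; an illustrative sketch (assuming a std::vector "data"
// and a command_queue "queue" already exist):
//
//   boost::compute::vector<int> vec(data.begin(), data.end(), queue);
//   boost::compute::detail::radix_sort(vec.begin(), vec.end(), queue);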
template<class Iterator>
inline void radix_sort(Iterator first,
                       Iterator last,
                       command_queue &queue)
{
    radix_sort_impl(first, last, buffer_iterator<int>(), true, queue);
}

template<class KeyIterator, class ValueIterator>
inline void radix_sort_by_key(KeyIterator keys_first,
                              KeyIterator keys_last,
                              ValueIterator values_first,
                              command_queue &queue)
{
    radix_sort_impl(keys_first, keys_last, values_first, true, queue);
}

template<class Iterator>
inline void radix_sort(Iterator first,
                       Iterator last,
                       const bool ascending,
                       command_queue &queue)
{
    radix_sort_impl(first, last, buffer_iterator<int>(), ascending, queue);
}

template<class KeyIterator, class ValueIterator>
inline void radix_sort_by_key(KeyIterator keys_first,
                              KeyIterator keys_last,
                              ValueIterator values_first,
                              const bool ascending,
                              command_queue &queue)
{
    radix_sort_impl(keys_first, keys_last, values_first, ascending, queue);
}
} // end detail namespace
} // end compute namespace
} // end boost namespace

#endif // BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP