insertion_sort.hpp 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_INSERTION_SORT_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_DETAIL_INSERTION_SORT_HPP
  12. #include <boost/compute/kernel.hpp>
  13. #include <boost/compute/program.hpp>
  14. #include <boost/compute/command_queue.hpp>
  15. #include <boost/compute/detail/meta_kernel.hpp>
  16. #include <boost/compute/detail/iterator_range_size.hpp>
  17. #include <boost/compute/memory/local_buffer.hpp>
  18. namespace boost {
  19. namespace compute {
  20. namespace detail {
  21. template<class Iterator, class Compare>
  22. inline void serial_insertion_sort(Iterator first,
  23. Iterator last,
  24. Compare compare,
  25. command_queue &queue)
  26. {
  27. typedef typename std::iterator_traits<Iterator>::value_type T;
  28. size_t count = iterator_range_size(first, last);
  29. if(count < 2){
  30. return;
  31. }
  32. meta_kernel k("serial_insertion_sort");
  33. size_t local_data_arg = k.add_arg<T *>(memory_object::local_memory, "data");
  34. size_t count_arg = k.add_arg<uint_>("n");
  35. k <<
  36. // copy data to local memory
  37. "for(uint i = 0; i < n; i++){\n" <<
  38. " data[i] = " << first[k.var<uint_>("i")] << ";\n"
  39. "}\n"
  40. // sort data in local memory
  41. "for(uint i = 1; i < n; i++){\n" <<
  42. " " << k.decl<const T>("value") << " = data[i];\n" <<
  43. " uint pos = i;\n" <<
  44. " while(pos > 0 && " <<
  45. compare(k.var<const T>("value"),
  46. k.var<const T>("data[pos-1]")) << "){\n" <<
  47. " data[pos] = data[pos-1];\n" <<
  48. " pos--;\n" <<
  49. " }\n" <<
  50. " data[pos] = value;\n" <<
  51. "}\n" <<
  52. // copy sorted data to output
  53. "for(uint i = 0; i < n; i++){\n" <<
  54. " " << first[k.var<uint_>("i")] << " = data[i];\n"
  55. "}\n";
  56. const context &context = queue.get_context();
  57. ::boost::compute::kernel kernel = k.compile(context);
  58. kernel.set_arg(local_data_arg, local_buffer<T>(count));
  59. kernel.set_arg(count_arg, static_cast<uint_>(count));
  60. queue.enqueue_task(kernel);
  61. }
  62. template<class Iterator>
  63. inline void serial_insertion_sort(Iterator first,
  64. Iterator last,
  65. command_queue &queue)
  66. {
  67. typedef typename std::iterator_traits<Iterator>::value_type T;
  68. ::boost::compute::less<T> less;
  69. return serial_insertion_sort(first, last, less, queue);
  70. }
  71. template<class KeyIterator, class ValueIterator, class Compare>
  72. inline void serial_insertion_sort_by_key(KeyIterator keys_first,
  73. KeyIterator keys_last,
  74. ValueIterator values_first,
  75. Compare compare,
  76. command_queue &queue)
  77. {
  78. typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
  79. typedef typename std::iterator_traits<ValueIterator>::value_type value_type;
  80. size_t count = iterator_range_size(keys_first, keys_last);
  81. if(count < 2){
  82. return;
  83. }
  84. meta_kernel k("serial_insertion_sort_by_key");
  85. size_t local_keys_arg = k.add_arg<key_type *>(memory_object::local_memory, "keys");
  86. size_t local_data_arg = k.add_arg<value_type *>(memory_object::local_memory, "data");
  87. size_t count_arg = k.add_arg<uint_>("n");
  88. k <<
  89. // copy data to local memory
  90. "for(uint i = 0; i < n; i++){\n" <<
  91. " keys[i] = " << keys_first[k.var<uint_>("i")] << ";\n"
  92. " data[i] = " << values_first[k.var<uint_>("i")] << ";\n"
  93. "}\n"
  94. // sort data in local memory
  95. "for(uint i = 1; i < n; i++){\n" <<
  96. " " << k.decl<const key_type>("key") << " = keys[i];\n" <<
  97. " " << k.decl<const value_type>("value") << " = data[i];\n" <<
  98. " uint pos = i;\n" <<
  99. " while(pos > 0 && " <<
  100. compare(k.var<const key_type>("key"),
  101. k.var<const key_type>("keys[pos-1]")) << "){\n" <<
  102. " keys[pos] = keys[pos-1];\n" <<
  103. " data[pos] = data[pos-1];\n" <<
  104. " pos--;\n" <<
  105. " }\n" <<
  106. " keys[pos] = key;\n" <<
  107. " data[pos] = value;\n" <<
  108. "}\n" <<
  109. // copy sorted data to output
  110. "for(uint i = 0; i < n; i++){\n" <<
  111. " " << keys_first[k.var<uint_>("i")] << " = keys[i];\n"
  112. " " << values_first[k.var<uint_>("i")] << " = data[i];\n"
  113. "}\n";
  114. const context &context = queue.get_context();
  115. ::boost::compute::kernel kernel = k.compile(context);
  116. kernel.set_arg(local_keys_arg, static_cast<uint_>(count * sizeof(key_type)), 0);
  117. kernel.set_arg(local_data_arg, static_cast<uint_>(count * sizeof(value_type)), 0);
  118. kernel.set_arg(count_arg, static_cast<uint_>(count));
  119. queue.enqueue_task(kernel);
  120. }
  121. template<class KeyIterator, class ValueIterator>
  122. inline void serial_insertion_sort_by_key(KeyIterator keys_first,
  123. KeyIterator keys_last,
  124. ValueIterator values_first,
  125. command_queue &queue)
  126. {
  127. typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
  128. serial_insertion_sort_by_key(
  129. keys_first,
  130. keys_last,
  131. values_first,
  132. boost::compute::less<key_type>(),
  133. queue
  134. );
  135. }
  136. } // end detail namespace
  137. } // end compute namespace
  138. } // end boost namespace
  139. #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_INSERTION_SORT_HPP