binary_find.hpp 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2014 Roshan <thisisroshansmail@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_BINARY_FIND_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_DETAIL_BINARY_FIND_HPP
  12. #include <boost/compute/functional.hpp>
  13. #include <boost/compute/algorithm/find_if.hpp>
  14. #include <boost/compute/algorithm/transform.hpp>
  15. #include <boost/compute/command_queue.hpp>
  16. #include <boost/compute/detail/parameter_cache.hpp>
  17. namespace boost {
  18. namespace compute {
  19. namespace detail{
  20. ///
  21. /// \brief Binary find kernel class
  22. ///
  23. /// Subclass of meta_kernel to perform single step in binary find.
  24. ///
  25. template<class InputIterator, class UnaryPredicate>
  26. class binary_find_kernel : public meta_kernel
  27. {
  28. public:
  29. binary_find_kernel(InputIterator first,
  30. InputIterator last,
  31. UnaryPredicate predicate)
  32. : meta_kernel("binary_find")
  33. {
  34. typedef typename std::iterator_traits<InputIterator>::value_type value_type;
  35. m_index_arg = add_arg<uint_ *>(memory_object::global_memory, "index");
  36. m_block_arg = add_arg<uint_>("block");
  37. atomic_min<uint_> atomic_min_uint;
  38. *this <<
  39. "uint i = get_global_id(0) * block;\n" <<
  40. decl<value_type>("value") << "=" << first[var<uint_>("i")] << ";\n" <<
  41. "if(" << predicate(var<value_type>("value")) << ") {\n" <<
  42. atomic_min_uint(var<uint_ *>("index"), var<uint_>("i")) << ";\n" <<
  43. "}\n";
  44. }
  45. size_t m_index_arg;
  46. size_t m_block_arg;
  47. };
  48. ///
  49. /// \brief Binary find algorithm
  50. ///
  51. /// Finds the end of true values in the partitioned range [first, last).
  52. /// \return Iterator pointing to end of true values
  53. ///
  54. /// \param first Iterator pointing to start of range
  55. /// \param last Iterator pointing to end of range
  56. /// \param predicate Predicate according to which the range is partitioned
  57. /// \param queue Queue on which to execute
  58. ///
  59. template<class InputIterator, class UnaryPredicate>
  60. inline InputIterator binary_find(InputIterator first,
  61. InputIterator last,
  62. UnaryPredicate predicate,
  63. command_queue &queue = system::default_queue())
  64. {
  65. const device &device = queue.get_device();
  66. boost::shared_ptr<parameter_cache> parameters =
  67. detail::parameter_cache::get_global_cache(device);
  68. const std::string cache_key = "__boost_binary_find";
  69. size_t find_if_limit = 128;
  70. size_t threads = parameters->get(cache_key, "tpb", 128);
  71. size_t count = iterator_range_size(first, last);
  72. InputIterator search_first = first;
  73. InputIterator search_last = last;
  74. scalar<uint_> index(queue.get_context());
  75. // construct and compile binary_find kernel
  76. binary_find_kernel<InputIterator, UnaryPredicate>
  77. binary_find_kernel(search_first, search_last, predicate);
  78. ::boost::compute::kernel kernel = binary_find_kernel.compile(queue.get_context());
  79. // set buffer for index
  80. kernel.set_arg(binary_find_kernel.m_index_arg, index.get_buffer());
  81. while(count > find_if_limit) {
  82. index.write(static_cast<uint_>(count), queue);
  83. // set block and run binary_find kernel
  84. uint_ block = static_cast<uint_>((count - 1)/(threads - 1));
  85. kernel.set_arg(binary_find_kernel.m_block_arg, block);
  86. queue.enqueue_1d_range_kernel(kernel, 0, threads, 0);
  87. size_t i = index.read(queue);
  88. if(i == count) {
  89. search_first = search_last - ((count - 1)%(threads - 1));
  90. break;
  91. } else {
  92. search_last = search_first + i;
  93. search_first = search_last - ((count - 1)/(threads - 1));
  94. }
  95. // Make sure that first and last stay within the input range
  96. search_last = (std::min)(search_last, last);
  97. search_last = (std::max)(search_last, first);
  98. search_first = (std::max)(search_first, first);
  99. search_first = (std::min)(search_first, last);
  100. count = iterator_range_size(search_first, search_last);
  101. }
  102. return find_if(search_first, search_last, predicate, queue);
  103. }
  104. } // end detail namespace
  105. } // end compute namespace
  106. } // end boost namespace
  107. #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_BINARY_FIND_HPP