search_n.hpp 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2014 Roshan <thisisroshansmail@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_SEARCH_N_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_DETAIL_SEARCH_N_HPP
  12. #include <iterator>
  13. #include <boost/static_assert.hpp>
  14. #include <boost/compute/algorithm/find.hpp>
  15. #include <boost/compute/container/vector.hpp>
  16. #include <boost/compute/detail/iterator_range_size.hpp>
  17. #include <boost/compute/detail/meta_kernel.hpp>
  18. #include <boost/compute/system.hpp>
  19. #include <boost/compute/type_traits/is_device_iterator.hpp>
  20. namespace boost {
  21. namespace compute {
  22. namespace detail {
  23. ///
  24. /// \brief Search kernel class
  25. ///
  26. /// Subclass of meta_kernel which is capable of performing search_n
  27. ///
  28. template<class TextIterator, class OutputIterator>
  29. class search_n_kernel : public meta_kernel
  30. {
  31. public:
  32. typedef typename std::iterator_traits<TextIterator>::value_type value_type;
  33. search_n_kernel() : meta_kernel("search_n")
  34. {}
  35. void set_range(TextIterator t_first,
  36. TextIterator t_last,
  37. value_type value,
  38. size_t n,
  39. OutputIterator result)
  40. {
  41. m_n = n;
  42. m_n_arg = add_arg<uint_>("n");
  43. m_value = value;
  44. m_value_arg = add_arg<value_type>("value");
  45. m_count = iterator_range_size(t_first, t_last);
  46. m_count = m_count + 1 - m_n;
  47. *this <<
  48. "uint i = get_global_id(0);\n" <<
  49. "uint i1 = i;\n" <<
  50. "uint j;\n" <<
  51. "for(j = 0; j<n; j++,i++)\n" <<
  52. "{\n" <<
  53. " if(value != " << t_first[expr<uint_>("i")] << ")\n" <<
  54. " j = n + 1;\n" <<
  55. "}\n" <<
  56. "if(j == n)\n" <<
  57. result[expr<uint_>("i1")] << " = 1;\n" <<
  58. "else\n" <<
  59. result[expr<uint_>("i1")] << " = 0;\n";
  60. }
  61. event exec(command_queue &queue)
  62. {
  63. if(m_count == 0) {
  64. return event();
  65. }
  66. set_arg(m_n_arg, uint_(m_n));
  67. set_arg(m_value_arg, m_value);
  68. return exec_1d(queue, 0, m_count);
  69. }
  70. private:
  71. size_t m_n;
  72. size_t m_n_arg;
  73. size_t m_count;
  74. value_type m_value;
  75. size_t m_value_arg;
  76. };
  77. } //end detail namespace
  78. ///
  79. /// \brief Substring matching algorithm
  80. ///
  81. /// Searches for the first occurrence of n consecutive occurrences of
  82. /// value in text [t_first, t_last).
  83. /// \return Iterator pointing to beginning of first occurrence
  84. ///
  85. /// \param t_first Iterator pointing to start of text
  86. /// \param t_last Iterator pointing to end of text
  87. /// \param n Number of times value repeats
  88. /// \param value Value which repeats
  89. /// \param queue Queue on which to execute
  90. ///
  91. /// Space complexity: \Omega(distance(\p t_first, \p t_last))
  92. template<class TextIterator, class ValueType>
  93. inline TextIterator search_n(TextIterator t_first,
  94. TextIterator t_last,
  95. size_t n,
  96. ValueType value,
  97. command_queue &queue = system::default_queue())
  98. {
  99. BOOST_STATIC_ASSERT(is_device_iterator<TextIterator>::value);
  100. // there is no need to check if pattern starts at last n - 1 indices
  101. vector<uint_> matching_indices(
  102. detail::iterator_range_size(t_first, t_last) + 1 - n,
  103. queue.get_context()
  104. );
  105. // search_n_kernel puts value 1 at every index in vector where pattern
  106. // of n values starts at
  107. detail::search_n_kernel<TextIterator,
  108. vector<uint_>::iterator> kernel;
  109. kernel.set_range(t_first, t_last, value, n, matching_indices.begin());
  110. kernel.exec(queue);
  111. vector<uint_>::iterator index = ::boost::compute::find(
  112. matching_indices.begin(), matching_indices.end(), uint_(1), queue
  113. );
  114. // pattern was not found
  115. if(index == matching_indices.end())
  116. return t_last;
  117. return t_first + detail::iterator_range_size(matching_indices.begin(), index);
  118. }
  119. } //end compute namespace
  120. } //end boost namespace
  121. #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_SEARCH_N_HPP