set_intersection.hpp 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2014 Roshan <thisisroshansmail@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_SET_INTERSECTION_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_SET_INTERSECTION_HPP
  12. #include <iterator>
  13. #include <boost/static_assert.hpp>
  14. #include <boost/compute/algorithm/detail/compact.hpp>
  15. #include <boost/compute/algorithm/detail/balanced_path.hpp>
  16. #include <boost/compute/algorithm/exclusive_scan.hpp>
  17. #include <boost/compute/algorithm/fill_n.hpp>
  18. #include <boost/compute/container/vector.hpp>
  19. #include <boost/compute/detail/iterator_range_size.hpp>
  20. #include <boost/compute/detail/meta_kernel.hpp>
  21. #include <boost/compute/system.hpp>
  22. #include <boost/compute/type_traits/is_device_iterator.hpp>
  23. namespace boost {
  24. namespace compute {
  25. namespace detail {
  26. ///
  27. /// \brief Serial set intersection kernel class
  28. ///
  29. /// Subclass of meta_kernel to perform serial set intersection after tiling
  30. ///
  31. class serial_set_intersection_kernel : meta_kernel
  32. {
  33. public:
  34. unsigned int tile_size;
  35. serial_set_intersection_kernel() : meta_kernel("set_intersection")
  36. {
  37. tile_size = 4;
  38. }
  39. template<class InputIterator1, class InputIterator2,
  40. class InputIterator3, class InputIterator4,
  41. class OutputIterator1, class OutputIterator2>
  42. void set_range(InputIterator1 first1,
  43. InputIterator2 first2,
  44. InputIterator3 tile_first1,
  45. InputIterator3 tile_last1,
  46. InputIterator4 tile_first2,
  47. OutputIterator1 result,
  48. OutputIterator2 counts)
  49. {
  50. m_count = iterator_range_size(tile_first1, tile_last1) - 1;
  51. *this <<
  52. "uint i = get_global_id(0);\n" <<
  53. "uint start1 = " << tile_first1[expr<uint_>("i")] << ";\n" <<
  54. "uint end1 = " << tile_first1[expr<uint_>("i+1")] << ";\n" <<
  55. "uint start2 = " << tile_first2[expr<uint_>("i")] << ";\n" <<
  56. "uint end2 = " << tile_first2[expr<uint_>("i+1")] << ";\n" <<
  57. "uint index = i*" << tile_size << ";\n" <<
  58. "uint count = 0;\n" <<
  59. "while(start1<end1 && start2<end2)\n" <<
  60. "{\n" <<
  61. " if(" << first1[expr<uint_>("start1")] << " == " <<
  62. first2[expr<uint_>("start2")] << ")\n" <<
  63. " {\n" <<
  64. result[expr<uint_>("index")] <<
  65. " = " << first1[expr<uint_>("start1")] << ";\n" <<
  66. " index++; count++;\n" <<
  67. " start1++; start2++;\n" <<
  68. " }\n" <<
  69. " else if(" << first1[expr<uint_>("start1")] << " < " <<
  70. first2[expr<uint_>("start2")] << ")\n" <<
  71. " start1++;\n" <<
  72. " else start2++;\n" <<
  73. "}\n" <<
  74. counts[expr<uint_>("i")] << " = count;\n";
  75. }
  76. event exec(command_queue &queue)
  77. {
  78. if(m_count == 0) {
  79. return event();
  80. }
  81. return exec_1d(queue, 0, m_count);
  82. }
  83. private:
  84. size_t m_count;
  85. };
  86. } //end detail namespace
  87. ///
  88. /// \brief Set intersection algorithm
  89. ///
  90. /// Finds the intersection of the sorted range [first1, last1) with the sorted
  91. /// range [first2, last2) and stores it in range starting at result
  92. /// \return Iterator pointing to end of intersection
  93. ///
  94. /// \param first1 Iterator pointing to start of first set
  95. /// \param last1 Iterator pointing to end of first set
  96. /// \param first2 Iterator pointing to start of second set
  97. /// \param last2 Iterator pointing to end of second set
  98. /// \param result Iterator pointing to start of range in which the intersection
  99. /// will be stored
  100. /// \param queue Queue on which to execute
  101. ///
  102. /// Space complexity:
  103. /// \Omega(2(distance(\p first1, \p last1) + distance(\p first2, \p last2)))
  104. template<class InputIterator1, class InputIterator2, class OutputIterator>
  105. inline OutputIterator set_intersection(InputIterator1 first1,
  106. InputIterator1 last1,
  107. InputIterator2 first2,
  108. InputIterator2 last2,
  109. OutputIterator result,
  110. command_queue &queue = system::default_queue())
  111. {
  112. BOOST_STATIC_ASSERT(is_device_iterator<InputIterator1>::value);
  113. BOOST_STATIC_ASSERT(is_device_iterator<InputIterator2>::value);
  114. BOOST_STATIC_ASSERT(is_device_iterator<OutputIterator>::value);
  115. typedef typename std::iterator_traits<InputIterator1>::value_type value_type;
  116. int tile_size = 1024;
  117. int count1 = detail::iterator_range_size(first1, last1);
  118. int count2 = detail::iterator_range_size(first2, last2);
  119. vector<uint_> tile_a((count1+count2+tile_size-1)/tile_size+1, queue.get_context());
  120. vector<uint_> tile_b((count1+count2+tile_size-1)/tile_size+1, queue.get_context());
  121. // Tile the sets
  122. detail::balanced_path_kernel tiling_kernel;
  123. tiling_kernel.tile_size = tile_size;
  124. tiling_kernel.set_range(first1, last1, first2, last2,
  125. tile_a.begin()+1, tile_b.begin()+1);
  126. fill_n(tile_a.begin(), 1, 0, queue);
  127. fill_n(tile_b.begin(), 1, 0, queue);
  128. tiling_kernel.exec(queue);
  129. fill_n(tile_a.end()-1, 1, count1, queue);
  130. fill_n(tile_b.end()-1, 1, count2, queue);
  131. vector<value_type> temp_result(count1+count2, queue.get_context());
  132. vector<uint_> counts((count1+count2+tile_size-1)/tile_size + 1, queue.get_context());
  133. fill_n(counts.end()-1, 1, 0, queue);
  134. // Find individual intersections
  135. detail::serial_set_intersection_kernel intersection_kernel;
  136. intersection_kernel.tile_size = tile_size;
  137. intersection_kernel.set_range(first1, first2, tile_a.begin(), tile_a.end(),
  138. tile_b.begin(), temp_result.begin(), counts.begin());
  139. intersection_kernel.exec(queue);
  140. exclusive_scan(counts.begin(), counts.end(), counts.begin(), queue);
  141. // Compact the results
  142. detail::compact_kernel compact_kernel;
  143. compact_kernel.tile_size = tile_size;
  144. compact_kernel.set_range(temp_result.begin(), counts.begin(), counts.end(), result);
  145. compact_kernel.exec(queue);
  146. return result + (counts.end() - 1).read(queue);
  147. }
  148. } //end compute namespace
  149. } //end boost namespace
  150. #endif // BOOST_COMPUTE_ALGORITHM_SET_INTERSECTION_HPP