gather.hpp 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_GATHER_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_GATHER_HPP
  12. #include <boost/static_assert.hpp>
  13. #include <boost/compute/command_queue.hpp>
  14. #include <boost/compute/detail/iterator_range_size.hpp>
  15. #include <boost/compute/detail/meta_kernel.hpp>
  16. #include <boost/compute/exception.hpp>
  17. #include <boost/compute/iterator/buffer_iterator.hpp>
  18. #include <boost/compute/system.hpp>
  19. #include <boost/compute/type_traits/type_name.hpp>
  20. #include <boost/compute/type_traits/is_device_iterator.hpp>
  21. namespace boost {
  22. namespace compute {
  23. namespace detail {
  24. template<class InputIterator, class MapIterator, class OutputIterator>
  25. class gather_kernel : public meta_kernel
  26. {
  27. public:
  28. gather_kernel() : meta_kernel("gather")
  29. {}
  30. void set_range(MapIterator first,
  31. MapIterator last,
  32. InputIterator input,
  33. OutputIterator result)
  34. {
  35. m_count = iterator_range_size(first, last);
  36. *this <<
  37. "const uint i = get_global_id(0);\n" <<
  38. result[expr<uint_>("i")] << "=" <<
  39. input[first[expr<uint_>("i")]] << ";\n";
  40. }
  41. event exec(command_queue &queue)
  42. {
  43. if(m_count == 0) {
  44. return event();
  45. }
  46. return exec_1d(queue, 0, m_count);
  47. }
  48. private:
  49. size_t m_count;
  50. };
  51. } // end detail namespace
  52. /// Copies the elements using the indices from the range [\p first, \p last)
  53. /// to the range beginning at \p result using the input values from the range
  54. /// beginning at \p input.
  55. ///
  56. /// Space complexity: \Omega(1)
  57. ///
  58. /// \see scatter()
  59. template<class InputIterator, class MapIterator, class OutputIterator>
  60. inline void gather(MapIterator first,
  61. MapIterator last,
  62. InputIterator input,
  63. OutputIterator result,
  64. command_queue &queue = system::default_queue())
  65. {
  66. BOOST_STATIC_ASSERT(is_device_iterator<InputIterator>::value);
  67. BOOST_STATIC_ASSERT(is_device_iterator<MapIterator>::value);
  68. BOOST_STATIC_ASSERT(is_device_iterator<OutputIterator>::value);
  69. detail::gather_kernel<InputIterator, MapIterator, OutputIterator> kernel;
  70. kernel.set_range(first, last, input, result);
  71. kernel.exec(queue);
  72. }
  73. } // end compute namespace
  74. } // end boost namespace
  75. #endif // BOOST_COMPUTE_ALGORITHM_GATHER_HPP