reduce_by_key.hpp 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2015 Jakub Szuppe <j.szuppe@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_REDUCE_BY_KEY_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_REDUCE_BY_KEY_HPP
  12. #include <iterator>
  13. #include <utility>
  14. #include <boost/static_assert.hpp>
  15. #include <boost/compute/command_queue.hpp>
  16. #include <boost/compute/device.hpp>
  17. #include <boost/compute/functional.hpp>
  18. #include <boost/compute/system.hpp>
  19. #include <boost/compute/algorithm/detail/reduce_by_key.hpp>
  20. #include <boost/compute/type_traits/is_device_iterator.hpp>
  21. namespace boost {
  22. namespace compute {
  23. /// The \c reduce_by_key() algorithm performs reduction for each contiguous
  24. /// subsequence of values determinate by equivalent keys.
  25. ///
  26. /// Returns a pair of iterators at the end of the ranges [\p keys_result, keys_result_last)
  27. /// and [\p values_result, \p values_result_last).
  28. ///
  29. /// If no function is specified, \c plus will be used.
  30. /// If no predicate is specified, \c equal_to will be used.
  31. ///
  32. /// \param keys_first the first key
  33. /// \param keys_last the last key
  34. /// \param values_first the first input value
  35. /// \param keys_result iterator pointing to the key output
  36. /// \param values_result iterator pointing to the reduced value output
  37. /// \param function binary reduction function
  38. /// \param predicate binary predicate which returns true only if two keys are equal
  39. /// \param queue command queue to perform the operation
  40. ///
  41. /// The \c reduce_by_key() algorithm assumes that the binary reduction function
  42. /// is associative. When used with non-associative functions the result may
  43. /// be non-deterministic and vary in precision. Notably this affects the
  44. /// \c plus<float>() function as floating-point addition is not associative
  45. /// and may produce slightly different results than a serial algorithm.
  46. ///
  47. /// For example, to calculate the sum of the values for each key:
  48. ///
  49. /// \snippet test/test_reduce_by_key.cpp reduce_by_key_int
  50. ///
  51. /// Space complexity on GPUs: \Omega(2n)<br>
  52. /// Space complexity on CPUs: \Omega(1)
  53. ///
  54. /// \see reduce()
  55. template<class InputKeyIterator, class InputValueIterator,
  56. class OutputKeyIterator, class OutputValueIterator,
  57. class BinaryFunction, class BinaryPredicate>
  58. inline std::pair<OutputKeyIterator, OutputValueIterator>
  59. reduce_by_key(InputKeyIterator keys_first,
  60. InputKeyIterator keys_last,
  61. InputValueIterator values_first,
  62. OutputKeyIterator keys_result,
  63. OutputValueIterator values_result,
  64. BinaryFunction function,
  65. BinaryPredicate predicate,
  66. command_queue &queue = system::default_queue())
  67. {
  68. BOOST_STATIC_ASSERT(is_device_iterator<InputKeyIterator>::value);
  69. BOOST_STATIC_ASSERT(is_device_iterator<InputValueIterator>::value);
  70. BOOST_STATIC_ASSERT(is_device_iterator<OutputKeyIterator>::value);
  71. BOOST_STATIC_ASSERT(is_device_iterator<OutputValueIterator>::value);
  72. return detail::dispatch_reduce_by_key(keys_first, keys_last, values_first,
  73. keys_result, values_result,
  74. function, predicate,
  75. queue);
  76. }
  77. /// \overload
  78. template<class InputKeyIterator, class InputValueIterator,
  79. class OutputKeyIterator, class OutputValueIterator,
  80. class BinaryFunction>
  81. inline std::pair<OutputKeyIterator, OutputValueIterator>
  82. reduce_by_key(InputKeyIterator keys_first,
  83. InputKeyIterator keys_last,
  84. InputValueIterator values_first,
  85. OutputKeyIterator keys_result,
  86. OutputValueIterator values_result,
  87. BinaryFunction function,
  88. command_queue &queue = system::default_queue())
  89. {
  90. BOOST_STATIC_ASSERT(is_device_iterator<InputKeyIterator>::value);
  91. BOOST_STATIC_ASSERT(is_device_iterator<InputValueIterator>::value);
  92. BOOST_STATIC_ASSERT(is_device_iterator<OutputKeyIterator>::value);
  93. BOOST_STATIC_ASSERT(is_device_iterator<OutputValueIterator>::value);
  94. typedef typename std::iterator_traits<InputKeyIterator>::value_type key_type;
  95. return reduce_by_key(keys_first, keys_last, values_first,
  96. keys_result, values_result,
  97. function, equal_to<key_type>(),
  98. queue);
  99. }
  100. /// \overload
  101. template<class InputKeyIterator, class InputValueIterator,
  102. class OutputKeyIterator, class OutputValueIterator>
  103. inline std::pair<OutputKeyIterator, OutputValueIterator>
  104. reduce_by_key(InputKeyIterator keys_first,
  105. InputKeyIterator keys_last,
  106. InputValueIterator values_first,
  107. OutputKeyIterator keys_result,
  108. OutputValueIterator values_result,
  109. command_queue &queue = system::default_queue())
  110. {
  111. BOOST_STATIC_ASSERT(is_device_iterator<InputKeyIterator>::value);
  112. BOOST_STATIC_ASSERT(is_device_iterator<InputValueIterator>::value);
  113. BOOST_STATIC_ASSERT(is_device_iterator<OutputKeyIterator>::value);
  114. BOOST_STATIC_ASSERT(is_device_iterator<OutputValueIterator>::value);
  115. typedef typename std::iterator_traits<InputKeyIterator>::value_type key_type;
  116. typedef typename std::iterator_traits<InputValueIterator>::value_type value_type;
  117. return reduce_by_key(keys_first, keys_last, values_first,
  118. keys_result, values_result,
  119. plus<value_type>(), equal_to<key_type>(),
  120. queue);
  121. }
  122. } // end compute namespace
  123. } // end boost namespace
  124. #endif // BOOST_COMPUTE_ALGORITHM_REDUCE_BY_KEY_HPP