serial_reduce_by_key.hpp 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2015 Jakub Szuppe <j.szuppe@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_SERIAL_REDUCE_BY_KEY_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_DETAIL_SERIAL_REDUCE_BY_KEY_HPP
#include <iterator>
#include <string>

#include <boost/compute/command_queue.hpp>
#include <boost/compute/functional.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/compute/container/detail/scalar.hpp>
#include <boost/compute/detail/meta_kernel.hpp>
#include <boost/compute/detail/iterator_range_size.hpp>
#include <boost/compute/type_traits/result_of.hpp>
#include <boost/compute/type_traits/type_name.hpp>
  20. namespace boost {
  21. namespace compute {
  22. namespace detail {
  23. template<class InputKeyIterator, class InputValueIterator,
  24. class OutputKeyIterator, class OutputValueIterator,
  25. class BinaryFunction, class BinaryPredicate>
  26. inline size_t serial_reduce_by_key(InputKeyIterator keys_first,
  27. InputKeyIterator keys_last,
  28. InputValueIterator values_first,
  29. OutputKeyIterator keys_result,
  30. OutputValueIterator values_result,
  31. BinaryFunction function,
  32. BinaryPredicate predicate,
  33. command_queue &queue)
  34. {
  35. typedef typename
  36. std::iterator_traits<InputValueIterator>::value_type value_type;
  37. typedef typename
  38. std::iterator_traits<InputKeyIterator>::value_type key_type;
  39. typedef typename
  40. ::boost::compute::result_of<BinaryFunction(value_type, value_type)>::type result_type;
  41. const context &context = queue.get_context();
  42. size_t count = detail::iterator_range_size(keys_first, keys_last);
  43. if(count < 1){
  44. return count;
  45. }
  46. meta_kernel k("serial_reduce_by_key");
  47. size_t count_arg = k.add_arg<uint_>("count");
  48. size_t result_size_arg = k.add_arg<uint_ *>(memory_object::global_memory,
  49. "result_size");
  50. k <<
  51. k.decl<result_type>("result") <<
  52. " = " << values_first[0] << ";\n" <<
  53. k.decl<key_type>("previous_key") << " = " << keys_first[0] << ";\n" <<
  54. k.decl<result_type>("value") << ";\n" <<
  55. k.decl<key_type>("key") << ";\n" <<
  56. k.decl<uint_>("size") << " = 1;\n" <<
  57. keys_result[0] << " = previous_key;\n" <<
  58. values_result[0] << " = result;\n" <<
  59. "for(ulong i = 1; i < count; i++) {\n" <<
  60. " value = " << values_first[k.var<uint_>("i")] << ";\n" <<
  61. " key = " << keys_first[k.var<uint_>("i")] << ";\n" <<
  62. " if (" << predicate(k.var<key_type>("previous_key"),
  63. k.var<key_type>("key")) << ") {\n" <<
  64. " result = " << function(k.var<result_type>("result"),
  65. k.var<result_type>("value")) << ";\n" <<
  66. " }\n " <<
  67. " else { \n" <<
  68. keys_result[k.var<uint_>("size - 1")] << " = previous_key;\n" <<
  69. values_result[k.var<uint_>("size - 1")] << " = result;\n" <<
  70. " result = value;\n" <<
  71. " size++;\n" <<
  72. " } \n" <<
  73. " previous_key = key;\n" <<
  74. "}\n" <<
  75. keys_result[k.var<uint_>("size - 1")] << " = previous_key;\n" <<
  76. values_result[k.var<uint_>("size - 1")] << " = result;\n" <<
  77. "*result_size = size;";
  78. kernel kernel = k.compile(context);
  79. scalar<uint_> result_size(context);
  80. kernel.set_arg(result_size_arg, result_size.get_buffer());
  81. kernel.set_arg(count_arg, static_cast<uint_>(count));
  82. queue.enqueue_task(kernel);
  83. return static_cast<size_t>(result_size.read(queue));
  84. }
  85. } // end detail namespace
  86. } // end compute namespace
  87. } // end boost namespace
  88. #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_SERIAL_REDUCE_BY_KEY_HPP