serial_scan.hpp 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_SERIAL_SCAN_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_DETAIL_SERIAL_SCAN_HPP
  12. #include <iterator>
  13. #include <boost/compute/device.hpp>
  14. #include <boost/compute/kernel.hpp>
  15. #include <boost/compute/command_queue.hpp>
  16. #include <boost/compute/detail/meta_kernel.hpp>
  17. #include <boost/compute/detail/iterator_range_size.hpp>
  18. namespace boost {
  19. namespace compute {
  20. namespace detail {
  21. template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
  22. inline OutputIterator serial_scan(InputIterator first,
  23. InputIterator last,
  24. OutputIterator result,
  25. bool exclusive,
  26. T init,
  27. BinaryOperator op,
  28. command_queue &queue)
  29. {
  30. if(first == last){
  31. return result;
  32. }
  33. typedef typename
  34. std::iterator_traits<InputIterator>::value_type input_type;
  35. typedef typename
  36. std::iterator_traits<OutputIterator>::value_type output_type;
  37. const context &context = queue.get_context();
  38. // create scan kernel
  39. meta_kernel k("serial_scan");
  40. // Arguments
  41. size_t n_arg = k.add_arg<ulong_>("n");
  42. size_t init_arg = k.add_arg<output_type>("initial_value");
  43. if(!exclusive){
  44. k <<
  45. k.decl<const ulong_>("start_idx") << " = 1;\n" <<
  46. k.decl<output_type>("sum") << " = " << first[0] << ";\n" <<
  47. result[0] << " = sum;\n";
  48. }
  49. else {
  50. k <<
  51. k.decl<const ulong_>("start_idx") << " = 0;\n" <<
  52. k.decl<output_type>("sum") << " = initial_value;\n";
  53. }
  54. k <<
  55. "for(ulong i = start_idx; i < n; i++){\n" <<
  56. k.decl<const input_type>("x") << " = "
  57. << first[k.var<ulong_>("i")] << ";\n";
  58. if(exclusive){
  59. k << result[k.var<ulong_>("i")] << " = sum;\n";
  60. }
  61. k << " sum = "
  62. << op(k.var<output_type>("sum"), k.var<output_type>("x"))
  63. << ";\n";
  64. if(!exclusive){
  65. k << result[k.var<ulong_>("i")] << " = sum;\n";
  66. }
  67. k << "}\n";
  68. // compile scan kernel
  69. kernel scan_kernel = k.compile(context);
  70. // setup kernel arguments
  71. size_t n = detail::iterator_range_size(first, last);
  72. scan_kernel.set_arg<ulong_>(n_arg, n);
  73. scan_kernel.set_arg<output_type>(init_arg, static_cast<output_type>(init));
  74. // execute the kernel
  75. queue.enqueue_1d_range_kernel(scan_kernel, 0, 1, 1);
  76. // return iterator pointing to the end of the result range
  77. return result + n;
  78. }
  79. } // end detail namespace
  80. } // end compute namespace
  81. } // end boost namespace
  82. #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_SERIAL_SCAN_HPP