find_extrema_with_atomics.hpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_FIND_EXTREMA_WITH_ATOMICS_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_DETAIL_FIND_EXTREMA_WITH_ATOMICS_HPP
  12. #include <boost/compute/types.hpp>
  13. #include <boost/compute/command_queue.hpp>
  14. #include <boost/compute/container/detail/scalar.hpp>
  15. #include <boost/compute/functional/atomic.hpp>
  16. #include <boost/compute/detail/meta_kernel.hpp>
  17. #include <boost/compute/detail/iterator_range_size.hpp>
  18. namespace boost {
  19. namespace compute {
  20. namespace detail {
  21. template<class InputIterator, class Compare>
  22. inline InputIterator find_extrema_with_atomics(InputIterator first,
  23. InputIterator last,
  24. Compare compare,
  25. const bool find_minimum,
  26. command_queue &queue)
  27. {
  28. typedef typename std::iterator_traits<InputIterator>::value_type value_type;
  29. typedef typename std::iterator_traits<InputIterator>::difference_type difference_type;
  30. const context &context = queue.get_context();
  31. meta_kernel k("find_extrema");
  32. atomic_cmpxchg<uint_> atomic_cmpxchg_uint;
  33. k <<
  34. "const uint gid = get_global_id(0);\n" <<
  35. "uint old_index = *index;\n" <<
  36. k.decl<value_type>("old") <<
  37. " = " << first[k.var<uint_>("old_index")] << ";\n" <<
  38. k.decl<value_type>("new") <<
  39. " = " << first[k.var<uint_>("gid")] << ";\n" <<
  40. k.decl<bool>("compare_result") << ";\n" <<
  41. "#ifdef BOOST_COMPUTE_FIND_MAXIMUM\n" <<
  42. "while(" <<
  43. "(compare_result = " << compare(k.var<value_type>("old"),
  44. k.var<value_type>("new")) << ")" <<
  45. " || (!(compare_result" <<
  46. " || " << compare(k.var<value_type>("new"),
  47. k.var<value_type>("old")) << ") "
  48. "&& gid < old_index)){\n" <<
  49. "#else\n" <<
  50. // while condition explained for minimum case with less (<)
  51. // as comparison function:
  52. // while(new_value < old_value
  53. // OR (new_value == old_value AND new_index < old_index))
  54. "while(" <<
  55. "(compare_result = " << compare(k.var<value_type>("new"),
  56. k.var<value_type>("old")) << ")" <<
  57. " || (!(compare_result" <<
  58. " || " << compare(k.var<value_type>("old"),
  59. k.var<value_type>("new")) << ") "
  60. "&& gid < old_index)){\n" <<
  61. "#endif\n" <<
  62. " if(" << atomic_cmpxchg_uint(k.var<uint_ *>("index"),
  63. k.var<uint_>("old_index"),
  64. k.var<uint_>("gid")) << " == old_index)\n" <<
  65. " break;\n" <<
  66. " else\n" <<
  67. " old_index = *index;\n" <<
  68. "old = " << first[k.var<uint_>("old_index")] << ";\n" <<
  69. "}\n";
  70. size_t index_arg_index = k.add_arg<uint_ *>(memory_object::global_memory, "index");
  71. std::string options;
  72. if(!find_minimum){
  73. options = "-DBOOST_COMPUTE_FIND_MAXIMUM";
  74. }
  75. kernel kernel = k.compile(context, options);
  76. // setup index buffer
  77. scalar<uint_> index(context);
  78. kernel.set_arg(index_arg_index, index.get_buffer());
  79. // initialize index
  80. index.write(0, queue);
  81. // run kernel
  82. size_t count = iterator_range_size(first, last);
  83. queue.enqueue_1d_range_kernel(kernel, 0, count, 0);
  84. // read index and return iterator
  85. return first + static_cast<difference_type>(index.read(queue));
  86. }
  87. } // end detail namespace
  88. } // end compute namespace
  89. } // end boost namespace
  90. #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_FIND_EXTREMA_WITH_ATOMICS_HPP