// prev_permutation.hpp
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2014 Roshan <thisisroshansmail@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_PREV_PERMUTATION_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_PREV_PERMUTATION_HPP
  12. #include <iterator>
  13. #include <boost/static_assert.hpp>
  14. #include <boost/compute/system.hpp>
  15. #include <boost/compute/command_queue.hpp>
  16. #include <boost/compute/container/detail/scalar.hpp>
  17. #include <boost/compute/algorithm/reverse.hpp>
  18. #include <boost/compute/type_traits/is_device_iterator.hpp>
  19. namespace boost {
  20. namespace compute {
  21. namespace detail {
  22. ///
  23. /// \brief Helper function for prev_permutation
  24. ///
  25. /// To find rightmost element which is greater
  26. /// than its next element
  27. ///
  28. template<class InputIterator>
  29. inline InputIterator prev_permutation_helper(InputIterator first,
  30. InputIterator last,
  31. command_queue &queue)
  32. {
  33. typedef typename std::iterator_traits<InputIterator>::value_type value_type;
  34. size_t count = detail::iterator_range_size(first, last);
  35. if(count == 0 || count == 1){
  36. return last;
  37. }
  38. count = count - 1;
  39. const context &context = queue.get_context();
  40. detail::meta_kernel k("prev_permutation");
  41. size_t index_arg = k.add_arg<int *>(memory_object::global_memory, "index");
  42. atomic_max<int_> atomic_max_int;
  43. k << k.decl<const int_>("i") << " = get_global_id(0);\n"
  44. << k.decl<const value_type>("cur_value") << "="
  45. << first[k.var<const int_>("i")] << ";\n"
  46. << k.decl<const value_type>("next_value") << "="
  47. << first[k.expr<const int_>("i+1")] << ";\n"
  48. << "if(cur_value > next_value){\n"
  49. << " " << atomic_max_int(k.var<int_ *>("index"), k.var<int_>("i")) << ";\n"
  50. << "}\n";
  51. kernel kernel = k.compile(context);
  52. scalar<int_> index(context);
  53. kernel.set_arg(index_arg, index.get_buffer());
  54. index.write(static_cast<int_>(-1), queue);
  55. queue.enqueue_1d_range_kernel(kernel, 0, count, 0);
  56. int result = static_cast<int>(index.read(queue));
  57. if(result == -1) return last;
  58. else return first + result;
  59. }
  60. ///
  61. /// \brief Helper function for prev_permutation
  62. ///
  63. /// To find the largest element to the right of the element found above
  64. /// that is smaller than it
  65. ///
  66. template<class InputIterator, class ValueType>
  67. inline InputIterator pp_floor(InputIterator first,
  68. InputIterator last,
  69. ValueType value,
  70. command_queue &queue)
  71. {
  72. typedef typename std::iterator_traits<InputIterator>::value_type value_type;
  73. size_t count = detail::iterator_range_size(first, last);
  74. if(count == 0){
  75. return last;
  76. }
  77. const context &context = queue.get_context();
  78. detail::meta_kernel k("pp_floor");
  79. size_t index_arg = k.add_arg<int *>(memory_object::global_memory, "index");
  80. size_t value_arg = k.add_arg<value_type>(memory_object::private_memory, "value");
  81. atomic_max<int_> atomic_max_int;
  82. k << k.decl<const int_>("i") << " = get_global_id(0);\n"
  83. << k.decl<const value_type>("cur_value") << "="
  84. << first[k.var<const int_>("i")] << ";\n"
  85. << "if(cur_value >= " << first[k.expr<int_>("*index")]
  86. << " && cur_value < value){\n"
  87. << " " << atomic_max_int(k.var<int_ *>("index"), k.var<int_>("i")) << ";\n"
  88. << "}\n";
  89. kernel kernel = k.compile(context);
  90. scalar<int_> index(context);
  91. kernel.set_arg(index_arg, index.get_buffer());
  92. index.write(static_cast<int_>(0), queue);
  93. kernel.set_arg(value_arg, value);
  94. queue.enqueue_1d_range_kernel(kernel, 0, count, 0);
  95. int result = static_cast<int>(index.read(queue));
  96. return first + result;
  97. }
  98. } // end detail namespace
  99. ///
  100. /// \brief Permutation generating algorithm
  101. ///
  102. /// Transforms the range [first, last) into the previous permutation from
  103. /// the set of all permutations arranged in lexicographic order
  104. /// \return Boolean value signifying if the first permutation was crossed
  105. /// and the range was reset
  106. ///
  107. /// \param first Iterator pointing to start of range
  108. /// \param last Iterator pointing to end of range
  109. /// \param queue Queue on which to execute
  110. ///
  111. /// Space complexity: \Omega(1)
  112. template<class InputIterator>
  113. inline bool prev_permutation(InputIterator first,
  114. InputIterator last,
  115. command_queue &queue = system::default_queue())
  116. {
  117. BOOST_STATIC_ASSERT(is_device_iterator<InputIterator>::value);
  118. typedef typename std::iterator_traits<InputIterator>::value_type value_type;
  119. if(first == last) return false;
  120. InputIterator first_element =
  121. detail::prev_permutation_helper(first, last, queue);
  122. if(first_element == last)
  123. {
  124. reverse(first, last, queue);
  125. return false;
  126. }
  127. value_type first_value = first_element.read(queue);
  128. InputIterator ceiling_element =
  129. detail::pp_floor(first_element + 1, last, first_value, queue);
  130. value_type ceiling_value = ceiling_element.read(queue);
  131. first_element.write(ceiling_value, queue);
  132. ceiling_element.write(first_value, queue);
  133. reverse(first_element + 1, last, queue);
  134. return true;
  135. }
  136. } // end compute namespace
  137. } // end boost namespace
  138. #endif // BOOST_COMPUTE_ALGORITHM_PREV_PERMUTATION_HPP