lexicographical_compare.hpp 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2014 Mageswaran.D <mageswaran1989@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
#include <algorithm>
#include <iterator>
#include <sstream>
#include <string>

#include <boost/static_assert.hpp>

#include <boost/compute/system.hpp>
#include <boost/compute/context.hpp>
#include <boost/compute/command_queue.hpp>
#include <boost/compute/algorithm/any_of.hpp>
#include <boost/compute/algorithm/find_if.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/compute/utility/program_cache.hpp>
#include <boost/compute/type_traits/is_device_iterator.hpp>
  18. namespace boost {
  19. namespace compute {
  20. namespace detail {
  21. const char lexicographical_compare_source[] =
  22. "__kernel void lexicographical_compare(const uint size1,\n"
  23. " const uint size2,\n"
  24. " __global const T1 *range1,\n"
  25. " __global const T2 *range2,\n"
  26. " __global bool *result_buf)\n"
  27. "{\n"
  28. " const uint i = get_global_id(0);\n"
  29. " if((i != size1) && (i != size2)){\n"
  30. //Individual elements are compared and results are stored in parallel.
  31. //0 is true
  32. " if(range1[i] < range2[i])\n"
  33. " result_buf[i] = 0;\n"
  34. " else\n"
  35. " result_buf[i] = 1;\n"
  36. " }\n"
  37. " else\n"
  38. " result_buf[i] = !((i == size1) && (i != size2));\n"
  39. "}\n";
  40. template<class InputIterator1, class InputIterator2>
  41. inline bool dispatch_lexicographical_compare(InputIterator1 first1,
  42. InputIterator1 last1,
  43. InputIterator2 first2,
  44. InputIterator2 last2,
  45. command_queue &queue)
  46. {
  47. const boost::compute::context &context = queue.get_context();
  48. boost::shared_ptr<program_cache> cache =
  49. program_cache::get_global_cache(context);
  50. size_t iterator_size1 = iterator_range_size(first1, last1);
  51. size_t iterator_size2 = iterator_range_size(first2, last2);
  52. size_t max_size = (std::max)(iterator_size1, iterator_size2);
  53. if(max_size == 0){
  54. return false;
  55. }
  56. boost::compute::vector<bool> result_vector(max_size, context);
  57. typedef typename std::iterator_traits<InputIterator1>::value_type value_type1;
  58. typedef typename std::iterator_traits<InputIterator2>::value_type value_type2;
  59. // load (or create) lexicographical compare program
  60. std::string cache_key =
  61. std::string("__boost_lexicographical_compare")
  62. + type_name<value_type1>() + type_name<value_type2>();
  63. std::stringstream options;
  64. options << " -DT1=" << type_name<value_type1>();
  65. options << " -DT2=" << type_name<value_type2>();
  66. program lexicographical_compare_program = cache->get_or_build(
  67. cache_key, options.str(), lexicographical_compare_source, context
  68. );
  69. kernel lexicographical_compare_kernel(lexicographical_compare_program,
  70. "lexicographical_compare");
  71. lexicographical_compare_kernel.set_arg<uint_>(0, iterator_size1);
  72. lexicographical_compare_kernel.set_arg<uint_>(1, iterator_size2);
  73. lexicographical_compare_kernel.set_arg(2, first1.get_buffer());
  74. lexicographical_compare_kernel.set_arg(3, first2.get_buffer());
  75. lexicographical_compare_kernel.set_arg(4, result_vector.get_buffer());
  76. queue.enqueue_1d_range_kernel(lexicographical_compare_kernel,
  77. 0,
  78. max_size,
  79. 0);
  80. return boost::compute::any_of(result_vector.begin(),
  81. result_vector.end(),
  82. _1 == 0,
  83. queue);
  84. }
  85. } // end detail namespace
  86. /// Checks if the first range [first1, last1) is lexicographically
  87. /// less than the second range [first2, last2).
  88. ///
  89. /// Space complexity:
  90. /// \Omega(max(distance(\p first1, \p last1), distance(\p first2, \p last2)))
  91. template<class InputIterator1, class InputIterator2>
  92. inline bool lexicographical_compare(InputIterator1 first1,
  93. InputIterator1 last1,
  94. InputIterator2 first2,
  95. InputIterator2 last2,
  96. command_queue &queue = system::default_queue())
  97. {
  98. BOOST_STATIC_ASSERT(is_device_iterator<InputIterator1>::value);
  99. BOOST_STATIC_ASSERT(is_device_iterator<InputIterator2>::value);
  100. return detail::dispatch_lexicographical_compare(first1, last1, first2, last2, queue);
  101. }
  102. } // end compute namespace
} // end boost namespace