inner_product.hpp 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_INNER_PRODUCT_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_INNER_PRODUCT_HPP
  12. #include <boost/static_assert.hpp>
  13. #include <boost/compute/system.hpp>
  14. #include <boost/compute/functional.hpp>
  15. #include <boost/compute/command_queue.hpp>
  16. #include <boost/compute/algorithm/accumulate.hpp>
  17. #include <boost/compute/container/vector.hpp>
  18. #include <boost/compute/iterator/transform_iterator.hpp>
  19. #include <boost/compute/iterator/zip_iterator.hpp>
  20. #include <boost/compute/functional/detail/unpack.hpp>
  21. #include <boost/compute/type_traits/is_device_iterator.hpp>
  22. namespace boost {
  23. namespace compute {
  24. /// Returns the inner product of the elements in the range
  25. /// [\p first1, \p last1) with the elements in the range beginning
  26. /// at \p first2.
  27. ///
  28. /// Space complexity: \Omega(1)<br>
  29. /// Space complexity when binary operator is recognized as associative: \Omega(n)
  30. template<class InputIterator1, class InputIterator2, class T>
  31. inline T inner_product(InputIterator1 first1,
  32. InputIterator1 last1,
  33. InputIterator2 first2,
  34. T init,
  35. command_queue &queue = system::default_queue())
  36. {
  37. BOOST_STATIC_ASSERT(is_device_iterator<InputIterator1>::value);
  38. BOOST_STATIC_ASSERT(is_device_iterator<InputIterator2>::value);
  39. typedef typename std::iterator_traits<InputIterator1>::value_type input_type;
  40. ptrdiff_t n = std::distance(first1, last1);
  41. return ::boost::compute::accumulate(
  42. ::boost::compute::make_transform_iterator(
  43. ::boost::compute::make_zip_iterator(
  44. boost::make_tuple(first1, first2)
  45. ),
  46. detail::unpack(multiplies<input_type>())
  47. ),
  48. ::boost::compute::make_transform_iterator(
  49. ::boost::compute::make_zip_iterator(
  50. boost::make_tuple(last1, first2 + n)
  51. ),
  52. detail::unpack(multiplies<input_type>())
  53. ),
  54. init,
  55. queue
  56. );
  57. }
  58. /// \overload
  59. template<class InputIterator1,
  60. class InputIterator2,
  61. class T,
  62. class BinaryAccumulateFunction,
  63. class BinaryTransformFunction>
  64. inline T inner_product(InputIterator1 first1,
  65. InputIterator1 last1,
  66. InputIterator2 first2,
  67. T init,
  68. BinaryAccumulateFunction accumulate_function,
  69. BinaryTransformFunction transform_function,
  70. command_queue &queue = system::default_queue())
  71. {
  72. BOOST_STATIC_ASSERT(is_device_iterator<InputIterator1>::value);
  73. BOOST_STATIC_ASSERT(is_device_iterator<InputIterator2>::value);
  74. typedef typename std::iterator_traits<InputIterator1>::value_type value_type;
  75. size_t count = detail::iterator_range_size(first1, last1);
  76. vector<value_type> result(count, queue.get_context());
  77. transform(first1,
  78. last1,
  79. first2,
  80. result.begin(),
  81. transform_function,
  82. queue);
  83. return ::boost::compute::accumulate(result.begin(),
  84. result.end(),
  85. init,
  86. accumulate_function,
  87. queue);
  88. }
  89. } // end compute namespace
  90. } // end boost namespace
  91. #endif // BOOST_COMPUTE_ALGORITHM_INNER_PRODUCT_HPP