svm_ptr.hpp 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013-2014 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_MEMORY_SVM_PTR_HPP
  11. #define BOOST_COMPUTE_MEMORY_SVM_PTR_HPP
  12. #include <boost/type_traits.hpp>
  13. #include <boost/static_assert.hpp>
  14. #include <boost/assert.hpp>
  15. #include <boost/compute/cl.hpp>
  16. #include <boost/compute/kernel.hpp>
  17. #include <boost/compute/context.hpp>
  18. #include <boost/compute/command_queue.hpp>
  19. #include <boost/compute/type_traits/is_device_iterator.hpp>
  20. namespace boost {
  21. namespace compute {
  22. // forward declaration for svm_ptr<T>
  23. template<class T>
  24. class svm_ptr;
  25. // svm functions require OpenCL 2.0
  26. #if defined(BOOST_COMPUTE_CL_VERSION_2_0) || defined(BOOST_COMPUTE_DOXYGEN_INVOKED)
  27. namespace detail {
  28. template<class T, class IndexExpr>
  29. struct svm_ptr_index_expr
  30. {
  31. typedef T result_type;
  32. svm_ptr_index_expr(const svm_ptr<T> &svm_ptr,
  33. const IndexExpr &expr)
  34. : m_svm_ptr(svm_ptr),
  35. m_expr(expr)
  36. {
  37. }
  38. operator T() const
  39. {
  40. BOOST_STATIC_ASSERT_MSG(boost::is_integral<IndexExpr>::value,
  41. "Index expression must be integral");
  42. BOOST_ASSERT(m_svm_ptr.get());
  43. const context &context = m_svm_ptr.get_context();
  44. const device &device = context.get_device();
  45. command_queue queue(context, device);
  46. T value;
  47. T* ptr =
  48. static_cast<T*>(m_svm_ptr.get()) + static_cast<std::ptrdiff_t>(m_expr);
  49. queue.enqueue_svm_map(static_cast<void*>(ptr), sizeof(T), CL_MAP_READ);
  50. value = *(ptr);
  51. queue.enqueue_svm_unmap(static_cast<void*>(ptr)).wait();
  52. return value;
  53. }
  54. const svm_ptr<T> &m_svm_ptr;
  55. IndexExpr m_expr;
  56. };
  57. } // end detail namespace
  58. #endif
  59. template<class T>
  60. class svm_ptr
  61. {
  62. public:
  63. typedef T value_type;
  64. typedef std::ptrdiff_t difference_type;
  65. typedef T* pointer;
  66. typedef T& reference;
  67. typedef std::random_access_iterator_tag iterator_category;
  68. svm_ptr()
  69. : m_ptr(0)
  70. {
  71. }
  72. svm_ptr(void *ptr, const context &context)
  73. : m_ptr(static_cast<T*>(ptr)),
  74. m_context(context)
  75. {
  76. }
  77. svm_ptr(const svm_ptr<T> &other)
  78. : m_ptr(other.m_ptr),
  79. m_context(other.m_context)
  80. {
  81. }
  82. svm_ptr<T>& operator=(const svm_ptr<T> &other)
  83. {
  84. m_ptr = other.m_ptr;
  85. m_context = other.m_context;
  86. return *this;
  87. }
  88. ~svm_ptr()
  89. {
  90. }
  91. void* get() const
  92. {
  93. return m_ptr;
  94. }
  95. svm_ptr<T> operator+(difference_type n)
  96. {
  97. return svm_ptr<T>(m_ptr + n, m_context);
  98. }
  99. difference_type operator-(svm_ptr<T> other)
  100. {
  101. BOOST_ASSERT(other.m_context == m_context);
  102. return m_ptr - other.m_ptr;
  103. }
  104. const context& get_context() const
  105. {
  106. return m_context;
  107. }
  108. bool operator==(const svm_ptr<T>& other) const
  109. {
  110. return (other.m_context == m_context) && (m_ptr == other.m_ptr);
  111. }
  112. bool operator!=(const svm_ptr<T>& other) const
  113. {
  114. return (other.m_context != m_context) || (m_ptr != other.m_ptr);
  115. }
  116. // svm functions require OpenCL 2.0
  117. #if defined(BOOST_COMPUTE_CL_VERSION_2_0) || defined(BOOST_COMPUTE_DOXYGEN_INVOKED)
  118. /// \internal_
  119. template<class Expr>
  120. detail::svm_ptr_index_expr<T, Expr>
  121. operator[](const Expr &expr) const
  122. {
  123. BOOST_ASSERT(m_ptr);
  124. return detail::svm_ptr_index_expr<T, Expr>(*this,
  125. expr);
  126. }
  127. #endif
  128. private:
  129. T *m_ptr;
  130. context m_context;
  131. };
  132. namespace detail {
  133. /// \internal_
  134. template<class T>
  135. struct set_kernel_arg<svm_ptr<T> >
  136. {
  137. void operator()(kernel &kernel_, size_t index, const svm_ptr<T> &ptr)
  138. {
  139. kernel_.set_arg_svm_ptr(index, ptr.get());
  140. }
  141. };
  142. } // end detail namespace
  143. /// \internal_ (is_device_iterator specialization for svm_ptr)
  144. template<class T>
  145. struct is_device_iterator<svm_ptr<T> > : boost::true_type {};
  146. } // end compute namespace
  147. } // end boost namespace
  148. #endif // BOOST_COMPUTE_MEMORY_SVM_PTR_HPP