// invoke_matching.hpp

// Copyright Jim Bosch 2010-2012.
// Copyright Stefan Seefeld 2016.
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)
  6. #ifndef boost_python_numpy_invoke_matching_hpp_
  7. #define boost_python_numpy_invoke_matching_hpp_
/**
 * @brief Template invocation based on dtype and dimension-count matching.
 */
  11. #include <boost/python/numpy/dtype.hpp>
  12. #include <boost/python/numpy/ndarray.hpp>
  13. #include <boost/mpl/integral_c.hpp>
  14. namespace boost { namespace python { namespace numpy {
  15. namespace detail
  16. {
  17. struct BOOST_NUMPY_DECL add_pointer_meta
  18. {
  19. template <typename T>
  20. struct apply
  21. {
  22. typedef typename boost::add_pointer<T>::type type;
  23. };
  24. };
  25. struct BOOST_NUMPY_DECL dtype_template_match_found {};
  26. struct BOOST_NUMPY_DECL nd_template_match_found {};
  27. template <typename Function>
  28. struct dtype_template_invoker
  29. {
  30. template <typename T>
  31. void operator()(T *) const
  32. {
  33. if (dtype::get_builtin<T>() == m_dtype)
  34. {
  35. m_func.Function::template apply<T>();
  36. throw dtype_template_match_found();
  37. }
  38. }
  39. dtype_template_invoker(dtype const & dtype_, Function func)
  40. : m_dtype(dtype_), m_func(func) {}
  41. private:
  42. dtype const & m_dtype;
  43. Function m_func;
  44. };
  45. template <typename Function>
  46. struct dtype_template_invoker< boost::reference_wrapper<Function> >
  47. {
  48. template <typename T>
  49. void operator()(T *) const
  50. {
  51. if (dtype::get_builtin<T>() == m_dtype)
  52. {
  53. m_func.Function::template apply<T>();
  54. throw dtype_template_match_found();
  55. }
  56. }
  57. dtype_template_invoker(dtype const & dtype_, Function & func)
  58. : m_dtype(dtype_), m_func(func) {}
  59. private:
  60. dtype const & m_dtype;
  61. Function & m_func;
  62. };
  63. template <typename Function>
  64. struct nd_template_invoker
  65. {
  66. template <int N>
  67. void operator()(boost::mpl::integral_c<int,N> *) const
  68. {
  69. if (m_nd == N)
  70. {
  71. m_func.Function::template apply<N>();
  72. throw nd_template_match_found();
  73. }
  74. }
  75. nd_template_invoker(int nd, Function func) : m_nd(nd), m_func(func) {}
  76. private:
  77. int m_nd;
  78. Function m_func;
  79. };
  80. template <typename Function>
  81. struct nd_template_invoker< boost::reference_wrapper<Function> >
  82. {
  83. template <int N>
  84. void operator()(boost::mpl::integral_c<int,N> *) const
  85. {
  86. if (m_nd == N)
  87. {
  88. m_func.Function::template apply<N>();
  89. throw nd_template_match_found();
  90. }
  91. }
  92. nd_template_invoker(int nd, Function & func) : m_nd(nd), m_func(func) {}
  93. private:
  94. int m_nd;
  95. Function & m_func;
  96. };
  97. } // namespace boost::python::numpy::detail
  98. template <typename Sequence, typename Function>
  99. void invoke_matching_nd(int nd, Function f)
  100. {
  101. detail::nd_template_invoker<Function> invoker(nd, f);
  102. try { boost::mpl::for_each< Sequence, detail::add_pointer_meta >(invoker);}
  103. catch (detail::nd_template_match_found &) { return;}
  104. PyErr_SetString(PyExc_TypeError, "number of dimensions not found in template list.");
  105. python::throw_error_already_set();
  106. }
  107. template <typename Sequence, typename Function>
  108. void invoke_matching_dtype(dtype const & dtype_, Function f)
  109. {
  110. detail::dtype_template_invoker<Function> invoker(dtype_, f);
  111. try { boost::mpl::for_each< Sequence, detail::add_pointer_meta >(invoker);}
  112. catch (detail::dtype_template_match_found &) { return;}
  113. PyErr_SetString(PyExc_TypeError, "dtype not found in template list.");
  114. python::throw_error_already_set();
  115. }
  116. namespace detail
  117. {
  /// Second dispatch stage of invoke_matching_array. The element type T is
  /// already fixed; invoke_matching_nd calls `apply<N>()` with the matched
  /// dimension count, which forwards to the user's `apply<T,N>()`.
  template <typename T, typename Function>
  struct array_template_invoker_wrapper_2
  {
    template <int N>
    void apply() const { m_func.Function::template apply<T,N>(); }

    // explicit: a bare Function& must not silently convert into a wrapper.
    explicit array_template_invoker_wrapper_2(Function & func) : m_func(func) {}

  private:
    Function & m_func;  // referenced functor must outlive this wrapper
  };
  127. template <typename DimSequence, typename Function>
  128. struct array_template_invoker_wrapper_1
  129. {
  130. template <typename T>
  131. void apply() const { invoke_matching_nd<DimSequence>(m_nd, array_template_invoker_wrapper_2<T,Function>(m_func));}
  132. array_template_invoker_wrapper_1(int nd, Function & func) : m_nd(nd), m_func(func) {}
  133. private:
  134. int m_nd;
  135. Function & m_func;
  136. };
  137. template <typename DimSequence, typename Function>
  138. struct array_template_invoker_wrapper_1< DimSequence, boost::reference_wrapper<Function> >
  139. : public array_template_invoker_wrapper_1< DimSequence, Function >
  140. {
  141. array_template_invoker_wrapper_1(int nd, Function & func)
  142. : array_template_invoker_wrapper_1< DimSequence, Function >(nd, func) {}
  143. };
  144. } // namespace boost::python::numpy::detail
  145. template <typename TypeSequence, typename DimSequence, typename Function>
  146. void invoke_matching_array(ndarray const & array_, Function f)
  147. {
  148. detail::array_template_invoker_wrapper_1<DimSequence,Function> wrapper(array_.get_nd(), f);
  149. invoke_matching_dtype<TypeSequence>(array_.get_dtype(), wrapper);
  150. }
  151. }}} // namespace boost::python::numpy
  152. #endif