context.hpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
  11. #define BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
  12. #include <boost/proto/core.hpp>
  13. #include <boost/proto/context.hpp>
  14. #include <boost/type_traits.hpp>
  15. #include <boost/preprocessor/repetition.hpp>
  16. #include <boost/compute/config.hpp>
  17. #include <boost/compute/function.hpp>
  18. #include <boost/compute/lambda/result_of.hpp>
  19. #include <boost/compute/lambda/functional.hpp>
  20. #include <boost/compute/type_traits/result_of.hpp>
  21. #include <boost/compute/type_traits/type_name.hpp>
  22. #include <boost/compute/detail/meta_kernel.hpp>
  23. namespace boost {
  24. namespace compute {
  25. namespace lambda {
  26. namespace mpl = boost::mpl;
  27. namespace proto = boost::proto;
  // Defines a context call operator for the binary proto tag `tag` that
  // prints "lhs op rhs" into the kernel source stream.
  // An operand that is itself a compound expression (proto arity > 0) is
  // wrapped in parentheses so the generated source keeps the grouping of the
  // original C++ lambda expression; terminals are printed bare.
  // NOTE: no comments inside the macro body below, a trailing '\' would be
  // spliced into a line comment.
  #define BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(tag, op) \
      template<class LHS, class RHS> \
      void operator()(tag, const LHS &lhs, const RHS &rhs) \
      { \
          const bool group_lhs = proto::arity_of<LHS>::value > 0; \
          const bool group_rhs = proto::arity_of<RHS>::value > 0; \
          \
          if(group_lhs) { stream << '('; } \
          proto::eval(lhs, *this); \
          if(group_lhs) { stream << ')'; } \
          \
          stream << op; \
          \
          if(group_rhs) { stream << '('; } \
          proto::eval(rhs, *this); \
          if(group_rhs) { stream << ')'; } \
      }
  52. // lambda expression context
  53. template<class Args>
  54. struct context : proto::callable_context<context<Args> >
  55. {
  56. typedef void result_type;
  57. typedef Args args_tuple;
  58. // create a lambda context for kernel with args
  59. context(boost::compute::detail::meta_kernel &kernel, const Args &args_)
  60. : stream(kernel),
  61. args(args_)
  62. {
  63. }
  64. // handle terminals
  65. template<class T>
  66. void operator()(proto::tag::terminal, const T &x)
  67. {
  68. // terminal values in lambda expressions are always literals
  69. stream << stream.lit(x);
  70. }
  71. void operator()(proto::tag::terminal, const uchar_ &x)
  72. {
  73. stream << "(uchar)(" << stream.lit(uint_(x)) << "u)";
  74. }
  75. void operator()(proto::tag::terminal, const char_ &x)
  76. {
  77. stream << "(char)(" << stream.lit(int_(x)) << ")";
  78. }
  79. void operator()(proto::tag::terminal, const ushort_ &x)
  80. {
  81. stream << "(ushort)(" << stream.lit(x) << "u)";
  82. }
  83. void operator()(proto::tag::terminal, const short_ &x)
  84. {
  85. stream << "(short)(" << stream.lit(x) << ")";
  86. }
  87. void operator()(proto::tag::terminal, const uint_ &x)
  88. {
  89. stream << "(" << stream.lit(x) << "u)";
  90. }
  91. void operator()(proto::tag::terminal, const ulong_ &x)
  92. {
  93. stream << "(" << stream.lit(x) << "ul)";
  94. }
  95. void operator()(proto::tag::terminal, const long_ &x)
  96. {
  97. stream << "(" << stream.lit(x) << "l)";
  98. }
  99. // handle placeholders
  100. template<int I>
  101. void operator()(proto::tag::terminal, placeholder<I>)
  102. {
  103. stream << boost::get<I>(args);
  104. }
  105. // handle functions
  106. #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG(z, n, unused) \
  107. BOOST_PP_COMMA_IF(n) BOOST_PP_CAT(const Arg, n) BOOST_PP_CAT(&arg, n)
  108. #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION(z, n, unused) \
  109. template<class F, BOOST_PP_ENUM_PARAMS(n, class Arg)> \
  110. void operator()( \
  111. proto::tag::function, \
  112. const F &function, \
  113. BOOST_PP_REPEAT(n, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG, ~) \
  114. ) \
  115. { \
  116. proto::value(function).apply(*this, BOOST_PP_ENUM_PARAMS(n, arg)); \
  117. }
  118. BOOST_PP_REPEAT_FROM_TO(1, BOOST_COMPUTE_MAX_ARITY, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION, ~)
  119. #undef BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION
  120. // operators
  121. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::plus, '+')
  122. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::minus, '-')
  123. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::multiplies, '*')
  124. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::divides, '/')
  125. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::modulus, '%')
  126. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less, '<')
  127. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater, '>')
  128. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less_equal, "<=")
  129. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater_equal, ">=")
  130. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::equal_to, "==")
  131. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::not_equal_to, "!=")
  132. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_and, "&&")
  133. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_or, "||")
  134. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_and, '&')
  135. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_or, '|')
  136. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_xor, '^')
  137. BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::assign, '=')
  138. // subscript operator
  139. template<class LHS, class RHS>
  140. void operator()(proto::tag::subscript, const LHS &lhs, const RHS &rhs)
  141. {
  142. proto::eval(lhs, *this);
  143. stream << '[';
  144. proto::eval(rhs, *this);
  145. stream << ']';
  146. }
  147. // ternary conditional operator
  148. template<class Pred, class Arg1, class Arg2>
  149. void operator()(proto::tag::if_else_, const Pred &p, const Arg1 &x, const Arg2 &y)
  150. {
  151. proto::eval(p, *this);
  152. stream << '?';
  153. proto::eval(x, *this);
  154. stream << ':';
  155. proto::eval(y, *this);
  156. }
  157. boost::compute::detail::meta_kernel &stream;
  158. Args args;
  159. };
  160. namespace detail {
  161. template<class Expr, class Arg>
  162. struct invoked_unary_expression
  163. {
  164. typedef typename ::boost::compute::result_of<Expr(Arg)>::type result_type;
  165. invoked_unary_expression(const Expr &expr, const Arg &arg)
  166. : m_expr(expr),
  167. m_arg(arg)
  168. {
  169. }
  170. Expr m_expr;
  171. Arg m_arg;
  172. };
  173. template<class Expr, class Arg>
  174. boost::compute::detail::meta_kernel&
  175. operator<<(boost::compute::detail::meta_kernel &kernel,
  176. const invoked_unary_expression<Expr, Arg> &expr)
  177. {
  178. context<boost::tuple<Arg> > ctx(kernel, boost::make_tuple(expr.m_arg));
  179. proto::eval(expr.m_expr, ctx);
  180. return kernel;
  181. }
  182. template<class Expr, class Arg1, class Arg2>
  183. struct invoked_binary_expression
  184. {
  185. typedef typename ::boost::compute::result_of<Expr(Arg1, Arg2)>::type result_type;
  186. invoked_binary_expression(const Expr &expr,
  187. const Arg1 &arg1,
  188. const Arg2 &arg2)
  189. : m_expr(expr),
  190. m_arg1(arg1),
  191. m_arg2(arg2)
  192. {
  193. }
  194. Expr m_expr;
  195. Arg1 m_arg1;
  196. Arg2 m_arg2;
  197. };
  198. template<class Expr, class Arg1, class Arg2>
  199. boost::compute::detail::meta_kernel&
  200. operator<<(boost::compute::detail::meta_kernel &kernel,
  201. const invoked_binary_expression<Expr, Arg1, Arg2> &expr)
  202. {
  203. context<boost::tuple<Arg1, Arg2> > ctx(
  204. kernel,
  205. boost::make_tuple(expr.m_arg1, expr.m_arg2)
  206. );
  207. proto::eval(expr.m_expr, ctx);
  208. return kernel;
  209. }
  210. } // end detail namespace
  211. // forward declare domain
  212. struct domain;
  213. // lambda expression wrapper
  214. template<class Expr>
  215. struct expression : proto::extends<Expr, expression<Expr>, domain>
  216. {
  217. typedef proto::extends<Expr, expression<Expr>, domain> base_type;
  218. BOOST_PROTO_EXTENDS_USING_ASSIGN(expression)
  219. expression(const Expr &expr = Expr())
  220. : base_type(expr)
  221. {
  222. }
  223. // result_of protocol
  224. template<class Signature>
  225. struct result
  226. {
  227. };
  228. template<class This>
  229. struct result<This()>
  230. {
  231. typedef
  232. typename ::boost::compute::lambda::result_of<Expr>::type type;
  233. };
  234. template<class This, class Arg>
  235. struct result<This(Arg)>
  236. {
  237. typedef
  238. typename ::boost::compute::lambda::result_of<
  239. Expr,
  240. typename boost::tuple<Arg>
  241. >::type type;
  242. };
  243. template<class This, class Arg1, class Arg2>
  244. struct result<This(Arg1, Arg2)>
  245. {
  246. typedef typename
  247. ::boost::compute::lambda::result_of<
  248. Expr,
  249. typename boost::tuple<Arg1, Arg2>
  250. >::type type;
  251. };
  252. template<class Arg>
  253. detail::invoked_unary_expression<expression<Expr>, Arg>
  254. operator()(const Arg &x) const
  255. {
  256. return detail::invoked_unary_expression<expression<Expr>, Arg>(*this, x);
  257. }
  258. template<class Arg1, class Arg2>
  259. detail::invoked_binary_expression<expression<Expr>, Arg1, Arg2>
  260. operator()(const Arg1 &x, const Arg2 &y) const
  261. {
  262. return detail::invoked_binary_expression<
  263. expression<Expr>,
  264. Arg1,
  265. Arg2
  266. >(*this, x, y);
  267. }
  268. // function<> conversion operator
  269. template<class R, class A1>
  270. operator function<R(A1)>() const
  271. {
  272. using ::boost::compute::detail::meta_kernel;
  273. std::stringstream source;
  274. ::boost::compute::detail::meta_kernel_variable<A1> arg1("x");
  275. source << "inline " << type_name<R>() << " lambda"
  276. << ::boost::compute::detail::generate_argument_list<R(A1)>('x')
  277. << "{\n"
  278. << " return " << meta_kernel::expr_to_string((*this)(arg1)) << ";\n"
  279. << "}\n";
  280. return make_function_from_source<R(A1)>("lambda", source.str());
  281. }
  282. template<class R, class A1, class A2>
  283. operator function<R(A1, A2)>() const
  284. {
  285. using ::boost::compute::detail::meta_kernel;
  286. std::stringstream source;
  287. ::boost::compute::detail::meta_kernel_variable<A1> arg1("x");
  288. ::boost::compute::detail::meta_kernel_variable<A1> arg2("y");
  289. source << "inline " << type_name<R>() << " lambda"
  290. << ::boost::compute::detail::generate_argument_list<R(A1, A2)>('x')
  291. << "{\n"
  292. << " return " << meta_kernel::expr_to_string((*this)(arg1, arg2)) << ";\n"
  293. << "}\n";
  294. return make_function_from_source<R(A1, A2)>("lambda", source.str());
  295. }
  296. };
  297. // lambda expression domain
  298. struct domain : proto::domain<proto::generator<expression> >
  299. {
  300. };
  301. } // end lambda namespace
  302. } // end compute namespace
  303. } // end boost namespace
  304. #endif // BOOST_COMPUTE_LAMBDA_CONTEXT_HPP