autodiff_cpp11.hpp 17 KB


  1. // Copyright Matthew Pulver 2018 - 2019.
  2. // Distributed under the Boost Software License, Version 1.0.
  3. // (See accompanying file LICENSE_1_0.txt or copy at
  4. // https://www.boost.org/LICENSE_1_0.txt)
  5. // Contributors:
  6. // * Kedar R. Bhat - C++11 compatibility.
  7. // Notes:
  // * Any changes to this file should always be downstream from autodiff.hpp.
  // C++17 is a higher-level language and is easier to maintain. For example, a number of functions which
  // read lucidly in autodiff.hpp are forced to be split into multiple structs/functions in this file for
  // C++11.
  12. // * Use of typename RootType and SizeType is a hack to prevent Visual Studio 2015 from compiling functions
  13. // that are never called, that would otherwise produce compiler errors. Also forces functions to be inline.
  14. #ifndef BOOST_MATH_DIFFERENTIATION_AUTODIFF_HPP
  15. #error \
  16. "Do not #include this file directly. This should only be #included by autodiff.hpp for C++11 compatibility."
  17. #endif
  18. #include <boost/mp11/integer_sequence.hpp>
  19. namespace boost {
  20. namespace math {
  21. namespace differentiation {
  22. inline namespace autodiff_v1 {
  23. namespace detail {
  24. template <typename RealType, size_t Order>
  25. fvar<RealType, Order>::fvar(root_type const& ca, bool const is_variable) {
  26. fvar_cpp11(is_fvar<RealType>{}, ca, is_variable);
  27. }
  28. template <typename RealType, size_t Order>
  29. template <typename RootType>
  30. void fvar<RealType, Order>::fvar_cpp11(std::true_type, RootType const& ca, bool const is_variable) {
  31. v.front() = RealType(ca, is_variable);
  32. if (0 < Order)
  33. std::fill(v.begin() + 1, v.end(), static_cast<RealType>(0));
  34. }
  35. template <typename RealType, size_t Order>
  36. template <typename RootType>
  37. void fvar<RealType, Order>::fvar_cpp11(std::false_type, RootType const& ca, bool const is_variable) {
  38. v.front() = ca;
  39. if (0 < Order) {
  40. v[1] = static_cast<root_type>(static_cast<int>(is_variable));
  41. if (1 < Order)
  42. std::fill(v.begin() + 2, v.end(), static_cast<RealType>(0));
  43. }
  44. }
  45. template <typename RealType, size_t Order>
  46. template <typename... Orders>
  47. get_type_at<RealType, sizeof...(Orders)> fvar<RealType, Order>::at_cpp11(std::true_type,
  48. size_t order,
  49. Orders...) const {
  50. return v.at(order);
  51. }
  52. template <typename RealType, size_t Order>
  53. template <typename... Orders>
  54. get_type_at<RealType, sizeof...(Orders)> fvar<RealType, Order>::at_cpp11(std::false_type,
  55. size_t order,
  56. Orders... orders) const {
  57. return v.at(order).at(orders...);
  58. }
  59. // Can throw "std::out_of_range: array::at: __n (which is 7) >= _Nm (which is 7)"
  60. template <typename RealType, size_t Order>
  61. template <typename... Orders>
  62. get_type_at<RealType, sizeof...(Orders)> fvar<RealType, Order>::at(size_t order, Orders... orders) const {
  63. return at_cpp11(std::integral_constant<bool, sizeof...(orders) == 0>{}, order, orders...);
  64. }
  // constexpr product of a parameter pack (derivative() uses it to multiply factorials).
  //
  // Base case: the empty product is 1. T does not appear in the parameter list, so it cannot be
  // deduced. This overload is only reached through explicit product<T>(...) calls such as the
  // recursive one below.
  template <typename T, typename... Ts>
  constexpr T product(Ts...) {
    return static_cast<T>(1);
  }
  // Recursive case: T is deduced from the first factor, and the remaining factors are folded in
  // through product<T>. There is a single return statement, as C++11 constexpr requires.
  template <typename T, typename... Ts>
  constexpr T product(T factor, Ts... factors) {
    return factor * product<T>(factors...);
  }
  73. // Can throw "std::out_of_range: array::at: __n (which is 7) >= _Nm (which is 7)"
  74. template <typename RealType, size_t Order>
  75. template <typename... Orders>
  76. get_type_at<fvar<RealType, Order>, sizeof...(Orders)> fvar<RealType, Order>::derivative(
  77. Orders... orders) const {
  78. static_assert(sizeof...(Orders) <= depth,
  79. "Number of parameters to derivative(...) cannot exceed fvar::depth.");
  80. return at(static_cast<size_t>(orders)...) *
  81. product(boost::math::factorial<root_type>(static_cast<unsigned>(orders))...);
  82. }
  // Binds the leading index of a multi-index callable. Curry(f, i)(j, k, ...) evaluates
  // f(i, j, k, ...) and returns the result as RootType. Each trailing index is converted to the
  // unsigned counterpart of its own type before the call.
  //
  // f is held by reference, so it must outlive the Curry object. Nested use such as
  // Curry(Curry(f, i), j)(k) forwards two trailing indices to the inner Curry. For that reason
  // operator() accepts any number of trailing indices, including none.
  template <typename RootType, typename Func>
  class Curry {
    // Unsigned counterpart of one index argument's (decayed) type. It is applied per index:
    // expanding the pack inside a single make_unsigned<...> only compiles for exactly one index.
    template <typename Index>
    using unsigned_index_t = typename std::make_unsigned<typename std::common_type<Index>::type>::type;

    Func const& f_;   // Non-owning; must outlive this Curry.
    size_t const i_;  // The bound leading index.

   public:
    template <typename SizeType>  // typename SizeType to force inline constructor.
    Curry(Func const& f, SizeType i) : f_(f), i_(static_cast<std::size_t>(i)) {}

    template <typename... Indices>
    RootType operator()(Indices... indices) const {
      return f_(i_, static_cast<unsigned_index_t<Indices>>(indices)...);
    }
  };
  96. template <typename RealType, size_t Order>
  97. template <typename Func, typename Fvar, typename... Fvars>
  98. promote<fvar<RealType, Order>, Fvar, Fvars...> fvar<RealType, Order>::apply_coefficients(
  99. size_t const order,
  100. Func const& f,
  101. Fvar const& cr,
  102. Fvars&&... fvars) const {
  103. fvar<RealType, Order> const epsilon = fvar<RealType, Order>(*this).set_root(0);
  104. size_t i = order < order_sum ? order : order_sum;
  105. using return_type = promote<fvar<RealType, Order>, Fvar, Fvars...>;
  106. return_type accumulator = cr.apply_coefficients(
  107. order - i, Curry<typename return_type::root_type, Func>(f, i), std::forward<Fvars>(fvars)...);
  108. while (i--)
  109. (accumulator *= epsilon) += cr.apply_coefficients(
  110. order - i, Curry<typename return_type::root_type, Func>(f, i), std::forward<Fvars>(fvars)...);
  111. return accumulator;
  112. }
  113. template <typename RealType, size_t Order>
  114. template <typename Func, typename Fvar, typename... Fvars>
  115. promote<fvar<RealType, Order>, Fvar, Fvars...> fvar<RealType, Order>::apply_coefficients_nonhorner(
  116. size_t const order,
  117. Func const& f,
  118. Fvar const& cr,
  119. Fvars&&... fvars) const {
  120. fvar<RealType, Order> const epsilon = fvar<RealType, Order>(*this).set_root(0);
  121. fvar<RealType, Order> epsilon_i = fvar<RealType, Order>(1); // epsilon to the power of i
  122. using return_type = promote<fvar<RealType, Order>, Fvar, Fvars...>;
  123. return_type accumulator = cr.apply_coefficients_nonhorner(
  124. order, Curry<typename return_type::root_type, Func>(f, 0), std::forward<Fvars>(fvars)...);
  125. size_t const i_max = order < order_sum ? order : order_sum;
  126. for (size_t i = 1; i <= i_max; ++i) {
  127. epsilon_i = epsilon_i.epsilon_multiply(i - 1, 0, epsilon, 1, 0);
  128. accumulator += epsilon_i.epsilon_multiply(
  129. i,
  130. 0,
  131. cr.apply_coefficients_nonhorner(
  132. order - i, Curry<typename return_type::root_type, Func>(f, i), std::forward<Fvars>(fvars)...),
  133. 0,
  134. 0);
  135. }
  136. return accumulator;
  137. }
  138. template <typename RealType, size_t Order>
  139. template <typename Func, typename Fvar, typename... Fvars>
  140. promote<fvar<RealType, Order>, Fvar, Fvars...> fvar<RealType, Order>::apply_derivatives(
  141. size_t const order,
  142. Func const& f,
  143. Fvar const& cr,
  144. Fvars&&... fvars) const {
  145. fvar<RealType, Order> const epsilon = fvar<RealType, Order>(*this).set_root(0);
  146. size_t i = order < order_sum ? order : order_sum;
  147. using return_type = promote<fvar<RealType, Order>, Fvar, Fvars...>;
  148. return_type accumulator =
  149. cr.apply_derivatives(
  150. order - i, Curry<typename return_type::root_type, Func>(f, i), std::forward<Fvars>(fvars)...) /
  151. factorial<root_type>(static_cast<unsigned>(i));
  152. while (i--)
  153. (accumulator *= epsilon) +=
  154. cr.apply_derivatives(
  155. order - i, Curry<typename return_type::root_type, Func>(f, i), std::forward<Fvars>(fvars)...) /
  156. factorial<root_type>(static_cast<unsigned>(i));
  157. return accumulator;
  158. }
  159. template <typename RealType, size_t Order>
  160. template <typename Func, typename Fvar, typename... Fvars>
  161. promote<fvar<RealType, Order>, Fvar, Fvars...> fvar<RealType, Order>::apply_derivatives_nonhorner(
  162. size_t const order,
  163. Func const& f,
  164. Fvar const& cr,
  165. Fvars&&... fvars) const {
  166. fvar<RealType, Order> const epsilon = fvar<RealType, Order>(*this).set_root(0);
  167. fvar<RealType, Order> epsilon_i = fvar<RealType, Order>(1); // epsilon to the power of i
  168. using return_type = promote<fvar<RealType, Order>, Fvar, Fvars...>;
  169. return_type accumulator = cr.apply_derivatives_nonhorner(
  170. order, Curry<typename return_type::root_type, Func>(f, 0), std::forward<Fvars>(fvars)...);
  171. size_t const i_max = order < order_sum ? order : order_sum;
  172. for (size_t i = 1; i <= i_max; ++i) {
  173. epsilon_i = epsilon_i.epsilon_multiply(i - 1, 0, epsilon, 1, 0);
  174. accumulator += epsilon_i.epsilon_multiply(
  175. i,
  176. 0,
  177. cr.apply_derivatives_nonhorner(
  178. order - i, Curry<typename return_type::root_type, Func>(f, i), std::forward<Fvars>(fvars)...) /
  179. factorial<root_type>(static_cast<unsigned>(i)),
  180. 0,
  181. 0);
  182. }
  183. return accumulator;
  184. }
  185. template <typename RealType, size_t Order>
  186. template <typename SizeType>
  187. fvar<RealType, Order> fvar<RealType, Order>::epsilon_multiply_cpp11(std::true_type,
  188. SizeType z0,
  189. size_t isum0,
  190. fvar<RealType, Order> const& cr,
  191. size_t z1,
  192. size_t isum1) const {
  193. size_t const m0 = order_sum + isum0 < Order + z0 ? Order + z0 - (order_sum + isum0) : 0;
  194. size_t const m1 = order_sum + isum1 < Order + z1 ? Order + z1 - (order_sum + isum1) : 0;
  195. size_t const i_max = m0 + m1 < Order ? Order - (m0 + m1) : 0;
  196. fvar<RealType, Order> retval = fvar<RealType, Order>();
  197. for (size_t i = 0, j = Order; i <= i_max; ++i, --j)
  198. retval.v[j] = epsilon_inner_product(z0, isum0, m0, cr, z1, isum1, m1, j);
  199. return retval;
  200. }
  201. template <typename RealType, size_t Order>
  202. template <typename SizeType>
  203. fvar<RealType, Order> fvar<RealType, Order>::epsilon_multiply_cpp11(std::false_type,
  204. SizeType z0,
  205. size_t isum0,
  206. fvar<RealType, Order> const& cr,
  207. size_t z1,
  208. size_t isum1) const {
  209. using ssize_t = typename std::make_signed<std::size_t>::type;
  210. RealType const zero(0);
  211. size_t const m0 = order_sum + isum0 < Order + z0 ? Order + z0 - (order_sum + isum0) : 0;
  212. size_t const m1 = order_sum + isum1 < Order + z1 ? Order + z1 - (order_sum + isum1) : 0;
  213. size_t const i_max = m0 + m1 < Order ? Order - (m0 + m1) : 0;
  214. fvar<RealType, Order> retval = fvar<RealType, Order>();
  215. for (size_t i = 0, j = Order; i <= i_max; ++i, --j)
  216. retval.v[j] = std::inner_product(
  217. v.cbegin() + ssize_t(m0), v.cend() - ssize_t(i + m1), cr.v.crbegin() + ssize_t(i + m0), zero);
  218. return retval;
  219. }
  220. template <typename RealType, size_t Order>
  221. fvar<RealType, Order> fvar<RealType, Order>::epsilon_multiply(size_t z0,
  222. size_t isum0,
  223. fvar<RealType, Order> const& cr,
  224. size_t z1,
  225. size_t isum1) const {
  226. return epsilon_multiply_cpp11(is_fvar<RealType>{}, z0, isum0, cr, z1, isum1);
  227. }
  228. template <typename RealType, size_t Order>
  229. template <typename SizeType>
  230. fvar<RealType, Order> fvar<RealType, Order>::epsilon_multiply_cpp11(std::true_type,
  231. SizeType z0,
  232. size_t isum0,
  233. root_type const& ca) const {
  234. fvar<RealType, Order> retval(*this);
  235. size_t const m0 = order_sum + isum0 < Order + z0 ? Order + z0 - (order_sum + isum0) : 0;
  236. for (size_t i = m0; i <= Order; ++i)
  237. retval.v[i] = retval.v[i].epsilon_multiply(z0, isum0 + i, ca);
  238. return retval;
  239. }
  240. template <typename RealType, size_t Order>
  241. template <typename SizeType>
  242. fvar<RealType, Order> fvar<RealType, Order>::epsilon_multiply_cpp11(std::false_type,
  243. SizeType z0,
  244. size_t isum0,
  245. root_type const& ca) const {
  246. fvar<RealType, Order> retval(*this);
  247. size_t const m0 = order_sum + isum0 < Order + z0 ? Order + z0 - (order_sum + isum0) : 0;
  248. for (size_t i = m0; i <= Order; ++i)
  249. if (retval.v[i] != static_cast<RealType>(0))
  250. retval.v[i] *= ca;
  251. return retval;
  252. }
  253. template <typename RealType, size_t Order>
  254. fvar<RealType, Order> fvar<RealType, Order>::epsilon_multiply(size_t z0,
  255. size_t isum0,
  256. root_type const& ca) const {
  257. return epsilon_multiply_cpp11(is_fvar<RealType>{}, z0, isum0, ca);
  258. }
  259. template <typename RealType, size_t Order>
  260. template <typename RootType>
  261. fvar<RealType, Order>& fvar<RealType, Order>::multiply_assign_by_root_type_cpp11(std::true_type,
  262. bool is_root,
  263. RootType const& ca) {
  264. auto itr = v.begin();
  265. itr->multiply_assign_by_root_type(is_root, ca);
  266. for (++itr; itr != v.end(); ++itr)
  267. itr->multiply_assign_by_root_type(false, ca);
  268. return *this;
  269. }
  270. template <typename RealType, size_t Order>
  271. template <typename RootType>
  272. fvar<RealType, Order>& fvar<RealType, Order>::multiply_assign_by_root_type_cpp11(std::false_type,
  273. bool is_root,
  274. RootType const& ca) {
  275. auto itr = v.begin();
  276. if (is_root || *itr != 0)
  277. *itr *= ca; // Skip multiplication of 0 by ca=inf to avoid nan, except when is_root.
  278. for (++itr; itr != v.end(); ++itr)
  279. if (*itr != 0)
  280. *itr *= ca;
  281. return *this;
  282. }
  283. template <typename RealType, size_t Order>
  284. fvar<RealType, Order>& fvar<RealType, Order>::multiply_assign_by_root_type(bool is_root,
  285. root_type const& ca) {
  286. return multiply_assign_by_root_type_cpp11(is_fvar<RealType>{}, is_root, ca);
  287. }
  288. template <typename RealType, size_t Order>
  289. template <typename RootType>
  290. fvar<RealType, Order>& fvar<RealType, Order>::negate_cpp11(std::true_type, RootType const&) {
  291. std::for_each(v.begin(), v.end(), [](RealType& r) { r.negate(); });
  292. return *this;
  293. }
  294. template <typename RealType, size_t Order>
  295. template <typename RootType>
  296. fvar<RealType, Order>& fvar<RealType, Order>::negate_cpp11(std::false_type, RootType const&) {
  297. std::for_each(v.begin(), v.end(), [](RealType& a) { a = -a; });
  298. return *this;
  299. }
  300. template <typename RealType, size_t Order>
  301. fvar<RealType, Order>& fvar<RealType, Order>::negate() {
  302. return negate_cpp11(is_fvar<RealType>{}, static_cast<root_type>(*this));
  303. }
  304. template <typename RealType, size_t Order>
  305. template <typename RootType>
  306. fvar<RealType, Order>& fvar<RealType, Order>::set_root_cpp11(std::true_type, RootType const& root) {
  307. v.front().set_root(root);
  308. return *this;
  309. }
  310. template <typename RealType, size_t Order>
  311. template <typename RootType>
  312. fvar<RealType, Order>& fvar<RealType, Order>::set_root_cpp11(std::false_type, RootType const& root) {
  313. v.front() = root;
  314. return *this;
  315. }
  316. template <typename RealType, size_t Order>
  317. fvar<RealType, Order>& fvar<RealType, Order>::set_root(root_type const& root) {
  318. return set_root_cpp11(is_fvar<RealType>{}, root);
  319. }
  320. template <typename RealType, size_t Order, size_t... Is>
  321. auto make_fvar_for_tuple(mp11::index_sequence<Is...>, RealType const& ca)
  322. -> decltype(make_fvar<RealType, zero<Is>::value..., Order>(ca)) {
  323. return make_fvar<RealType, zero<Is>::value..., Order>(ca);
  324. }
  325. template <typename RealType, size_t... Orders, size_t... Is, typename... RealTypes>
  326. auto make_ftuple_impl(mp11::index_sequence<Is...>, RealTypes const&... ca)
  327. -> decltype(std::make_tuple(make_fvar_for_tuple<RealType, Orders>(mp11::make_index_sequence<Is>{},
  328. ca)...)) {
  329. return std::make_tuple(make_fvar_for_tuple<RealType, Orders>(mp11::make_index_sequence<Is>{}, ca)...);
  330. }
  331. } // namespace detail
  332. template <typename RealType, size_t... Orders, typename... RealTypes>
  333. auto make_ftuple(RealTypes const&... ca)
  334. -> decltype(detail::make_ftuple_impl<RealType, Orders...>(mp11::index_sequence_for<RealTypes...>{},
  335. ca...)) {
  336. static_assert(sizeof...(Orders) == sizeof...(RealTypes),
  337. "Number of Orders must match number of function parameters.");
  338. return detail::make_ftuple_impl<RealType, Orders...>(mp11::index_sequence_for<RealTypes...>{}, ca...);
  339. }
  340. } // namespace autodiff_v1
  341. } // namespace differentiation
  342. } // namespace math
  343. } // namespace boost