chebyshev_transform.hpp 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. // (C) Copyright Nick Thompson 2017.
  2. // Use, modification and distribution are subject to the
  3. // Boost Software License, Version 1.0. (See accompanying file
  4. // LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  5. #ifndef BOOST_MATH_SPECIAL_CHEBYSHEV_TRANSFORM_HPP
  6. #define BOOST_MATH_SPECIAL_CHEBYSHEV_TRANSFORM_HPP
#include <cmath>
#include <cstddef>
#include <limits>
#include <stdexcept>
#include <type_traits>
#include <vector>
#include <fftw3.h>
#include <boost/math/constants/constants.hpp>
#include <boost/math/special_functions/chebyshev.hpp>
#ifdef BOOST_HAS_FLOAT128
#include <quadmath.h>
#endif
  15. namespace boost { namespace math {
  16. namespace detail{
  17. template <class T>
  18. struct fftw_cos_transform;
  19. template<>
  20. struct fftw_cos_transform<double>
  21. {
  22. fftw_cos_transform(int n, double* data1, double* data2)
  23. {
  24. plan = fftw_plan_r2r_1d(n, data1, data2, FFTW_REDFT10, FFTW_ESTIMATE);
  25. }
  26. ~fftw_cos_transform()
  27. {
  28. fftw_destroy_plan(plan);
  29. }
  30. void execute(double* data1, double* data2)
  31. {
  32. fftw_execute_r2r(plan, data1, data2);
  33. }
  34. static double cos(double x) { return std::cos(x); }
  35. static double fabs(double x) { return std::fabs(x); }
  36. private:
  37. fftw_plan plan;
  38. };
  39. template<>
  40. struct fftw_cos_transform<float>
  41. {
  42. fftw_cos_transform(int n, float* data1, float* data2)
  43. {
  44. plan = fftwf_plan_r2r_1d(n, data1, data2, FFTW_REDFT10, FFTW_ESTIMATE);
  45. }
  46. ~fftw_cos_transform()
  47. {
  48. fftwf_destroy_plan(plan);
  49. }
  50. void execute(float* data1, float* data2)
  51. {
  52. fftwf_execute_r2r(plan, data1, data2);
  53. }
  54. static float cos(float x) { return std::cos(x); }
  55. static float fabs(float x) { return std::fabs(x); }
  56. private:
  57. fftwf_plan plan;
  58. };
  59. template<>
  60. struct fftw_cos_transform<long double>
  61. {
  62. fftw_cos_transform(int n, long double* data1, long double* data2)
  63. {
  64. plan = fftwl_plan_r2r_1d(n, data1, data2, FFTW_REDFT10, FFTW_ESTIMATE);
  65. }
  66. ~fftw_cos_transform()
  67. {
  68. fftwl_destroy_plan(plan);
  69. }
  70. void execute(long double* data1, long double* data2)
  71. {
  72. fftwl_execute_r2r(plan, data1, data2);
  73. }
  74. static long double cos(long double x) { return std::cos(x); }
  75. static long double fabs(long double x) { return std::fabs(x); }
  76. private:
  77. fftwl_plan plan;
  78. };
  79. #ifdef BOOST_HAS_FLOAT128
  80. template<>
  81. struct fftw_cos_transform<__float128>
  82. {
  83. fftw_cos_transform(int n, __float128* data1, __float128* data2)
  84. {
  85. plan = fftwq_plan_r2r_1d(n, data1, data2, FFTW_REDFT10, FFTW_ESTIMATE);
  86. }
  87. ~fftw_cos_transform()
  88. {
  89. fftwq_destroy_plan(plan);
  90. }
  91. void execute(__float128* data1, __float128* data2)
  92. {
  93. fftwq_execute_r2r(plan, data1, data2);
  94. }
  95. static __float128 cos(__float128 x) { return cosq(x); }
  96. static __float128 fabs(__float128 x) { return fabsq(x); }
  97. private:
  98. fftwq_plan plan;
  99. };
  100. #endif
  101. }
  102. template<class Real>
  103. class chebyshev_transform
  104. {
  105. public:
  106. template<class F>
  107. chebyshev_transform(const F& f, Real a, Real b,
  108. Real tol = 500 * std::numeric_limits<Real>::epsilon(),
  109. size_t max_refinements = 15) : m_a(a), m_b(b)
  110. {
  111. if (a >= b)
  112. {
  113. throw std::domain_error("a < b is required.\n");
  114. }
  115. using boost::math::constants::half;
  116. using boost::math::constants::pi;
  117. using std::cos;
  118. using std::abs;
  119. Real bma = (b-a)*half<Real>();
  120. Real bpa = (b+a)*half<Real>();
  121. size_t n = 256;
  122. std::vector<Real> vf;
  123. size_t refinements = 0;
  124. while(refinements < max_refinements)
  125. {
  126. vf.resize(n);
  127. m_coeffs.resize(n);
  128. detail::fftw_cos_transform<Real> plan(static_cast<int>(n), vf.data(), m_coeffs.data());
  129. Real inv_n = 1/static_cast<Real>(n);
  130. for(size_t j = 0; j < n/2; ++j)
  131. {
  132. // Use symmetry cos((j+1/2)pi/n) = - cos((n-1-j+1/2)pi/n)
  133. Real y = detail::fftw_cos_transform<Real>::cos(pi<Real>()*(j+half<Real>())*inv_n);
  134. vf[j] = f(y*bma + bpa)*inv_n;
  135. vf[n-1-j]= f(bpa-y*bma)*inv_n;
  136. }
  137. plan.execute(vf.data(), m_coeffs.data());
  138. Real max_coeff = 0;
  139. for (auto const & coeff : m_coeffs)
  140. {
  141. if (detail::fftw_cos_transform<Real>::fabs(coeff) > max_coeff)
  142. {
  143. max_coeff = detail::fftw_cos_transform<Real>::fabs(coeff);
  144. }
  145. }
  146. size_t j = m_coeffs.size() - 1;
  147. while (abs(m_coeffs[j])/max_coeff < tol)
  148. {
  149. --j;
  150. }
  151. // If ten coefficients are eliminated, the we say we've done all
  152. // we need to do:
  153. if (n - j > 10)
  154. {
  155. m_coeffs.resize(j+1);
  156. return;
  157. }
  158. n *= 2;
  159. ++refinements;
  160. }
  161. }
  162. Real operator()(Real x) const
  163. {
  164. using boost::math::constants::half;
  165. if (x > m_b || x < m_a)
  166. {
  167. throw std::domain_error("x not in [a, b]\n");
  168. }
  169. Real z = (2*x - m_a - m_b)/(m_b - m_a);
  170. return chebyshev_clenshaw_recurrence(m_coeffs.data(), m_coeffs.size(), z);
  171. }
  172. // Integral over entire domain [a, b]
  173. Real integrate() const
  174. {
  175. Real Q = m_coeffs[0]/2;
  176. for(size_t j = 2; j < m_coeffs.size(); j += 2)
  177. {
  178. Q += -m_coeffs[j]/((j+1)*(j-1));
  179. }
  180. return (m_b - m_a)*Q;
  181. }
  182. const std::vector<Real>& coefficients() const
  183. {
  184. return m_coeffs;
  185. }
  186. Real prime(Real x) const
  187. {
  188. Real z = (2*x - m_a - m_b)/(m_b - m_a);
  189. Real dzdx = 2/(m_b - m_a);
  190. if (m_coeffs.size() < 2)
  191. {
  192. return 0;
  193. }
  194. Real b2 = 0;
  195. Real d2 = 0;
  196. Real b1 = m_coeffs[m_coeffs.size() -1];
  197. Real d1 = 0;
  198. for(size_t j = m_coeffs.size() - 2; j >= 1; --j)
  199. {
  200. Real tmp1 = 2*z*b1 - b2 + m_coeffs[j];
  201. Real tmp2 = 2*z*d1 - d2 + 2*b1;
  202. b2 = b1;
  203. b1 = tmp1;
  204. d2 = d1;
  205. d1 = tmp2;
  206. }
  207. return dzdx*(z*d1 - d2 + b1);
  208. }
  209. private:
  210. std::vector<Real> m_coeffs;
  211. Real m_a;
  212. Real m_b;
  213. };
  214. }}
  215. #endif