hypergeometric_1F1_error_plot.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. // Copyright John Maddock 2006.
  2. // Use, modification and distribution are subject to the
  3. // Boost Software License, Version 1.0. (See accompanying file
  4. // LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
#define BOOST_ENABLE_ASSERT_HANDLER
#define BOOST_MATH_MAX_SERIES_ITERATION_POLICY INT_MAX
// for consistent behaviour across compilers/platforms:
#define BOOST_MATH_PROMOTE_DOUBLE_POLICY false
// overflow to infinity is OK, we treat these as zero error as long as the sign is correct!
#define BOOST_MATH_OVERFLOW_ERROR_POLICY ignore_error
#include <climits>
#include <cmath>
#include <cstddef>
#include <ctime>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <limits>
#include <map>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include <boost/multiprecision/mpfr.hpp>
#include <boost/math/special_functions/hypergeometric_1F1.hpp>
#include <boost/math/special_functions/hypergeometric_pFq.hpp>
#include <boost/math/special_functions/relative_difference.hpp>
#include <boost/random.hpp>
#include <boost/iostreams/tee.hpp>
#include <boost/iostreams/stream.hpp>
  22. typedef double test_type;
  23. using boost::multiprecision::mpfr_float;
  24. namespace boost {
  25. //
  26. // We convert assertions into exceptions, so we can log them and continue:
  27. //
  28. void assertion_failed(char const * expr, char const *, char const * file, long line)
  29. {
  30. std::ostringstream oss;
  31. oss << file << ":" << line << " Assertion failed: " << expr;
  32. throw std::runtime_error(oss.str());
  33. }
  34. }
  35. void print_value(double x, std::ostream& os = std::cout)
  36. {
  37. int e;
  38. double m = std::frexp(x, &e);
  39. m = std::ldexp(m, 54);
  40. e -= 54;
  41. boost::int64_t val = (boost::int64_t)m;
  42. BOOST_ASSERT(std::ldexp((double)val, e) == x);
  43. os << "std::ldexp((double)" << val << ", " << e << ")";
  44. }
  45. void print_row(double a, double b, double z, mpfr_float result, std::ostream& os = std::cout)
  46. {
  47. os << " {{ ";
  48. print_value(a, os);
  49. os << ", ";
  50. print_value(b, os);
  51. os << ", ";
  52. print_value(z, os);
  53. os << ", SC_(" << std::setprecision(45) << result << ") }}" << std::endl;
  54. }
  55. struct error_data
  56. {
  57. error_data(double a, double b, double z, boost::intmax_t e)
  58. : a(a), b(b), z(z), error(e) {}
  59. double a, b, z;
  60. boost::intmax_t error;
  61. bool operator<(const error_data& other)const
  62. {
  63. return error < other.error;
  64. }
  65. };
  66. int main()
  67. {
  68. try {
  69. test_type max_a, max_b, max_z, min_a, min_b, min_z;
  70. unsigned number_of_samples;
  71. std::ofstream log_stream, incalculable_stream, unevaluated_stream, bins_stream;
  72. std::string basename;
  73. std::cout << "Enter range for a: ";
  74. std::cin >> min_a >> max_a;
  75. std::cout << "Enter range for b: ";
  76. std::cin >> min_b >> max_b;
  77. std::cout << "Enter range for z: ";
  78. std::cin >> min_z >> max_z;
  79. std::cout << "Enter number of samples: ";
  80. std::cin >> number_of_samples;
  81. std::cout << "Enter basename for log files: ";
  82. std::cin >> basename;
  83. typedef boost::iostreams::tee_device<std::ostream, std::ostream> tee_sink;
  84. typedef boost::iostreams::stream<tee_sink> tee_stream;
  85. log_stream.open((basename + ".log").c_str());
  86. tee_stream tee_log(tee_sink(std::cout, log_stream));
  87. incalculable_stream.open((basename + "_incalculable.log").c_str());
  88. unevaluated_stream.open((basename + "_unevaluated.log").c_str());
  89. bins_stream.open((basename + "_bins.csv").c_str());
  90. tee_stream tee_bins(tee_sink(std::cout, bins_stream));
  91. boost::random::mt19937 gen(std::time(0));
  92. boost::random::uniform_real_distribution<test_type> a_dist(min_a, max_a);
  93. boost::random::uniform_real_distribution<test_type> b_dist(min_b, max_b);
  94. boost::random::uniform_real_distribution<test_type> z_dist(min_z, max_z);
  95. std::multiset<error_data> errors;
  96. std::map<std::pair<int, int>, int> bins;
  97. unsigned incalculable = 0;
  98. unsigned evaluation_errors = 0;
  99. test_type max_error = 0;
  100. do
  101. {
  102. test_type a = a_dist(gen);
  103. test_type b = b_dist(gen);
  104. test_type z = z_dist(gen);
  105. test_type found, expected;
  106. mpfr_float mp_expected;
  107. try {
  108. mp_expected = boost::math::hypergeometric_pFq_precision({ mpfr_float(a) }, { mpfr_float(b) }, mpfr_float(z), 25, 200.0);
  109. expected = (test_type)mp_expected;
  110. }
  111. catch (const std::exception&)
  112. {
  113. // Unable to compute reference value:
  114. ++incalculable;
  115. tee_log << "Unable to compute reference value in reasonable time: " << std::endl;
  116. print_row(a, b, z, mpfr_float(0), tee_log);
  117. incalculable_stream << std::setprecision(6) << std::scientific << a << "," << b << "," << z << "\n";
  118. continue;
  119. }
  120. try
  121. {
  122. found = boost::math::hypergeometric_1F1(a, b, z);
  123. }
  124. catch (const std::exception&)
  125. {
  126. ++evaluation_errors;
  127. --number_of_samples;
  128. log_stream << "Unexpected exception calculating value: " << std::endl;
  129. print_row(a, b, z, mp_expected, log_stream);
  130. unevaluated_stream << std::setprecision(6) << std::scientific << a << "," << b << "," << z << "\n";
  131. continue;
  132. }
  133. test_type err = boost::math::epsilon_difference(found, expected);
  134. if (err > max_error)
  135. {
  136. tee_log << "New maximum error is: " << err << std::endl;
  137. print_row(a, b, z, mp_expected, tee_log);
  138. max_error = err;
  139. }
  140. try {
  141. errors.insert(error_data(a, b, z, boost::math::lltrunc(err)));
  142. }
  143. catch (...)
  144. {
  145. errors.insert(error_data(a, b, z, INT_MAX));
  146. }
  147. --number_of_samples;
  148. if (number_of_samples % 500 == 0)
  149. std::cout << number_of_samples << " samples to go" << std::endl;
  150. } while (number_of_samples);
  151. tee_log << "Max error found was: " << max_error << std::endl;
  152. unsigned current_bin = 0;
  153. unsigned lim = 1;
  154. unsigned old_lim = 0;
  155. while (errors.size())
  156. {
  157. old_lim = lim;
  158. lim *= 2;
  159. //std::cout << "Enter upper limit for bin " << current_bin << ": ";
  160. //std::cin >> lim;
  161. auto p = errors.upper_bound(error_data(0, 0, 0, lim));
  162. int bin_count = std::distance(errors.begin(), p);
  163. if (bin_count)
  164. {
  165. std::ofstream os((basename + "_errors_" + std::to_string(current_bin + 1) + ".csv").c_str());
  166. os << "a,b,z,error\n";
  167. bins[std::make_pair(old_lim, lim)] = bin_count;
  168. for (auto pos = errors.begin(); pos != p; ++pos)
  169. {
  170. os << pos->a << "," << pos->b << "," << pos->z << "," << pos->error << "\n";
  171. }
  172. errors.erase(errors.begin(), p);
  173. }
  174. ++current_bin;
  175. }
  176. tee_bins << "Results:\n\n";
  177. tee_bins << "#bin,Range,2^N,Count\n";
  178. int hash = 0;
  179. for (auto p = bins.begin(); p != bins.end(); ++p, ++hash)
  180. {
  181. tee_bins << hash << "," << p->first.first << "-" << p->first.second << "," << hash+1 << "," << p->second << std::endl;
  182. }
  183. if (evaluation_errors)
  184. {
  185. tee_bins << ",Failed,," << evaluation_errors << std::endl;
  186. }
  187. if (incalculable)
  188. {
  189. tee_bins << ",Incalculable,," << incalculable << std::endl;
  190. }
  191. }
  192. catch (const std::exception& e)
  193. {
  194. std::cout << "Terminating with unhandled exception: " << e.what() << std::endl;
  195. }
  196. return 0;
  197. }