normal_distribution.hpp 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2013-2014 Kyle Lutz <kyle.r.lutz@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_RANDOM_NORMAL_DISTRIBUTION_HPP
  11. #define BOOST_COMPUTE_RANDOM_NORMAL_DISTRIBUTION_HPP
  12. #include <limits>
  13. #include <boost/assert.hpp>
  14. #include <boost/type_traits.hpp>
  15. #include <boost/compute/command_queue.hpp>
  16. #include <boost/compute/function.hpp>
  17. #include <boost/compute/types/fundamental.hpp>
  18. #include <boost/compute/type_traits/make_vector_type.hpp>
  19. namespace boost {
  20. namespace compute {
  21. /// \class normal_distribution
  22. /// \brief Produces random, normally-distributed floating-point numbers.
  23. ///
  24. /// The following example shows how to setup a normal distribution to
  25. /// produce random \c float values centered at \c 5:
  26. ///
  27. /// \snippet test/test_normal_distribution.cpp generate
  28. ///
  29. /// \see default_random_engine, uniform_real_distribution
  30. template<class RealType = float>
  31. class normal_distribution
  32. {
  33. public:
  34. typedef RealType result_type;
  35. /// Creates a new normal distribution producing numbers with the given
  36. /// \p mean and \p stddev.
  37. normal_distribution(RealType mean = 0.f, RealType stddev = 1.f)
  38. : m_mean(mean),
  39. m_stddev(stddev)
  40. {
  41. }
  42. /// Destroys the normal distribution object.
  43. ~normal_distribution()
  44. {
  45. }
  46. /// Returns the mean value of the distribution.
  47. result_type mean() const
  48. {
  49. return m_mean;
  50. }
  51. /// Returns the standard-deviation of the distribution.
  52. result_type stddev() const
  53. {
  54. return m_stddev;
  55. }
  56. /// Returns the minimum value of the distribution.
  57. result_type min BOOST_PREVENT_MACRO_SUBSTITUTION () const
  58. {
  59. return -std::numeric_limits<RealType>::infinity();
  60. }
  61. /// Returns the maximum value of the distribution.
  62. result_type max BOOST_PREVENT_MACRO_SUBSTITUTION () const
  63. {
  64. return std::numeric_limits<RealType>::infinity();
  65. }
  66. /// Generates normally-distributed floating-point numbers and stores
  67. /// them to the range [\p first, \p last).
  68. template<class OutputIterator, class Generator>
  69. void generate(OutputIterator first,
  70. OutputIterator last,
  71. Generator &generator,
  72. command_queue &queue)
  73. {
  74. typedef typename make_vector_type<RealType, 2>::type RealType2;
  75. size_t count = detail::iterator_range_size(first, last);
  76. vector<uint_> tmp(count, queue.get_context());
  77. generator.generate(tmp.begin(), tmp.end(), queue);
  78. BOOST_COMPUTE_FUNCTION(RealType2, box_muller, (const uint2_ x),
  79. {
  80. const RealType one = 1;
  81. const RealType two = 2;
  82. // Use nextafter to push values down into [0,1) range; without this, floating point rounding can
  83. // lead to have x1 = 1, but that would lead to taking the log of 0, which would result in negative
  84. // infinities; by pushing the values off 1 towards 0, we ensure this won't happen.
  85. const RealType x1 = nextafter(x.x / (RealType) UINT_MAX, (RealType) 0);
  86. const RealType x2 = x.y / (RealType) UINT_MAX;
  87. const RealType rho = sqrt(-two * log(one-x1));
  88. const RealType z1 = rho * cos(two * M_PI_F * x2);
  89. const RealType z2 = rho * sin(two * M_PI_F * x2);
  90. return (RealType2)(MEAN, MEAN) + (RealType2)(z1, z2) * (RealType2)(STDDEV, STDDEV);
  91. });
  92. box_muller.define("MEAN", boost::lexical_cast<std::string>(m_mean));
  93. box_muller.define("STDDEV", boost::lexical_cast<std::string>(m_stddev));
  94. box_muller.define("RealType", type_name<RealType>());
  95. box_muller.define("RealType2", type_name<RealType2>());
  96. transform(
  97. make_buffer_iterator<uint2_>(tmp.get_buffer(), 0),
  98. make_buffer_iterator<uint2_>(tmp.get_buffer(), count / 2),
  99. make_buffer_iterator<RealType2>(first.get_buffer(), 0),
  100. box_muller,
  101. queue
  102. );
  103. }
  104. private:
  105. RealType m_mean;
  106. RealType m_stddev;
  107. BOOST_STATIC_ASSERT_MSG(
  108. boost::is_floating_point<RealType>::value,
  109. "Template argument must be a floating point type"
  110. );
  111. };
  112. } // end compute namespace
  113. } // end boost namespace
  114. #endif // BOOST_COMPUTE_RANDOM_NORMAL_DISTRIBUTION_HPP