bernoulli_distribution.hpp 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2014 Roshan <thisisroshansmail@gmail.com>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_RANDOM_BERNOULLI_DISTRIBUTION_HPP
  11. #define BOOST_COMPUTE_RANDOM_BERNOULLI_DISTRIBUTION_HPP
  12. #include <boost/assert.hpp>
  13. #include <boost/type_traits.hpp>
  14. #include <boost/compute/command_queue.hpp>
  15. #include <boost/compute/function.hpp>
  16. #include <boost/compute/types/fundamental.hpp>
  17. #include <boost/compute/detail/iterator_range_size.hpp>
  18. #include <boost/compute/detail/literal.hpp>
  19. namespace boost {
  20. namespace compute {
  21. ///
  22. /// \class bernoulli_distribution
  23. /// \brief Produces random boolean values according to the following
  24. /// discrete probability function with parameter p :
  25. /// P(true/p) = p and P(false/p) = (1 - p)
  26. ///
  27. /// The following example shows how to setup a bernoulli distribution to
  28. /// produce random boolean values with parameter p = 0.25
  29. ///
  30. /// \snippet test/test_bernoulli_distribution.cpp generate
  31. ///
  32. template<class RealType = float>
  33. class bernoulli_distribution
  34. {
  35. public:
  36. /// Creates a new bernoulli distribution
  37. bernoulli_distribution(RealType p = 0.5f)
  38. : m_p(p)
  39. {
  40. }
  41. /// Destroys the bernoulli_distribution object
  42. ~bernoulli_distribution()
  43. {
  44. }
  45. /// Returns the value of the parameter p
  46. RealType p() const
  47. {
  48. return m_p;
  49. }
  50. /// Generates bernoulli distributed booleans and stores
  51. /// them in the range [\p first, \p last).
  52. template<class OutputIterator, class Generator>
  53. void generate(OutputIterator first,
  54. OutputIterator last,
  55. Generator &generator,
  56. command_queue &queue)
  57. {
  58. size_t count = detail::iterator_range_size(first, last);
  59. vector<uint_> tmp(count, queue.get_context());
  60. generator.generate(tmp.begin(), tmp.end(), queue);
  61. BOOST_COMPUTE_FUNCTION(bool, scale_random, (const uint_ x),
  62. {
  63. return (convert_RealType(x) / MAX_RANDOM) < PARAM;
  64. });
  65. scale_random.define("PARAM", detail::make_literal(m_p));
  66. scale_random.define("MAX_RANDOM", "UINT_MAX");
  67. scale_random.define(
  68. "convert_RealType", std::string("convert_") + type_name<RealType>()
  69. );
  70. transform(
  71. tmp.begin(), tmp.end(), first, scale_random, queue
  72. );
  73. }
  74. private:
  75. RealType m_p;
  76. BOOST_STATIC_ASSERT_MSG(
  77. boost::is_floating_point<RealType>::value,
  78. "Template argument must be a floating point type"
  79. );
  80. };
  81. } // end compute namespace
  82. } // end boost namespace
  83. #endif // BOOST_COMPUTE_RANDOM_BERNOULLI_DISTRIBUTION_HPP