-
-
Save jeetsukumaran/5117563 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <sstream> | |
#include <string> | |
#include <random> | |
namespace sftrabbit {

/// Random number distribution with the same member interface as the
/// <random> distributions (param_type, reset, min, max, operator()).
///
/// A variate is produced as X / (X + Y), where X and Y are drawn from
/// std::gamma_distribution objects whose shape parameters are `a` and `b`
/// (their scale parameter is left at its default).
///
/// @tparam RealType floating-point type of the parameters and results.
template <typename RealType = double>
class beta_distribution
{
public:
    using result_type = RealType;

    /// Value type holding the two shape parameters of a beta_distribution.
    class param_type
    {
    public:
        using distribution_type = beta_distribution;

        /// @param a first shape parameter (defaults to 2).
        /// @param b second shape parameter (defaults to 2).
        explicit param_type(RealType a = 2.0, RealType b = 2.0)
            : a_param(a), b_param(b) { }

        RealType a() const { return a_param; }
        RealType b() const { return b_param; }

        /// Two parameter sets are equal when both shape parameters match.
        bool operator==(const param_type& other) const
        {
            return a_param == other.a_param && b_param == other.b_param;
        }

        bool operator!=(const param_type& other) const
        {
            return !(*this == other);
        }

    private:
        RealType a_param, b_param;
    };

    /// Builds the distribution from its two shape parameters.
    explicit beta_distribution(RealType a = 2.0, RealType b = 2.0)
        : a_gamma(a), b_gamma(b) { }

    /// Builds the distribution from a parameter set.
    explicit beta_distribution(const param_type& param)
        : a_gamma(param.a()), b_gamma(param.b()) { }

    /// Present for interface compatibility; this class does nothing here.
    void reset() { }

    /// @return the current parameter set.
    param_type param() const
    {
        return param_type(a(), b());
    }

    /// Replaces the parameter set by rebuilding both gamma distributions.
    void param(const param_type& param)
    {
        a_gamma = gamma_dist_type(param.a());
        b_gamma = gamma_dist_type(param.b());
    }

    /// Draws one variate using the stored parameters.
    template <typename URNG>
    result_type operator()(URNG& engine)
    {
        return generate(engine, a_gamma, b_gamma);
    }

    /// Draws one variate using `param` instead of the stored parameters.
    /// Temporary gamma distributions are used, so the stored ones are not
    /// touched by this call.
    template <typename URNG>
    result_type operator()(URNG& engine, const param_type& param)
    {
        gamma_dist_type a_param_gamma(param.a());
        gamma_dist_type b_param_gamma(param.b());
        return generate(engine, a_param_gamma, b_param_gamma);
    }

    result_type min() const { return 0.0; }
    result_type max() const { return 1.0; }

    /// @return the first shape parameter.
    RealType a() const { return a_gamma.alpha(); }
    /// @return the second shape parameter.
    RealType b() const { return b_gamma.alpha(); }

    /// Distributions are equal when both underlying gamma distributions
    /// compare equal. Those comparisons already cover the shape parameters,
    /// so no separate param() comparison is needed.
    bool operator==(const beta_distribution<result_type>& other) const
    {
        return a_gamma == other.a_gamma && b_gamma == other.b_gamma;
    }

    bool operator!=(const beta_distribution<result_type>& other) const
    {
        return !(*this == other);
    }

private:
    using gamma_dist_type = std::gamma_distribution<result_type>;

    gamma_dist_type a_gamma, b_gamma;

    /// Computes x / (x + y) from one draw of each gamma distribution.
    /// NOTE: if both draws are exactly zero the quotient is 0/0.
    template <typename URNG>
    result_type generate(URNG& engine,
                         gamma_dist_type& x_gamma,
                         gamma_dist_type& y_gamma)
    {
        const result_type x = x_gamma(engine);
        return x / (x + y_gamma(engine));
    }
};

/// Writes the distribution as `~Beta(a,b)`.
///
/// The parameters are formatted with the stream's current flags and
/// precision, so reading the text back reproduces them exactly only when
/// that precision is high enough for the values written.
template <typename CharT, typename RealType>
std::basic_ostream<CharT>& operator<<(std::basic_ostream<CharT>& os,
                                      const beta_distribution<RealType>& beta)
{
    os << "~Beta(" << beta.a() << "," << beta.b() << ")";
    return os;
}

/// Reads a distribution written as `~Beta(a,b)`.
///
/// Leading whitespace is skipped when the stream's skipws flag is set. On
/// malformed input, or when either parameter is not strictly positive,
/// failbit is set and `beta` is left unchanged. Matching stops at the first
/// unexpected character rather than scanning ahead for '('.
template <typename CharT, typename RealType>
std::basic_istream<CharT>& operator>>(std::basic_istream<CharT>& is,
                                      beta_distribution<RealType>& beta)
{
    using traits = typename std::basic_istream<CharT>::traits_type;

    // Consumes one character and reports whether it equals `c`, widened to
    // the stream's character type so that wide streams work too.
    auto expect = [&is](char c) {
        return traits::eq_int_type(is.get(),
                                   traits::to_int_type(is.widen(c)));
    };

    if (is.flags() & std::ios_base::skipws) {
        is >> std::ws;
    }

    bool prefix_ok = true;
    for (const char* p = "~Beta("; prefix_ok && *p != '\0'; ++p) {
        prefix_ok = expect(*p);
    }

    RealType a{};
    RealType b{};
    if (prefix_ok && is >> a && expect(',') && is >> b && expect(')') &&
        a > 0 && b > 0) {
        beta = beta_distribution<RealType>(a, b);
    } else {
        is.setstate(std::ios_base::failbit);
    }
    return is;
}

} // namespace sftrabbit
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment