Created
August 20, 2024 20:48
-
-
Save es3n1n/5f9ca17c030305f56679afba0543f7cb to your computer and use it in GitHub Desktop.
unfinished uniform_int_distribution drop-in replacement that utilizes https://arxiv.org/pdf/1805.10941
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// \brief A drop-in replacement for `std::uniform_int_distribution`.
///
/// The standard does not specify the algorithm behind `std::uniform_int_distribution`, so libc++, libstdc++ and the
/// MSVC STL map identical engine output to different values. This class pins the algorithm (Daniel Lemire, "Fast
/// Random Integer Generation in an Interval", Algorithm 5) and uses nothing but portable C++17, so the same engine
/// state yields the same values on every platform.
///
/// \tparam IntType The integer type to be used for the distribution. Defaults to `int`.
template <typename IntType = int>
class UniformIntDistribution {
public:
    using ResultTy = IntType;
    using result_type = IntType; // spelling used by the standard distributions

    /// \brief Creates a distribution over the closed range [min, max].
    /// \param min The smallest value that can be produced (inclusive).
    /// \param max The largest value that can be produced (inclusive). Must not be less than \p min.
    explicit UniformIntDistribution(ResultTy min = 0, ResultTy max = (std::numeric_limits<ResultTy>::max)())
        : min_(min), max_(max) { }

    /// \brief Returns the lower bound of the distribution (inclusive).
    [[nodiscard]] ResultTy(min)() const {
        return min_;
    }

    /// \brief Returns the upper bound of the distribution (inclusive).
    [[nodiscard]] ResultTy(max)() const {
        return max_;
    }

    /// \brief Generates a random integer within the distribution range.
    /// \param engine The random engine instance to use.
    /// \return A value in [min, max].
    template <typename Engine>
    ResultTy operator()(Engine& engine) const {
        return eval(engine, min_, max_);
    }

private:
    using UResultTy = std::make_unsigned_t<ResultTy>;

    /// \brief Adapts a uniform random bit generator into a generator of integers in [0, index).
    /// \tparam DiffTy Unsigned type of the requested range size and of the produced values.
    /// \tparam URng The wrapped engine type; `URng::min()` / `URng::max()` must be usable in constant expressions.
    template <class DiffTy, class URng>
    class RngFromURng {
    private:
        /// Full-width product of two 64-bit operands, split into 64-bit halves.
        struct Product {
            uint64_t low;
            uint64_t high;
        };

    public:
        using ConvUDiffTy = std::make_unsigned_t<DiffTy>;
        using RngResultTy = std::invoke_result_t<URng&>;
        /// The wider of the engine result type and the requested range type; all bit juggling happens in it.
        using WideTy = std::conditional_t<sizeof(RngResultTy) < sizeof(ConvUDiffTy), ConvUDiffTy, RngResultTy>;
        static constexpr unsigned int wide_bits = sizeof(WideTy) * CHAR_BIT;
        static constexpr unsigned int diff_bits = sizeof(ConvUDiffTy) * CHAR_BIT;

        explicit RngFromURng(URng& urng): ref_(urng) {
            static_assert(wide_bits <= 64, "integer types wider than 64 bits are not supported");
            static_assert(bits_per_call > 0, "the engine must satisfy min() < max()");
        }

        RngFromURng(const RngFromURng&) = delete;
        RngFromURng& operator=(const RngFromURng&) = delete;

        /// \brief Produces a uniformly distributed value in [0, index).
        /// \param index Size of the requested range; must be non-zero.
        DiffTy operator()(DiffTy index) {
            // Lemire's Algorithm 5 <-> this code:
            //   m <-> product, l <-> rem, s <-> bound, t <-> threshold,
            //   L <-> generated_bits, 2^L - 1 <-> mask
            const WideTy bound = static_cast<WideTy>(index);

            // Concatenate as many engine outputs as are needed for 2^L to cover the bound.
            WideTy mask = call_mask;
            unsigned int draws = 1;
            if constexpr (bits_per_call < wide_bits) {
                while (mask < static_cast<WideTy>(bound - 1)) {
                    mask <<= bits_per_call;
                    mask |= call_mask;
                    ++draws;
                }
            }

            // x <- random integer in [0, 2^L);  m <- x * s;  l <- m mod 2^L
            Product product = random_product(bound, draws);
            WideTy rem = static_cast<WideTy>(static_cast<WideTy>(product.low) & mask);
            if (rem < bound) {
                // t <- (2^L - s) mod s; only computed on this rare path because the modulo is expensive.
                const WideTy threshold = static_cast<WideTy>((mask - bound + 1) % bound);
                while (rem < threshold) {
                    product = random_product(bound, draws);
                    rem = static_cast<WideTy>(static_cast<WideTy>(product.low) & mask);
                }
            }

            unsigned int generated_bits = wide_bits;
            if constexpr (bits_per_call < wide_bits) {
                // The mask saturates at wide_bits when the concatenation overflowed WideTy.
                generated_bits = count_set_bits(mask);
            }

            // m / 2^L
            return static_cast<DiffTy>(shift_right(product, generated_bits));
        }

        /// \brief Produces a value whose low `diff_bits` bits are all random.
        ///
        /// Engine outputs are concatenated only until the requested type is filled (not the possibly wider engine
        /// result type), so the number of engine calls does not depend on how wide `uint_fast32_t` and friends happen
        /// to be on the current platform.
        WideTy get_all_bits() {
            WideTy value = get_bits();
            if constexpr (bits_per_call < diff_bits) {
                for (size_t have = bits_per_call; have < diff_bits; have += bits_per_call) {
                    value <<= bits_per_call;
                    value |= get_bits();
                }
            }
            return value;
        }

    private:
        /// \brief Returns a random value within [0, call_mask], i.e. exactly `bits_per_call` random bits.
        WideTy get_bits() {
            constexpr auto urng_min = (URng::min)();
            for (;;) { // reject outputs above the largest 2^k - 1 that fits into the engine range
                const WideTy value = static_cast<WideTy>(ref_() - urng_min);
                if (value <= call_mask) {
                    return value;
                }
            }
        }

        /// \brief Number of bits k such that 2^k - 1 is the largest all-ones value within the engine's range.
        static constexpr size_t calc_bits() {
            auto bits = wide_bits;
            auto mask = static_cast<WideTy>(-1);
            for (; (URng::max)() - (URng::min)() < mask; mask >>= 1) {
                --bits;
            }
            return bits;
        }

        /// \brief Counts the set bits of \p value (portable stand-in for C++20 `std::popcount`).
        static constexpr unsigned int count_set_bits(WideTy value) noexcept {
            unsigned int count = 0;
            for (; value != 0; value = static_cast<WideTy>(value & (value - 1))) {
                ++count; // each step clears the lowest set bit
            }
            return count;
        }

        /// \brief Computes the full-width product of \p lhs and \p rhs without compiler-specific 128-bit types.
        static constexpr Product multiply(uint64_t lhs, uint64_t rhs) noexcept {
            if constexpr (wide_bits <= 32) {
                // Both operands fit into 32 bits, so the product fits into 64 bits.
                return Product{lhs * rhs, 0};
            } else {
                // Schoolbook multiplication on 32-bit halves.
                const uint64_t lhs_lo = lhs & 0xFFFFFFFFu;
                const uint64_t lhs_hi = lhs >> 32;
                const uint64_t rhs_lo = rhs & 0xFFFFFFFFu;
                const uint64_t rhs_hi = rhs >> 32;

                const uint64_t lo_lo = lhs_lo * rhs_lo;
                const uint64_t hi_lo = lhs_hi * rhs_lo;
                const uint64_t lo_hi = lhs_lo * rhs_hi;
                const uint64_t hi_hi = lhs_hi * rhs_hi;

                // Cannot overflow: (2^32 - 1) + (2^32 - 1) + (2^32 - 1)^2 == 2^64 - 1.
                const uint64_t cross = (lo_lo >> 32) + (hi_lo & 0xFFFFFFFFu) + lo_hi;

                return Product{(cross << 32) | (lo_lo & 0xFFFFFFFFu), hi_hi + (hi_lo >> 32) + (cross >> 32)};
            }
        }

        /// \brief Returns the low 64 bits of \p product shifted right by \p count (0..64) bits.
        static constexpr uint64_t shift_right(const Product& product, unsigned int count) noexcept {
            if (count == 0) {
                return product.low;
            }
            if (count >= 64) {
                return product.high;
            }
            return (product.low >> count) | (product.high << (64 - count));
        }

        /// \brief Draws \p draws engine outputs, concatenates them and multiplies the result by \p bound.
        Product random_product(const WideTy bound, unsigned int draws) {
            WideTy value = get_bits();
            if constexpr (bits_per_call < wide_bits) {
                while (--draws > 0) {
                    value <<= bits_per_call;
                    value |= get_bits();
                }
            }
            return multiply(static_cast<uint64_t>(value), static_cast<uint64_t>(bound));
        }

        URng& ref_; // reference to the wrapped engine
        static constexpr size_t bits_per_call = calc_bits(); // number of random bits generated by get_bits()
        static constexpr WideTy call_mask =
            static_cast<WideTy>(static_cast<WideTy>(-1) >> (wide_bits - bits_per_call)); // 2^bits_per_call - 1
    };

    /// \brief Evaluates the distribution and generates a random integer.
    /// \param engine The random engine instance to use.
    /// \param min The minimum value of the distribution (inclusive).
    /// \param max The maximum value of the distribution (inclusive). Must not be less than \p min.
    /// \return A random integer within the specified range.
    template <typename Engine>
    ResultTy eval(Engine& engine, ResultTy min, ResultTy max) const {
        const UResultTy u_min = adjust(static_cast<UResultTy>(min));
        const UResultTy u_max = adjust(static_cast<UResultTy>(max));
        const UResultTy span = static_cast<UResultTy>(u_max - u_min);

        RngFromURng<UResultTy, Engine> generator(engine);
        UResultTy offset;
        if (span == static_cast<UResultTy>(-1)) {
            // The range covers every value of the type, so `span + 1` would wrap to zero. Take raw bits instead;
            // going through the adapter (rather than calling the engine directly) honours Engine::min() and gathers
            // enough bits when a single engine output is narrower than the result type.
            offset = static_cast<UResultTy>(generator.get_all_bits());
        } else {
            offset = static_cast<UResultTy>(generator(static_cast<UResultTy>(span + 1)));
        }

        return static_cast<ResultTy>(adjust(static_cast<UResultTy>(offset + u_min)));
    }

    /// \brief Converts signed ranges to unsigned ranges and vice versa while preserving the ordering of values.
    static UResultTy adjust(UResultTy val) {
        if constexpr (std::is_signed_v<ResultTy>) {
            constexpr UResultTy sign_bit = static_cast<UResultTy>((static_cast<UResultTy>(-1) >> 1) + 1); // 2^(N-1)
            // Adding or subtracting 2^(N-1) modulo 2^N is the same as flipping the top bit.
            return static_cast<UResultTy>(val ^ sign_bit);
        } else { // ResultTy is already unsigned, do nothing
            return val;
        }
    }

    ResultTy min_;
    ResultTy max_;
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment