Created
November 17, 2024 21:38
-
-
Save DSCF-1224/02acde2d131f76d6a4fd95811c1e1e01 to your computer and use it in GitHub Desktop.
PRNG using SIMD (AVX2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <immintrin.h>

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <random>
#ifndef PRNG_HPP_ | |
#define PRNG_HPP_ | |
namespace prng | |
{ | |
// Common base of every 32-bit generator in this file: it publishes the
// result type and the inclusive output range reported by min() / max().
class Prng32
{
public:
    using result_type = uint32_t;

    // Smallest value representable by result_type.
    static constexpr result_type min() noexcept
    {
        using limits = std::numeric_limits<result_type>;
        return limits::min();
    }

    // Largest value representable by result_type.
    static constexpr result_type max() noexcept
    {
        using limits = std::numeric_limits<result_type>;
        return limits::max();
    }
};
class Prng32M128 : public Prng32 | |
{ | |
friend class SplitMix32M128; | |
friend class Xorshift32M128; | |
private: | |
using state_type = __m128i; | |
state_type state_; | |
public: | |
static const int kNumSequences = 4; | |
private: | |
constexpr void Load(const result_type seed[kNumSequences]) noexcept | |
{ | |
this->state_ = _mm_load_si128((state_type*)seed); | |
} | |
public: | |
constexpr explicit Prng32M128(void) noexcept | |
{ | |
alignas(sizeof(result_type)) result_type seed[kNumSequences]; | |
std::random_device seed_gen; | |
for (size_t i = 0; i < kNumSequences; i++) | |
{ | |
seed[i] = seed_gen(); | |
} | |
this->Load(seed); | |
} | |
constexpr explicit Prng32M128(const state_type &seed) noexcept | |
{ | |
this->state_ = seed; | |
} | |
explicit Prng32M128(const result_type seed[kNumSequences]) noexcept | |
{ | |
this->Load(seed); | |
} | |
}; | |
class Prng32M256 : public Prng32 | |
{ | |
friend class SplitMix32M256; | |
friend class Xorshift32M256; | |
private: | |
using state_type = __m256i; | |
state_type state_; | |
public: | |
static const int kNumSequences = 8; | |
private: | |
constexpr void Load(const result_type seed[kNumSequences]) noexcept | |
{ | |
this->state_ = _mm256_load_si256((state_type*)seed); | |
} | |
public: | |
constexpr explicit Prng32M256(void) noexcept | |
{ | |
alignas(sizeof(result_type)) result_type seed[kNumSequences]; | |
std::random_device seed_gen; | |
for (size_t i = 0; i < kNumSequences; i++) | |
{ | |
seed[i] = seed_gen(); | |
} | |
this->Load(seed); | |
} | |
constexpr explicit Prng32M256(const state_type &seed) noexcept | |
{ | |
this->state_ = seed; | |
} | |
explicit Prng32M256(const result_type seed[kNumSequences]) noexcept | |
{ | |
this->Load(seed); | |
} | |
}; | |
class Prng32Single : public Prng32 | |
{ | |
friend class SplitMix32; | |
friend class Xorshift32; | |
private: | |
result_type state_; | |
public: | |
constexpr explicit Prng32Single(void) noexcept | |
{ | |
std::random_device seed_gen; | |
this->state_ = seed_gen(); | |
} | |
constexpr explicit Prng32Single(const result_type seed) noexcept | |
{ | |
this->state_ = seed; | |
} | |
}; | |
// Common base of every 64-bit generator in this file: it publishes the
// result type and the inclusive output range reported by min() / max().
class Prng64
{
public:
    using result_type = uint64_t;

    // Smallest value representable by result_type.
    static constexpr result_type min() noexcept
    {
        using limits = std::numeric_limits<result_type>;
        return limits::min();
    }

    // Largest value representable by result_type.
    static constexpr result_type max() noexcept
    {
        using limits = std::numeric_limits<result_type>;
        return limits::max();
    }
};
class Prng64M128 : public Prng64 | |
{ | |
friend class Xorshift64M128; | |
private: | |
using state_type = __m128i; | |
state_type state_; | |
public: | |
static const int kNumSequences = 2; | |
private: | |
constexpr void Load(const result_type seed[kNumSequences]) noexcept | |
{ | |
this->state_ = _mm_load_si128((state_type*)seed); | |
} | |
public: | |
constexpr explicit Prng64M128(void) noexcept | |
{ | |
alignas(sizeof(result_type)) result_type seed[kNumSequences]; | |
std::random_device seed_gen; | |
for (size_t i = 0; i < kNumSequences; i++) | |
{ | |
seed[i] = seed_gen(); | |
} | |
this->Load(seed); | |
} | |
constexpr explicit Prng64M128(const state_type &seed) noexcept | |
{ | |
this->state_ = seed; | |
} | |
explicit Prng64M128(const result_type seed[kNumSequences]) noexcept | |
{ | |
this->Load(seed); | |
} | |
}; | |
class Prng64M256 : public Prng64 | |
{ | |
friend class Xorshift64M256; | |
private: | |
using state_type = __m256i; | |
state_type state_; | |
public: | |
static const int kNumSequences = 4; | |
private: | |
constexpr void Load(const result_type seed[kNumSequences]) noexcept | |
{ | |
this->state_ = _mm256_load_si256((state_type*)seed); | |
} | |
public: | |
constexpr explicit Prng64M256(void) noexcept | |
{ | |
alignas(sizeof(result_type)) result_type seed[kNumSequences]; | |
std::random_device seed_gen; | |
for (size_t i = 0; i < kNumSequences; i++) | |
{ | |
seed[i] = seed_gen(); | |
} | |
this->Load(seed); | |
} | |
constexpr explicit Prng64M256(const state_type &seed) noexcept | |
{ | |
this->state_ = seed; | |
} | |
explicit Prng64M256(const result_type seed[kNumSequences]) noexcept | |
{ | |
this->Load(seed); | |
} | |
}; | |
class Prng64Single : public Prng64 | |
{ | |
friend class SplitMix64; | |
friend class Xorshift64; | |
private: | |
result_type state_; | |
public: | |
constexpr explicit Prng64Single(void) noexcept | |
{ | |
std::random_device seed_gen; | |
this->state_ = seed_gen(); | |
} | |
constexpr explicit Prng64Single(const result_type seed) noexcept | |
{ | |
this->state_ = seed; | |
} | |
}; | |
class SplitMix32 : public Prng32Single | |
{ | |
public: | |
static const result_type kMultiplier1 = static_cast<result_type>(0x85ebca6b); | |
static const result_type kMultiplier2 = static_cast<result_type>(0xc2b2ae35); | |
static const result_type kOffset = static_cast<result_type>(0x9e3779b9); | |
static const int kShift1 = 16; | |
static const int kShift2 = 13; | |
static const int kShift3 = 16; | |
constexpr SplitMix32(void) noexcept : Prng32Single() | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
constexpr SplitMix32(const result_type seed) noexcept : Prng32Single(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
constexpr result_type operator()() noexcept | |
{ | |
this->state_ += kOffset; | |
result_type x = this->state_; | |
x = (x ^ (x >> kShift1)) * kMultiplier1; | |
x = (x ^ (x >> kShift2)) * kMultiplier2; | |
x ^= x >> kShift3; | |
return x; | |
} | |
}; | |
class SplitMix32M128 : public Prng32M128 | |
{ | |
private: | |
static const state_type kMultiplier1; | |
static const state_type kMultiplier2; | |
static const state_type kOffset; | |
public: | |
constexpr explicit SplitMix32M128(const state_type &seed) noexcept : Prng32M128(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
explicit SplitMix32M128(const result_type seed[kNumSequences]) noexcept : Prng32M128(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
result_type operator()() noexcept | |
{ | |
alignas(sizeof(result_type)) result_type harvest[kNumSequences]; | |
this->Generate(harvest); | |
return harvest[0]; | |
} | |
state_type Generate(void) noexcept | |
{ | |
this->state_ = _mm_add_epi32(this->state_, kOffset); | |
state_type x = this->state_; | |
x = _mm_mullo_epi32(_mm_xor_si128(x, _mm_srli_epi32(x, SplitMix32::kShift1)), kMultiplier1); | |
x = _mm_mullo_epi32(_mm_xor_si128(x, _mm_srli_epi32(x, SplitMix32::kShift2)), kMultiplier2); | |
x = _mm_xor_si128(x, _mm_srli_epi32(x, SplitMix32::kShift3)); | |
return x; | |
} | |
void Generate(const result_type harvest[kNumSequences]) noexcept | |
{ | |
_mm_store_si128((state_type*)harvest, this->Generate()); | |
} | |
}; | |
// Per-lane broadcasts of the scalar SplitMix32 constants. They are defined
// `inline` so that including this header from several translation units does
// not yield multiple definitions. The casts make the unsigned -> int
// conversion expected by _mm_set1_epi32 explicit; the bit pattern is kept.
inline const Prng32M128::state_type SplitMix32M128::kMultiplier1 = _mm_set1_epi32(static_cast<int>(SplitMix32::kMultiplier1));
inline const Prng32M128::state_type SplitMix32M128::kMultiplier2 = _mm_set1_epi32(static_cast<int>(SplitMix32::kMultiplier2));
inline const Prng32M128::state_type SplitMix32M128::kOffset = _mm_set1_epi32(static_cast<int>(SplitMix32::kOffset));
class SplitMix32M256 : public Prng32M256 | |
{ | |
private: | |
static const state_type kMultiplier1; | |
static const state_type kMultiplier2; | |
static const state_type kOffset; | |
public: | |
constexpr explicit SplitMix32M256(const state_type &seed) noexcept : Prng32M256(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
explicit SplitMix32M256(const result_type seed[kNumSequences]) noexcept : Prng32M256(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
result_type operator()() noexcept | |
{ | |
alignas(sizeof(result_type)) result_type harvest[kNumSequences]; | |
this->Generate(harvest); | |
return harvest[0]; | |
} | |
state_type Generate(void) noexcept | |
{ | |
this->state_ = _mm256_add_epi32(this->state_, kOffset); | |
state_type x = this->state_; | |
x = _mm256_mullo_epi32(_mm256_xor_si256(x, _mm256_srli_epi32(x, SplitMix32::kShift1)), kMultiplier1); | |
x = _mm256_mullo_epi32(_mm256_xor_si256(x, _mm256_srli_epi32(x, SplitMix32::kShift2)), kMultiplier2); | |
x = _mm256_xor_si256(x, _mm256_srli_epi32(x, SplitMix32::kShift3)); | |
return x; | |
} | |
void Generate(const result_type harvest[kNumSequences]) noexcept | |
{ | |
_mm256_store_si256((state_type*)harvest, this->Generate()); | |
} | |
}; | |
// Per-lane broadcasts of the scalar SplitMix32 constants. They are defined
// `inline` so that including this header from several translation units does
// not yield multiple definitions. The casts make the unsigned -> int
// conversion expected by _mm256_set1_epi32 explicit; the bit pattern is kept.
inline const Prng32M256::state_type SplitMix32M256::kMultiplier1 = _mm256_set1_epi32(static_cast<int>(SplitMix32::kMultiplier1));
inline const Prng32M256::state_type SplitMix32M256::kMultiplier2 = _mm256_set1_epi32(static_cast<int>(SplitMix32::kMultiplier2));
inline const Prng32M256::state_type SplitMix32M256::kOffset = _mm256_set1_epi32(static_cast<int>(SplitMix32::kOffset));
// reference implementation: https://www.jstatsoft.org/article/view/v008i14 | |
class Xorshift32 : public Prng32Single | |
{ | |
public: | |
static const int kShift1 = 13; | |
static const int kShift2 = 17; | |
static const int kShift3 = 5; | |
constexpr explicit Xorshift32(void) noexcept : Prng32Single() | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
constexpr explicit Xorshift32(const result_type seed) noexcept : Prng32Single(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
constexpr result_type operator()() noexcept | |
{ | |
result_type x = this->state_; | |
x ^= (x << kShift1); | |
x ^= (x >> kShift2); | |
x ^= (x << kShift3); | |
this->state_ = x; | |
return x; | |
} | |
}; | |
// reference implementation: https://www.jstatsoft.org/article/view/v008i14 | |
class Xorshift32M128 : public Prng32M128 | |
{ | |
public: | |
constexpr explicit Xorshift32M128(const state_type &seed) noexcept : Prng32M128(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
explicit Xorshift32M128(const result_type seed[kNumSequences]) noexcept : Prng32M128(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
result_type operator()() noexcept | |
{ | |
alignas(sizeof(result_type)) result_type harvest[kNumSequences]; | |
this->Generate(harvest); | |
return harvest[0]; | |
} | |
state_type Generate(void) noexcept | |
{ | |
state_type x = this->state_; | |
x = _mm_xor_si128(x, _mm_slli_epi32(x, Xorshift32::kShift1)); | |
x = _mm_xor_si128(x, _mm_srli_epi32(x, Xorshift32::kShift2)); | |
x = _mm_xor_si128(x, _mm_slli_epi32(x, Xorshift32::kShift3)); | |
this->state_ = x; | |
return x; | |
} | |
void Generate(const result_type harvest[kNumSequences]) noexcept | |
{ | |
_mm_store_si128((state_type*)harvest, this->Generate()); | |
} | |
}; | |
// reference implementation: https://www.jstatsoft.org/article/view/v008i14 | |
class Xorshift32M256 : public Prng32M256 | |
{ | |
public: | |
constexpr explicit Xorshift32M256(const state_type &seed) noexcept : Prng32M256(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
explicit Xorshift32M256(const result_type seed[kNumSequences]) noexcept : Prng32M256(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
result_type operator()() noexcept | |
{ | |
alignas(sizeof(result_type)) result_type harvest[kNumSequences]; | |
this->Generate(harvest); | |
return harvest[0]; | |
} | |
state_type Generate(void) noexcept | |
{ | |
state_type x = this->state_; | |
x = _mm256_xor_si256(x, _mm256_slli_epi32(x, Xorshift32::kShift1)); | |
x = _mm256_xor_si256(x, _mm256_srli_epi32(x, Xorshift32::kShift2)); | |
x = _mm256_xor_si256(x, _mm256_slli_epi32(x, Xorshift32::kShift3)); | |
this->state_ = x; | |
return x; | |
} | |
void Generate(const result_type harvest[kNumSequences]) noexcept | |
{ | |
_mm256_store_si256((state_type*)harvest, this->Generate()); | |
} | |
}; | |
// reference implementation: https://www.jstatsoft.org/article/view/v008i14 | |
class Xorshift64 : public Prng64Single | |
{ | |
public: | |
static const int kShift1 = 13; | |
static const int kShift2 = 7; | |
static const int kShift3 = 17; | |
explicit Xorshift64(void) noexcept : Prng64Single() | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
constexpr explicit Xorshift64(const result_type seed) : Prng64Single(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
constexpr result_type operator()() noexcept | |
{ | |
result_type x = this->state_; | |
x ^= (x << kShift1); | |
x ^= (x >> kShift2); | |
x ^= (x << kShift3); | |
this->state_ = x; | |
return x; | |
} | |
}; | |
// reference implementation: https://www.jstatsoft.org/article/view/v008i14 | |
class Xorshift64M128 : public Prng64M128 | |
{ | |
public: | |
constexpr explicit Xorshift64M128(const state_type &seed) noexcept : Prng64M128(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
explicit Xorshift64M128(const result_type seed[kNumSequences]) noexcept : Prng64M128(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
result_type operator()() noexcept | |
{ | |
alignas(sizeof(result_type)) result_type harvest[kNumSequences]; | |
this->Generate(harvest); | |
return harvest[0]; | |
} | |
state_type Generate(void) noexcept | |
{ | |
state_type x = this->state_; | |
x = _mm_xor_si128(x, _mm_slli_epi64(x, Xorshift64::kShift1)); | |
x = _mm_xor_si128(x, _mm_srli_epi64(x, Xorshift64::kShift2)); | |
x = _mm_xor_si128(x, _mm_slli_epi64(x, Xorshift64::kShift3)); | |
this->state_ = x; | |
return x; | |
} | |
void Generate(const result_type harvest[kNumSequences]) noexcept | |
{ | |
_mm_store_si128((state_type*)harvest, this->Generate()); | |
} | |
}; | |
// reference implementation: https://www.jstatsoft.org/article/view/v008i14 | |
class Xorshift64M256 : public Prng64M256 | |
{ | |
public: | |
constexpr explicit Xorshift64M256(const state_type &seed) noexcept : Prng64M256(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
explicit Xorshift64M256(const result_type seed[kNumSequences]) noexcept : Prng64M256(seed) | |
{ | |
/* NOTHING TO DO HERE */ | |
} | |
result_type operator()() noexcept | |
{ | |
alignas(sizeof(result_type)) result_type harvest[kNumSequences]; | |
this->Generate(harvest); | |
return harvest[0]; | |
} | |
state_type Generate(void) noexcept | |
{ | |
state_type x = this->state_; | |
x = _mm256_xor_si256(x, _mm256_slli_epi64(x, Xorshift64::kShift1)); | |
x = _mm256_xor_si256(x, _mm256_srli_epi64(x, Xorshift64::kShift2)); | |
x = _mm256_xor_si256(x, _mm256_slli_epi64(x, Xorshift64::kShift3)); | |
this->state_ = x; | |
return x; | |
} | |
void Generate(const result_type harvest[kNumSequences]) noexcept | |
{ | |
_mm256_store_si256((state_type*)harvest, this->Generate()); | |
} | |
}; | |
template<std::uniform_random_bit_generator G> | |
static constexpr bool IsUniformRandomBitGenerator(void) | |
{ | |
return true; | |
} | |
template<typename G> | |
static constexpr bool IsUniformRandomBitGenerator(void) | |
{ | |
return false; | |
} | |
int IsUniformRandomBitGenerator(void) | |
{ | |
static_assert( IsUniformRandomBitGenerator< std::mt19937 >(), "std::mt19937" ); | |
static_assert( IsUniformRandomBitGenerator< std::mt19937_64 >(), "std::mt19937_64" ); | |
static_assert( IsUniformRandomBitGenerator< SplitMix32 >(), "SplitMix32" ); | |
static_assert( IsUniformRandomBitGenerator< SplitMix32M128 >(), "SplitMix32M128" ); | |
static_assert( IsUniformRandomBitGenerator< SplitMix32M256 >(), "SplitMix32M256" ); | |
static_assert( IsUniformRandomBitGenerator< Xorshift32 >(), "Xorshift32" ); | |
static_assert( IsUniformRandomBitGenerator< Xorshift32M128 >(), "Xorshift32M128" ); | |
static_assert( IsUniformRandomBitGenerator< Xorshift32M256 >(), "Xorshift32M256" ); | |
static_assert( IsUniformRandomBitGenerator< Xorshift64 >(), "Xorshift64" ); | |
static_assert( IsUniformRandomBitGenerator< Xorshift64M128 >(), "Xorshift64M128" ); | |
static_assert( IsUniformRandomBitGenerator< Xorshift64M256 >(), "Xorshift64M256" ); | |
return EXIT_SUCCESS; | |
} | |
} | |
#endif // PRNG_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"version": "2.0.0", | |
"tasks": [ | |
{ | |
"type": "cppbuild", | |
"label": "C/C++: g++-14 build active file", | |
"command": "/usr/bin/g++-14", | |
"args": [ | |
"-fdiagnostics-color=always", | |
"-std=gnu++2b", | |
"-mavx2", | |
"-O0", | |
"-g", | |
"-rdynamic", | |
"${file}", | |
"-o", | |
"${fileDirname}/${fileBasenameNoExtension}" | |
], | |
"options": { | |
"cwd": "${fileDirname}" | |
}, | |
"problemMatcher": [ | |
"$gcc" | |
], | |
"group": { | |
"kind": "build", | |
"isDefault": true | |
}, | |
"detail": "Task generated by Debugger." | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment