Skip to content

Instantly share code, notes, and snippets.

@DSCF-1224
Created November 17, 2024 21:38
Show Gist options
  • Save DSCF-1224/02acde2d131f76d6a4fd95811c1e1e01 to your computer and use it in GitHub Desktop.
Save DSCF-1224/02acde2d131f76d6a4fd95811c1e1e01 to your computer and use it in GitHub Desktop.
PRNG using SIMD (AVX2)
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <random>
#include <immintrin.h>
#ifndef PRNG_HPP_
#define PRNG_HPP_
namespace prng
{
// Common base of every 32-bit generator in this header: it fixes the output
// type and reports the full range of that type through min()/max().
class Prng32
{
public:
    using result_type = uint32_t;

    // Smallest value of result_type.
    static constexpr auto min() noexcept -> result_type
    {
        using limits = std::numeric_limits<result_type>;
        return limits::min();
    }

    // Largest value of result_type.
    static constexpr auto max() noexcept -> result_type
    {
        using limits = std::numeric_limits<result_type>;
        return limits::max();
    }
};
class Prng32M128 : public Prng32
{
friend class SplitMix32M128;
friend class Xorshift32M128;
private:
using state_type = __m128i;
state_type state_;
public:
static const int kNumSequences = 4;
private:
constexpr void Load(const result_type seed[kNumSequences]) noexcept
{
this->state_ = _mm_load_si128((state_type*)seed);
}
public:
constexpr explicit Prng32M128(void) noexcept
{
alignas(sizeof(result_type)) result_type seed[kNumSequences];
std::random_device seed_gen;
for (size_t i = 0; i < kNumSequences; i++)
{
seed[i] = seed_gen();
}
this->Load(seed);
}
constexpr explicit Prng32M128(const state_type &seed) noexcept
{
this->state_ = seed;
}
explicit Prng32M128(const result_type seed[kNumSequences]) noexcept
{
this->Load(seed);
}
};
class Prng32M256 : public Prng32
{
friend class SplitMix32M256;
friend class Xorshift32M256;
private:
using state_type = __m256i;
state_type state_;
public:
static const int kNumSequences = 8;
private:
constexpr void Load(const result_type seed[kNumSequences]) noexcept
{
this->state_ = _mm256_load_si256((state_type*)seed);
}
public:
constexpr explicit Prng32M256(void) noexcept
{
alignas(sizeof(result_type)) result_type seed[kNumSequences];
std::random_device seed_gen;
for (size_t i = 0; i < kNumSequences; i++)
{
seed[i] = seed_gen();
}
this->Load(seed);
}
constexpr explicit Prng32M256(const state_type &seed) noexcept
{
this->state_ = seed;
}
explicit Prng32M256(const result_type seed[kNumSequences]) noexcept
{
this->Load(seed);
}
};
class Prng32Single : public Prng32
{
friend class SplitMix32;
friend class Xorshift32;
private:
result_type state_;
public:
constexpr explicit Prng32Single(void) noexcept
{
std::random_device seed_gen;
this->state_ = seed_gen();
}
constexpr explicit Prng32Single(const result_type seed) noexcept
{
this->state_ = seed;
}
};
// Common base of every 64-bit generator in this header: it fixes the output
// type and reports the full range of that type through min()/max().
class Prng64
{
public:
    using result_type = uint64_t;

    // Smallest value of result_type.
    static constexpr auto min() noexcept -> result_type
    {
        using limits = std::numeric_limits<result_type>;
        return limits::min();
    }

    // Largest value of result_type.
    static constexpr auto max() noexcept -> result_type
    {
        using limits = std::numeric_limits<result_type>;
        return limits::max();
    }
};
class Prng64M128 : public Prng64
{
friend class Xorshift64M128;
private:
using state_type = __m128i;
state_type state_;
public:
static const int kNumSequences = 2;
private:
constexpr void Load(const result_type seed[kNumSequences]) noexcept
{
this->state_ = _mm_load_si128((state_type*)seed);
}
public:
constexpr explicit Prng64M128(void) noexcept
{
alignas(sizeof(result_type)) result_type seed[kNumSequences];
std::random_device seed_gen;
for (size_t i = 0; i < kNumSequences; i++)
{
seed[i] = seed_gen();
}
this->Load(seed);
}
constexpr explicit Prng64M128(const state_type &seed) noexcept
{
this->state_ = seed;
}
explicit Prng64M128(const result_type seed[kNumSequences]) noexcept
{
this->Load(seed);
}
};
class Prng64M256 : public Prng64
{
friend class Xorshift64M256;
private:
using state_type = __m256i;
state_type state_;
public:
static const int kNumSequences = 4;
private:
constexpr void Load(const result_type seed[kNumSequences]) noexcept
{
this->state_ = _mm256_load_si256((state_type*)seed);
}
public:
constexpr explicit Prng64M256(void) noexcept
{
alignas(sizeof(result_type)) result_type seed[kNumSequences];
std::random_device seed_gen;
for (size_t i = 0; i < kNumSequences; i++)
{
seed[i] = seed_gen();
}
this->Load(seed);
}
constexpr explicit Prng64M256(const state_type &seed) noexcept
{
this->state_ = seed;
}
explicit Prng64M256(const result_type seed[kNumSequences]) noexcept
{
this->Load(seed);
}
};
class Prng64Single : public Prng64
{
friend class SplitMix64;
friend class Xorshift64;
private:
result_type state_;
public:
constexpr explicit Prng64Single(void) noexcept
{
std::random_device seed_gen;
this->state_ = seed_gen();
}
constexpr explicit Prng64Single(const result_type seed) noexcept
{
this->state_ = seed;
}
};
// Scalar 32-bit generator: each call adds kOffset to the state and returns the
// new state scrambled by two xor-shift/multiply rounds and a final xor-shift.
class SplitMix32 : public Prng32Single
{
public:
    // Mixing constants. Declared constexpr (implicitly inline) so that they
    // can be odr-used, e.g. bound to a const reference, without a separate
    // out-of-class definition.
    static constexpr result_type kMultiplier1 = static_cast<result_type>(0x85ebca6b);
    static constexpr result_type kMultiplier2 = static_cast<result_type>(0xc2b2ae35);
    static constexpr result_type kOffset = static_cast<result_type>(0x9e3779b9);
    static constexpr int kShift1 = 16;
    static constexpr int kShift2 = 13;
    static constexpr int kShift3 = 16;

    // Seeds from std::random_device via the base class.
    // Not constexpr: the base constructor cannot run in a constant expression.
    SplitMix32() noexcept : Prng32Single()
    {
        /* NOTHING TO DO HERE */
    }

    // Starts the sequence from the given seed.
    constexpr SplitMix32(const result_type seed) noexcept : Prng32Single(seed)
    {
        /* NOTHING TO DO HERE */
    }

    // Advances the state by kOffset and returns the mixed value.
    // All arithmetic is on unsigned 32-bit values and wraps modulo 2^32.
    constexpr result_type operator()() noexcept
    {
        this->state_ += kOffset;
        result_type x = this->state_;
        x = (x ^ (x >> kShift1)) * kMultiplier1;
        x = (x ^ (x >> kShift2)) * kMultiplier2;
        x ^= x >> kShift3;
        return x;
    }
};
class SplitMix32M128 : public Prng32M128
{
private:
static const state_type kMultiplier1;
static const state_type kMultiplier2;
static const state_type kOffset;
public:
constexpr explicit SplitMix32M128(const state_type &seed) noexcept : Prng32M128(seed)
{
/* NOTHING TO DO HERE */
}
explicit SplitMix32M128(const result_type seed[kNumSequences]) noexcept : Prng32M128(seed)
{
/* NOTHING TO DO HERE */
}
result_type operator()() noexcept
{
alignas(sizeof(result_type)) result_type harvest[kNumSequences];
this->Generate(harvest);
return harvest[0];
}
state_type Generate(void) noexcept
{
this->state_ = _mm_add_epi32(this->state_, kOffset);
state_type x = this->state_;
x = _mm_mullo_epi32(_mm_xor_si128(x, _mm_srli_epi32(x, SplitMix32::kShift1)), kMultiplier1);
x = _mm_mullo_epi32(_mm_xor_si128(x, _mm_srli_epi32(x, SplitMix32::kShift2)), kMultiplier2);
x = _mm_xor_si128(x, _mm_srli_epi32(x, SplitMix32::kShift3));
return x;
}
void Generate(const result_type harvest[kNumSequences]) noexcept
{
_mm_store_si128((state_type*)harvest, this->Generate());
}
};
const Prng32M128::state_type SplitMix32M128::kMultiplier1 = _mm_set1_epi32(SplitMix32::kMultiplier1);
const Prng32M128::state_type SplitMix32M128::kMultiplier2 = _mm_set1_epi32(SplitMix32::kMultiplier2);
const Prng32M128::state_type SplitMix32M128::kOffset = _mm_set1_epi32(SplitMix32::kOffset);
class SplitMix32M256 : public Prng32M256
{
private:
static const state_type kMultiplier1;
static const state_type kMultiplier2;
static const state_type kOffset;
public:
constexpr explicit SplitMix32M256(const state_type &seed) noexcept : Prng32M256(seed)
{
/* NOTHING TO DO HERE */
}
explicit SplitMix32M256(const result_type seed[kNumSequences]) noexcept : Prng32M256(seed)
{
/* NOTHING TO DO HERE */
}
result_type operator()() noexcept
{
alignas(sizeof(result_type)) result_type harvest[kNumSequences];
this->Generate(harvest);
return harvest[0];
}
state_type Generate(void) noexcept
{
this->state_ = _mm256_add_epi32(this->state_, kOffset);
state_type x = this->state_;
x = _mm256_mullo_epi32(_mm256_xor_si256(x, _mm256_srli_epi32(x, SplitMix32::kShift1)), kMultiplier1);
x = _mm256_mullo_epi32(_mm256_xor_si256(x, _mm256_srli_epi32(x, SplitMix32::kShift2)), kMultiplier2);
x = _mm256_xor_si256(x, _mm256_srli_epi32(x, SplitMix32::kShift3));
return x;
}
void Generate(const result_type harvest[kNumSequences]) noexcept
{
_mm256_store_si256((state_type*)harvest, this->Generate());
}
};
const Prng32M256::state_type SplitMix32M256::kMultiplier1 = _mm256_set1_epi32(SplitMix32::kMultiplier1);
const Prng32M256::state_type SplitMix32M256::kMultiplier2 = _mm256_set1_epi32(SplitMix32::kMultiplier2);
const Prng32M256::state_type SplitMix32M256::kOffset = _mm256_set1_epi32(SplitMix32::kOffset);
// reference implementation: https://www.jstatsoft.org/article/view/v008i14
// Scalar 32-bit xorshift generator: every call applies a left, a right and a
// left shift-xor to the state and returns the new state.
// A state of zero maps to zero, so a zero seed produces only zeros.
class Xorshift32 : public Prng32Single
{
public:
    // Shift amounts. Declared constexpr (implicitly inline) so that they can
    // be odr-used without a separate out-of-class definition.
    static constexpr int kShift1 = 13;
    static constexpr int kShift2 = 17;
    static constexpr int kShift3 = 5;

    // Seeds from std::random_device via the base class.
    // Not constexpr: the base constructor cannot run in a constant expression.
    explicit Xorshift32() noexcept : Prng32Single()
    {
        /* NOTHING TO DO HERE */
    }

    // Starts the sequence from the given seed.
    constexpr explicit Xorshift32(const result_type seed) noexcept : Prng32Single(seed)
    {
        /* NOTHING TO DO HERE */
    }

    // Advances the state by one step and returns it.
    constexpr result_type operator()() noexcept
    {
        result_type x = this->state_;
        x ^= (x << kShift1);
        x ^= (x >> kShift2);
        x ^= (x << kShift3);
        this->state_ = x;
        return x;
    }
};
// reference implementation: https://www.jstatsoft.org/article/view/v008i14
class Xorshift32M128 : public Prng32M128
{
public:
constexpr explicit Xorshift32M128(const state_type &seed) noexcept : Prng32M128(seed)
{
/* NOTHING TO DO HERE */
}
explicit Xorshift32M128(const result_type seed[kNumSequences]) noexcept : Prng32M128(seed)
{
/* NOTHING TO DO HERE */
}
result_type operator()() noexcept
{
alignas(sizeof(result_type)) result_type harvest[kNumSequences];
this->Generate(harvest);
return harvest[0];
}
state_type Generate(void) noexcept
{
state_type x = this->state_;
x = _mm_xor_si128(x, _mm_slli_epi32(x, Xorshift32::kShift1));
x = _mm_xor_si128(x, _mm_srli_epi32(x, Xorshift32::kShift2));
x = _mm_xor_si128(x, _mm_slli_epi32(x, Xorshift32::kShift3));
this->state_ = x;
return x;
}
void Generate(const result_type harvest[kNumSequences]) noexcept
{
_mm_store_si128((state_type*)harvest, this->Generate());
}
};
// reference implementation: https://www.jstatsoft.org/article/view/v008i14
class Xorshift32M256 : public Prng32M256
{
public:
constexpr explicit Xorshift32M256(const state_type &seed) noexcept : Prng32M256(seed)
{
/* NOTHING TO DO HERE */
}
explicit Xorshift32M256(const result_type seed[kNumSequences]) noexcept : Prng32M256(seed)
{
/* NOTHING TO DO HERE */
}
result_type operator()() noexcept
{
alignas(sizeof(result_type)) result_type harvest[kNumSequences];
this->Generate(harvest);
return harvest[0];
}
state_type Generate(void) noexcept
{
state_type x = this->state_;
x = _mm256_xor_si256(x, _mm256_slli_epi32(x, Xorshift32::kShift1));
x = _mm256_xor_si256(x, _mm256_srli_epi32(x, Xorshift32::kShift2));
x = _mm256_xor_si256(x, _mm256_slli_epi32(x, Xorshift32::kShift3));
this->state_ = x;
return x;
}
void Generate(const result_type harvest[kNumSequences]) noexcept
{
_mm256_store_si256((state_type*)harvest, this->Generate());
}
};
// reference implementation: https://www.jstatsoft.org/article/view/v008i14
// Scalar 64-bit xorshift generator: every call applies a left, a right and a
// left shift-xor to the state and returns the new state.
// A state of zero maps to zero, so a zero seed produces only zeros.
class Xorshift64 : public Prng64Single
{
public:
    // Shift amounts. Declared constexpr (implicitly inline) so that they can
    // be odr-used without a separate out-of-class definition.
    static constexpr int kShift1 = 13;
    static constexpr int kShift2 = 7;
    static constexpr int kShift3 = 17;

    // Seeds from std::random_device via the base class.
    explicit Xorshift64() noexcept : Prng64Single()
    {
        /* NOTHING TO DO HERE */
    }

    // Starts the sequence from the given seed. noexcept like the base
    // constructor it forwards to and like the other generators' constructors.
    constexpr explicit Xorshift64(const result_type seed) noexcept : Prng64Single(seed)
    {
        /* NOTHING TO DO HERE */
    }

    // Advances the state by one step and returns it.
    constexpr result_type operator()() noexcept
    {
        result_type x = this->state_;
        x ^= (x << kShift1);
        x ^= (x >> kShift2);
        x ^= (x << kShift3);
        this->state_ = x;
        return x;
    }
};
// reference implementation: https://www.jstatsoft.org/article/view/v008i14
class Xorshift64M128 : public Prng64M128
{
public:
constexpr explicit Xorshift64M128(const state_type &seed) noexcept : Prng64M128(seed)
{
/* NOTHING TO DO HERE */
}
explicit Xorshift64M128(const result_type seed[kNumSequences]) noexcept : Prng64M128(seed)
{
/* NOTHING TO DO HERE */
}
result_type operator()() noexcept
{
alignas(sizeof(result_type)) result_type harvest[kNumSequences];
this->Generate(harvest);
return harvest[0];
}
state_type Generate(void) noexcept
{
state_type x = this->state_;
x = _mm_xor_si128(x, _mm_slli_epi64(x, Xorshift64::kShift1));
x = _mm_xor_si128(x, _mm_srli_epi64(x, Xorshift64::kShift2));
x = _mm_xor_si128(x, _mm_slli_epi64(x, Xorshift64::kShift3));
this->state_ = x;
return x;
}
void Generate(const result_type harvest[kNumSequences]) noexcept
{
_mm_store_si128((state_type*)harvest, this->Generate());
}
};
// reference implementation: https://www.jstatsoft.org/article/view/v008i14
class Xorshift64M256 : public Prng64M256
{
public:
constexpr explicit Xorshift64M256(const state_type &seed) noexcept : Prng64M256(seed)
{
/* NOTHING TO DO HERE */
}
explicit Xorshift64M256(const result_type seed[kNumSequences]) noexcept : Prng64M256(seed)
{
/* NOTHING TO DO HERE */
}
result_type operator()() noexcept
{
alignas(sizeof(result_type)) result_type harvest[kNumSequences];
this->Generate(harvest);
return harvest[0];
}
state_type Generate(void) noexcept
{
state_type x = this->state_;
x = _mm256_xor_si256(x, _mm256_slli_epi64(x, Xorshift64::kShift1));
x = _mm256_xor_si256(x, _mm256_srli_epi64(x, Xorshift64::kShift2));
x = _mm256_xor_si256(x, _mm256_slli_epi64(x, Xorshift64::kShift3));
this->state_ = x;
return x;
}
void Generate(const result_type harvest[kNumSequences]) noexcept
{
_mm256_store_si256((state_type*)harvest, this->Generate());
}
};
template<std::uniform_random_bit_generator G>
static constexpr bool IsUniformRandomBitGenerator(void)
{
return true;
}
template<typename G>
static constexpr bool IsUniformRandomBitGenerator(void)
{
return false;
}
// Compile-time self-check: every generator of this header (plus two standard
// engines as a sanity reference) must model std::uniform_random_bit_generator.
// All checks are static_asserts; at run time the function only returns
// EXIT_SUCCESS.
// Declared inline because it is defined in a header: a non-inline definition
// would be emitted by every including translation unit and clash at link time.
inline int IsUniformRandomBitGenerator(void)
{
    static_assert( IsUniformRandomBitGenerator< std::mt19937 >(), "std::mt19937" );
    static_assert( IsUniformRandomBitGenerator< std::mt19937_64 >(), "std::mt19937_64" );
    static_assert( IsUniformRandomBitGenerator< SplitMix32 >(), "SplitMix32" );
    static_assert( IsUniformRandomBitGenerator< SplitMix32M128 >(), "SplitMix32M128" );
    static_assert( IsUniformRandomBitGenerator< SplitMix32M256 >(), "SplitMix32M256" );
    static_assert( IsUniformRandomBitGenerator< Xorshift32 >(), "Xorshift32" );
    static_assert( IsUniformRandomBitGenerator< Xorshift32M128 >(), "Xorshift32M128" );
    static_assert( IsUniformRandomBitGenerator< Xorshift32M256 >(), "Xorshift32M256" );
    static_assert( IsUniformRandomBitGenerator< Xorshift64 >(), "Xorshift64" );
    static_assert( IsUniformRandomBitGenerator< Xorshift64M128 >(), "Xorshift64M128" );
    static_assert( IsUniformRandomBitGenerator< Xorshift64M256 >(), "Xorshift64M256" );
    return EXIT_SUCCESS;
}
}
#endif // PRNG_HPP_
{
"version": "2.0.0",
"tasks": [
{
"type": "cppbuild",
"label": "C/C++: g++-14 build active file",
"command": "/usr/bin/g++-14",
"args": [
"-fdiagnostics-color=always",
"-std=gnu++2b",
"-mavx2",
"-O0",
"-g",
"-rdynamic",
"${file}",
"-o",
"${fileDirname}/${fileBasenameNoExtension}"
],
"options": {
"cwd": "${fileDirname}"
},
"problemMatcher": [
"$gcc"
],
"group": {
"kind": "build",
"isDefault": true
},
"detail": "Task generated by Debugger."
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment