Skip to content

Instantly share code, notes, and snippets.

@DSCF-1224
Last active November 26, 2024 22:08
Show Gist options
  • Save DSCF-1224/02acde2d131f76d6a4fd95811c1e1e01 to your computer and use it in GitHub Desktop.
Save DSCF-1224/02acde2d131f76d6a4fd95811c1e1e01 to your computer and use it in GitHub Desktop.
PRNG using SIMD (AVX2)
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <random>
#include <immintrin.h>
#ifndef PRNG_HPP_
#define PRNG_HPP_
namespace prng
{
/// Common base of every generator in this header.
///
/// Supplies the `result_type` alias together with the static `min()` and
/// `max()` bounds, which report the full range of `T`.
///
/// @tparam T  Value type produced by the generator.
template <typename T>
class Prng
{
public:
    using result_type = T;

    /// @return The smallest value representable by `result_type`.
    [[nodiscard]]
    static constexpr result_type min(void) noexcept
    {
        // Defined in-class so no dependent return type has to be spelled
        // out-of-line (which needs `typename` on pre-P0634 compilers).
        return std::numeric_limits<result_type>::min();
    }

    /// @return The largest value representable by `result_type`.
    [[nodiscard]]
    static constexpr result_type max(void) noexcept
    {
        return std::numeric_limits<result_type>::max();
    }
};
template <typename T>
class PrngSingle : public Prng<T>
{
private:
Prng<T>::result_type state_;
public:
[[nodiscard]]
constexpr explicit PrngSingle(void) noexcept;
[[nodiscard]]
constexpr explicit PrngSingle(const Prng<T>::result_type seed) noexcept;
friend class SplitMix32;
friend class SplitMix64;
friend class Xorshift32;
friend class Xorshift64;
};
template <typename T>
constexpr PrngSingle<T>::PrngSingle(void) noexcept
{
std::random_device seed_gen;
this->state_ = seed_gen();
}
template <typename T>
constexpr PrngSingle<T>::PrngSingle(const Prng<T>::result_type seed) noexcept
{
this->state_ = seed;
}
using Prng32Single = PrngSingle<std::uint32_t>;
using Prng64Single = PrngSingle<std::uint64_t>;
template <typename T, int num_sequences>
class PrngM128 : public Prng<T>
{
public:
using state_type = __m128i;
static const int kNumSequences = num_sequences;
[[nodiscard]]
constexpr explicit PrngM128(void) noexcept;
[[nodiscard]]
constexpr explicit PrngM128(const state_type &seed) noexcept;
virtual constexpr Prng<T>::result_type operator()() noexcept final;
virtual constexpr state_type Generate(void) noexcept;
private:
state_type state_;
friend class SplitMix32M128;
friend class Xorshift32M128;
friend class SplitMix64M128;
friend class Xorshift64M128;
};
template <typename T, int num_sequences>
constexpr PrngM128<T, num_sequences>::PrngM128(void) noexcept
{
std::random_device seed_gen;
for (size_t i = 0; i < kNumSequences; i++)
{
this->state_[i] = seed_gen();
}
}
template <typename T, int num_sequences>
constexpr PrngM128<T, num_sequences>::PrngM128(const state_type &seed) noexcept
{
this->state_ = seed;
}
template <typename T, int num_sequences>
constexpr Prng<T>::result_type PrngM128<T, num_sequences>::operator()() noexcept
{
alignas(sizeof(16)) T harvest[kNumSequences];
_mm_store_si128((state_type*)harvest, this->Generate());
return harvest[0];
}
template <typename T, int num_sequences>
constexpr PrngM128<T, num_sequences>::state_type PrngM128<T, num_sequences>::Generate(void) noexcept
{
return _mm_setzero_si128();
}
using Prng32M128 = PrngM128<std::uint32_t, 4>;
using Prng64M128 = PrngM128<std::uint64_t, 2>;
/// Base of the 256-bit SIMD generators: `num_sequences` independent lanes of
/// type `T` packed into one `__m256i` state.
///
/// The concrete generators listed as friends access `state_` directly.
///
/// @tparam T              Lane value type.
/// @tparam num_sequences  Number of lanes held in the 256-bit state.
template <typename T, int num_sequences>
class PrngM256 : public Prng<T>
{
public:
    using state_type = __m256i;
    static const int kNumSequences = num_sequences;

    // The seeding code moves a whole vector through a T[] buffer, so the
    // lanes must cover the vector exactly.
    static_assert(sizeof(T) * num_sequences == sizeof(state_type),
                  "lanes must exactly fill the 256-bit state");

    /// Seeds every lane from std::random_device.
    [[nodiscard]]
    constexpr explicit PrngM256(void) noexcept;

    /// Starts all lanes from the caller-supplied vector `seed`.
    [[nodiscard]]
    constexpr explicit PrngM256(const state_type &seed) noexcept;

private:
    state_type state_;

    /// Draws one lane-wide seed, combining several draws of `seed_gen` when
    /// `T` is wider than a single draw.
    [[nodiscard]]
    static T RandomLane(std::random_device &seed_gen)
    {
        using word_type = std::random_device::result_type;
        constexpr int kWordBits = std::numeric_limits<word_type>::digits;
        constexpr int kLaneBits = std::numeric_limits<T>::digits;

        T lane = static_cast<T>(seed_gen());
        if constexpr (kLaneBits > kWordBits)
        {
            for (int filled = kWordBits; filled < kLaneBits; filled += kWordBits)
            {
                lane = static_cast<T>((lane << kWordBits) | static_cast<T>(seed_gen()));
            }
        }
        return lane;
    }

    friend class SplitMix32M256;
    friend class Xorshift32M256;
    friend class Xorshift64M256;
};

template <typename T, int num_sequences>
constexpr PrngM256<T, num_sequences>::PrngM256(void) noexcept
{
    // Fill a correctly typed lane buffer and load it as a vector. Subscripting
    // `__m256i` directly is a compiler extension whose element type need not
    // match `T`, so it cannot be used to address the lanes.
    std::random_device seed_gen;
    alignas(state_type) T lanes[kNumSequences];
    for (T &lane : lanes)
    {
        lane = RandomLane(seed_gen);
    }
    this->state_ = _mm256_load_si256(reinterpret_cast<const state_type *>(lanes));
}

template <typename T, int num_sequences>
constexpr PrngM256<T, num_sequences>::PrngM256(const state_type &seed) noexcept
    : state_(seed)
{
}

using Prng32M256 = PrngM256<std::uint32_t, 8>;
using Prng64M256 = PrngM256<std::uint64_t, 4>;
// SplitMix32, scalar 32-bit variant.
// Each call adds kOffset to the state and returns a xor-shift / multiply mix of it.
// reference implementation: https://github.com/bryc/code/blob/master/jshash/PRNGs.md#splitmix32
class SplitMix32 : public Prng32Single
{
public:
    // Multipliers, state increment and shift counts of the mixing function.
    // Public: the SIMD variants in this header reuse them.
    static const result_type kMultiplier1 = static_cast<result_type>(0x85ebca6b);
    static const result_type kMultiplier2 = static_cast<result_type>(0xc2b2ae35);
    static const result_type kOffset = static_cast<result_type>(0x9e3779b9);
    static const int kShift1 = 16;
    static const int kShift2 = 13;
    static const int kShift3 = 16;

    // Seeds the state through the default constructor of Prng32Single.
    [[nodiscard]]
    constexpr SplitMix32(void) noexcept : Prng32Single()
    {
    }

    // Starts the sequence from `seed`.
    [[nodiscard]]
    constexpr SplitMix32(const result_type seed) noexcept : Prng32Single(seed)
    {
    }

    // Advances the state and returns the next 32-bit value.
    constexpr result_type operator()() noexcept
    {
        result_type mixed = (this->state_ += kOffset);
        mixed ^= mixed >> kShift1;
        mixed *= kMultiplier1;
        mixed ^= mixed >> kShift2;
        mixed *= kMultiplier2;
        return mixed ^ (mixed >> kShift3);
    }
};
// SplitMix32, four 32-bit lanes in one __m128i.
// Applies the same add / xor-shift / multiply steps as the scalar SplitMix32,
// using its constants, to every lane at once.
// reference implementation: https://github.com/bryc/code/blob/master/jshash/PRNGs.md#splitmix32
class SplitMix32M128 : public Prng32M128
{
private:
    // Scalar SplitMix32 constants broadcast to every lane.
    // `inline` gives the program exactly one definition even when this header
    // is included from several translation units.
    // The casts only reinterpret the bit pattern for the int-taking intrinsic.
    static inline const state_type kMultiplier1 = _mm_set1_epi32(static_cast<int>(SplitMix32::kMultiplier1));
    static inline const state_type kMultiplier2 = _mm_set1_epi32(static_cast<int>(SplitMix32::kMultiplier2));
    static inline const state_type kOffset = _mm_set1_epi32(static_cast<int>(SplitMix32::kOffset));

public:
    // Seeds all lanes through the default constructor of Prng32M128.
    [[nodiscard]]
    constexpr explicit SplitMix32M128(void) noexcept;

    // Starts all lanes from the vector `seed`.
    [[nodiscard]]
    constexpr explicit SplitMix32M128(const state_type &seed) noexcept;

    // Advances every lane and returns the four new 32-bit values.
    constexpr state_type Generate(void) noexcept override;
};

constexpr SplitMix32M128::SplitMix32M128(void) noexcept : Prng32M128()
{
    /* NOTHING TO DO HERE */
}

constexpr SplitMix32M128::SplitMix32M128(const state_type &seed) noexcept : Prng32M128(seed)
{
    /* NOTHING TO DO HERE */
}

constexpr SplitMix32M128::state_type SplitMix32M128::Generate(void) noexcept
{
    this->state_ = _mm_add_epi32(this->state_, kOffset);
    state_type x = this->state_;
    x = _mm_mullo_epi32(_mm_xor_si128(x, _mm_srli_epi32(x, SplitMix32::kShift1)), kMultiplier1);
    x = _mm_mullo_epi32(_mm_xor_si128(x, _mm_srli_epi32(x, SplitMix32::kShift2)), kMultiplier2);
    x = _mm_xor_si128(x, _mm_srli_epi32(x, SplitMix32::kShift3));
    return x;
}
// SplitMix32, eight 32-bit lanes in one __m256i.
// Applies the same add / xor-shift / multiply steps as the scalar SplitMix32,
// using its constants, to every lane at once.
// reference implementation: https://github.com/bryc/code/blob/master/jshash/PRNGs.md#splitmix32
class SplitMix32M256 : public Prng32M256
{
private:
    // Scalar SplitMix32 constants broadcast to every lane.
    // `inline` gives the program exactly one definition even when this header
    // is included from several translation units.
    // The casts only reinterpret the bit pattern for the int-taking intrinsic.
    static inline const state_type kMultiplier1 = _mm256_set1_epi32(static_cast<int>(SplitMix32::kMultiplier1));
    static inline const state_type kMultiplier2 = _mm256_set1_epi32(static_cast<int>(SplitMix32::kMultiplier2));
    static inline const state_type kOffset = _mm256_set1_epi32(static_cast<int>(SplitMix32::kOffset));

public:
    // Seeds all lanes through the default constructor of Prng32M256.
    [[nodiscard]]
    constexpr explicit SplitMix32M256(void) noexcept;

    // Starts all lanes from the vector `seed`.
    [[nodiscard]]
    constexpr explicit SplitMix32M256(const state_type &seed) noexcept;

    // Returns lane 0 of a Generate() call.
    constexpr result_type operator()() noexcept;

    // Advances every lane and returns the eight new 32-bit values.
    constexpr state_type Generate(void) noexcept;
};

constexpr SplitMix32M256::SplitMix32M256(void) noexcept : Prng32M256()
{
    /* NOTHING TO DO HERE */
}

constexpr SplitMix32M256::SplitMix32M256(const state_type &seed) noexcept : Prng32M256(seed)
{
    /* NOTHING TO DO HERE */
}

constexpr SplitMix32M256::result_type SplitMix32M256::operator()() noexcept
{
    // _mm256_store_si256 requires a 32-byte aligned destination, so the
    // buffer takes the alignment of the vector type, not of one lane.
    alignas(state_type) result_type harvest[kNumSequences];
    _mm256_store_si256(reinterpret_cast<state_type *>(harvest), this->Generate());
    return harvest[0];
}

constexpr SplitMix32M256::state_type SplitMix32M256::Generate(void) noexcept
{
    this->state_ = _mm256_add_epi32(this->state_, kOffset);
    state_type x = this->state_;
    x = _mm256_mullo_epi32(_mm256_xor_si256(x, _mm256_srli_epi32(x, SplitMix32::kShift1)), kMultiplier1);
    x = _mm256_mullo_epi32(_mm256_xor_si256(x, _mm256_srli_epi32(x, SplitMix32::kShift2)), kMultiplier2);
    x = _mm256_xor_si256(x, _mm256_srli_epi32(x, SplitMix32::kShift3));
    return x;
}
// SplitMix64, scalar 64-bit variant.
// Each call adds kOffset to the state and returns a xor-shift / multiply mix of it.
// reference implementation: https://xorshift.di.unimi.it/splitmix64.c
class SplitMix64 : public Prng64Single
{
public:
    // Multipliers, state increment and shift counts of the mixing function.
    static const result_type kMultiplier1 = static_cast<result_type>(0xbf58476d1ce4e5b9);
    static const result_type kMultiplier2 = static_cast<result_type>(0x94d049bb133111eb);
    static const result_type kOffset = static_cast<result_type>(0x9e3779b97f4a7c15);
    static const int kShift1 = 30;
    static const int kShift2 = 27;
    static const int kShift3 = 31;

    // Seeds the state through the default constructor of Prng64Single.
    [[nodiscard]]
    constexpr SplitMix64(void) noexcept : Prng64Single()
    {
    }

    // Starts the sequence from `seed`.
    [[nodiscard]]
    constexpr SplitMix64(const result_type seed) noexcept : Prng64Single(seed)
    {
    }

    // Advances the state and returns the next 64-bit value.
    constexpr result_type operator()() noexcept
    {
        result_type mixed = (this->state_ += kOffset);
        mixed ^= mixed >> kShift1;
        mixed *= kMultiplier1;
        mixed ^= mixed >> kShift2;
        mixed *= kMultiplier2;
        return mixed ^ (mixed >> kShift3);
    }
};
// Xorshift32, scalar 32-bit variant (shift triple 13, 17, 5).
// The returned value is also the new state; a state of 0 maps to 0.
// reference implementation: https://www.jstatsoft.org/article/view/v008i14
class Xorshift32 : public Prng32Single
{
public:
    // Shift counts; public because the SIMD variants in this header reuse them.
    static const int kShift1 = 13;
    static const int kShift2 = 17;
    static const int kShift3 = 5;

    // Seeds the state through the default constructor of Prng32Single.
    [[nodiscard]]
    constexpr explicit Xorshift32(void) noexcept : Prng32Single()
    {
    }

    // Starts the sequence from `seed`.
    [[nodiscard]]
    constexpr explicit Xorshift32(const result_type seed) noexcept : Prng32Single(seed)
    {
    }

    // Applies the three xor-shift steps to the state in place and returns it.
    constexpr result_type operator()() noexcept
    {
        result_type &word = this->state_;
        word ^= word << kShift1;
        word ^= word >> kShift2;
        word ^= word << kShift3;
        return word;
    }
};
// Xorshift32, four 32-bit lanes in one __m128i.
// Uses the shift counts of the scalar Xorshift32 on every lane at once.
// reference implementation: https://www.jstatsoft.org/article/view/v008i14
class Xorshift32M128 : public Prng32M128
{
public:
    // Seeds all lanes through the default constructor of Prng32M128.
    [[nodiscard]]
    constexpr explicit Xorshift32M128(void) noexcept : Prng32M128()
    {
    }

    // Starts all lanes from the vector `seed`.
    [[nodiscard]]
    constexpr explicit Xorshift32M128(const state_type &seed) noexcept : Prng32M128(seed)
    {
    }

    // Runs the three xor-shift steps on every lane, stores the result as the
    // new state and returns it.
    constexpr state_type Generate(void) noexcept override
    {
        const state_type first = _mm_xor_si128(this->state_, _mm_slli_epi32(this->state_, Xorshift32::kShift1));
        const state_type second = _mm_xor_si128(first, _mm_srli_epi32(first, Xorshift32::kShift2));
        this->state_ = _mm_xor_si128(second, _mm_slli_epi32(second, Xorshift32::kShift3));
        return this->state_;
    }
};
// Xorshift32, eight 32-bit lanes in one __m256i.
// Uses the shift counts of the scalar Xorshift32 on every lane at once.
// reference implementation: https://www.jstatsoft.org/article/view/v008i14
class Xorshift32M256 : public Prng32M256
{
public:
    // Seeds all lanes through the default constructor of Prng32M256.
    [[nodiscard]]
    constexpr explicit Xorshift32M256(void) noexcept;

    // Starts all lanes from the vector `seed`.
    [[nodiscard]]
    constexpr explicit Xorshift32M256(const state_type &seed) noexcept;

    // Returns lane 0 of a Generate() call.
    constexpr result_type operator()() noexcept;

    // Runs the three xor-shift steps on every lane, stores the result as the
    // new state and returns it.
    constexpr state_type Generate(void) noexcept;
};

constexpr Xorshift32M256::Xorshift32M256(void) noexcept : Prng32M256()
{
    /* NOTHING TO DO HERE */
}

constexpr Xorshift32M256::Xorshift32M256(const state_type &seed) noexcept : Prng32M256(seed)
{
    /* NOTHING TO DO HERE */
}

constexpr Xorshift32M256::result_type Xorshift32M256::operator()() noexcept
{
    // _mm256_store_si256 requires a 32-byte aligned destination, so the
    // buffer takes the alignment of the vector type, not of one lane.
    alignas(state_type) result_type harvest[kNumSequences];
    _mm256_store_si256(reinterpret_cast<state_type *>(harvest), this->Generate());
    return harvest[0];
}

constexpr Xorshift32M256::state_type Xorshift32M256::Generate(void) noexcept
{
    state_type x = this->state_;
    x = _mm256_xor_si256(x, _mm256_slli_epi32(x, Xorshift32::kShift1));
    x = _mm256_xor_si256(x, _mm256_srli_epi32(x, Xorshift32::kShift2));
    x = _mm256_xor_si256(x, _mm256_slli_epi32(x, Xorshift32::kShift3));
    this->state_ = x;
    return x;
}
// Xorshift64, scalar 64-bit variant (shift triple 13, 7, 17).
// The returned value is also the new state; a state of 0 maps to 0.
// reference implementation: https://www.jstatsoft.org/article/view/v008i14
class Xorshift64 : public Prng64Single
{
public:
    // Shift counts; public because the SIMD variants in this header reuse them.
    static const int kShift1 = 13;
    static const int kShift2 = 7;
    static const int kShift3 = 17;

    // Seeds the state through the default constructor of Prng64Single.
    [[nodiscard]]
    constexpr explicit Xorshift64(void) noexcept : Prng64Single()
    {
    }

    // Starts the sequence from `seed`.
    [[nodiscard]]
    constexpr explicit Xorshift64(const result_type seed) noexcept : Prng64Single(seed)
    {
    }

    // Applies the three xor-shift steps to the state in place and returns it.
    constexpr result_type operator()() noexcept
    {
        result_type &word = this->state_;
        word ^= word << kShift1;
        word ^= word >> kShift2;
        word ^= word << kShift3;
        return word;
    }
};
// Xorshift64, two 64-bit lanes in one __m128i.
// Uses the shift counts of the scalar Xorshift64 on both lanes at once.
// reference implementation: https://www.jstatsoft.org/article/view/v008i14
class Xorshift64M128 : public Prng64M128
{
public:
    // Seeds both lanes through the default constructor of Prng64M128.
    [[nodiscard]]
    constexpr explicit Xorshift64M128(void) noexcept : Prng64M128()
    {
    }

    // Starts both lanes from the vector `seed`.
    [[nodiscard]]
    constexpr explicit Xorshift64M128(const state_type &seed) noexcept : Prng64M128(seed)
    {
    }

    // Runs the three xor-shift steps on both lanes, stores the result as the
    // new state and returns it.
    constexpr state_type Generate(void) noexcept override
    {
        const state_type first = _mm_xor_si128(this->state_, _mm_slli_epi64(this->state_, Xorshift64::kShift1));
        const state_type second = _mm_xor_si128(first, _mm_srli_epi64(first, Xorshift64::kShift2));
        this->state_ = _mm_xor_si128(second, _mm_slli_epi64(second, Xorshift64::kShift3));
        return this->state_;
    }
};
// Xorshift64, four 64-bit lanes in one __m256i.
// Uses the shift counts of the scalar Xorshift64 on every lane at once.
// reference implementation: https://www.jstatsoft.org/article/view/v008i14
class Xorshift64M256 : public Prng64M256
{
public:
    // Seeds all lanes through the default constructor of Prng64M256.
    [[nodiscard]]
    constexpr explicit Xorshift64M256(void) noexcept;

    // Starts all lanes from the vector `seed`.
    [[nodiscard]]
    constexpr explicit Xorshift64M256(const state_type &seed) noexcept;

    // Returns lane 0 of a Generate() call.
    constexpr result_type operator()() noexcept;

    // Runs the three xor-shift steps on every lane, stores the result as the
    // new state and returns it.
    constexpr state_type Generate(void) noexcept;
};

constexpr Xorshift64M256::Xorshift64M256(void) noexcept : Prng64M256()
{
    /* NOTHING TO DO HERE */
}

constexpr Xorshift64M256::Xorshift64M256(const state_type &seed) noexcept : Prng64M256(seed)
{
    /* NOTHING TO DO HERE */
}

constexpr Xorshift64M256::result_type Xorshift64M256::operator()() noexcept
{
    // _mm256_store_si256 requires a 32-byte aligned destination, so the
    // buffer takes the alignment of the vector type, not of one lane.
    alignas(state_type) result_type harvest[kNumSequences];
    _mm256_store_si256(reinterpret_cast<state_type *>(harvest), this->Generate());
    return harvest[0];
}

constexpr Xorshift64M256::state_type Xorshift64M256::Generate(void) noexcept
{
    state_type x = this->state_;
    x = _mm256_xor_si256(x, _mm256_slli_epi64(x, Xorshift64::kShift1));
    x = _mm256_xor_si256(x, _mm256_srli_epi64(x, Xorshift64::kShift2));
    x = _mm256_xor_si256(x, _mm256_slli_epi64(x, Xorshift64::kShift3));
    this->state_ = x;
    return x;
}
template<std::uniform_random_bit_generator G>
static constexpr bool IsUniformRandomBitGenerator(void)
{
return true;
}
template<typename G>
static constexpr bool IsUniformRandomBitGenerator(void)
{
return false;
}
/// Compile-time self-check: every generator in this header (and, as a
/// reference, the two standard Mersenne Twisters) must satisfy
/// std::uniform_random_bit_generator.
///
/// All the work happens in the `static_assert`s; calling the function does
/// nothing at run time.
///
/// Declared `inline` because it is defined in a header: without it, every
/// translation unit that includes this file would emit its own definition
/// and the program would fail to link.
///
/// @return Always EXIT_SUCCESS.
inline int IsUniformRandomBitGenerator(void)
{
    static_assert( IsUniformRandomBitGenerator< std::mt19937 >(), "std::mt19937" );
    static_assert( IsUniformRandomBitGenerator< std::mt19937_64 >(), "std::mt19937_64" );
    static_assert( IsUniformRandomBitGenerator< SplitMix32 >(), "SplitMix32" );
    static_assert( IsUniformRandomBitGenerator< SplitMix32M128 >(), "SplitMix32M128" );
    static_assert( IsUniformRandomBitGenerator< SplitMix32M256 >(), "SplitMix32M256" );
    static_assert( IsUniformRandomBitGenerator< Xorshift32 >(), "Xorshift32" );
    static_assert( IsUniformRandomBitGenerator< Xorshift32M128 >(), "Xorshift32M128" );
    static_assert( IsUniformRandomBitGenerator< Xorshift32M256 >(), "Xorshift32M256" );
    static_assert( IsUniformRandomBitGenerator< Xorshift64 >(), "Xorshift64" );
    static_assert( IsUniformRandomBitGenerator< Xorshift64M128 >(), "Xorshift64M128" );
    static_assert( IsUniformRandomBitGenerator< Xorshift64M256 >(), "Xorshift64M256" );
    return EXIT_SUCCESS;
}
}
#endif // PRNG_HPP_
{
"version": "2.0.0",
"tasks": [
{
"type": "cppbuild",
"label": "C/C++: g++-14 build active file",
"command": "/usr/bin/g++-14",
"args": [
"-fdiagnostics-color=always",
"-std=gnu++2b",
"-mavx2",
"-O0",
"-g",
"-rdynamic",
"${file}",
"-o",
"${fileDirname}/${fileBasenameNoExtension}"
],
"options": {
"cwd": "${fileDirname}"
},
"problemMatcher": [
"$gcc"
],
"group": {
"kind": "build",
"isDefault": true
},
"detail": "Task generated by Debugger."
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment