Created
October 6, 2017 00:15
-
-
Save Deamon5550/5a14a8fce6028fe34f1311c3978a892e to your computer and use it in GitHub Desktop.
SIMD Mersenne Twister
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <chrono>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <ctime>
#include <random>
#include <ratio>
#include <x86intrin.h>
#include "boost/random.hpp"
using namespace std;
using namespace std::chrono;
// Overlays one 128-bit SSE integer vector with four 32-bit words; v[0] is the
// word at the lowest address. Writing one member and reading the other is
// type punning through a union, which GCC documents as supported but ISO C++
// leaves undefined.
union m128_to_v4
{
    __m128i m;      // the whole vector
    uint32_t v[4];  // the same bytes as four 32-bit words
};
// MT19937 (32-bit Mersenne Twister) whose state refill ("twist") processes four
// state words at a time with SSE2 intrinsics. It yields the same output sequence
// as std::mt19937 for the same seed.
class CustomTwister {
public:
    // x[0..623] are the 624 state words. x[624] is scratch space that twist()
    // sets to x[0] so the final 4-wide load of x[621..624] needs no wrap-around.
    uint32_t x[625];
    // Index of the next state word to temper; any value >= 624 forces a twist.
    int32_t i = 624 + 1;

    CustomTwister() { seed(5489u); }
    CustomTwister(uint32_t s) { seed(s); }

    // Reinitialises the state from `s` with the MT19937 seeding recurrence
    //   x[k] = 1812433253 * (x[k-1] ^ (x[k-1] >> 30)) + k   (mod 2^32)
    // and rewinds the output index so the next call to operator() twists first,
    // as std::mt19937::seed does. (Previously the index was not reset, so a
    // reseed mid-stream emitted tempered seed words instead of twisted ones.)
    void seed(uint32_t s) {
        x[0] = s;
        for (int k = 1; k < 624; k++) {
            x[k] = 1812433253u * (x[k - 1] ^ (x[k - 1] >> 30)) + static_cast<uint32_t>(k);
        }
        x[624] = x[0];
        i = 624;
    }

    // Lane-wise 32-bit multiply keeping the low 32 bits of each product.
    // twist() does not use it; it stays public for any existing callers.
    static inline __m128i muly(const __m128i &a, const __m128i &b)
    {
#ifdef __SSE4_1__ // modern CPU - use SSE 4.1
        return _mm_mullo_epi32(a, b);
#else // old CPU - use SSE 2
        __m128i tmp1 = _mm_mul_epu32(a,b); /* mul 2,0*/
        __m128i tmp2 = _mm_mul_epu32( _mm_srli_si128(a,4), _mm_srli_si128(b,4)); /* mul 3,1 */
        return _mm_unpacklo_epi32(_mm_shuffle_epi32(tmp1, _MM_SHUFFLE (0,0,2,0)), _mm_shuffle_epi32(tmp2, _MM_SHUFFLE (0,0,2,0))); /* shuffle results to [63..0] and pack */
#endif
    }

    // Regenerates all 624 state words in place and rewinds the output index.
    // Word k becomes x[(k + 397) % 624] ^ A(y_k), evaluated for increasing k, where
    //   y_k  = (x[k] & 0x80000000) | (x[(k + 1) % 624] & 0x7FFFFFFF)
    //   A(y) = (y >> 1) ^ (y odd ? 0x9908B0DF : 0)
    // Partners below k are therefore already-updated words.
    void twist() {
        // k in [0, 224): partners x[k+397 .. k+400] lie in [397, 620] and are still old.
        for (int k = 0; k < 224; k += 4) {
            const __m128i partner = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x + k + 397));
            _mm_storeu_si128(reinterpret_cast<__m128i *>(x + k),
                             _mm_xor_si128(twist_term(x + k), partner));
        }
        // k in [224, 228): partners are x[621], x[622], x[623] and, wrapping
        // around, the already-updated x[0], so this group is finished per word.
        {
            uint32_t t[4];
            _mm_storeu_si128(reinterpret_cast<__m128i *>(t), twist_term(x + 224));
            x[224] = x[621] ^ t[0];
            x[225] = x[622] ^ t[1];
            x[226] = x[623] ^ t[2];
            x[227] = x[0] ^ t[3];
        }
        // The successor of x[623] is the NEW x[0]. Refresh the scratch copy before
        // the last group (k = 620) loads x[621..624].
        x[624] = x[0];
        // k in [228, 624): partners x[k-227 .. k-224] were all updated above.
        for (int k = 228; k < 624; k += 4) {
            const __m128i partner = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x + k - 227));
            _mm_storeu_si128(reinterpret_cast<__m128i *>(x + k),
                             _mm_xor_si128(twist_term(x + k), partner));
        }
        i = 0;
    }

    // Returns the next 32-bit output: twists when the state is exhausted, then
    // applies the MT19937 tempering transform to the next state word.
    uint32_t operator()() {
        if (i >= 624) {
            twist();
        }
        uint32_t z = x[i++];
        z ^= (z >> 11);
        z ^= (z << 7) & 0x9d2c5680u;
        z ^= (z << 15) & 0xefc60000u;  // MT19937 mask c (was missing a digit)
        z ^= (z >> 18);
        return z;
    }

private:
    // Computes A(y) for the four words p[0..3] (reading p[0..4]), where
    //   y = (p[j] & 0x80000000) | (p[j+1] & 0x7FFFFFFF)
    //   A(y) = (y >> 1) ^ (y odd ? 0x9908B0DF : 0)
    // Uses only SSE2: a logical shift for y >> 1 and a shift-built lane mask
    // for the odd test. No multiply is needed.
    static inline __m128i twist_term(const uint32_t *p) {
        const __m128i upper = _mm_set1_epi32(static_cast<int>(0x80000000u));
        const __m128i lower = _mm_set1_epi32(0x7FFFFFFF);
        const __m128i matrix = _mm_set1_epi32(static_cast<int>(0x9908B0DFu));
        const __m128i cur = _mm_loadu_si128(reinterpret_cast<const __m128i *>(p));
        const __m128i next = _mm_loadu_si128(reinterpret_cast<const __m128i *>(p + 1));
        const __m128i y = _mm_or_si128(_mm_and_si128(cur, upper), _mm_and_si128(next, lower));
        // Move bit 0 to bit 31, then arithmetic-shift it back down: all ones
        // in lanes where y is odd, zero elsewhere.
        const __m128i odd = _mm_srai_epi32(_mm_slli_epi32(y, 31), 31);
        return _mm_xor_si128(_mm_srli_epi32(y, 1), _mm_and_si128(odd, matrix));
    }
};
// Maps one 32-bit output of `twister` onto [0, 1) by dividing by 2^32.
// The result is never exactly 1.0, matching the half-open range of
// std::uniform_real_distribution<>(0, 1) and boost::uniform_01 that main()
// compares against. (Dividing by 2^32 - 1 made 1.0 reachable.)
double random_01(CustomTwister& twister) {
    return twister() / 4294967296.0;
}
int main() { | |
std::random_device rd; | |
unsigned seed = rd(); | |
std::mt19937 std_generator(seed); | |
std::uniform_real_distribution<> std_01dist(0.0, 1.0); | |
boost::mt19937 boost_generator; | |
boost_generator.seed(seed); | |
boost::uniform_01<> boost_01dist; | |
const unsigned num_trials = 1000000; | |
for(unsigned counter = 0; counter < num_trials / 10; ++counter) { | |
double num = std_01dist(std_generator); | |
} | |
double std_total = 0.0; | |
high_resolution_clock::time_point std_start = high_resolution_clock::now(); | |
for(unsigned counter = 0; counter < num_trials; ++counter) { | |
double num = std_01dist(std_generator); | |
std_total += num; | |
} | |
// std_total should be roughly close to num_trials / 2 | |
high_resolution_clock::time_point std_end = high_resolution_clock::now(); | |
double std_duration = duration_cast<std::chrono::microseconds>(std_end - std_start).count() / 1000.0; | |
for(unsigned counter = 0; counter < num_trials / 10; ++counter) { | |
double num = boost_01dist(boost_generator); | |
} | |
double boost_total = 0.0; | |
high_resolution_clock::time_point boost_start = high_resolution_clock::now(); | |
for(unsigned counter = 0; counter < num_trials; ++counter) { | |
double num = boost_01dist(boost_generator); | |
boost_total += num; | |
} | |
// boost_total should be roughly close to num_trials / 2 | |
high_resolution_clock::time_point boost_end = high_resolution_clock::now(); | |
double boost_duration = duration_cast<std::chrono::microseconds>(boost_end - boost_start).count() / 1000.0; | |
CustomTwister twister(seed); | |
for(unsigned counter = 0; counter < num_trials / 10; ++counter) { | |
double num = random_01(twister); | |
} | |
double custom_total = 0.0; | |
high_resolution_clock::time_point custom_start = high_resolution_clock::now(); | |
for(unsigned counter = 0; counter < num_trials; ++counter) { | |
double num = random_01(twister); | |
custom_total += num; | |
} | |
// boost_total should be roughly close to num_trials / 2 | |
high_resolution_clock::time_point custom_end = high_resolution_clock::now(); | |
double custom_duration = duration_cast<std::chrono::microseconds>(custom_end - custom_start).count() / 1000.0; | |
printf(" std took %.3fms with a delta of %.2f\n", std_duration, abs(std_total - 500000)); | |
printf(" boost took %.3fms with a delta of %.2f\n", boost_duration, abs(boost_total - 500000)); | |
printf("custom took %.3fms with a delta of %.2f\n", custom_duration, abs(custom_total - 500000)); | |
printf("\n"); | |
if (std_duration > boost_duration && custom_duration > boost_duration) { | |
printf("boost is %.1fx faster than std\n", std_duration / boost_duration); | |
printf("boost is %.1fx faster than custom\n", custom_duration / boost_duration); | |
} else if (std_duration > custom_duration) { | |
printf("custom is %.1fx faster than std\n", std_duration / custom_duration); | |
printf("custom is %.1fx faster than boost\n", boost_duration / custom_duration); | |
} else { | |
printf("std is %.1fx faster than boost\n", boost_duration / std_duration); | |
printf("std is %.1fx faster than custom\n", custom_duration / std_duration); | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment