|
#ifdef __aarch64__ |
|
|
|
#include <arm_neon.h>

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <ratio>
#include <vector>
|
|
|
// A tiny 8-bit linear congruential generator with 64 independent byte lanes.
//
// Every lane follows x <- x * kMulConstant + kAddConstant, with the result
// truncated to 8 bits. The NEON entry points advance 16 lanes per vector
// operation; Generate1 is the scalar variant and only touches lane 0.
class MiniLcg {
 public:
  // Seeds the 64 lanes from the byte recurrence v <- v * 13 + 47, starting at
  // v = 123.
  MiniLcg() {
    std::uint8_t val = 123;
    for (std::uint8_t& s : state_) {
      s = val;
      // Arithmetic is done in int; the cast makes the wrap to 8 bits explicit.
      val = static_cast<std::uint8_t>(val * 13 + 47);
    }
  }

  // Advances lane 0 by one step and writes the new value to dst[0].
  void Generate1(std::uint8_t* dst) {
    state_[0] =
        static_cast<std::uint8_t>(state_[0] * kMulConstant + kAddConstant);
    dst[0] = state_[0];
  }

  // Advances lanes 0..15 by one step and writes the 16 new bytes to dst.
  // dst must have room for 16 bytes.
  void Generate16(std::uint8_t* dst) {
    uint8x16_t val = vld1q_u8(state_);
    uint8x16_t kmul = vdupq_n_u8(kMulConstant);
    uint8x16_t kadd = vdupq_n_u8(kAddConstant);
    // vmlaq_u8(a, b, c) computes a + b * c per byte lane.
    val = vmlaq_u8(kadd, kmul, val);
    vst1q_u8(state_, val);
    vst1q_u8(dst, val);
  }

  // Advances all 64 lanes by one step and writes the 64 new bytes to dst.
  // dst must have room for 64 bytes.
  void Generate64(std::uint8_t* dst) {
    uint8x16_t val0 = vld1q_u8(state_ + 0);
    uint8x16_t val1 = vld1q_u8(state_ + 16);
    uint8x16_t val2 = vld1q_u8(state_ + 32);
    uint8x16_t val3 = vld1q_u8(state_ + 48);
    uint8x16_t kmul = vdupq_n_u8(kMulConstant);
    uint8x16_t kadd = vdupq_n_u8(kAddConstant);
    val0 = vmlaq_u8(kadd, kmul, val0);
    val1 = vmlaq_u8(kadd, kmul, val1);
    val2 = vmlaq_u8(kadd, kmul, val2);
    val3 = vmlaq_u8(kadd, kmul, val3);
    vst1q_u8(state_ + 0, val0);
    vst1q_u8(state_ + 16, val1);
    vst1q_u8(state_ + 32, val2);
    vst1q_u8(state_ + 48, val3);
    vst1q_u8(dst + 0, val0);
    vst1q_u8(dst + 16, val1);
    vst1q_u8(dst + 32, val2);
    vst1q_u8(dst + 48, val3);
  }

  // Writes `size` generated bytes to dst.
  //
  // `size` must be a multiple of 16 (checked with assert). The state is loaded
  // into vector registers once, kept there for the whole call, and stored back
  // once at the end. Whole 64-byte groups advance all four vectors; any
  // remaining 16-byte chunks advance and emit only the first vector
  // (lanes 0..15).
  void Generate(std::uint8_t* dst, int size) {
    assert((size % 16) == 0);
    int i = 0;
    uint8x16_t val0 = vld1q_u8(state_ + 0);
    uint8x16_t val1 = vld1q_u8(state_ + 16);
    uint8x16_t val2 = vld1q_u8(state_ + 32);
    uint8x16_t val3 = vld1q_u8(state_ + 48);
    uint8x16_t kmul = vdupq_n_u8(kMulConstant);
    uint8x16_t kadd = vdupq_n_u8(kAddConstant);
    for (; i <= size - 64; i += 64) {
      val0 = vmlaq_u8(kadd, kmul, val0);
      val1 = vmlaq_u8(kadd, kmul, val1);
      val2 = vmlaq_u8(kadd, kmul, val2);
      val3 = vmlaq_u8(kadd, kmul, val3);
      vst1q_u8(dst + i + 0, val0);
      vst1q_u8(dst + i + 16, val1);
      vst1q_u8(dst + i + 32, val2);
      vst1q_u8(dst + i + 48, val3);
    }
    // Tail: fewer than 64 bytes left, handled 16 at a time from val0 only.
    for (; i <= size - 16; i += 16) {
      val0 = vmlaq_u8(kadd, kmul, val0);
      vst1q_u8(dst + i, val0);
    }
    vst1q_u8(state_ + 0, val0);
    vst1q_u8(state_ + 16, val1);
    vst1q_u8(state_ + 32, val2);
    vst1q_u8(state_ + 48, val3);
  }

 private:
  // Multiplier and increment of the per-lane recurrence.
  static constexpr std::uint8_t kMulConstant = 37;
  static constexpr std::uint8_t kAddConstant = 47;

  // Current value of each of the 64 byte lanes.
  std::uint8_t state_[64];
};
|
|
|
// Prints the first `size` bytes of `buf` to stdout as space-separated decimal
// values enclosed in square brackets, e.g. "[ 1 2 3 ]". No newline is emitted.
void print_array(const std::uint8_t* buf, int size) {
  printf("[");
  for (int i = 0; i < size; i++) {
    // The byte is promoted to int when passed through the varargs call, which
    // is what %d expects.
    printf(" %d", buf[i]);
  }
  printf(" ]");
}
|
|
|
// Converts a steady_clock duration to a number of seconds held in a float.
float ToFloatSeconds(const std::chrono::steady_clock::duration& duration) {
  using FloatSeconds = std::chrono::duration<float>;
  return std::chrono::duration_cast<FloatSeconds>(duration).count();
}
|
|
|
// Repeatedly fills `buf` (in 16-byte steps, via MiniLcg::Generate16) until
// more than one second of steady_clock time has passed. The clock is checked
// once after each full pass over the buffer, so at least one pass always runs.
// NOTE(review): presumably this exists to bring the CPU to a steady state
// before the timed runs — confirm.
void Warmup(MiniLcg* lcg, std::uint8_t* buf, int size) {
  const auto start = std::chrono::steady_clock::now();
  while (true) {
    for (int i = 0; i <= size - 16; i += 16) {
      lcg->Generate16(buf + i);
    }
    const auto end = std::chrono::steady_clock::now();
    const float elapsed = ToFloatSeconds(end - start);
    if (elapsed > 1) {
      return;
    }
  }
}
|
|
|
// Returns the smallest p in [1, size) such that buf[i] == buf[i + p] for every
// i in [0, size - p), i.e. the shortest shift under which the buffer matches
// itself. Returns `size` when no such shift exists (including size <= 1).
int Period(const std::uint8_t* buf, int size) {
  for (int p = 1; p < size; p++) {
    // Compare the buffer against itself shifted by p; std::equal stops at the
    // first mismatch.
    if (std::equal(buf, buf + (size - p), buf + p)) return p;
  }
  return size;
}
|
|
|
void Benchmark(const char* name, bool sanity_check, |
|
const std::vector<std::uint8_t>& buf, |
|
const std::function<void()>& func) { |
|
const auto& start = std::chrono::steady_clock::now(); |
|
func(); |
|
const auto& end = std::chrono::steady_clock::now(); |
|
float elapsed = ToFloatSeconds(end - start); |
|
if (sanity_check) { |
|
printf("%s has period %d. Sample values: ", name, |
|
Period(buf.data(), buf.size())); |
|
print_array(buf.data(), 256); |
|
printf("\n"); |
|
} else { |
|
printf("%s generated %lu bytes in %.3g seconds, throughput %.3g GB/s\n", |
|
name, buf.size(), elapsed, 1e-9 * buf.size() / elapsed); |
|
} |
|
} |
|
|
|
// Entry point. Fills a 16 MiB buffer with each MiniLcg generation method in
// turn and reports on each run. Passing "--sanity-check" as the first argument
// makes each report print the period and sample values instead of timings.
// The same MiniLcg instance is used throughout, so each run continues from the
// state left by the previous one.
int main(int argc, char* argv[]) {
  const bool sanity_check =
      argc >= 2 && std::strcmp(argv[1], "--sanity-check") == 0;
  const int size = 1 << 24;
  std::vector<std::uint8_t> buf(size);
  MiniLcg lcg;
  Warmup(&lcg, buf.data(), size);
  Benchmark("Generate1", sanity_check, buf, [&]() {
    for (int i = 0; i < size; i++) {
      lcg.Generate1(&buf[i]);
    }
  });
  Benchmark("Generate16", sanity_check, buf, [&]() {
    for (int i = 0; i <= size - 16; i += 16) {
      lcg.Generate16(&buf[i]);
    }
  });
  Benchmark("Generate64", sanity_check, buf, [&]() {
    for (int i = 0; i <= size - 64; i += 64) {
      lcg.Generate64(&buf[i]);
    }
  });
  // `size` equals buf.size(); passing it avoids a size_t -> int narrowing.
  Benchmark("Generate", sanity_check, buf,
            [&]() { lcg.Generate(buf.data(), size); });
}
|
|
|
#endif |