Skip to content

Instantly share code, notes, and snippets.

@bjacob
Last active November 26, 2020 16:54
Show Gist options
  • Save bjacob/7d635b91acd02559d73a6d159fe9cfbe to your computer and use it in GitHub Desktop.
Save bjacob/7d635b91acd02559d73a6d159fe9cfbe to your computer and use it in GitHub Desktop.

A shot at the cheapest possible PRNG on ARM NEON

This is just an 8-bit PRNG. The randomness is as bad as it could possibly get. The period is 256 in the scalar implementation. The implementation is widened 64x for SIMD purposes, which has the unintended side benefit of multiplying the period by that amount too.

Sample output on Pixel4 (pinned on the biggest CPU core by taskset 80):

$ adb shell taskset 80 /data/local/tmp/minilcg
Generate1 generated 16777216 bytes in 0.0361 seconds, throughput 0.465 GB/s
Generate16 generated 16777216 bytes in 0.00416 seconds, throughput 4.04 GB/s
Generate64 generated 16777216 bytes in 0.00293 seconds, throughput 5.72 GB/s
Generate generated 16777216 bytes in 0.000851 seconds, throughput 19.7 GB/s

At nearly 20 GB/s this is close to the DRAM bandwidth on this device.

The code is just naive NEON intrinsics; I haven't even looked at the generated code.

If the performance of the Generate16 function (and, to a lesser extent, Generate64) matters, note that this function spends more instructions outside of the core computation than inside it (the core is just one instruction!). The largest performance gains would therefore come from tailoring it to exactly the interface that its users really need; e.g. maybe users would be happy to address our internal state bytes in place rather than have us copy them to a user-provided destination buffer.

#ifdef __aarch64__
#include <arm_neon.h>

#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <ratio>
#include <vector>
// Minimal 8-bit linear congruential generator, replicated over 64 byte lanes
// so that the whole state fits in four 128-bit NEON registers.
//
// Every lane independently follows x <- x * kMulConstant + kAddConstant, with
// the arithmetic wrapping modulo 256.
class MiniLcg {
 public:
  // Seeds the 64 lanes with consecutive values of a separate 8-bit recurrence
  // (starting at 123, x <- 13 * x + 47), so the lanes do not start in lockstep.
  MiniLcg() {
    std::uint8_t val = 123;
    for (std::uint8_t& s : state_) {
      s = val;
      val = static_cast<std::uint8_t>(val * 13 + 47);
    }
  }

  // Advances lane 0 only and writes its new value to dst[0].
  void Generate1(std::uint8_t* dst) {
    state_[0] =
        static_cast<std::uint8_t>(state_[0] * kMulConstant + kAddConstant);
    dst[0] = state_[0];
  }

  // Advances lanes 0..15 by one step and writes the 16 new values to dst.
  void Generate16(std::uint8_t* dst) {
    uint8x16_t val = vld1q_u8(state_);
    const uint8x16_t kmul = vdupq_n_u8(kMulConstant);
    const uint8x16_t kadd = vdupq_n_u8(kAddConstant);
    // vmlaq_u8(a, b, c) computes a + b * c per lane, i.e. one LCG step.
    val = vmlaq_u8(kadd, kmul, val);
    vst1q_u8(state_, val);
    vst1q_u8(dst, val);
  }

  // Advances all 64 lanes by one step and writes the 64 new values to dst.
  void Generate64(std::uint8_t* dst) {
    uint8x16_t val0 = vld1q_u8(state_ + 0);
    uint8x16_t val1 = vld1q_u8(state_ + 16);
    uint8x16_t val2 = vld1q_u8(state_ + 32);
    uint8x16_t val3 = vld1q_u8(state_ + 48);
    const uint8x16_t kmul = vdupq_n_u8(kMulConstant);
    const uint8x16_t kadd = vdupq_n_u8(kAddConstant);
    val0 = vmlaq_u8(kadd, kmul, val0);
    val1 = vmlaq_u8(kadd, kmul, val1);
    val2 = vmlaq_u8(kadd, kmul, val2);
    val3 = vmlaq_u8(kadd, kmul, val3);
    vst1q_u8(state_ + 0, val0);
    vst1q_u8(state_ + 16, val1);
    vst1q_u8(state_ + 32, val2);
    vst1q_u8(state_ + 48, val3);
    vst1q_u8(dst + 0, val0);
    vst1q_u8(dst + 16, val1);
    vst1q_u8(dst + 32, val2);
    vst1q_u8(dst + 48, val3);
  }

  // Writes exactly `size` bytes to dst; does nothing when size <= 0.
  //
  // The output is produced in 64-byte chunks (each advancing all 64 lanes),
  // then 16-byte chunks (each advancing lanes 0..15 only). If fewer than 16
  // bytes remain, lanes 0..15 are advanced once more and only the bytes that
  // are still needed are copied out, so `size` need not be a multiple of 16.
  void Generate(std::uint8_t* dst, int size) {
    if (size <= 0) {
      return;
    }
    // The state lives in registers for the whole call and is written back
    // once at the end.
    uint8x16_t val0 = vld1q_u8(state_ + 0);
    uint8x16_t val1 = vld1q_u8(state_ + 16);
    uint8x16_t val2 = vld1q_u8(state_ + 32);
    uint8x16_t val3 = vld1q_u8(state_ + 48);
    const uint8x16_t kmul = vdupq_n_u8(kMulConstant);
    const uint8x16_t kadd = vdupq_n_u8(kAddConstant);
    int i = 0;
    for (; i <= size - 64; i += 64) {
      val0 = vmlaq_u8(kadd, kmul, val0);
      val1 = vmlaq_u8(kadd, kmul, val1);
      val2 = vmlaq_u8(kadd, kmul, val2);
      val3 = vmlaq_u8(kadd, kmul, val3);
      vst1q_u8(dst + i + 0, val0);
      vst1q_u8(dst + i + 16, val1);
      vst1q_u8(dst + i + 32, val2);
      vst1q_u8(dst + i + 48, val3);
    }
    for (; i <= size - 16; i += 16) {
      val0 = vmlaq_u8(kadd, kmul, val0);
      vst1q_u8(dst + i, val0);
    }
    const int tail = size - i;
    if (tail > 0) {
      // A full 16-byte store would run past the end of dst, so go through a
      // scratch buffer and copy only the bytes the caller asked for.
      std::uint8_t scratch[16];
      val0 = vmlaq_u8(kadd, kmul, val0);
      vst1q_u8(scratch, val0);
      std::memcpy(dst + i, scratch, static_cast<std::size_t>(tail));
    }
    vst1q_u8(state_ + 0, val0);
    vst1q_u8(state_ + 16, val1);
    vst1q_u8(state_ + 32, val2);
    vst1q_u8(state_ + 48, val3);
  }

 private:
  // Multiplier and increment of the per-lane recurrence.
  static constexpr std::uint8_t kMulConstant = 37;
  static constexpr std::uint8_t kAddConstant = 47;
  // Current value of each of the 64 lanes.
  std::uint8_t state_[64];
};
// Prints the first `size` bytes of `buf` to stdout as decimal numbers in the
// form "[ a b c ]". No trailing newline is emitted.
void print_array(const std::uint8_t* buf, int size) {
  printf("[");
  const std::uint8_t* cursor = buf;
  for (int left = size; left > 0; --left) {
    printf(" %d", *cursor);
    ++cursor;
  }
  printf(" ]");
}
// Converts a steady_clock duration to a number of seconds expressed as a
// single-precision float.
float ToFloatSeconds(const std::chrono::steady_clock::duration& duration) {
  using FloatSeconds = std::chrono::duration<float>;
  const FloatSeconds as_seconds =
      std::chrono::duration_cast<FloatSeconds>(duration);
  return as_seconds.count();
}
// Repeatedly fills `buf` with Generate16 output, 16 bytes at a time, until
// more than one second has elapsed since entry. The clock is only consulted
// after each complete pass over the buffer, so at least one pass always runs.
void Warmup(MiniLcg* lcg, std::uint8_t* buf, int size) {
  const auto begin = std::chrono::steady_clock::now();
  const int last_offset = size - 16;
  float seconds_spent = 0;
  do {
    for (int offset = 0; offset <= last_offset; offset += 16) {
      lcg->Generate16(buf + offset);
    }
    seconds_spent = ToFloatSeconds(std::chrono::steady_clock::now() - begin);
  } while (!(seconds_spent > 1));
}
// Returns the smallest p in [1, size) such that buf[i] == buf[i + p] for every
// i in [0, size - p), i.e. the shortest shift under which the buffer matches
// itself. Returns `size` when no such shift exists (including size <= 1).
int Period(const std::uint8_t* buf, int size) {
  int shift = 1;
  while (shift < size) {
    // Compare the buffer against itself moved by `shift` bytes.
    const std::size_t overlap = static_cast<std::size_t>(size - shift);
    if (std::memcmp(buf, buf + shift, overlap) == 0) {
      return shift;
    }
    ++shift;
  }
  return size;
}
// Runs `func` once, timing it with steady_clock, then prints one report line
// about `buf` prefixed with `name`.
//
// With sanity_check set, the report is the value of Period() over `buf`
// followed by up to the first 256 bytes of `buf`; otherwise it is the byte
// count, the elapsed time and the resulting throughput in GB/s.
// NOTE(review): `func` is presumably expected to fill `buf`; nothing in this
// function enforces that.
void Benchmark(const char* name, bool sanity_check,
               const std::vector<std::uint8_t>& buf,
               const std::function<void()>& func) {
  const auto start = std::chrono::steady_clock::now();
  func();
  const auto end = std::chrono::steady_clock::now();
  const float elapsed = ToFloatSeconds(end - start);
  if (sanity_check) {
    const int size = static_cast<int>(buf.size());
    // Clamp the sample so a buffer shorter than 256 bytes is never read past
    // its end.
    const int sample_count = size < 256 ? size : 256;
    printf("%s has period %d. Sample values: ", name,
           Period(buf.data(), size));
    print_array(buf.data(), sample_count);
    printf("\n");
  } else {
    // %zu is the conversion that matches std::size_t on every target.
    printf("%s generated %zu bytes in %.3g seconds, throughput %.3g GB/s\n",
           name, buf.size(), elapsed, 1e-9 * buf.size() / elapsed);
  }
}
// Entry point: warms up, then runs each generation strategy over the same
// 16 MiB buffer. Passing "--sanity-check" as the first argument switches the
// report from throughput numbers to period and sample values.
int main(int argc, char* argv[]) {
  const bool sanity_check =
      argc >= 2 && std::strcmp(argv[1], "--sanity-check") == 0;
  const int size = 1 << 24;
  std::vector<std::uint8_t> buf(size);
  std::uint8_t* const out = buf.data();
  MiniLcg lcg;
  Warmup(&lcg, out, size);

  // One byte per call.
  Benchmark("Generate1", sanity_check, buf, [&]() {
    for (int pos = 0; pos < size; ++pos) {
      lcg.Generate1(out + pos);
    }
  });

  // One 16-byte vector per call.
  Benchmark("Generate16", sanity_check, buf, [&]() {
    const int last = size - 16;
    for (int pos = 0; pos <= last; pos += 16) {
      lcg.Generate16(out + pos);
    }
  });

  // Four 16-byte vectors per call.
  Benchmark("Generate64", sanity_check, buf, [&]() {
    const int last = size - 64;
    for (int pos = 0; pos <= last; pos += 64) {
      lcg.Generate64(out + pos);
    }
  });

  // Whole buffer in a single call.
  Benchmark("Generate", sanity_check, buf, [&]() {
    lcg.Generate(out, static_cast<int>(buf.size()));
  });
}
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment