Skip to content

Instantly share code, notes, and snippets.

@ttsugriy
Created September 9, 2023 23:25
Show Gist options
  • Save ttsugriy/9dece74b8ddee0802d834dd779d38bd1 to your computer and use it in GitHub Desktop.
Save ttsugriy/9dece74b8ddee0802d834dd779d38bd1 to your computer and use it in GitHub Desktop.
Reservoir sampling benchmark
#include <algorithm>
#include <cmath>
#include <iostream>
#include <iterator>
#include <random>
#include <vector>
/// Reservoir sampler implementing Li's "Algorithm L".
///
/// Maintains a uniform random sample of at most `capacity` values drawn from
/// a stream of unknown length. Rather than drawing a random number for every
/// item, it draws how many items to skip before the next replacement, so the
/// amount of random-number work grows with the number of replacements and not
/// with the stream length.
class AlgorithmL {
private:
  std::vector<double> reservoir;  // sample slots; only the first min(counter, size) hold stream items
  long long counter;              // items offered so far == 0-based index of the next item
  long long next;                 // 0-based index of the next item that will replace a slot
  double w;                       // the running "W" value of Algorithm L, always in [0, 1]
  std::mt19937_64 rng;
  std::uniform_real_distribution<double> unif{0, 1};

public:
  /// @param capacity Maximum number of values kept. A capacity of 0 yields a
  ///                 sampler that accepts input but never stores anything.
  /// @param seed     Seed for the internal generator. Defaults to the engine's
  ///                 default seed, so samplers built without a seed behave
  ///                 deterministically.
  AlgorithmL(int capacity,
             std::mt19937_64::result_type seed = std::mt19937_64::default_seed)
      // `next` starts at the index of the last slot-filling item; skip() always
      // advances by at least one, so the first candidate is item `capacity`.
      : reservoir(capacity), counter{0}, next{capacity - 1}, w{1.0}, rng(seed) {
    if (reservoir.empty()) {
      return;  // next stays at -1, which add() can never match
    }
    w = std::exp(std::log(uniformOpen()) / static_cast<double>(reservoir.size()));
    skip();
  }

  /// Offers one stream item to the sampler.
  void add(double value) {
    const auto capacity = static_cast<long long>(reservoir.size());
    if (counter < capacity) {
      // The first `capacity` items fill the reservoir in order.
      reservoir[static_cast<std::size_t>(counter)] = value;
    } else if (counter == next) {
      // Unbiased slot choice drawn from the same engine as the skips.
      std::uniform_int_distribution<std::size_t> slot{0, reservoir.size() - 1};
      reservoir[slot(rng)] = value;
      skip();
    }
    ++counter;
  }

  /// Returns a copy of the current sample. Holds min(items added, capacity)
  /// values, so slots that were never filled are not reported.
  auto sample() const {
    const auto filled =
        std::min(counter, static_cast<long long>(reservoir.size()));
    return std::vector<double>(reservoir.begin(), reservoir.begin() + filled);
  }

private:
  /// Draws a uniform value that is strictly positive, so std::log never sees 0.
  double uniformOpen() {
    double u = unif(rng);
    while (u <= 0.0) {
      u = unif(rng);
    }
    return u;
  }

  /// Schedules the next replacement and updates `w`.
  /// Must only be called when the reservoir is non-empty.
  void skip() {
    // Upper bound (2^62) on a single gap. It keeps the conversion to
    // long long defined when the quotient is huge, infinite or NaN; an index
    // that large is never reached by `counter` in practice.
    constexpr double kMaxGap = 4611686018427387904.0;
    // log1p(-w) stays accurate as w shrinks, where log(1 - w) would round to 0.
    const double gap =
        std::floor(std::log(uniformOpen()) / std::log1p(-w));
    next += static_cast<long long>(gap < kMaxGap ? gap : kMaxGap) + 1;
    w *= std::exp(std::log(uniformOpen()) / static_cast<double>(reservoir.size()));
  }
};
/// Reservoir sampler implementing the classic "Algorithm R".
///
/// Maintains a uniform random sample of at most `capacity` values drawn from
/// a stream of unknown length, drawing one random index per item once the
/// reservoir is full.
class AlgorithmR {
private:
  std::vector<double> reservoir;  // sample slots; only the first min(counter, size) hold stream items
  long long counter;              // items offered so far == 0-based index of the next item
  std::mt19937_64 rng;

public:
  /// @param capacity Maximum number of values kept. A capacity of 0 yields a
  ///                 sampler that accepts input but never stores anything.
  /// @param seed     Seed for the internal generator. Defaults to the engine's
  ///                 default seed, so samplers built without a seed behave
  ///                 deterministically.
  AlgorithmR(int capacity,
             std::mt19937_64::result_type seed = std::mt19937_64::default_seed)
      : reservoir(capacity), counter{0}, rng(seed) {}

  /// Offers one stream item to the sampler.
  void add(double value) {
    const auto capacity = static_cast<long long>(reservoir.size());
    if (counter < capacity) {
      // The first `capacity` items fill the reservoir in order.
      reservoir[static_cast<std::size_t>(counter)] = value;
    } else if (capacity > 0) {
      // Item number `counter` (0-based) is the (counter + 1)-th item seen, so
      // the index is drawn from the inclusive range [0, counter]; the item is
      // kept with probability capacity / (counter + 1).
      std::uniform_int_distribution<long long> pick{0, counter};
      const long long replacementIndex = pick(rng);
      if (replacementIndex < capacity) {
        reservoir[static_cast<std::size_t>(replacementIndex)] = value;
      }
    }
    ++counter;
  }

  /// Returns a copy of the current sample. Holds min(items added, capacity)
  /// values, so slots that were never filled are not reported.
  auto sample() const {
    const auto filled =
        std::min(counter, static_cast<long long>(reservoir.size()));
    return std::vector<double>(reservoir.begin(), reservoir.begin() + filled);
  }
};
// Benchmarks AlgorithmR::add. The benchmark argument is the reservoir
// capacity; every benchmark iteration feeds one more value to the sampler.
static void BH_algoR(benchmark::State& state) {
  const auto k = static_cast<int>(state.range(0));
  AlgorithmR algoR{k};
  double d = 0.0;
  for (auto _ : state) {
    algoR.add(d++);
  }
  // Make the final reservoir observable so the optimiser cannot discard the
  // work done inside the timed loop. This runs after timing has finished.
  auto result = algoR.sample();
  benchmark::DoNotOptimize(result);
}
BENCHMARK(BH_algoR)->Arg(5)->Arg(20)->Arg(200)->Arg(1000)->Arg(10000);
// Benchmarks AlgorithmL::add. The benchmark argument is the reservoir
// capacity; every benchmark iteration feeds one more value to the sampler.
static void BH_algoL(benchmark::State& state) {
  const auto k = static_cast<int>(state.range(0));
  AlgorithmL algoL{k};
  double d = 0.0;
  for (auto _ : state) {
    algoL.add(d++);
  }
  // Make the final reservoir observable so the optimiser cannot discard the
  // work done inside the timed loop. This runs after timing has finished.
  auto result = algoL.sample();
  benchmark::DoNotOptimize(result);
}
BENCHMARK(BH_algoL)->Arg(5)->Arg(20)->Arg(200)->Arg(1000)->Arg(10000);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment