Skip to content

Instantly share code, notes, and snippets.

@madmann91
Last active June 14, 2021 21:53
Show Gist options
  • Save madmann91/680cc1cd4a7bfb9687859f139c8aa058 to your computer and use it in GitHub Desktop.
Save madmann91/680cc1cd4a7bfb9687859f139c8aa058 to your computer and use it in GitHub Desktop.
Sampling from a discrete probability distribution
#include <tgmath.h>
#include <stddef.h>
#include <stdbool.h>
#include <stdint.h>
#include <inttypes.h>
#include <string.h>
#include <stdio.h>
// Discrete probability distribution sampling according to the method named
// "Squared Histogram" in "Fast Generation of Discrete Random Variables",
// by G. Marsaglia et al.
// Sampling table for a discrete distribution over n outcomes. The two arrays
// are supplied by whoever fills in the struct (main uses stack arrays) and are
// written by gen_dist and read by sample_dist.
struct dist {
// v[i]: threshold of cell i. A uniform draw u that lands in cell i selects
// outcome i when u < v[i].
float* v;
// k[i]: outcome selected for cell i when u is not below v[i].
size_t* k;
// Number of outcomes, which is also the number of entries in v and k.
size_t n;
};
// Park-Miller "minimal standard" Lehmer generator (the minstd_rand variant).
// Advances *state to (48271 * state) mod (2^31 - 1) and returns the value
// *state held before the call.
//
// The product is formed in 64 bits: a 32-bit multiply would wrap modulo 2^32
// before the reduction, which breaks the recurrence and lets the state
// collapse to 0.
//
// *state should be seeded in [1, 2147483646]; a state of 0 maps to 0 forever.
static inline uint32_t minstd(uint32_t* state) {
    const uint64_t multiplier = 48271;
    const uint64_t modulus = 2147483647; // 2^31 - 1
    const uint32_t prev = *state;
    // prev < 2^32 and multiplier < 2^16, so the product fits easily in 64 bits.
    *state = (uint32_t)((prev * multiplier) % modulus);
    return prev;
}
// Pulls one value from minstd (advancing *state) and scales it to a float by
// dividing by 2147483646. The quotient is exactly 1.0f whenever the generator
// value and the divisor round to the same float, so callers must tolerate 1.0f.
static inline float randf(uint32_t* state) {
    const float scale = (float)2147483646;
    const uint32_t raw = minstd(state);
    return raw / scale;
}
// Writes the sampling table to stdout as three space-separated rows:
// the k entries, the v entries, and the cell indices 0..n-1.
static void print_dist(const struct dist* dist) {
    const size_t count = dist->n;
    size_t cell;

    // Row 1: alternate outcome of each cell.
    cell = 0;
    while (cell < count) {
        printf("%zu ", dist->k[cell]);
        cell++;
    }
    printf("\n");

    // Row 2: threshold of each cell.
    cell = 0;
    while (cell < count) {
        printf("%f ", dist->v[cell]);
        cell++;
    }
    printf("\n");

    // Row 3: the cell index itself, for reference.
    cell = 0;
    while (cell < count) {
        printf("%zu ", cell);
        cell++;
    }
    printf("\n");
}
// Draws one outcome from the table: a single uniform value picks the cell
// (u * n, truncated) and is then compared against that cell's threshold to
// choose between the cell's own index and its k entry.
static inline uint32_t sample_dist(const struct dist* dist, uint32_t* state) {
    const float uniform = randf(state);
    size_t cell = uniform * dist->n;

    // randf can yield 1.0f, which would index one past the last cell.
    if (cell >= dist->n) {
        cell = dist->n - 1;
    }

    if (uniform < dist->v[cell]) {
        return cell;
    }
    return dist->k[cell];
}
// Builds the square-histogram sampling table for the probabilities in p and
// stores it in dist->v and dist->k (dist->n entries each, already allocated).
//
// With a = 1/n, cell i spans [i*a, (i+1)*a). On return v[i] is the point
// inside that span below which outcome i is chosen; above it outcome k[i] is
// chosen.
//
// p must hold dist->n values. It is used as working storage and is
// overwritten. NOTE(review): the values are presumably non-negative and sum
// to 1; nothing here checks that.
static void gen_dist(struct dist* dist, float* p) {
    const size_t n = dist->n;
    if (n == 0)
        return; // nothing to build; also avoids 1/0 and n - 1 wrapping around

    const float a = 1.0f / n;

    // Start with every cell owned entirely by its own outcome.
    for (size_t i = 0; i < n; ++i) {
        dist->k[i] = i;
        dist->v[i] = (i + 1) * a;
    }

    // Each round closes the poorest open cell by topping it up from the
    // richest open cell. At most n - 1 cells ever need topping up.
    for (size_t round = 0; round + 1 < n; ++round) {
        size_t min = n, max = n; // n means "no open cell seen yet"
        for (size_t j = 0; j < n; ++j) {
            // A closed cell always has k[j] != j. Its p[j] was forced to a and
            // no longer represents real mass, so it must be neither topped up
            // again (that would overwrite its threshold) nor used as a donor.
            // Rounding error could otherwise make it look like a candidate.
            if (dist->k[j] != j)
                continue;
            if (min == n || p[j] < p[min]) min = j;
            if (max == n || p[j] > p[max]) max = j;
        }
        // No open cells, a single open cell, or all open cells equal.
        if (min == max)
            break;
        dist->k[min] = max;
        dist->v[min] = min * a + p[min];
        p[max] -= a - p[min]; // donor gives up exactly what the cell lacked
        p[min] = a;
    }
}
// Prints a text bar chart of bins[0..n-1] to stdout, using 4-character-wide
// columns: nine bar rows (top row = 90% of the total, bottom row = 10%), a
// separator, the bin indices, and each bin's integer percentage of the total.
//
// A bar cell is filled when the bin count is strictly greater than the row's
// threshold. With no samples at all the bars are empty and every percentage
// is reported as 0.
static void print_hist(const size_t* bins, size_t n) {
    size_t total = 0;
    for (size_t i = 0; i < n; ++i)
        total += bins[i];

    // Derive each threshold from the row number instead of repeatedly
    // subtracting total / 10: that step is 0 when total < 10 (endless loop)
    // and can step past zero and wrap when total is not a multiple of 10.
    for (size_t row = 9; row > 0; --row) {
        const size_t threshold = total * row / 10;
        for (size_t j = 0; j < n; ++j) {
            if (bins[j] > threshold)
                printf(" ## ");
            else
                printf("    "); // same width as a filled cell keeps columns aligned
        }
        printf("\n");
    }

    for (size_t j = 0; j < n; ++j)
        printf("----");
    printf("\n");

    for (size_t j = 0; j < n; ++j)
        printf(" %2zu ", j);
    printf("\n");

    for (size_t j = 0; j < n; ++j) {
        // Guard the division: an all-zero histogram has total == 0.
        const size_t percent = total != 0 ? (bins[j] * 100) / total : 0;
        printf(" %2zu%%", percent);
    }
    printf("\n");
}
int main() {
float p[] = { 0.1f, 0.4f, 0.3f, 0.2f };
const size_t n = sizeof(p) / sizeof(p[0]);
float v[n];
size_t k[n];
struct dist dist = { .v = v, .k = k, .n = n };
gen_dist(&dist, p);
printf("Squared histogram data:\n");
print_dist(&dist);
size_t bins[n];
memset(bins, 0, sizeof(size_t) * n);
uint32_t state = 1;
for (size_t i = 0; i < 10000; ++i)
bins[sample_dist(&dist, &state)]++;
printf("\nHistogram from sampling discrete distribution:\n");
print_hist(bins, n);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment