Last active
June 14, 2021 21:53
-
-
Save madmann91/680cc1cd4a7bfb9687859f139c8aa058 to your computer and use it in GitHub Desktop.
Sampling from a discrete probability distribution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <tgmath.h> | |
#include <stddef.h> | |
#include <stdbool.h> | |
#include <stdint.h> | |
#include <inttypes.h> | |
#include <string.h> | |
#include <stdio.h> | |
// Discrete probability distribution sampling according to the method named
// "Squared Histogram" in "Fast Generation of Discrete Random Variables",
// by G. Marsaglia et al.
/*
 * Lookup table used to sample a discrete distribution over n outcomes.
 * The struct only holds pointers: storage for v and k (n elements each) is
 * supplied by whoever fills the struct in.
 */
struct dist {
    float *v;  /* per-bin threshold: a draw u that lands in bin j yields j when u < v[j] */
    size_t *k; /* per-bin alternative outcome, returned when u >= v[j] */
    size_t n;  /* number of outcomes, which is also the number of bins */
};
/*
 * Steps a Lehmer ("minimal standard") generator and returns the value the
 * state held BEFORE the step:
 *
 *     next = (prev * 48271) mod (2^31 - 1)
 *
 * The product is formed in 64 bits; a 32-bit product would wrap modulo 2^32
 * before the reduction and no longer follow the recurrence above.
 *
 * state: generator state, updated in place. Seed it with a value in
 *        [1, 2147483646]; a state of 0 maps to 0 forever.
 */
static inline uint32_t minstd(uint32_t *state) {
    const uint64_t multiplier = 48271;
    const uint64_t modulus = 2147483647; /* 2^31 - 1 */
    const uint32_t prev = *state;

    /* prev is promoted to 64 bits by the multiplication, so nothing wraps. */
    *state = (uint32_t)((prev * multiplier) % modulus);
    return prev;
}
/*
 * Draws the next value from minstd() and scales it to a float by dividing by
 * 2147483646.
 *
 * state: generator state, advanced by one step.
 *
 * NOTE: the divisor and the largest raw values are not exactly representable
 * in single precision (assuming IEEE-754 floats), so the quotient can come
 * out as exactly 1.0f. Callers must tolerate a result of 1.0f.
 */
static inline float randf(uint32_t *state) {
    const float scale = (float)2147483646;
    const uint32_t raw = minstd(state);

    return raw / scale;
}
static void print_dist(const struct dist* dist) { | |
for (size_t i = 0; i < dist->n; ++i) | |
printf("%zu ", dist->k[i]); | |
printf("\n"); | |
for (size_t i = 0; i < dist->n; ++i) | |
printf("%f ", dist->v[i]); | |
printf("\n"); | |
for (size_t i = 0; i < dist->n; ++i) | |
printf("%zu ", i); | |
printf("\n"); | |
} | |
static inline uint32_t sample_dist(const struct dist* dist, uint32_t* state) { | |
float u = randf(state); | |
size_t j = u * dist->n; | |
if (j >= dist->n) j = dist->n - 1; | |
return u < dist->v[j] ? j : dist->k[j]; | |
} | |
static void gen_dist(struct dist* dist, float* p) { | |
float a = 1.0f / dist->n; | |
for (size_t i = 0; i < dist->n; ++i) { | |
dist->k[i] = i; | |
dist->v[i] = (i + 1) * a; | |
} | |
for (size_t i = 0; i < dist->n - 1; ++i) { | |
size_t min = 0, max = 0; | |
for (size_t j = 1; j < dist->n; ++j) { | |
if (p[min] > p[j]) min = j; | |
if (p[max] < p[j]) max = j; | |
} | |
if (min == max) | |
break; | |
dist->k[min] = max; | |
dist->v[min] = min * a + p[min]; | |
p[max] -= a - p[min]; | |
p[min] = a; | |
} | |
} | |
/*
 * Prints a text histogram of bins[0..n-1] to stdout.
 *
 * Output, top to bottom: nine bar rows for the 90% .. 10% levels of the
 * total count (a bin is marked on a row when its count exceeds that level),
 * a separator line, the bin indices, and each bin's share of the total in
 * whole percent. Every column is four characters wide.
 *
 * The row levels are derived from a fixed 9..1 counter rather than by
 * repeatedly subtracting total / 10, so the loop always terminates, even
 * when total is smaller than 10 or not a multiple of 10. A total of 0 prints
 * empty rows and 0% for every bin instead of dividing by zero.
 */
static void print_hist(const size_t *bins, size_t n) {
    size_t total = 0;
    for (size_t j = 0; j < n; ++j) {
        total += bins[j];
    }

    for (size_t level = 9; level > 0; --level) {
        const size_t threshold = total * level / 10;
        for (size_t j = 0; j < n; ++j) {
            if (bins[j] > threshold) {
                printf(" ## ");
            } else {
                printf("    "); /* blank cell, same width as a bar cell */
            }
        }
        printf("\n");
    }

    for (size_t j = 0; j < n; ++j) {
        printf("----");
    }
    printf("\n");

    for (size_t j = 0; j < n; ++j) {
        printf(" %2zu ", j);
    }
    printf("\n");

    for (size_t j = 0; j < n; ++j) {
        const size_t percent = (total != 0) ? (bins[j] * 100) / total : 0;
        printf(" %2zu%%", percent);
    }
    printf("\n");
}
int main() { | |
float p[] = { 0.1f, 0.4f, 0.3f, 0.2f }; | |
const size_t n = sizeof(p) / sizeof(p[0]); | |
float v[n]; | |
size_t k[n]; | |
struct dist dist = { .v = v, .k = k, .n = n }; | |
gen_dist(&dist, p); | |
printf("Squared histogram data:\n"); | |
print_dist(&dist); | |
size_t bins[n]; | |
memset(bins, 0, sizeof(size_t) * n); | |
uint32_t state = 1; | |
for (size_t i = 0; i < 10000; ++i) | |
bins[sample_dist(&dist, &state)]++; | |
printf("\nHistogram from sampling discrete distribution:\n"); | |
print_hist(bins, n); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment