Skip to content

Instantly share code, notes, and snippets.

@clausecker
Last active November 27, 2024 16:12
Show Gist options
  • Save clausecker/1a4e0d0358f26ef5a6fd01d74b187799 to your computer and use it in GitHub Desktop.
Save clausecker/1a4e0d0358f26ef5a6fd01d74b187799 to your computer and use it in GitHub Desktop.
Bias test for arc4random_uniform()
/*
* biastest.c -- bias test for arc4random_uniform()
* written by Robert Clausecker <[email protected]>
*/
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <signal.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
/* Draws random numbers and counts them in the table; arguments are (table, number of draws, bound). */
static void collect_sample(unsigned char *, long long, uint32_t);
/* Prints statistics about a filled table; arguments are (table, number of draws, bound). */
static void analyze_sample(const unsigned char *, long long, uint32_t);
/* Set once sampling has finished; progress() stops printing and re-arming its alarm when it is true. */
static atomic_bool complete = false;
/* Number of random draws to make; the -n option overwrites this default of 5 * 2^32. */
static long long tries = 5ULL << 32;
/* Draws made so far; incremented by collect_sample() and read by progress(). */
static atomic_llong tries_done = 0;
/*
 * Write the command line synopsis to stderr and terminate the process
 * with a failure status.  This function does not return.
 *
 * argv0: program name to show in the synopsis.
 */
static void
usage(const char *argv0)
{
	(void)fprintf(stderr, "usage: %s [-n tries] [-t threshold]\n", argv0);

	exit(EXIT_FAILURE);
}
/*
 * Parse text as a base-10 integer and store it in *out.
 *
 * Returns true only if the whole string is a number within [lo, hi];
 * empty input, trailing garbage, and values that overflow long long are
 * rejected and *out is left untouched.
 */
static bool
parse_count(const char *text, long long lo, long long hi, long long *out)
{
	char *end;
	long long value;

	errno = 0;
	value = strtoll(text, &end, 10);
	if (end == text || *end != '\0' || errno == ERANGE)
		return (false);
	if (value < lo || value > hi)
		return (false);

	*out = value;
	return (true);
}

/*
 * Entry point.  Parses the options
 *
 *	-n tries	number of random draws (default 5 * 2^32)
 *	-t threshold	upper bound passed to arc4random_uniform()
 *			(default 3 * 2^30)
 *
 * allocates one zeroed byte counter per possible result, fills the
 * table with collect_sample() and reports on it with analyze_sample().
 * Exits with EXIT_FAILURE on bad usage, out-of-range option values, or
 * allocation failure.
 */
int
main(int argc, char *argv[])
{
	uint32_t threshold = 3UL << 30;
	long long value;
	int ch;
	unsigned char *sample;

	/* getopt() reports the end of the options with -1, not EOF */
	while ((ch = getopt(argc, argv, "n:t:")) != -1) {
		switch (ch) {
		case 'n':
			/* zero or negative tries would leave nothing to analyze */
			if (!parse_count(optarg, 1, LLONG_MAX, &value)) {
				fprintf(stderr, "tries must be between 1 and %lld\n",
				    (long long)LLONG_MAX);
				exit(EXIT_FAILURE);
			}
			tries = value;
			break;
		case 't':
			/* range-check before narrowing so large values cannot wrap */
			if (!parse_count(optarg, 1, UINT32_MAX, &value)) {
				fprintf(stderr, "threshold must be between 1 and %lu\n",
				    (unsigned long)UINT32_MAX);
				exit(EXIT_FAILURE);
			}
			threshold = (uint32_t)value;
			break;
		default:
			usage(argv[0]);
		}
	}

	/* no positional arguments are accepted */
	if (optind != argc)
		usage(argv[0]);

	sample = calloc(threshold, 1);
	if (sample == NULL) {
		perror("calloc(threshold, 1)");
		return (EXIT_FAILURE);
	}

	collect_sample(sample, tries, threshold);
	analyze_sample(sample, tries, threshold);

	free(sample);
	return (EXIT_SUCCESS);
}
/*
 * SIGALRM handler that doubles as a directly callable progress printer.
 *
 * While sampling is still running (complete is false) it rewrites the
 * status line on stderr with the number of samples taken so far, then
 * reinstalls itself and schedules the next SIGALRM one second later.
 * Once complete is true it does nothing, so the reports stop.
 *
 * NOTE(review): fprintf() is not on the POSIX list of async-signal-safe
 * functions; this relies on the interrupted code not being inside stdio
 * on the same stream.
 *
 * signo: signal number, unused.
 */
static void
progress(int signo)
{
	long long done;

	(void)signo;

	if (complete)
		return;

	/* take one snapshot and use it for both the count and the percentage */
	done = tries_done;
	fprintf(stderr, "\r%10lld of %10lld samples taken (%5.2f%% done)",
	    done, tries, (done * 100.0) / tries);

	/* re-arm for the next report */
	signal(SIGALRM, progress);
	alarm(1);
}
/*
 * Draw `tries` numbers from arc4random_uniform(threshold) and count how
 * often each value occurs.
 *
 * sample:	table of at least `threshold` counters, expected to be
 *		zeroed by the caller; sample[x] is incremented for every
 *		draw of x and saturates at UCHAR_MAX
 * tries:	number of draws to make
 * threshold:	exclusive upper bound of the drawn values
 *
 * Progress is reported on stderr about once per second through the
 * SIGALRM driven progress() routine.  A warning is printed once on
 * stdout the first time a counter would exceed UCHAR_MAX.  On return
 * the global flag `complete` is set.
 */
static void
collect_sample(unsigned char *sample, long long tries, uint32_t threshold)
{
	long long i;
	uint32_t x;
	bool overflowed = false;

	/* print the initial status line and start the one second alarm cycle */
	progress(SIGALRM);

	for (i = 0; i < tries; i++) {
		x = arc4random_uniform(threshold);
		tries_done++;
		assert(x < threshold);

		if (sample[x] == UCHAR_MAX) {
			/* counter is saturated: drop the observation, warn only once */
			if (!overflowed) {
				printf("sample table overflow, results will be incorrect\n");
				overflowed = true;
			}
		} else
			sample[x]++;
	}

	/*
	 * Stop the periodic reports before printing the final one: without
	 * this, a pending SIGALRM could run the handler's fprintf() in the
	 * middle of the fprintf() performed by the direct call below, which
	 * re-enters stdio on the same stream.
	 */
	signal(SIGALRM, SIG_IGN);
	alarm(0);

	/*
	 * Final report.  This call re-arms the alarm one last time; the
	 * handler then sees `complete` set and stays silent.
	 */
	progress(SIGALRM);
	complete = true;
	fputc('\n', stderr);
}
static void
analyze_sample(const unsigned char *sample, long long tries, uint32_t threshold)
{
double discrepancy, average, variance, total;
long long histogram[UCHAR_MAX + 1] = { 0 }, sum, n, median;
uint32_t i, i_min, i_max;
int min, max;
printf("distribution properties:\n");
/* find median, average, deviation, smallest, and largest bucket */
total = 0.0;
for (i = 0; i < threshold; i++) {
histogram[sample[i]]++;
total += (double)i * sample[i];
}
average = total / tries;
variance = 0.0;
median = threshold;
n = 0;
i_min = 0;
i_max = 0;
min = sample[i_min];
max = sample[i_max];
for (i = 0; i < threshold; i++) {
discrepancy = i - average;
variance += sample[i] * discrepancy * discrepancy;
n += sample[i];
if (median == threshold && n > tries / 2)
median = i;
if (sample[i] < min) {
i_min = i;
min = sample[i_min];
} else if (sample[i] > max) {
i_max = i;
max = sample[i_max];
}
}
variance /= tries;
assert(median < threshold);
printf("\tthreshold: %lu\n", (unsigned long)threshold);
printf("\tobservations: %lld\n", tries);
printf("\tleast common: %lu (%d observations)\n", (unsigned long)i_min, min);
printf("\tmost common: %lu (%d observations)\n", (unsigned long)i_max, max);
printf("\tmedian: %lld (expected %lu)\n", median, (unsigned long)threshold / 2);
printf("\taverage: %f (expected %f)\n", average, 0.5 * (threshold - 1));
printf("\tdeviation: %f (expected %f)\n\n", sqrt(variance),
sqrt(((double)threshold * threshold - 1.0) / 12));
/* build histogram and analyze it */
printf("sample properties:\n");
/* find median, average, and deviation */
average = (double)tries / threshold;
variance = 0.0;
for (i = 0; i < UCHAR_MAX; i++) {
discrepancy = i - average;
variance += histogram[i] * discrepancy * discrepancy;
}
variance /= threshold;
n = 0;
median = UCHAR_MAX + 1;
for (i = 0; i <= UCHAR_MAX; i++) {
n += histogram[i];
if (n >= threshold / 2) {
median = i;
break;
}
}
assert(median <= UCHAR_MAX); /* unreachable */
printf("\tmedian: %lld\n", median);
printf("\taverage: %f\n", average);
printf("\tdeviation: %f (expected %f)\n\n", sqrt(variance), sqrt(average * (1.0 - 1.0 / threshold)));
printf("histogram:\n");
for (i = 0; i < 256; i++)
if (histogram[i] != 0)
printf("\t%3d:\t%lld\n", (int)i, histogram[i]);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment