Last active
August 22, 2021 09:24
-
-
Save teryror/3d52a64a7081257503dd0787a47c3f21 to your computer and use it in GitHub Desktop.
Const evaluatable Rust implementation of Vose's Alias Method
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Const evaluatable Rust implementation of Vose's Alias Method, as described | |
/// by Keith Schwarz at https://www.keithschwarz.com/darts-dice-coins/ | |
/// | |
/// In brief, this is an O(n) precomputation, which allows sampling an arbitrary | |
/// finite probability distribution in O(1) time, by first simulating a fair | |
/// n-sided die, followed by a biased coin. | |
/// | |
/// Because floating point arithmetic cannot be used in const functions, this is | |
/// built to operate on integer weights, rather than precomputed probabilities. | |
/// | |
/// Where the standard Alias Method scales the probabilities by a factor of n
/// and uses 1 as a cutoff to partition them into large and small probabilities,
/// this finds the least common multiple of n and the total weight, scales up
/// the weights so that they sum to it, and uses the LCM divided by n as the threshold.
/// | |
/// Unlike the original method, the partitioning and pairing steps are exact
/// integer arithmetic and therefore numerically stable; only the final
/// per-column probability is converted to 32-bit fixed point, which introduces
/// a rounding error of less than 2^-32 per column.
/// | |
/// The implementation could be made much more elegant as more language features | |
/// become available in const fns, most notably for loops, panics, and arguments | |
/// of mutable reference types. | |
use rand::{Rng, thread_rng}; | |
use rand::distributions::Distribution; | |
/// Greatest common divisor of `a` and `b` (Euclid's algorithm, recursive form).
///
/// `gcd(x, 0) == x` for every `x`, so in particular `gcd(0, 0) == 0`.
const fn gcd(a: u32, b: u32) -> u32 {
    if b == 0 {
        a
    } else {
        gcd(b, a % b)
    }
}
const fn lcm(a: u32, b: u32) -> u32 { | |
(a * b) / gcd(a, b) | |
} | |
/// Precomputed table for Vose's alias method over `N` outcomes.
///
/// To sample, pick a column `i` uniformly at random. Keep outcome `i` if a
/// uniformly random `u32` is below `prob[i]`; otherwise yield `alias[i]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AliasTable<const N: usize> {
    /// Per-column probability of keeping the column's own outcome, as a 32-bit
    /// fixed-point fraction (the stored value divided by 2^32).
    prob: [u32; N],
    /// Outcome returned by column `i` when its biased coin fails.
    alias: [usize; N],
}
impl<const N: usize> AliasTable<N> {
    /// Builds the alias table for `N` outcomes with integer `weights`, so that
    /// sampling yields outcome `i` with probability `weights[i] / sum(weights)`.
    /// This holds up to the 32-bit fixed-point rounding of `prob`.
    ///
    /// Runs in O(N) time. Being a `const fn`, it can be evaluated at compile
    /// time (e.g. to initialise a `const` item).
    ///
    /// # Panics
    ///
    /// - If any weight is zero. The division `1 / weights[i]` below stands in
    ///   for an assertion; in a const context it is a compile error instead.
    /// - If `N == 0`, because the rescaling step then divides by zero.
    /// - NOTE(review): all intermediate arithmetic is `u32`. If
    ///   `sum(weights) * N` exceeds `u32::MAX`, the running sum or the LCM
    ///   computation may overflow. That panics in debug builds and const
    ///   evaluation, but wraps silently in release builds and yields a wrong
    ///   table. Confirm callers stay within that bound.
    pub const fn new(mut weights: [u32; N]) -> Self {
        let mut prob = [0; N];
        let mut alias = [0; N];
        // Vec and similar data structures cannot be used in const fns, because
        // only other const fns may be called, which may not take &mut arguments.
        // So we have to use an ad-hoc, inline implementation for the work lists.
        //
        // These could have capacity N - 1, except the current state of const
        // generics doesn't allow that.
        //
        // `small` and `large` are stacks of column indices, and `*_count` is
        // the stack height. A push writes at `[count]` then increments; a pop
        // decrements then reads.
        let mut small = [0; N];
        let mut small_count = 0;
        let mut large = [0; N];
        let mut large_count = 0;

        // Pass 1: total weight, rejecting zero weights.
        let mut total_weight = 0;
        let mut i = 0;
        while i < N {
            // TODO(const_panic): assert_ne!(weights[i], 0, "Weight at position {} is zero!", i);
            // Until then, this division panics (or fails const evaluation)
            // exactly when the weight is zero.
            let _ = 1 / weights[i];
            total_weight += weights[i];
            i += 1;
        }

        // Pass 2: rescale so the total is divisible by N. `rescaled_total` is a
        // multiple of `total_weight`, so every weight is multiplied by the same
        // integer factor and the relative weights are preserved exactly.
        let rescaled_total = lcm(total_weight, N as u32);
        let weight_factor = rescaled_total / total_weight;
        let mut i = 0;
        while i < N {
            weights[i] *= weight_factor;
            i += 1;
        }
        // The rescaled weights sum to `rescaled_total`, which is also a multiple
        // of N, so this is the exact average weight per column. It is the
        // integer counterpart of the cutoff "1" in the textbook algorithm.
        let weight_threshold = rescaled_total / (N as u32);

        // Pass 3: partition columns into under-full (`small`) and full or
        // over-full (`large`).
        let mut i = 0;
        while i < N {
            if weights[i] < weight_threshold {
                small[small_count] = i;
                small_count += 1;
            } else {
                large[large_count] = i;
                large_count += 1;
            }
            i += 1;
        }

        // Pass 4: repeatedly fill an under-full column `l` with the excess of an
        // over-full column `g`.
        // - Column `l` keeps its own outcome with probability
        //   weights[l] / weight_threshold, stored as a 32-bit fixed-point
        //   fraction (scaled by 2^32). Since weights[l] < weight_threshold, the
        //   quotient fits in a u32.
        // - The rest of column `l` is donated by `g`, whose remaining weight
        //   shrinks accordingly. It cannot underflow, because
        //   weights[g] >= weight_threshold.
        // - `g` is then re-classified.
        // Each step removes one column and exactly `weight_threshold` of weight,
        // so the remaining weight always equals
        // (remaining columns) * weight_threshold.
        while small_count > 0 && large_count > 0 {
            small_count -= 1;
            let l = small[small_count];
            large_count -= 1;
            let g = large[large_count];
            prob[l] = (((weights[l] as u64) << 32) / (weight_threshold as u64)) as u32;
            alias[l] = g;
            weights[g] -= weight_threshold - weights[l];
            if weights[g] < weight_threshold {
                small[small_count] = g;
                small_count += 1;
            } else {
                large[large_count] = g;
                large_count += 1;
            }
        }

        // Any columns left in `large` hold exactly `weight_threshold`, so they
        // always keep their own outcome. u32::MAX is the largest storable
        // fixed-point value. Sampling compares `x < prob`, so the coin still
        // fails with probability 2^-32, but the alias is the column itself.
        while large_count > 0 {
            large_count -= 1;
            let g = large[large_count];
            prob[g] = u32::MAX;
            alias[g] = g;
        }
        // TODO(const_panic): assert_eq!(small_count, 0);
        // This should only be possible with floating point arithmetic
        // due to numerical instability; we should be fine without this loop.
        // By the invariant above, `small` cannot be non-empty once `large` is
        // exhausted when weights are exact integers. Kept as a defensive fallback.
        while small_count > 0 {
            small_count -= 1;
            let l = small[small_count];
            prob[l] = u32::MAX;
            alias[l] = l;
        }
        AliasTable { prob, alias }
    }
}
impl<const N: usize> Distribution<usize> for AliasTable<N> { | |
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { | |
let i = rng.gen_range(0..N); | |
let x = rng.gen::<u32>(); | |
if x < self.prob[i] { | |
i | |
} else { | |
self.alias[i] | |
} | |
} | |
} | |
/// A fixed set of `N` items together with an alias table over their indices.
/// Sampling returns item `i` with probability proportional to its weight.
pub struct PopulationTable<T, const N: usize> {
    /// The values handed out by sampling, indexed like the alias table's columns.
    items: [T; N],
    /// Alias table built from the items' weights; it yields indices into `items`.
    distr: AliasTable<N>,
}
impl<T, const N: usize> PopulationTable<T, N> { | |
pub const fn new(items: [T; N], weights: [u32; N]) -> Self { | |
PopulationTable { items, distr: AliasTable::new(weights) } | |
} | |
} | |
impl<T, const N: usize> Distribution<T> for PopulationTable<T, N> where T: Clone { | |
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T { | |
let idx = rng.sample(&self.distr); | |
self.items[idx].clone() | |
} | |
} | |
/// Declares a `const` [`PopulationTable`] from `weight => item` pairs, e.g.
///
/// ```ignore
/// population_table! {
///     pub COLORS: &'static str = [
///         3 => "red",
///         1 => "blue",
///     ]
/// }
/// ```
///
/// The table length `N` is derived by counting entries: each weight is mapped
/// to `()` and the length of the resulting unit array is taken. Counting units,
/// rather than re-expanding the weights, avoids evaluating the weight
/// expressions twice.
///
/// It also avoids giving the weights a second, independently inferred type. In
/// a standalone length array, a bare literal such as `4_000_000_000` would
/// default to `i32` and be rejected by the `overflowing_literals` lint, even
/// though it is a valid `u32` weight.
macro_rules! population_table {
    // Internal helper: map one weight expression to `()` for counting.
    // Listed first so `@unit` invocations are never tried against the public rule.
    (@unit $weight:expr) => { () };
    ($v:vis $name:ident : $t:ty = [ $( $weight:expr => $item:expr ),+ $(,)? ] ) => {
        $v const $name: PopulationTable<$t, { [$(population_table!(@unit $weight)),*].len() }> =
            PopulationTable::new([$($item),*], [$($weight),*]);
    };
}
population_table! { | |
NAME_TABLE: &'static str = [ | |
2 => "Alice", | |
1 => "Bob", | |
3 => "Charlie", | |
] | |
} | |
pub fn main() { | |
assert_eq!(NAME_TABLE.distr.alias, [0, 2, 2]); | |
assert_eq!(NAME_TABLE.distr.prob, [u32::MAX, 1 << 31, u32::MAX]); | |
let mut rng = thread_rng(); | |
let name = rng.sample(&NAME_TABLE); | |
println!("Hello, {}!", name); | |
} | |
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn greatest_common_divisor() {
        assert_eq!(gcd(2, 4), 2);
        assert_eq!(gcd(2, 5), 1);
        assert_eq!(gcd(252, 105), 21);
    }

    #[test]
    fn least_common_multiple() {
        assert_eq!(lcm(2, 4), 4);
        assert_eq!(lcm(2, 5), 10);
        assert_eq!(lcm(18, 12), 36);
    }

    /// Equal weights: every column is exactly full and aliases to itself.
    #[test]
    fn alias_table_uniform_weights() {
        let table = AliasTable::new([5, 5, 5, 5]);
        assert_eq!(table.prob, [u32::MAX; 4]);
        assert_eq!(table.alias, [0, 1, 2, 3]);
    }

    /// Weights 1:2 over two columns need rescaling. The total is 3 and
    /// lcm(3, 2) = 6, so the weights become [2, 4] with a per-column threshold
    /// of 3. Column 0 keeps outcome 0 with probability 2/3 and otherwise yields
    /// outcome 1, giving P(0) = 1/2 * 2/3 = 1/3.
    #[test]
    fn alias_table_rescales_weights() {
        let table = AliasTable::new([1, 2]);
        assert_eq!(table.prob, [((2u64 << 32) / 3) as u32, u32::MAX]);
        assert_eq!(table.alias, [1, 1]);
    }

    /// The table can be built at compile time, in a `const` item.
    #[test]
    fn alias_table_is_const_evaluable() {
        const TABLE: AliasTable<3> = AliasTable::new([2, 1, 3]);
        assert_eq!(TABLE.alias, [0, 2, 2]);
        assert_eq!(TABLE.prob, [u32::MAX, 1 << 31, u32::MAX]);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment