Created
March 1, 2021 11:11
-
-
Save xxzefgh/d14857b1663fefe41315f046d319eb87 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Picks one element of `population` at random, where element `i` is selected
 * with probability proportional to `weights[i]`.
 *
 * Single-draw port of Python's `random.choices`:
 * https://github.com/python/cpython/blob/3.9/Lib/random.py#L473
 *
 * @param population - Candidates to choose from; must not be empty.
 * @param weights - Relative weight of each candidate, one per element of
 *   `population`. Weights are assumed to be non-negative (not checked here).
 * @param random - Source of randomness; expected to return a value in [0, 1).
 * @returns The selected element of `population`.
 * @throws RangeError if `weights` and `population` differ in length, if the
 *   population is empty, or if the sum of weights is not a finite number
 *   greater than zero.
 */
export function randomChoice<T>(
  population: readonly T[],
  weights: readonly number[],
  random: () => number,
): T {
  const n = population.length;
  if (weights.length !== n) {
    throw new RangeError('The number of weights does not match the population');
  }
  // Without this guard an empty input would surface as a confusing
  // "total must be finite" error further down.
  if (n === 0) {
    throw new RangeError('Cannot choose from an empty population');
  }

  // Running prefix sums: cumWeights[i] = weights[0] + ... + weights[i].
  const cumWeights: number[] = [];
  let total = 0;
  for (const weight of weights) {
    total += weight;
    cumWeights.push(total);
  }

  if (total <= 0) {
    throw new RangeError('Total of weights must be greater than zero');
  }
  // Also rejects NaN totals, which slip past the `<= 0` comparison above.
  if (!Number.isFinite(total)) {
    throw new RangeError('Total of weights must be finite');
  }

  // Searching only up to n - 1 caps the result at the last valid index, even
  // if the scaled random value lands on or beyond the final cumulative weight.
  return population[bisect(cumWeights, random() * total, 0, n - 1)];
}
/**
 * Right-bisection over a sorted array (port of Python's `bisect.bisect_right`).
 *
 * Returns the insertion point for `x` within `a[lo..hi)` such that every
 * element before the returned index is `<= x` and every element from it
 * onward (up to `hi`) is `> x`. Elements are compared with the `<` operator.
 *
 * @param a - Array sorted in ascending order over the searched range.
 * @param x - Value to locate.
 * @param lo - Inclusive lower bound of the search range; must be >= 0.
 * @param hi - Exclusive upper bound of the search range; must be <= `a.length`.
 * @returns An index in `[lo, hi]` (or `lo` unchanged when `lo >= hi`).
 * @throws RangeError if `lo` is negative or `hi` exceeds `a.length`.
 */
function bisect<T>(a: readonly T[], x: T, lo: number, hi: number): number {
  if (lo < 0) {
    throw new RangeError('lo must be non-negative');
  }
  // Reading past the end yields `undefined` in JavaScript, and `x < undefined`
  // is always false, so the search would silently walk off the array.
  if (hi > a.length) {
    throw new RangeError('hi must not exceed the array length');
  }
  while (lo < hi) {
    const mid = Math.floor((lo + hi) / 2);
    if (x < a[mid]) {
      hi = mid;
    } else {
      // a[mid] <= x: the insertion point lies strictly to the right of mid.
      lo = mid + 1;
    }
  }
  return lo;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment