Skip to content

Instantly share code, notes, and snippets.

@orlp
Last active November 19, 2020 02:55
Show Gist options
  • Save orlp/e9b31d3397a7dd3e34d6bc862ce3b88d to your computer and use it in GitHub Desktop.
Save orlp/e9b31d3397a7dd3e34d6bc862ce3b88d to your computer and use it in GitHub Desktop.
A succinct NumPy implementation of Vose's Alias Method, an algorithm for sampling from a fixed, weighted discrete distribution with O(n) construction time and O(1) sampling time.
import numpy as np
from collections import deque
class VoseAliasMethod:
    """Sample indices from a fixed discrete distribution via Vose's Alias Method.

    Construction builds two length-``n`` tables in O(n) time. Each draw then
    costs O(1): pick a column uniformly at random, then flip a biased coin to
    decide between the column itself and its alias.

    Algorithm as described at https://www.keithschwarz.com/darts-dice-coins/.

    Attributes:
        n: Number of outcomes (length of ``weights``).
        prob: ``float64`` array; ``prob[i]`` is the probability of keeping
            column ``i`` instead of jumping to ``alias[i]``.
        alias: ``int64`` array; the alternate outcome stored in each column.
    """

    def __init__(self, weights):
        """Build the probability and alias tables.

        Args:
            weights: 1-D array-like of non-negative, finite weights with a
                positive sum. They need not be normalized: outcome ``i`` is
                drawn with probability ``weights[i] / sum(weights)``.

        Raises:
            ValueError: If ``weights`` is not 1-D, is empty, contains negative
                or non-finite values, or does not have a positive finite sum.
        """
        w = np.asarray(weights, dtype=np.float64)
        if w.ndim != 1 or w.size == 0:
            raise ValueError("weights must be a non-empty 1-D array-like")
        if not np.all(np.isfinite(w)):
            raise ValueError("weights must be finite")
        if np.any(w < 0):
            raise ValueError("weights must be non-negative")
        total = np.sum(w)
        # Without this check a zero sum yields NaNs, and every draw silently
        # returns index 0; an overflowing sum would yield all-zero mass.
        if not (np.isfinite(total) and total > 0):
            raise ValueError("weights must have a positive, finite sum")

        self.n = w.shape[0]
        self.prob = np.zeros(self.n, dtype=np.float64)
        self.alias = np.zeros(self.n, dtype=np.int64)

        # Scale the pmf so the average column height is exactly 1.0.
        p = (w / total) * self.n
        small = deque(np.nonzero(p < 1.0)[0])
        large = deque(np.nonzero(p >= 1.0)[0])

        # Top up each under-full column with mass donated by an over-full one;
        # the donor shrinks by the amount it gave and is re-filed accordingly.
        while small and large:
            under = small.popleft()
            over = large.popleft()
            self.prob[under] = p[under]
            self.alias[under] = over
            p[over] = (p[over] + p[under]) - 1.0
            (small if p[over] < 1.0 else large).append(over)

        # Leftover columns have height exactly 1 in exact arithmetic; anything
        # still in `small` is there only because of floating-point round-off.
        self.prob[list(small)] = 1.0
        self.prob[list(large)] = 1.0

    def sample(self, size, rng=None):
        """Draw outcome indices from the distribution.

        Args:
            size: Output shape, as accepted by NumPy's random functions
                (an int or a tuple of ints).
            rng: Optional ``numpy.random.Generator``. When omitted, NumPy's
                legacy global random state is used (so ``np.random.seed``
                controls the draws exactly as before).

        Returns:
            Integer array of shape ``size`` with values in ``range(n)``.
        """
        if rng is None:
            ri = np.random.randint(0, self.n, size=size)
            rx = np.random.uniform(size=size)
        else:
            ri = rng.integers(0, self.n, size=size)
            rx = rng.random(size=size)
        # Keep column `ri` with probability prob[ri]; otherwise take its alias.
        return np.where(rx < self.prob[ri], ri, self.alias[ri])
if __name__ == "__main__":
    # Demo: outcomes 0, 1 and 2 are drawn in proportion 1 : 3 : 6.
    # Draws are random, so the arrays shown below are just one possible run.
    demo = VoseAliasMethod([1, 3, 6])

    flat_draws = demo.sample(8)
    print(flat_draws)
    # [2 1 0 2 1 2 1 2]

    grid_draws = demo.sample((5, 5))
    print(grid_draws)
    # [[1 0 0 2 2]
    #  [0 2 1 2 2]
    #  [2 1 2 1 1]
    #  [0 2 1 2 2]
    #  [2 2 2 1 1]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment