Created
November 10, 2017 15:21
-
-
Save emillundh/db7c19c3b17ec24b90c2ab47f622c50b to your computer and use it in GitHub Desktop.
Timing some algorithms for weighted choices
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import timeit | |
import random | |
import numpy as np | |
# Benchmark inputs: ten integer outcomes 0..9 whose raw weights grow linearly
# in steps of 0.01 (outcome 0 has weight 0), rescaled so they sum to 1.
x = range(10)
y = [step / 100 for step in range(10)]
ysum = sum(y)
y = [weight / ysum for weight in y]
# ===== Pure python methods | |
def python_1(x, y):
    """
    https://stackoverflow.com/questions/3679694/a-weighted-version-of-random-choice

    Draw one element of *x* with probability proportional to its weight in *y*.

    Each positively-weighted element owns the interval
    [cumulative weight before it, cumulative weight after it); a uniform draw
    over [0, total] picks the interval it lands in.  Weights need not sum to 1,
    and elements with weight <= 0 are never chosen.

    Returns the chosen element, or None if no weight is positive.
    """
    space = {}  # interval start -> element owning the interval
    current = 0
    for choice, weight in zip(x, y):
        if weight > 0:
            space[current] = choice
            current += weight
    if not space:
        # Nothing is selectable; looking up space[0] here would raise KeyError.
        return None
    rand = random.uniform(0, current)
    picked = None
    for start in sorted(space):
        if rand < start:
            return picked
        picked = space[start]
    # rand lies in the last interval.  This also covers rand == current, which
    # random.uniform may return through float rounding and which would
    # otherwise have no matching key.
    return picked
def python_2(seq, weights):
    """
    https://scaron.info/blog/python-weighted-choice.html

    Draw one element of *seq*; *weights* must sum to 1 (within 1e-6).

    A uniform draw in [0, 1) is walked along the weights, subtracting each
    weight until the remainder falls strictly inside the current one.  The
    strict comparison means zero-weight elements are never returned.  If float
    rounding leaves part of the draw unconsumed, the last positively-weighted
    element is returned instead of falling off the end.

    The input checks use ``assert`` and are therefore skipped under
    ``python -O``.

    Returns the chosen element, or None if no element has positive weight.
    """
    assert len(weights) == len(seq)
    assert abs(1. - sum(weights)) < 1e-6
    draw = random.random()
    last_positive = None
    for item, weight in zip(seq, weights):
        if draw < weight:
            return item
        draw -= weight
        if weight > 0:
            last_positive = item
    return last_positive
# ======= random.choices in python3 | |
def python3(x, y):
    """Draw from *x* weighted by *y* using the stdlib ``random.choices``.

    Note: the result is a one-element list (``random.choices`` always returns
    a list of ``k`` picks), not the bare element.
    """
    picks = random.choices(x, weights=y, k=1)
    return picks
# ========= methods using numpy | |
def numpy1(x, y):
    """Draw one element of *x* with ``np.random.choice``, passing Python sequences.

    *y* is used as the probability vector ``p``.  The Python sequences are
    handed to NumPy unchanged, so any conversion NumPy performs is part of
    what this variant times.

    Returns the sampled element.  It is a NumPy scalar for numeric *x*.
    Previously the sample was computed and then discarded.
    """
    return np.random.choice(x, p=y)
def numpy2(x, y):
    """Draw one element of *x* with ``np.random.choice`` after explicit array casts.

    *x* is converted to an ``int32`` array and *y* to a ``float32`` array before
    sampling, so *x* must contain integers that fit in int32.

    Returns the sampled element as a NumPy integer scalar.  Previously the
    sample was computed and then discarded.
    """
    x = np.fromiter(x, dtype='int32')
    y = np.fromiter(y, dtype='float32')
    return np.random.choice(x, p=y)
def numpy3(objects, weights):
    """
    https://stackoverflow.com/questions/10803135/weighted-choice-short-and-simple

    Draw one element of *objects* with probability proportional to *weights*.

    The index is found by counting how many cumulative weights the draw has
    reached (a linear scan, done in NumPy).  The draw is scaled by the total
    weight, so weights need not sum to 1.

    Raises IndexError if *weights* is empty.
    """
    cs = np.cumsum(weights)  # cumulative weights; cs[-1] is the total
    draw = np.random.rand() * cs[-1]
    # Count thresholds <= draw.  This gives the half-open interval
    # [cs[i-1], cs[i]) for element i, so zero-weight elements are never chosen.
    # np.count_nonzero keeps the count in C; builtin sum() would iterate the
    # boolean array element by element in Python.
    idx = int(np.count_nonzero(cs <= draw))
    # draw can round up to cs[-1]; clamp so that edge maps to the last element
    # instead of raising IndexError.
    return objects[min(idx, len(cs) - 1)]
def numpy4(x, weights):
    """
    https://glowingpython.blogspot.se/2012/09/weighted-random-choice.html

    Draw one element of *x* with probability proportional to *weights*.

    The index is found by binary search (``np.searchsorted``) over the
    cumulative weights.  Weights need not sum to 1.

    Raises IndexError if *weights* is empty.
    """
    t = np.cumsum(weights)
    # Use the cumulative total itself rather than a separate builtin sum():
    # on Python 3.12+ sum() uses compensated float summation and can exceed
    # t[-1], which made searchsorted return len(t).  sum() is also slow on
    # NumPy arrays.
    total = t[-1]
    # side='right' selects i with t[i-1] <= v < t[i]; a draw of exactly 0
    # therefore skips leading zero-weight elements.
    idx = int(np.searchsorted(t, np.random.rand() * total, side='right'))
    # Guard the rounding edge where the scaled draw equals the total.
    return x[min(idx, len(t) - 1)]
if __name__ == "__main__":
    # Time 100000 single draws of each implementation on the module-level
    # (x, y) inputs.  The guard keeps the benchmark from running on import.
    for prob in (python_1, python_2, python3, numpy1, numpy2, numpy3, numpy4):
        elapsed = timeit.timeit(
            '{}(x, y)'.format(prob.__name__),
            globals=globals(),
            number=100000,
        )
        print(prob.__name__, elapsed)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment