Created November 7, 2012 00:23
-
-
Save bagrow/4028656 to your computer and use it in GitHub Desktop.
Sample non-uniformly from a set of choices
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# weighted_choice.py
# Jim Bagrow
# Last Modified: 2012-11-06
import bisect
import random
from collections import Counter

import pylab
def weighted_choice(choices, num_draws=1):
    """Make weighted draws, without replacement, from a set of choices.

    ``choices`` is a sequence of 2-tuples ``[(a, w_a), (b, w_b), ...]`` where
    ``w_x`` is the relative weight of drawing value ``x``. Weights are
    normalized internally, so they need not sum to one, but they must be
    non-negative.

    Draws are made one at a time. After each draw the chosen value is removed
    and the remaining weights are renormalized, so no value is returned twice.
    Values with zero weight are never drawn while positive weight remains.

    Special cases:
      * ``num_draws <= 0`` or empty ``choices`` returns ``[]``.
      * ``num_draws >= len(choices)`` returns every value, in input order.
      * If every weight is zero, draws are uniform (``random.sample``).
      * If fewer than ``num_draws`` values have positive weight, only those
        values are returned, so the result may be shorter than ``num_draws``.

    Returns a list of the drawn values, in the order they were drawn.

    Raises ValueError if any weight is negative.

    Example:
    >>> choices = [("H", 0.8), ("T", 0.2)]  # an unfair coin
    >>> flips = [weighted_choice(choices)[0] for _ in range(100)]
    """
    choices = list(choices)
    if num_draws <= 0 or not choices:
        return []
    if any(weight < 0 for _value, weight in choices):
        raise ValueError("weights must be non-negative")

    values = [value for value, _weight in choices]
    if num_draws >= len(values):
        return values
    if sum(weight for _value, weight in choices) == 0:
        # No preference among values: fall back to a uniform sample.
        return random.sample(values, num_draws)

    # Zero-weight items can never be selected, so drop them up front. This
    # bounds the number of possible draws and guarantees termination.
    pool = [(value, weight) for value, weight in choices if weight > 0]
    draws = []
    for _ in range(min(num_draws, len(pool))):
        cumulative_weights = []
        total = 0
        for _value, weight in pool:
            total += weight
            cumulative_weights.append(total)
        i = bisect.bisect(cumulative_weights, random.random() * total)
        # random() < 1, but random() * total can round up to exactly total.
        # Clamp so the last (positive-weight) item is chosen instead of
        # indexing past the end.
        i = min(i, len(pool) - 1)
        # Removing the drawn item renormalizes the remaining weights, which
        # is what makes these draws "without replacement".
        draws.append(pool.pop(i)[0])
    return draws
if __name__ == '__main__':
    # Demo: draw i with probability proportional to 1/i, then compare the
    # empirical counts against the expected counts on a log-log plot.
    num_samples = 10000
    choices = [(i, 1.0 / i) for i in range(1, 15)]

    # Tally single draws. range (not xrange) works on both Python 2 and 3.
    draw2count = Counter(weighted_choice(choices)[0] for _ in range(num_samples))

    # Expected count for value d is num_samples * (1/d) / sum_i(1/i). On a
    # log-log plot this is a straight line with slope -1.
    norm = sum(weight for _value, weight in choices)
    draws = sorted(draw2count)
    counts = [draw2count[d] for d in draws]
    expected = [num_samples * (1.0 / d) / norm for d in draws]

    # pylab.hold() was removed in matplotlib 3.0. Successive plot calls
    # already draw onto the same axes, so no hold call is needed.
    pylab.loglog(draws, counts, 'o-')
    pylab.plot(draws, expected, 'r')  # should have slope -1
    pylab.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.