Created
March 30, 2024 22:36
-
-
Save hariedo/5923c5d2bc4d811a28b376abec3ba96d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- python -*- | |
''' | |
SYNOPSIS | |
>>> import weighted | |
>>> choices = { 'heads': 2, 'tails': 1 } | |
>>> for i in range(100):
...     print(weighted.weighted(choices))
This example would print a hundred coin flips, but the heads | |
will appear twice as often as tails, statistically speaking. | |
AUTHOR | |
Ed Halley ([email protected]) 13 December 2007 | |
''' | |
import random | |
def pare(choices):
    '''Adds up every weight in a dict of key:weight pairs.

    Returns the sum of the dict's values; an empty dict gives 0.
    '''
    weights = choices.values()
    return sum(weights)
def weighted(choices, total=0):
    '''Given a dict of key:weight pairs, chooses a key at random.

    The dict values are non-negative numerical weights.  Keys with higher
    values are chosen more often than keys with lower values.

    If the caller knows the total of all weights, it can be given to
    avoid recalculating it internally on each call.  If the given total
    is not accurate, a key may be chosen with a poorly-shaped
    distribution.

    A total of zero (the default) means "not supplied"; the weights are
    then summed with pare().  Raises IndexError if choices is empty.
    '''
    if not total:
        total = pare(choices)
    # Pick a point in [0, total) and walk the weights, subtracting each
    # span, until the remaining distance falls inside a key's span.
    mark = random.random() * total
    # Iterate the dict directly: works on Python 2 and 3 alike, whereas
    # indexing choices.keys() fails on Python 3 dict views.
    for key in choices:
        span = choices[key]
        if span > mark:
            return key
        mark -= span
    # Should not reach here if total is accurate; when it overstates the
    # real sum (or every weight is zero) fall back to a uniform pick.
    return random.choice(list(choices))
if __name__ == '__main__':
    # Self-test: draw many weighted picks and compare how often each key
    # was chosen against the count its weight predicts.
    import sys

    print('Testing weighted random distribution...')
    choices = { 'ten': 10,
                'eight': 8,
                'seven': 7,
                'six': 6,
                'four': 4,
                'one': 1 }
    reps = 10000000
    step = 500000  # picks between progress updates
    total = pare(choices)
    # Seed every key with zero so a never-picked key still reports
    # cleanly, and no blanket except is needed around the increment.
    tally = dict.fromkeys(choices, 0)
    for done in range(0, reps, step):
        # Countdown on one line; '\r' returns the cursor to column 0.
        sys.stdout.write('%d \r' % (reps - done))
        for _ in range(min(step, reps - done)):
            tally[weighted(choices, total)] += 1
    print('%d reps == %d picks' % (reps, sum(tally.values())))
    print("%s\t%9s %9s %s" % ('key:', 'picks:',
                              'fair:', 'bias (ideal 100%):'))
    for key in choices:
        # Floor division keeps the integer result on Python 2 and 3.
        expected = choices[key] * reps // total
        print("%s\t%9d %9d (%g%%)" % (key, tally[key],
                                      expected, tally[key] * 100. / expected))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment