Created
July 10, 2017 02:09
-
-
Save Mr4k/eabaca318499bd54e5e18431efbc6622 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from bintrees import FastAVLTree as AVLTree | |
import random | |
import time | |
def weightedRandom(weights):
    """
    Draw from a general discrete distribution by a linear scan.

    :param weights: A dictionary mapping each weight (a probability) to the
        outcome it belongs to. The weights should sum to one.
    :return: A random outcome drawn from the distribution defined by the
        weights, or None if ``weights`` is empty.
    """
    # Uniform random number in [0, 1); subtract weights until it is used up.
    remainder = random.random()
    color = None
    for value, color in weights.items():
        remainder -= value
        if remainder <= 0:
            return color
    # Floating-point rounding can leave the weights summing to slightly less
    # than one, so the scan may run past the last weight. Attribute that
    # leftover sliver of probability to the last outcome seen instead of
    # returning None (color is still None only when weights is empty).
    return color
def partitionWeights(weights):
    """
    The preprocessing step: split the distribution into equal-mass boxes.

    Builds one box per weight using Vose's alias method. Each box stands for
    a mass of 1/n and holds at most two outcomes: the first ``value`` of its
    mass belongs to ``color1`` and the rest to ``color2``.

    :param weights: A dictionary mapping each weight to its outcome. The
        weights should sum to one.
    :return: A list of ``(value, color1, color2)`` tuples, one per weight, for
        use with ``drawFromPartition``. A box owned entirely by one outcome
        uses that outcome for both colors, so a draw never yields a
        placeholder.
    """
    numWeights = len(weights)
    if numWeights == 0:
        return []
    share = 1.0 / numWeights
    # Split outcomes into those holding less than one box's worth of mass
    # ("small") and the rest ("large"). Plain lists are used instead of a
    # value-keyed tree so that equal weights can never overwrite each other.
    small = []
    large = []
    for value, color in weights.items():
        if value < share:
            small.append((value, color))
        else:
            large.append((value, color))
    boxes = []
    while small and large:
        smallValue, smallColor = small.pop()
        largeValue, largeColor = large.pop()
        # The small outcome's box is topped up with mass from a large outcome.
        boxes.append((smallValue, smallColor, largeColor))
        largeValue -= share - smallValue
        if largeValue < share:
            small.append((largeValue, largeColor))
        else:
            large.append((largeValue, largeColor))
    # If the weights sum to one, each leftover entry holds (up to rounding)
    # exactly one box's worth of mass, so it gets a box of its own.
    for value, color in large + small:
        boxes.append((value, color, color))
    return boxes
def drawFromPartition(partition):
    """
    The draw step: sample one outcome from a box partition.

    A box is picked uniformly at random. Within it, the first outcome wins
    when a uniform draw scaled to the box's mass (1/len(partition)) falls at
    or below the box's ``value``; otherwise the second outcome is returned.

    :param partition: A list of ``(value, color1, color2)`` boxes.
    :return: A sample from the distribution represented by the partition.
    """
    box_count = len(partition)
    mass, primary, alias = partition[random.randint(0, box_count - 1)]
    threshold = random.random() / box_count
    return primary if threshold <= mass else alias
# Speed test. Time 100000 draws with the linear-scan sampler, then time
# building the box partition once plus 100000 draws from it.
weights = {}
numWeights = 1000
# Use a flat (1-D) array so each element is a numpy scalar. Calling float()
# on the length-1 rows of an (n, 1) array is deprecated in recent NumPy.
nweights = np.random.rand(numWeights)
nweights /= nweights.sum()
# Map each weight to its outcome index: the {weight: outcome} layout both
# samplers expect.
weights = {float(w): i for i, w in enumerate(nweights)}
start = time.time()
for i in range(100000):
    weightedRandom(weights)
end = time.time()
# print(...) with a single argument behaves the same on Python 2 and 3.
print(end - start)
start = time.time()
partition = partitionWeights(weights)
for i in range(100000):
    drawFromPartition(partition)
end = time.time()
print(end - start)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment