Skip to content

Instantly share code, notes, and snippets.

@TimSC
Created November 22, 2016 15:06
Show Gist options
  • Save TimSC/cc903b23f4a8777e80e0ee0f48690741 to your computer and use it in GitHub Desktop.
Save TimSC/cc903b23f4a8777e80e0ee0f48690741 to your computer and use it in GitHub Desktop.
Weighted sampling in Python 2 and 3. Released under CC0.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import random
def get_weighted_sample(values, weights, return_index=False):
    """Pick one element of *values* at random, proportionally to *weights*.

    Entries whose weight is zero are never selected.

    Args:
        values: Indexable collection of candidates. It is only indexed when
            ``return_index`` is false.
        weights: Iterable of non-negative numbers, one per candidate. It is
            traversed exactly once, so a generator or other one-shot
            iterator is accepted.
        return_index: When true, return the position of the chosen
            candidate instead of the candidate itself.

    Returns:
        ``values[i]`` for the chosen position ``i``, or ``i`` itself when
        ``return_index`` is true.

    Raises:
        RuntimeError: If a weight is negative or NaN, or if the sum of the
            weights is not a positive finite number.
    """
    cumulative = []
    total = 0.0
    last_positive = None
    for index, weight in enumerate(weights):
        # NaN compares false against everything, so it would slip past the
        # sign check below and silently poison the running total.
        if weight != weight:
            raise RuntimeError("Weights cannot be NaN")
        if weight < 0.0:
            raise RuntimeError("Weights cannot be negative")
        if weight > 0.0:
            last_positive = index
        total += weight
        cumulative.append(total)

    if total <= 0.0:
        raise RuntimeError("Total weight must be positive")
    if total == float("inf"):
        # An infinite total makes the threshold inf/NaN and the draw
        # meaningless.
        raise RuntimeError("Total weight must be finite")

    threshold = random.random() * total

    # Floating-point rounding can push the threshold up to ``total``, in
    # which case no cumulative sum exceeds it; fall back to the last
    # candidate that has a positive weight.
    chosen = last_positive
    for index, running_total in enumerate(cumulative):
        # A zero-weight entry repeats the previous cumulative sum, so the
        # first strictly greater sum always belongs to a positive weight.
        if threshold < running_total:
            chosen = index
            break

    if return_index:
        return chosen
    return values[chosen]
if __name__ == '__main__':
    # Demo: draw 10000 weighted samples and plot how often each value came up.
    # The plotting dependencies are imported here so that importing this
    # module for get_weighted_sample alone needs only the standard library.
    import numpy as np
    import matplotlib.pyplot as plt

    values = [v + 5.0 for v in range(10)]
    weights = [v / 10.0 for v in range(10)]
    print(values)
    print(weights)

    # np.bincount accepts only non-negative integers, and the sampled values
    # are floats, so sample positions instead and count those.
    indices = [
        get_weighted_sample(values, weights, return_index=True)
        for _ in range(10000)
    ]

    # minlength keeps one bin per candidate even if trailing ones are never
    # drawn, so freq always lines up with values on the x axis.
    freq = np.bincount(indices, minlength=len(values))
    plt.plot(values, freq)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment