
@tylerneylon
Created January 12, 2023 19:10
This is an example implementation of an efficient reservoir sampling algorithm.
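The efficiency comes from the skip step in `reservoir_indexes` (the Algorithm L idea described on the Wikipedia page cited in the file). In the sketch below, $U$ and $U'$ denote fresh draws from `rand01()`, $W$ is the variable `w`, and $S$ is the number of stream items passed over before the next selection; these symbol names are chosen here for illustration. Once the reservoir is full, each later item is selected independently with probability $W$, so $S$ is geometric, and inverting its tail probability gives the formula used in the code (the inequality flips because $\ln(1 - W) < 0$):

$$
\Pr[S \ge s] = (1 - W)^s, \qquad
S = \left\lfloor \frac{\ln U}{\ln(1 - W)} \right\rfloor
\quad\text{since}\quad
\frac{\ln U}{\ln(1 - W)} \ge s \iff U \le (1 - W)^s .
$$

$$
W_0 = U^{1/k} = \exp\!\left(\frac{\ln U}{k}\right), \qquad
i \leftarrow i + S + 1, \qquad
W \leftarrow W \cdot U'^{1/k} = W \exp\!\left(\frac{\ln U'}{k}\right).
$$

$W_0$ is distributed like the largest of $k$ independent uniform keys, and after each selection the largest of $k$ uniform keys on $(0, W)$ is distributed like $W U'^{1/k}$, which is the `w *=` update. Because random numbers are drawn only at selections, the expected number of draws grows on the order of $k(1 + \log(n/k))$ for a stream of length $n$, rather than one draw per item.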
""" reservoir.py
This is an example implementation of an efficient
reservoir sampling algorithm -- this algorithm is useful
when you have a data stream, possibly of unknown length,
and you'd like to maintain an incrementally updated
random subset of fixed size k. This algorithm works by
occasionally adding a new data point from the stream
into the 'reservoir,' which is simply a length-k list
of data points.
A more detailed explanation is here:
https://en.wikipedia.org/wiki/Reservoir_sampling
"""
import math
import random
from collections import Counter
from itertools import islice


def rand01():
    ''' This returns a uniformly random number in (0, 1).
    random.random() returns values in [0, 1), so this
    rejects 0 to keep math.log() of the result defined.
    '''
    while True:
        x = random.random()
        if x > 0:
            return x
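

def reservoir_indexes(k):
    ''' This is a key part of a reservoir sampling algorithm.
    The input k is the size of the reservoir and must be a
    positive integer. This function yields an infinite and
    increasing sequence of indexes into the stream to be
    sampled: the first k indexes (0 through k - 1) fill the
    reservoir, and each later index i means that stream item i
    should overwrite a uniformly random slot in the reservoir.
    '''
    # Populate the initial reservoir.
    for i in range(k):
        yield i
    # From here on, each yielded index replaces an existing entry.
    # Each later item is selected with probability w, so the number
    # of items skipped before the next selection is geometric; it
    # is drawn below by inverse transform sampling.
    w = math.exp(math.log(rand01()) / k)
    while True:
        i += int(math.log(rand01()) / math.log(1 - w)) + 1
        yield i
        # Shrink w after each selection.
        w *= math.exp(math.log(rand01()) / k)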


# Put together a simple demo and test.
if __name__ == '__main__':

    print('Reservoir sampling indexes for k = 5:')
    sampler = reservoir_indexes(5)
    for idx in islice(sampler, 20):
        print(idx)

    # Here's a quick check that the generator appears to sample
    # 5 items uniformly at random from a size-100 stream: each
    # index should land in the final reservoir in about
    # 5 / 100 = 0.05 of the trials. This also shows an example
    # of how to use the generator.
    n_trials = 10_000
    counter = Counter()
    print('Frequency rates, choosing 5 items from 100:')
    print(f'(Running {n_trials} trials and averaging.)')
    for _ in range(n_trials):
        # Simulate a size-5 reservoir over a stream of indexes 0..99;
        # the reservoir starts out holding indexes 0 through 4.
        indexes = list(range(5))
        for i in reservoir_indexes(5):
            if i >= 100:
                break
            if i < 5:
                continue
            # Index i overwrites a uniformly random reservoir slot.
            indexes[random.randint(0, 4)] = i
        counter.update(indexes)

    # Print out the frequency of each index.
    for i in range(100):
        print(f'{i:2d}-{counter[i] / n_trials:.2f}', end=' ')
        if i % 5 == 4:
            print()
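The demo above only simulates the reservoir with integer indexes. Here is a minimal sketch of applying the same skip logic to an actual iterable. It re-declares `rand01` and `reservoir_indexes` so the block runs on its own (using `math.log1p(-w)` in place of `math.log(1 - w)`, an assumption made here for numerical accuracy when `w` is tiny), and the wrapper name `reservoir_sample` is hypothetical, not something the gist defines. Skipped items are consumed with `itertools.islice` without drawing any random numbers for them.

```python
import math
import random
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar('T')


def rand01() -> float:
    """Uniform random float in (0, 1); 0 is rejected so log() is defined."""
    while True:
        x = random.random()
        if x > 0:
            return x


def reservoir_indexes(k: int) -> Iterator[int]:
    """Skip-based index generator mirroring the gist's version (k >= 1)."""
    for i in range(k):
        yield i
    w = math.exp(math.log(rand01()) / k)
    while True:
        # log1p(-w) equals log(1 - w) but stays accurate when w is tiny.
        i += int(math.log(rand01()) / math.log1p(-w)) + 1
        yield i
        w *= math.exp(math.log(rand01()) / k)


def reservoir_sample(stream: Iterable[T], k: int) -> List[T]:
    """Hypothetical wrapper: up to k items drawn uniformly from stream.

    Returns every item if the stream holds k or fewer of them.
    """
    it = iter(stream)
    reservoir: List[T] = []
    pos = 0  # stream position of the next item `it` will produce
    sentinel = object()
    for idx in reservoir_indexes(k):
        # Discard the items at positions pos .. idx - 1 without
        # drawing any random numbers for them.
        skip = idx - pos
        next(islice(it, skip, skip), None)
        item = next(it, sentinel)
        if item is sentinel:
            return reservoir  # the stream ended before position idx
        pos = idx + 1
        if idx < k:
            reservoir.append(item)  # still filling the first k slots
        else:
            reservoir[random.randrange(k)] = item  # overwrite a random slot
    return reservoir  # not reached: reservoir_indexes never stops


if __name__ == '__main__':
    random.seed(0)
    print(sorted(reservoir_sample(range(1_000_000), 5)))
```

Because the function only ever calls `next()` on the iterator, it works on streams of unknown length, such as lines read from a file.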
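For comparison, here is a minimal sketch of the simpler method that the Wikipedia page cited in the file calls Algorithm R. It consults the random number generator once for every item after the first k, which is the per-item cost that the skip computation in `reservoir_indexes` avoids; both approaches yield a uniformly random size-k subset. The function name `reservoir_sample_r` and the optional `rng` parameter are illustrative choices, not part of the gist.

```python
import random
from typing import Iterable, List, Optional, TypeVar

T = TypeVar('T')


def reservoir_sample_r(stream: Iterable[T], k: int,
                       rng: Optional[random.Random] = None) -> List[T]:
    """Algorithm R: up to k items drawn uniformly from stream.

    Uses one random draw per item after the first k.
    """
    if rng is None:
        rng = random.Random()
    reservoir: List[T] = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)  # fill the first k slots
        else:
            # Keep item i with probability k / (i + 1) by drawing a
            # position j in 0..i; if j is a valid slot, overwrite it.
            j = rng.randint(0, i)
            if j < k:
                reservoir[j] = item
    return reservoir


if __name__ == '__main__':
    print(sorted(reservoir_sample_r(range(100), 5, random.Random(0))))
```

Passing a seeded `random.Random` makes runs reproducible. The trade-off against the skip-based generator is simplicity versus the number of random draws on long streams.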