Last active
November 4, 2017 20:06
-
-
Save matthewfeickert/8c6bcf26462b6e06983aee08597da2db to your computer and use it in GitHub Desktop.
Quick Python demonstration of Uniform sampling problem (https://twitter.com/fermatslibrary/status/924263998589145090)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
"""
Problem Statement: Sample from the Uniform distribution over the range [0,1]
until the sum of the numbers sampled is greater than 1. On average, how
many samples are taken?
Answer: e
Problem Source: https://twitter.com/fermatslibrary/status/924263998589145090
Author: Matthew Feickert
Date: 2017-10-28
"""
import sys | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as mpatches | |
import itertools # for fast looping | |
def simulate_sampling(n_trials): | |
""" | |
Simulate experiments | |
Args: | |
n_trials: `int` The number of trials | |
Returns: | |
sum(counts): `int` | |
counts: `array of ints` | |
""" | |
counts = [] | |
for _ in itertools.repeat(None, n_trials): | |
n_samples = 0 | |
sum_ = 0 | |
while sum_ <= 1: | |
sum_ += np.random.uniform(0, 1) | |
n_samples += 1 | |
counts.append(n_samples) | |
return sum(counts), counts | |
def main(n_trials=100000, output_path='sample_uniform.png'):
    """
    Run the simulation, print summary statistics, and save a plot.

    Prints the average number of samples per trial together with its
    absolute and relative difference from e. It then saves a plot of the
    fraction of trials that needed each number of samples, with the observed
    mean marked by a vertical line.

    Args:
        n_trials: `int` The number of trials to simulate. Must be at least 1.
        output_path: `str` File the plot image is written to. Defaults to
            'sample_uniform.png' in the current working directory.

    Raises:
        ValueError: If ``n_trials`` is less than 1, since the average number
            of samples would be undefined.
    """
    if n_trials < 1:
        raise ValueError(
            'n_trials must be at least 1, got {}'.format(n_trials))

    n_samples, counts = simulate_sampling(n_trials)
    # float() forces true division on Python 2 as well, where int / int
    # would silently floor the average down to 2.
    result = float(n_samples) / n_trials

    print('Average number of samples taken over {} trials: {}'.format(
        n_trials, result))
    print('Difference between result and e (~{:.6f}): {:.6f}'.format(
        np.e,
        np.absolute(result - np.e)))
    print('Relative difference between result and e: {:.6f}'.format(
        np.absolute(1 - (result / np.e))))

    # Plot results
    # Tally how many trials needed each number of samples in a single pass.
    # Every trial needs at least 2 samples, so the x axis starts there and
    # extends to the largest count actually observed, so rare long trials
    # are not silently dropped by a fixed cutoff.
    tallies = np.bincount(counts)
    x = np.arange(2, len(tallies))
    # Normalise by the number of trials so the y values match the
    # 'Relative count' axis label.
    relative_counts = tallies[2:] / float(n_trials)

    fig, ax = plt.subplots()
    # Markers only; color is given once (as a keyword) to avoid the
    # conflicting fmt-string colour of the original 'ro'.
    ax.plot(x, relative_counts, marker='o', linestyle='none',
            color='black', markerfacecolor='blue')
    line_result = ax.axvline(
        x=result, color='black',
        label='mean number of samples: {}'.format(result))
    ax.set_xlabel('Number of samples')
    ax.set_ylabel('Relative count')

    # Legend: the mean line plus an invisible patch whose only purpose is to
    # show the reference value of e as text.
    handles = [line_result, mpatches.Patch(
        color='none', label='e ~ {:.6f}'.format(np.e))]
    ax.legend(handles=handles)
    fig.savefig(output_path)
    # Close the figure so repeated calls neither draw on top of each other
    # nor accumulate open figures.
    plt.close(fig)
if __name__ == '__main__':
    # An optional first command-line argument overrides the default number
    # of trials; with no arguments, main() runs with its defaults.
    cli_args = sys.argv[1:]
    if not cli_args:
        main()
    else:
        main(int(cli_args[0]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment