Skip to content

Instantly share code, notes, and snippets.

@breeko
Last active April 18, 2019 13:53
Show Gist options
  • Save breeko/69e54a4e77229618efeb468fc97c90bf to your computer and use it in GitHub Desktop.
Save breeko/69e54a4e77229618efeb468fc97c90bf to your computer and use it in GitHub Desktop.
Skittles Distribution
import numpy as np
from collections import Counter
from matplotlib import pyplot as plt
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--num-trials", help="number of trials to run", type=int, default=1000)
parser.add_argument("-s", "--std", help="standard deviation of num of skittles", type=float, default=2)
parser.add_argument("-o", "--out", help="histogram out location", type=str, default="out.png")
# Parse the command line only when executed as a script. Parsing at import
# time consumes the importer's sys.argv and exits on unrecognised arguments,
# which makes the module unusable as a library or from a test runner.
args = parser.parse_args() if __name__ == "__main__" else None

# Mean number of Skittles per bag; get_bag centres its bag-size draw on it.
NUM_SKITTLES_MEAN = 59.2736
def get_bag(std_dist, mean=None):
    """Simulate one bag of Skittles and return a canonical string of its colour counts.

    The bag size is drawn from a normal distribution centred on ``mean`` with
    standard deviation ``std_dist``. Each Skittle's colour is then drawn
    uniformly, with replacement, from the five codes "r", "o", "y", "g", "p".

    Args:
        std_dist: Standard deviation of the number of Skittles per bag.
        mean: Mean number of Skittles per bag. When omitted, the module-level
            ``NUM_SKITTLES_MEAN`` is used.

    Returns:
        A string such as ``"g:12;o:11;p:13;r:10;y:13"``: ``"colour:count"``
        pairs, sorted and joined with ``";"``. Colours absent from the bag are
        omitted, so two bags with identical counts always give the same string.
    """
    if mean is None:
        mean = NUM_SKITTLES_MEAN
    # Round rather than truncate: int() drops the fractional part, which biased
    # the simulated bag size roughly half a Skittle below ``mean``.
    # Clamp at zero: a large std_dist can yield a negative draw, and
    # np.random.choice rejects a negative size with ValueError.
    num_skittles = max(0, int(round(mean + np.random.standard_normal() * std_dist)))
    bag_raw = np.random.choice(["r", "o", "y", "g", "p"], size=num_skittles, replace=True)
    bag_count = Counter(bag_raw)
    return ";".join(sorted("{}:{}".format(color[0], count) for color, count in bag_count.items()))
def find_match(max_attempts=10**8, std_dist=2, stop_after_match=True, verbose=True):
    """Draw random bags until one repeats the colour counts of an earlier bag.

    Args:
        max_attempts: Maximum number of bags to draw.
        std_dist: Standard deviation of the bag size, passed to get_bag.
        stop_after_match: Return at the first repeated bag. When false, keep
            drawing (and reporting matches) until max_attempts is reached.
        verbose: Print each match as it is found.

    Returns:
        The zero-based index of the first bag that matched an earlier one,
        when stop_after_match is true and a match occurs. Otherwise None:
        either no match occurred within max_attempts, or stop_after_match is
        false.
    """
    seen = {}  # bag string -> indices of all bags drawn with that string
    for cur_num in range(max_attempts):
        cur_bag = get_bag(std_dist=std_dist)
        indices = seen.setdefault(cur_bag, [])
        if indices:
            if verbose:
                print("Matching pair found!")
                prior_nums = ",".join(str(p) for p in indices)
                print("#{} bag matches #{}".format(prior_nums, cur_num))
            if stop_after_match:
                return cur_num
        # Append in place; rebuilding the list copied every earlier index per draw.
        indices.append(cur_num)
    return None
def run_trials(num_trials, std_dist=2):
    """Run find_match repeatedly and collect one result per trial.

    Args:
        num_trials: Number of independent trials to run.
        std_dist: Standard deviation of the bag size, forwarded to find_match.

    Returns:
        A list with one find_match result per trial, in order. Each result is
        the zero-based index of the first repeated bag, or None if find_match
        found no match within its default attempt limit.
    """
    results = []
    for _trial in range(num_trials):
        results.append(find_match(std_dist=std_dist, verbose=False))
    return results
if __name__ == "__main__":
    # Run the simulation, then save a histogram of the per-trial match indices
    # titled with summary statistics.
    results = run_trials(args.num_trials, std_dist=args.std)
    plt.hist(results, bins=100)
    summary = (
        args.num_trials,
        args.std,
        np.mean(results),
        np.median(results),
        np.std(results),
        np.percentile(results, 95),
    )
    plt.title('N: {} std {} mean: {:0.2f} median: {:0.2f} std: {:0.2f} 95% {:0.2f}'.format(*summary))
    figure = plt.gcf()
    figure.savefig(args.out)
    plt.clf()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment