Created
March 24, 2016 03:41
-
-
Save wchargin/8a0245d629f613f74195 to your computer and use it in GitHub Desktop.
shuffling routines and automatic evaluation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Shuffling algorithms and tools for evaluating them.""" | |
import collections | |
import math | |
import random | |
def fisher_yates(xs):
    """Shuffle `xs` in place with the iterative Fisher-Yates algorithm.

    Walks the list from left to right; position `i` receives an element
    drawn uniformly from the not-yet-fixed suffix `xs[i:]`, so every
    permutation is equally likely.

    Args:
        xs: a mutable sequence supporting `len` and item get/set.

    Returns:
        The same `xs` object, now shuffled (returned for convenience).
    """
    n = len(xs)
    # `range` (not the Python-2-only `xrange`) keeps this runnable on both
    # Python 2 and Python 3.
    for i in range(n):
        # randint is inclusive on both ends, so `i` itself may be chosen.
        idx = random.randint(i, n - 1)
        xs[i], xs[idx] = xs[idx], xs[i]
    return xs
def fisher_yates_recursive(xs):
    """Shuffle `xs` in place with a recursive Fisher-Yates algorithm.

    Picks a uniformly random element for the front of the list, then
    recursively shuffles the remainder. The recursion works on a start
    index rather than on `xs[1:]` slices, so no copies of the tail are
    made at each level.

    One stack frame is used per element, so inputs longer than the
    interpreter's recursion limit cannot be shuffled with this function.

    Args:
        xs: a mutable sequence supporting `len` and item get/set.

    Returns:
        The same `xs` object, now shuffled. A falsy `xs` (e.g. an empty
        list) is returned untouched.
    """
    if not xs:
        return xs

    last = len(xs) - 1

    def shuffle_from(start):
        # Choose the occupant of position `start` from the unshuffled
        # tail xs[start:], then recurse on the rest of the tail.
        idx = random.randint(start, last)
        xs[start], xs[idx] = xs[idx], xs[start]
        if start < last:
            shuffle_from(start + 1)

    shuffle_from(0)
    return xs
def permutation_stderr(shuffler, n=6, trials=50000):
    """Measure how unevenly `shuffler` distributes permutations.

    Shuffles a fresh copy of `[0, 1, ..., n - 1]` `trials` times, counts
    how often each resulting ordering occurs, and returns the population
    standard deviation of those counts divided by `trials` (i.e. the
    standard deviation of the relative frequencies). A perfectly even
    distribution yields 0.0; larger values mean a more uneven one.

    The deviation is taken about the true mean count (`trials` divided by
    the number of possible outcomes), and permutations that never
    appeared are included with a count of zero, so a shuffler that cannot
    reach some permutations is penalised rather than looking uniform.

    Compare the result to that of `random.shuffle`.

    Args:
        shuffler: callable that shuffles a list in place; its return
            value is ignored.
        n: length of the list to shuffle.
        trials: number of shuffles to perform.

    Returns:
        The standard deviation of the per-outcome relative frequencies,
        as a float.

    Raises:
        ZeroDivisionError: if `trials` is 0.
    """
    input_list = list(range(n))
    results = collections.Counter()
    for _ in range(trials):
        this_list = input_list[:]
        shuffler(this_list)
        results[tuple(this_list)] += 1

    # Orderings that were never produced are outcomes with zero hits.
    # The max() guards against a shuffler that emits non-permutations and
    # thereby yields more distinct outcomes than n!.
    unseen = max(math.factorial(len(input_list)) - len(results), 0)
    outcomes = len(results) + unseen

    # float() keeps this a true division on Python 2 as well.
    mean_count = float(trials) / outcomes
    squared_deviation = sum((count - mean_count) ** 2
                            for count in results.values())
    squared_deviation += unseen * mean_count ** 2

    stdev = math.sqrt(squared_deviation / outcomes)
    return stdev / trials
def evaluate(shuffler, reference=random.shuffle, **kwargs):
    """Score `shuffler` against a trusted `reference` shuffler.

    Runs `permutation_stderr` on both algorithms and returns the
    reference's standard error divided by the shuffler's. Both figures
    are noisy estimates, so for an unbiased `shuffler` the quotient
    fluctuates around 1 (and may land slightly above or below it); a
    value significantly below 1 indicates that `shuffler` may be biased.

    Args:
        shuffler: in-place shuffling callable under test.
        reference: in-place shuffling callable assumed to be unbiased.
        **kwargs: forwarded unchanged to both `permutation_stderr` calls.

    Returns:
        The quotient of the two standard errors as a float. If the
        shuffler's standard error is exactly zero the quotient is
        undefined; `float("inf")` is returned in that case, or 1.0 when
        the reference's standard error is zero as well.
    """
    that = permutation_stderr(reference, **kwargs)
    this = permutation_stderr(shuffler, **kwargs)
    if this == 0:
        # A zero denominator means the shuffler's outcomes were perfectly
        # even; report that instead of raising ZeroDivisionError.
        return 1.0 if that == 0 else float("inf")
    return that / this
def _write_text(stream, text):
    """Write `text` to `stream`, replacing characters it cannot encode."""
    encoding = getattr(stream, "encoding", None) or "ascii"
    # Round-trip through the stream's encoding so that e.g. an en dash
    # becomes "?" on an ASCII-only stdout instead of raising
    # UnicodeEncodeError.
    stream.write(text.encode(encoding, "replace").decode(encoding))


def main():
    """Run tests against the shufflers provided in this module.

    For each shuffler, prints its name followed by the score returned by
    `evaluate` on one line of standard output.
    """
    import sys
    shufflers = [("random.shuffle", random.shuffle),
                 (u"Fisher\u2013Yates", fisher_yates),
                 (u"Fisher\u2013Yates (recursive)", fisher_yates_recursive)]
    padding = 40
    for (name, shuffler) in shufflers:
        _write_text(sys.stdout, ("Testing %s... " % name).ljust(padding))
        # Flush so the name is visible while the (slow) evaluation runs.
        sys.stdout.flush()
        score = evaluate(shuffler)
        # Written via the stream rather than the `print` statement so the
        # module parses on both Python 2 and Python 3.
        _write_text(sys.stdout, 'score: %s.' % score + "\n")
# Run the self-evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment