Created
August 28, 2014 00:03
-
-
Save berdario/7b7232cee64bafb3a938 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
from numpy.random import binomial | |
from random import sample | |
from functools import partial | |
from operator import add, itemgetter | |
from itertools import takewhile | |
# Problem: | |
# You are given 10 opaque bags containing large numbers of red balls | |
# and blue balls in unspecified proportions. The balls are identical | |
# except for their color. Propose an algorithm for picking 100 balls | |
# from any of the bags that tries to maximize the number of blue balls | |
# picked. Of course each bag is identifiable but picking a ball from a | |
# given bag is random. | |
# Intuitively, a good approach is to go through all the bags (buckets),
# count how many blue balls (1s) each one yields before the first red ball
# (a 0), and then prioritize the "best" buckets.
# This is what the extractor function does:
def extractor(buckets, sort_fn, scale_fn, draws=100):
    """Draw ``draws`` balls, favouring buckets that have produced long blue runs.

    Works in passes.  On pass ``step`` the buckets are ranked with
    ``sort_fn(bucket_scores)`` and only the first ``scale_fn(step)`` of them
    are visited.  A visit keeps drawing from one bucket until a red ball (0)
    comes up or the draw budget is exhausted.

    Args:
        buckets: iterable of hashable, zero-argument callables; each call
            draws one ball and returns 1 (blue) or 0 (red).
        sort_fn: receives the ``{bucket: (blue_total, -visits)}`` dict and
            returns its ``(bucket, score)`` items in priority order.
        scale_fn: maps the pass number (0, 1, 2, ...) to how many of the
            ranked buckets to visit on that pass.
        draws: total number of balls to draw (default 100).

    Returns:
        Total number of blue balls drawn.  If a pass visits no bucket (no
        buckets at all, or ``scale_fn`` returned 0) the function stops early
        instead of looping forever.
    """
    i, step = 0, 0
    bucket_scores = {bucket: (0, 0) for bucket in buckets}
    while i < draws:
        visited = 0
        for bucket, (blue, neg_visits) in sort_fn(bucket_scores)[:scale_fn(step)]:
            if i >= draws:
                break
            visited += 1
            # Keep drawing from this bucket for as long as it yields blue.
            tally = result = bucket()
            i += 1
            while result and i < draws:
                result = bucket()
                i += 1
                tally += result
            # Store a tuple: map() is a lazy iterator on Python 3 and would
            # make the scores unorderable when sort_fn ranks them next pass.
            bucket_scores[bucket] = (blue + tally, neg_visits - 1)
        if not visited:
            # No bucket was drawn from, so further passes cannot progress.
            break
        step += 1
    return sum(blue for blue, _ in bucket_scores.values())
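# In extractor, sort_fn ranks the buckets by score (highest priority first),
# and scale_fn(step) sets how many of the top-ranked buckets are visited on
# each pass, so the worst buckets get dropped.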
def alg4(buckets, draws=100):
    """Draw ``draws`` balls, narrowing to the best two buckets, then the best one.

    Each pass visits every bucket still in play, drawing from it until a red
    ball (0) comes up or the budget runs out.  After a pass:

    * if exactly two buckets are in play, keep only the higher-scoring one;
    * if more than two are in play and more than 5 blue balls have been
      drawn in total, keep only the two highest-scoring buckets.

    A bucket's score is the total number of blue balls it has produced.

    Args:
        buckets: iterable of hashable, zero-argument callables returning
            1 (blue) or 0 (red).
        draws: total number of balls to draw (default 100).

    Returns:
        Total number of blue balls drawn (0 when ``buckets`` is empty).
    """
    buckets = list(buckets)
    if not buckets:
        # Nothing to draw from; without this guard the pass loop never advances.
        return 0
    i = 0
    bucket_scores = {bucket: 0 for bucket in buckets}
    while i < draws:
        for bucket in buckets:
            if i >= draws:
                break
            tally = result = bucket()
            i += 1
            while result and i < draws:
                result = bucket()
                i += 1
                tally += result
            bucket_scores[bucket] += tally
        if len(buckets) == 2:
            buckets = [max(bucket_scores.items(), key=itemgetter(1))[0]]
        elif len(buckets) > 2 and sum(bucket_scores.values()) > 5:
            top2 = sorted(bucket_scores.items(), key=itemgetter(1), reverse=True)[:2]
            buckets = [bucket for bucket, _ in top2]
    return sum(bucket_scores.values())
def scaling(factors):
    """Build a pass-indexed schedule from ``(start_step, value)`` pairs.

    ``factors`` is expected to be sorted by ``start_step``.  The returned
    ``scale_fn(n)`` scans the pairs while ``start_step <= n`` and returns the
    value of the last pair scanned.  It raises IndexError when ``n`` is
    smaller than the first ``start_step``.
    """
    # Materialise once: scale_fn is called on every pass, and a one-shot
    # iterator would be exhausted after the first call.
    factors = list(factors)

    def scale_fn(n):
        # Index the pair instead of using ``lambda (i, _): ...``: tuple
        # parameter unpacking is Python-2-only syntax (removed by PEP 3113).
        active = list(takewhile(lambda pair: pair[0] <= n, factors))
        return active[-1][1]
    return scale_fn
def alg1(buckets):
    """Rank by (blue total, fewest visits); keep only the top half from pass 2.

    Passes 0 and 1 visit every bucket.  From pass 2 onwards only the
    best-ranked half is visited, clamped to at least one bucket.

    Returns the number of blue balls drawn, as computed by ``extractor``.
    """
    l = len(buckets)
    def sort_fn(bucket_scores):
        # Score pairs are (blue total, -visits); highest first.
        return sorted(bucket_scores.items(), key=itemgetter(1), reverse=True)
    # With a single bucket l // 2 == 0, so extractor would visit no bucket
    # from pass 2 on and make no progress; always keep at least one.
    return extractor(buckets, sort_fn, scaling([(0, l), (2, max(1, l // 2))]))
def alg1b(buckets):
    """Visit every bucket on passes 0-1, then only the single best-ranked one.

    Buckets are ranked by their extractor score pair, highest first.
    """
    def best_first(scores):
        ranked = list(scores.items())
        ranked.sort(key=itemgetter(1), reverse=True)
        return ranked

    schedule = scaling([(0, len(buckets)), (2, 1)])
    return extractor(buckets, best_first, schedule)
def alg1c(buckets):
    """Visit every bucket on pass 0, then only the two best-ranked ones.

    Buckets are ranked by their extractor score pair, highest first.
    """
    count = len(buckets)
    schedule = scaling([(0, count), (1, 2)])

    def best_first(scores):
        ranked = list(scores.items())
        ranked.sort(key=itemgetter(1), reverse=True)
        return ranked

    return extractor(buckets, best_first, schedule)
def alg1d(buckets):
    """Visit all buckets on pass 0, the best two on pass 1, then only the best one.

    Buckets are ranked by their extractor score pair, highest first.
    """
    def best_first(scores):
        ranked = list(scores.items())
        ranked.sort(key=itemgetter(1), reverse=True)
        return ranked

    steps = [(0, len(buckets)), (1, 2), (2, 1)]
    return extractor(buckets, best_first, scaling(steps))
def alg1e(buckets):
    """Visit every bucket once on pass 0, then stick with the best-ranked one.

    Buckets are ranked by their extractor score pair, highest first.
    """
    count = len(buckets)

    def best_first(scores):
        ranked = list(scores.items())
        ranked.sort(key=itemgetter(1), reverse=True)
        return ranked

    schedule = scaling([(0, count), (1, 1)])
    return extractor(buckets, best_first, schedule)
def alg1f(buckets):
    """Visit all buckets on passes 0-1, the best two on pass 2, then only the best.

    Buckets are ranked by their extractor score pair, highest first.
    """
    def best_first(scores):
        ranked = list(scores.items())
        ranked.sort(key=itemgetter(1), reverse=True)
        return ranked

    steps = [(0, len(buckets)), (2, 2), (3, 1)]
    schedule = scaling(steps)
    return extractor(buckets, best_first, schedule)
def alg2(buckets):
    """Rank by (fewest visits, blue total); keep only the top half from pass 2.

    The score pair is (blue total, -visits), so sorting on its swapped form in
    reverse puts the least-visited buckets first and breaks ties by blue total.
    From pass 2 onwards only the top half is visited, clamped to at least one.

    Returns the number of blue balls drawn, as computed by ``extractor``.
    """
    l = len(buckets)
    def sort_fn(bucket_scores):
        return sorted(bucket_scores.items(), key=lambda x: (x[1][1], x[1][0]), reverse=True)
    # With a single bucket l // 2 == 0, so extractor would visit no bucket
    # from pass 2 on and make no progress; always keep at least one.
    return extractor(buckets, sort_fn, scaling([(0, l), (2, max(1, l // 2))]))
def alg3(buckets):
    """Rank by (blue total, fewest visits); shrink to a half, then a quarter.

    Passes 0 and 1 visit every bucket, pass 2 the top half, and pass 3
    onwards the top quarter.  Both fractions are clamped to at least one
    bucket.

    Returns the number of blue balls drawn, as computed by ``extractor``.
    """
    l = len(buckets)
    def sort_fn(bucket_scores):
        # Score pairs are (blue total, -visits); highest first.
        return sorted(bucket_scores.items(), key=itemgetter(1), reverse=True)
    # With fewer than 4 (or 2) buckets the integer division yields 0, so
    # extractor would visit no bucket on those passes and make no progress.
    half = max(1, l // 2)
    quarter = max(1, l // 4)
    return extractor(buckets, sort_fn, scaling([(0, l), (2, half), (3, quarter)]))
# Number of simulated games per experiment (each game draws up to 100 balls).
N = 10 ** 3
def get_buckets(ps):
    """Return one simulated bag per probability in ``ps``, in shuffled order.

    Each bag is a zero-argument callable that draws one ball with
    ``binomial(1, p)``: 1 (blue) with probability ``p``, otherwise 0 (red).
    """
    shuffled = sample(ps, len(ps))
    return [partial(binomial, 1, prob) for prob in shuffled]
# Bag-set factories: every call returns ten freshly shuffled bags.
#   positives - each bag yields blue with probability above 0.5
#   negatives - each bag yields blue with probability below 0.5
#   mixed     - blue probabilities spread from 0.9 down to 0.1
positives = partial(get_buckets, 2 * [0.6, 0.7, 0.75, 0.8, 0.9])
negatives = partial(get_buckets, 2 * [0.4, 0.3, 0.25, 0.2, 0.1])
mixed = partial(get_buckets, [0.9, 0.8, 0.7, 0.6, 0.5] + [0.4, 0.3, 0.25, 0.2, 0.1])
# Experiment runner: print the total number of blue balls an algorithm
# collects over N simulated games.  Only alg1d on the mixed bags is enabled;
# other pairings previously tried were alg1d, alg1f and alg4 on each of
# positives(), negatives() and mixed() -- add them to this table to rerun.
for _label, _algorithm, _make_bags in [('alg1d mix', alg1d, mixed)]:
    print(_label, sum(_algorithm(_make_bags()) for _ in range(N)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment