Created
March 17, 2016 22:18
-
-
Save pursuingpareto/b15f1197d96b1a2bbc48 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def score(bracket, results, filt,
          teams_remaining, blacklist, round_num=0):
    """
    Calculate the score of a prediction bracket against a results bracket.

    Every bitstring below stores the first round in its least significant
    bits, followed by each later round. Game ``j`` of one round is played
    between the winners of games ``2*j`` and ``2*j + 1`` of the previous
    round.

    - bracket
        A bitstring representing a prediction bracket. For
        a 64 team tournament this would be 63 bits.
    - results
        A bitstring representing the actual outcome of
        a tournament, in the same layout as ``bracket``.
    - filt
        With the exception of the first round in a tournament,
        it's not possible to score a round by just comparing
        the bits in bracket to the bits in results. For example,
        correctly predicting the championship game requires not
        only the correct bit for that game, but also the correct
        prediction for all the games the tournament winner had
        won before the final round.
        The filt parameter is a one time pre-computed bitstring
        (see ``make_mask_binary``) used to indicate which games in a
        round must be correctly predicted in order to correctly
        predict successive games. For a 64 team tournament filt
        would contain 62 bits.
    - teams_remaining
        The number of teams left in the tournament at the round
        being scored (64 for a fresh 64 team tournament).
    - blacklist
        A sequence of N bits where N is the number of games in the
        current round. It uses the accuracy of predictions from
        previous rounds to "remember" which games are possible to
        correctly predict. When calling the score function initially
        these bits should all be set to 1.
    - round_num
        A number representing the current round. For a 64 team
        tournament this would take the values 0,1,2,3,4, and 5.

    Returns the total number of points; a correct pick in round
    ``round_num`` is worth ``2 ** round_num`` points.
    """
    total = 0
    # The loop stops once a winner remains. It also terminates for
    # non-positive team counts instead of recursing forever.
    while teams_remaining > 1:
        # Each round has one game per pair of remaining teams. Its bits are
        # the num_games least significant bits of bracket, results and
        # filt; shifting them off exposes the next round.
        # Use // so this also runs under Python 3.
        num_games = teams_remaining // 2
        round_mask = (1 << num_games) - 1
        round_predictions = bracket & round_mask
        round_results = results & round_mask
        round_filter = filt & round_mask
        bracket >>= num_games
        results >>= num_games
        filt >>= num_games

        # A prediction matches when its bit equals the result bit (XNOR).
        # ~ on a Python int also sets every higher bit, so clip to this round.
        overlap = ~(round_predictions ^ round_results) & round_mask
        # In all rounds except the first, overlap tends to overestimate a
        # bracket's correctness. Games whose predicted team was already
        # eliminated are removed by ANDing with the blacklist.
        scores = overlap & blacklist
        # Later rounds are weighted more heavily: each hit is 2 ** round_num.
        total += popcount(scores) << round_num

        # Only the games whose winner advanced (per the pre-computed filter)
        # can make a next-round pick scoreable. Collapse each pair of games
        # (2*j, 2*j + 1) onto next-round game j.
        relevant_scores = scores & round_filter
        upper_bits, lower_bits = get_odds_and_evens(relevant_scores)
        blacklist = upper_bits | lower_bits

        teams_remaining = num_games
        round_num += 1
    return total
def get_odds_and_evens(bits):
    """
    Separate the interleaved bits of a bitstring into two packed halves.

    Bit ``2*i + 1`` of ``bits`` becomes bit ``i`` of ``evens`` and bit
    ``2*i`` becomes bit ``i`` of ``odds``. The names follow the numbering
    of a 32-bit word from its most significant bit. Callers that OR the two
    results get one bit per adjacent pair ``(2*i, 2*i + 1)``.

    Inputs below 2**32 use the branch-free outer perfect unshuffle, which
    repeatedly shuffles smaller segments of the word. Wider inputs fall
    back to a simple pairwise loop, so brackets with more than 32 games in
    a round are supported as well.

    Returns ``(evens, odds)``. Raises ValueError for negative input.
    """
    if bits < 0:
        raise ValueError("get_odds_and_evens requires a non-negative integer")
    if bits < 1 << 32:
        # Four swap stages move the bits from odd positions into the upper
        # 16 bits and the bits from even positions into the lower 16 bits,
        # each group keeping its original order.
        tmp = (bits ^ (bits >> 1)) & 0x22222222
        bits ^= tmp ^ (tmp << 1)
        tmp = (bits ^ (bits >> 2)) & 0x0C0C0C0C
        bits ^= tmp ^ (tmp << 2)
        tmp = (bits ^ (bits >> 4)) & 0x00F000F0
        bits ^= tmp ^ (tmp << 4)
        tmp = (bits ^ (bits >> 8)) & 0x0000FF00
        bits ^= tmp ^ (tmp << 8)
        return bits >> 16, bits & 0xFFFF

    # General path: peel off one (even, odd) pair per iteration.
    evens = odds = 0
    position = 0
    while bits:
        odds |= (bits & 1) << position
        evens |= ((bits >> 1) & 1) << position
        bits >>= 2
        position += 1
    return evens, odds
def popcount(x):
    """
    Count the number of 1s in a bitstring.

    Works for non-negative integers of any width. A fixed 64-bit SWAR
    version would silently ignore every bit above position 63.

    Raises ValueError for negative input, whose two's-complement expansion
    has infinitely many 1s.
    """
    if x < 0:
        raise ValueError("popcount requires a non-negative integer")
    return bin(x).count("1")
# This function can be slow since it would only be called once,
# when the tournament is over.
def make_mask_binary(results, teams_remaining):
    """
    Build the ``filt`` bitstring used by ``score`` from a results bracket.

    - results
        The actual outcome bitstring. The first round is in the
        ``teams_remaining // 2`` least significant bits, followed by each
        later round.
    - teams_remaining
        Number of teams at the start of the tournament (e.g. 64).

    Game ``k`` among the later-round games (bit ``teams_remaining // 2 + k``
    of ``results``) is fed by the two games whose filter bits are ``2*k``
    and ``2*k + 1``. Only the feeder game whose winner actually advanced
    matters. A result bit of 1 means the winner came from the
    even-indexed feeder (filter pair ``0b01``). A result bit of 0 means it
    came from the odd-indexed feeder (filter pair ``0b10``).

    Returns an int with ``2 * (teams_remaining // 2 - 1)`` meaningful bits
    (62 for a 64 team tournament).
    """
    first_round_games = teams_remaining // 2
    # The filter for round r is driven by the results of round r + 1, so the
    # first-round result bits themselves are not used.
    future_rounds = results >> first_round_games
    mask = 0
    # Build positionally rather than via bin(), which would drop leading
    # zero results and lose their filter pairs.
    for game in range(first_round_games - 1):
        if (future_rounds >> game) & 1:
            mask |= 0b01 << (2 * game)
        else:
            mask |= 0b10 << (2 * game)
    return mask
def make_test(bracket, results, N=64):
    """
    Build the initial ``(blacklist, filt)`` arguments for ``score``.

    - bracket
        Unused; kept so existing callers keep working.
    - results
        The actual outcome bitstring, passed to ``make_mask_binary``.
    - N
        Number of teams in the tournament.

    The blacklist starts with one 1 bit per first-round game, since every
    first-round game can be predicted correctly.
    """
    # Integer arithmetic instead of "1" * (N/2), which fails under Python 3.
    blacklist = (1 << (N // 2)) - 1
    filt = make_mask_binary(results, N)
    return blacklist, filt
def test():
    """
    Self-check ``score`` on a 64 team bracket whose results are all 1s.

    A perfect bracket scores 32*1 + 16*2 + 8*4 + 4*8 + 2*16 + 1*32 = 192.
    Raises AssertionError on failure; prints "tests pass" otherwise.
    """
    # random was never imported at module level, so the randomized checks
    # below used to fail with NameError.
    import random

    results = int("1" * 63, 2)

    # Perfect prediction.
    bracket = int("1" * 63, 2)
    blacklist, filt = make_test(bracket, results)
    assert score(bracket, results, filt, 64, blacklist) == 192

    # Only the first round correct: 32 points.
    bracket = int("0" * 31 + "1" * 32, 2)
    blacklist, filt = make_test(bracket, results)
    assert score(bracket, results, filt, 64, blacklist) == 32

    # Second round all wrong; later correct bits must not be credited.
    bracket = int("1" * 15 + "0" * 16 + "1" * 32, 2)
    blacklist, filt = make_test(bracket, results)
    assert score(bracket, results, filt, 64, blacklist) == 32

    # Even second-round games right: those winners advance, so only the
    # 8 odd second-round games (2 points each) are lost.
    bracket = int("1" * 15 + "01" * 8 + "1" * 32, 2)
    blacklist, filt = make_test(bracket, results)
    assert score(bracket, results, filt, 64, blacklist) == (192 - 2*8)

    # Odd second-round games right: their winners did not advance, so
    # nothing after the second round is credited.
    bracket = int("1" * 15 + "10" * 8 + "1" * 32, 2)
    blacklist, filt = make_test(bracket, results)
    assert score(bracket, results, filt, 64, blacklist) == (32 + 2*8)

    # A perfect prediction of random results scores the maximum with any
    # filter that keeps exactly one bit for every pair of games feeding a
    # later game. The filter from the last make_test call does that, so it
    # is reused here.
    for _ in range(10):
        bracket = random.getrandbits(63)
        results = bracket
        assert score(bracket, results, filt, 64, blacklist) == 192
    print("tests pass")
# Run the self-checks whenever this file is executed or imported.
test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment