Skip to content

Instantly share code, notes, and snippets.

@Synthetica9
Last active August 25, 2018 13:30
Show Gist options
  • Save Synthetica9/14ab958b1d54f7575e76a7fbd0926509 to your computer and use it in GitHub Desktop.
Save Synthetica9/14ab958b1d54f7575e76a7fbd0926509 to your computer and use it in GitHub Desktop.
#! /usr/bin/env nix-shell
#! nix-shell -i python -p "with python3Packages; [python hypothesis]"
from hypothesis import *
from hypothesis.strategies import *
from heapq import *
from itertools import count
from string import ascii_uppercase
class InitalSortedStack:
    """Stack whose initial contents are sorted once, then used LIFO.

    The initial values are sorted ascending, so the first pops return the
    largest initial values. ``push`` simply appends without re-sorting, so a
    freshly pushed value is the next one popped. This is the "naive"
    data structure fed to ``algo``.
    """

    def __init__(self, src=None):
        """Build the stack from the iterable *src* (an empty stack when None).

        *src* is consumed but never mutated; a new list is stored.
        """
        # sorted() accepts any iterable (including generators) and already
        # returns a fresh list, so no intermediate list() copy is needed.
        self.content = sorted(() if src is None else src)

    def push(self, obj):
        """Append *obj* on top of the stack (no re-sorting)."""
        self.content.append(obj)

    def pop(self):
        """Remove and return the top element; raises IndexError when empty."""
        return self.content.pop()

    def empty(self):
        """Return True when the stack holds no elements."""
        return not self.content

    def __repr__(self):
        """Show the underlying list, bottom of the stack first."""
        return repr(self.content)
class MaxHeap(InitalSortedStack):
    """Max-priority queue built on heapq's min-heap by storing negated values.

    Because the stored values are negated, the smallest stored item is the
    largest real value. ``pop`` therefore always returns the current maximum,
    including values added later via ``push``. ``empty`` is inherited.
    This is the "smart" data structure fed to ``algo``.
    """

    def __init__(self, *args, **kwargs):
        """Collect the initial values via the parent constructor, then heapify them negated."""
        super().__init__(*args, **kwargs)
        negated = []
        for value in self.content:
            negated.append(-value)
        heapify(negated)
        self.content = negated

    def push(self, obj):
        """Insert *obj*, preserving the heap invariant."""
        stored = -obj
        heappush(self.content, stored)

    def pop(self):
        """Remove and return the largest value; raises IndexError when empty."""
        smallest_stored = heappop(self.content)
        return -smallest_stored

    def __repr__(self):
        """Show the real (un-negated) values in internal heap order, not sorted order."""
        real_values = [-stored for stored in self.content]
        return repr(real_values)
@composite
def zero_sum_lists(draw, *args, **kwargs):
    """Hypothesis strategy producing lists of integers that sum to zero.

    Extra ``*args``/``**kwargs`` are forwarded to ``lists(integers(), ...)``.
    A balancing element ``-sum`` is appended after drawing. The result can
    therefore be one element longer than any ``max_size`` passed in.
    """
    values = draw(lists(integers(), *args, **kwargs))
    total = sum(values)
    # When the drawn list already sums to zero, optionally append a 0 anyway.
    # booleans() shrinks towards False, so minimal examples skip that element.
    if total != 0 or draw(booleans()):
        values.append(-total)
    assert sum(values) == 0
    return values
def algo(datastructure):
    """Return a settling function parameterized by a container type.

    *datastructure* must be constructible from an iterable and must provide
    ``push``, ``pop`` and ``empty``. The returned ``func(data, verbose=False)``
    works as follows:
    - It splits the zero-sum list *data* into positive amounts and absolute
      values of negative amounts.
    - It repeatedly pops one amount from each side and pushes any non-zero
      remainder back onto the side it belongs to.
    - It returns the number of such steps, printed as "Transactions" when
      *verbose* is set.

    The order in which ``pop`` yields items, and so the step count, depends
    on *datastructure*.
    """
    def func(data, verbose=False):
        assert sum(data) == 0
        creditors = datastructure(amount for amount in data if amount > 0)
        debtors = datastructure(-amount for amount in data if amount < 0)
        transactions = 0
        while True:
            if verbose:
                print("+:", creditors, "-:", debtors)
            if creditors.empty() or debtors.empty():
                # A zero-sum input must exhaust both sides at the same time.
                assert creditors.empty() and debtors.empty()
                break
            # Pop the creditor first, then the debtor (same order as always).
            balance = creditors.pop() - debtors.pop()
            if balance > 0:
                creditors.push(balance)
            elif balance < 0:
                debtors.push(-balance)
            transactions += 1
        if verbose:
            print("Transactions:", transactions)
        return transactions
    return func
# Two settling strategies that differ only in the order amounts are popped:
# the sorted-once stack ("naive") versus the max-heap ("smart").
naive, smart = algo(InitalSortedStack), algo(MaxHeap)
if __name__ == '__main__':
    def criterion(data):
        """True when the max-heap strategy needs fewer steps than the naive one."""
        return smart(data) < naive(data)

    # Ask Hypothesis for a zero-sum list where `smart` beats `naive`, then
    # display it with letter labels, largest amount first. zip() stops after
    # 26 entries, which is enough for a shrunk example.
    example = find(zero_sum_lists(), criterion)
    example = sorted(example, reverse=True)
    for (label, amount) in zip(ascii_uppercase, example):
        print(f'{label}: {amount:+}')

    print()
    print("Naive:")
    naive(example, verbose=True)
    print()
    print("Smart:")
    smart(example, verbose=True)

    # Estimate how often a random zero-sum list triggers the criterion.
    trials = 10000
    # Loop-invariant, so build the strategy once instead of once per trial.
    # NOTE(review): recent Hypothesis versions warn that .example() is meant
    # for interactive use; confirm this is acceptable for the installed version.
    strategy = zero_sum_lists()
    hits = sum(criterion(strategy.example()) for _ in range(trials))
    # Divide by the number of trials. The original divided by the last loop
    # index (trials - 1), which was off by one.
    print(f'\nFound relevant in {round(100*hits/trials, 2)}% of randomly generated lists')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment