Created: March 20, 2018 16:38
Save xire-/00beccec7e99e69a13df7fccae40d253 to your computer and use it in GitHub Desktop.
payment solver
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf8 -*- | |
from collections import defaultdict | |
from itertools import permutations | |
class PaymentSolver:
    """Book-keeping of debts among a fixed set of people.

    Debts are antisymmetric: if X owes Y N euro, then Y owes X -N euro.
    ``solver[x, y]`` reads or writes how much *x* owes *y*;
    :meth:`simplify_payments` rewrites the debts into fewer/smaller
    transfers while keeping every person's net balance unchanged.
    """

    # Amounts whose magnitude is at or below EPS are treated as zero.
    # Splitting a bill (e.g. 45 / 7) leaves floating-point residues such as
    # 1e-15; without a tolerance they show up as "owe X 0.00 EUR" lines and
    # trigger pointless simplification steps.
    EPS = 1e-9

    def __init__(self, people):
        # Only one direction is stored per pair: the key is (x, y) with
        # x < y and the value is how much x owes y. The reverse direction
        # is derived by negation in __getitem__.
        self.x_owe_y = defaultdict(float)
        self.people = people

    def _key(self, x, y):
        """Validate the pair (x, y) and map it to its storage key.

        Returns:
            ``(key, flipped)``: *key* is the ordered pair used in
            ``self.x_owe_y``; *flipped* is True when ``x > y``, meaning the
            stored value must be negated.

        Raises:
            ValueError: if either person is not in ``self.people`` or
                ``x == y``.
        """
        for person in (x, y):
            if person not in self.people:
                raise ValueError(f'unknown person: {person!r}')
        if x == y:
            raise ValueError(f'a person cannot owe themselves: {x!r}')
        if x > y:
            return (y, x), True
        return (x, y), False

    def __getitem__(self, idx):
        """Return how much ``idx[0]`` owes ``idx[1]`` (negative: is owed)."""
        x, y = idx
        key, flipped = self._key(x, y)
        # .get avoids inserting a 0.0 entry for every pair that is merely read.
        value = self.x_owe_y.get(key, 0.0)
        return -value if flipped else value

    def __setitem__(self, idx, value):
        """Record that ``idx[0]`` owes ``idx[1]`` exactly *value* euro."""
        x, y = idx
        key, flipped = self._key(x, y)
        self.x_owe_y[key] = -value if flipped else value

    def __str__(self):
        """Render, per person, every positive debt and their net balance.

        "TOTAL OWED" is the person's *net* balance: the sum over everyone
        else of what they owe, so a negative value means they are owed money.
        """
        people = sorted(self.people)
        lines = []
        for p1 in people:
            lines.append(f'___{p1:_<30s}')
            net = 0.0
            for p2 in people:
                if p1 == p2:
                    continue
                owed = self[p1, p2]
                net += owed
                if owed > self.EPS:
                    lines.append(f'owe {p2} {owed:.2f} EUR')
            if abs(net) <= self.EPS:
                net = 0.0  # avoid printing rounding noise as '-0.00'
            lines.append('---')
            lines.append(f'TOTAL OWED: {net:.2f} EUR')
            lines.append('')
        return '\n'.join(lines)

    def paid_for(self, who, total, group):
        """Record that *who* paid *total* euro on behalf of *group*.

        The amount is split evenly over the members of *group* (``who`` may
        be one of them); every other member now owes *who* one share.

        Raises:
            ValueError: if *who* or any member of *group* is unknown, or
                *group* is empty (the share would be undefined).
        """
        if who not in self.people:
            raise ValueError(f'unknown person: {who!r}')
        # Materialise once so a generator can be both validated and counted.
        members = list(group)
        if not members:
            raise ValueError('group must not be empty')
        unknown = [p for p in members if p not in self.people]
        if unknown:
            raise ValueError(f'unknown people in group: {unknown!r}')
        share = total / len(members)
        for p in members:
            if p != who:  # the payer does not owe themselves
                self[p, who] += share

    def split_money(self, who, total, group):
        """Record that *who* holds *total* euro to be split with *group*.

        Equivalent to every other member of *group* owing *who* a negative
        share, i.e. *who* owes each of them ``total / len(group)``.
        """
        self.paid_for(who, -total, group)

    def _simplify_triangle(self):
        """Reroute chains A -> B -> C through a direct A -> C payment.

        If A owes B y > 0 and B owes C z > 0, with m = min(y, z):

            A -> C : x  becomes  x + m
            A -> B : y  becomes  y - m
            B -> C : z  becomes  z - m

        Net balances are preserved and the total amount that has to move
        drops by at least m, so the loop terminates. x may have either sign:
        when C already owes A (x < 0) this cancels the cycle A -> B -> C -> A.
        At the fixpoint nobody both owes and is owed money.
        """
        eps = self.EPS
        while True:
            fixpoint = True
            for a, b, c in permutations(sorted(self.people), 3):
                y = self[a, b]
                z = self[b, c]
                if y <= eps or z <= eps:
                    continue
                fixpoint = False
                m = min(y, z)
                self[a, c] = self[a, c] + m
                self[a, b] = y - m  # exactly 0.0 when y is the minimum
                self[b, c] = z - m  # exactly 0.0 when z is the minimum
            if fixpoint:
                break

    def _simplify_cross(self):
        """Rebalance two debtors that both owe the same two creditors.

        If A owes C x and D y, and B owes C z and D w (all > 0), with
        m = min(y, z):

            A -> C : x + m        A -> D : y - m
            B -> C : z - m        B -> D : w + m

        Net balances and the total amount moved are unchanged, but one of
        A -> D / B -> C drops to zero, so one transfer disappears. No debt
        changes sign, so no new transfer is ever created.
        """
        eps = self.EPS
        while True:
            fixpoint = True
            for a, b, c, d in permutations(sorted(self.people), 4):
                x = self[a, c]
                y = self[a, d]
                z = self[b, c]
                w = self[b, d]
                if min(x, y, z, w) <= eps:
                    continue
                fixpoint = False
                m = min(y, z)
                self[a, c] = x + m
                self[a, d] = y - m
                self[b, c] = z - m
                self[b, d] = w + m
            if fixpoint:
                break

    def simplify_payments(self):
        """Reduce the debts to fewer/smaller transfers, keeping net balances.

        One pass of each step suffices: after the triangle step nobody both
        owes and is owed money, and the cross step only shrinks or grows
        existing debts without flipping signs, so it cannot create a new
        chain for the triangle step to remove.
        """
        self._simplify_triangle()
        self._simplify_cross()
# ================================================================================ | |
def main():
    """Worked example: record some shared expenses, simplify, print the result."""
    everyone = {'Marco', 'Palma', 'Mauro', 'Gaspa', 'Cesco', 'Benve', 'Lorenzo'}
    solver = PaymentSolver(everyone)

    # (payer, amount in euro, people the amount was spent on)
    expenses = [
        ('Lorenzo', 45.00, everyone),
        ('Gaspa', 12.00, {'Gaspa', 'Cesco', 'Benve'}),
        ('Marco', 7.20, everyone - {'Palma'}),
    ]
    for payer, amount, beneficiaries in expenses:
        solver.paid_for(who=payer, total=amount, group=beneficiaries)

    # Manual adjustments: overwrite one debt outright, then top up another.
    solver['Cesco', 'Gaspa'] = 0.0
    solver['Benve', 'Gaspa'] += 10.0

    solver.simplify_payments()
    print(solver)
if __name__ == "__main__":  # run the example only when executed as a script
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.