Skip to content

Instantly share code, notes, and snippets.

@kflu
Created May 27, 2012 05:49
Show Gist options
  • Save kflu/2802311 to your computer and use it in GitHub Desktop.
Save kflu/2802311 to your computer and use it in GitHub Desktop.
Balanced partition
"""Balanced partition.

You have a set of n integers, each in the range 0 ... K. Partition these
integers into two subsets such that |S1 - S2| is minimized, where S1 and S2
denote the sums of the elements in each of the two subsets.

http://people.csail.mit.edu/bdean/6.046/dp/
"""
class BalancedPartition:
    """Split a list of non-negative integers into two index sets with minimal sum difference.

    Uses the subset-sum dynamic program: ``table[i][s]`` is 1 when some subset
    of ``A[0..i]`` sums to exactly ``s`` (0 <= s <= S // 2) and 0 otherwise;
    ``trace[i][s]`` is 1 when that sum is reachable only by including ``A[i]``.
    The tables are filled bottom-up (no recursion), in O(n * S) time.
    """

    def __init__(self, A):
        """Build the DP tables for ``A``.

        :param A: sequence of non-negative integers.
        :raises ValueError: if any element of ``A`` is negative; the DP indexes
            columns by partial sums, which is only valid for non-negative values.
        """
        if any(a < 0 for a in A):
            raise ValueError("BalancedPartition requires non-negative integers")
        self.A = A
        self.n = len(A)
        self.S = sum(A)
        # Floor division keeps the column count an int on Python 3 as well as 2.
        half = self.S // 2
        self.table = [[0] * (half + 1) for _ in range(self.n)]
        self.trace = [[0] * (half + 1) for _ in range(self.n)]
        if self.n == 0:
            return  # nothing to tabulate; solve() handles the empty case
        for s in range(half + 1):
            # M(0, s) = 1 iff s == A[0]
            self.table[0][s] = 1 if s == self.A[0] else 0
        for row in self.table:
            row[0] = 1  # the empty subset always sums to 0
        self._fill(half)

    def _fill(self, half):
        """Fill rows 1..n-1 of ``table`` and ``trace`` from the row above."""
        for i in range(1, self.n):
            a = self.A[i]
            prev, row, picked = self.table[i - 1], self.table[i], self.trace[i]
            for s in range(1, half + 1):
                if prev[s] == 1:
                    # Reachable without A[i]; prefer leaving it out (trace stays 0).
                    row[s] = 1
                elif s >= a and prev[s - a] == 1:
                    # Reachable only by adding A[i] to a subset of A[0..i-1].
                    row[s] = 1
                    picked[s] = 1

    def M(self, i, s):
        """Return 1 if some subset of A[0..i] sums to exactly ``s``, else 0.

        A negative ``s`` is never reachable. ``s`` must not exceed ``S // 2``.
        """
        if s < 0:
            return 0
        return self.table[i][s]

    def solve(self):
        """Return ``(difference, partition_1, partition_2)``.

        ``difference`` is the minimal |S1 - S2|. The partitions are sets of
        indices into ``A``; ``partition_1`` is the side whose sum is the largest
        reachable subset sum not exceeding ``S // 2``. An empty ``A`` yields
        ``(0, set(), set())``.
        """
        if self.n == 0:
            return (0, set(), set())
        last = self.n - 1
        # Largest reachable subset sum <= S // 2; s == 0 is always reachable,
        # so the generator is never exhausted.
        max_sum = next(x for x in range(self.S // 2, -1, -1)
                       if self.M(last, x) == 1)
        # Walk the trace back from (last, max_sum) to recover the chosen indices.
        i, s, has = last, max_sum, []
        while i > 0 and s > 0:
            assert self.table[i][s] == 1
            if self.trace[i][s] == 1:
                has.append(i)
                s -= self.A[i]
            i -= 1
        if s != 0:
            # Only row 0 remains, so the leftover sum must be A[0] itself.
            assert s == self.A[0]
            has.append(0)
        chosen = set(has)
        return (self.S - 2 * max_sum, chosen, set(range(self.n)) - chosen)
# =======================
# TESTS
# =======================
def test_simple():
    """A single element ends up alone on one side; the other side is empty."""
    outcome = BalancedPartition([1]).solve()
    diff, parts = outcome[0], outcome[1:]
    assert diff == 1
    assert {0} in parts and set() in parts
def test_1():
    """[1, 2, 3] splits perfectly: indices {0, 1} (sum 3) versus {2} (sum 3)."""
    outcome = BalancedPartition([1, 2, 3]).solve()
    diff, parts = outcome[0], outcome[1:]
    assert diff == 0
    for expected in ({0, 1}, {2}):
        assert expected in parts
def test_2():
    """An all-zero input is trivially balanced."""
    diff = BalancedPartition([0, 0, 0]).solve()[0]
    assert diff == 0
def test_3():
    """A mixed input (including a zero) yields a valid, optimal partition."""
    A = [8, 2, 4, 0, 1]
    diff, P1, P2 = BalancedPartition(A).solve()
    assert diff == 1
    # The two index sets must be disjoint and together cover every index.
    assert P1 & P2 == set()
    assert P1 | P2 == set(range(len(A)))
    # The reported difference must match the actual sums of the two sides.
    assert abs(sum(A[i] for i in P1) - sum(A[i] for i in P2)) == diff
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment