Created
May 27, 2012 05:49
-
-
Save kflu/2802311 to your computer and use it in GitHub Desktop.
Balanced partition
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Balanced partition | |
You have a set of n integers, each in the range 0 ... K. Partition these
integers into two subsets such that you minimize |S1 - S2|, where S1 and S2
denote the sums of the elements in each of the two subsets.
http://people.csail.mit.edu/bdean/6.046/dp/
""" | |
class BalancedPartition:
    """Split a list of non-negative integers into two subsets of near-equal sum.

    The solver keeps a lazily filled subset-sum table:

    * ``table[i][s]`` is 1 when some subset of ``A[0..i]`` sums to ``s``,
      0 when no subset does, and ``None`` while the answer is still unknown.
    * ``trace[i][s]`` is 1 when element ``i`` was taken to reach ``s``;
      ``solve`` walks it backwards to rebuild one of the two subsets.

    Both tables have ``S // 2 + 1`` columns, because the smaller of the two
    subsets can never sum to more than half of the total ``S``.
    """

    def __init__(self, A):
        """Record the input and initialise the base cases of the table.

        A -- sequence of non-negative integers; an empty sequence is allowed.
        """
        self.A = A
        self.n = len(A)
        self.S = sum(A)
        # Floor division keeps the width an int on both Python 2 and 3.
        half = self.S // 2
        self.table = [[None] * (half + 1) for _ in range(self.n)]
        self.trace = [[0] * (half + 1) for _ in range(self.n)]
        if self.n:
            # M(0, s) = 1 iff s == A[0]; skipped for empty input, which has
            # no row 0 to fill.
            first = self.A[0]
            for s in range(half + 1):
                self.table[0][s] = 1 if s == first else 0
        for row in self.table:
            row[0] = 1  # The empty subset always sums to 0.

    def M(self, i, s):
        """Return 1 if some subset of ``A[0..i]`` sums to ``s``, else 0.

        A negative ``s`` yields 0.  ``s`` must not exceed ``S // 2``, the last
        column of the table.  Both positive and negative answers are stored in
        ``table`` so every (i, s) pair is evaluated at most once.
        """
        if s < 0:
            return 0
        known = self.table[i][s]
        if known is not None:
            return known
        assert i >= 1  # Row 0 is fully initialised in __init__.
        if self.M(i - 1, s) == 1:
            # Reachable without element i; trace stays 0.
            result = 1
        elif self.M(i - 1, s - self.A[i]) == 1:
            # Reachable only by taking element i.
            self.trace[i][s] = 1
            result = 1
        else:
            result = 0
        # Memoize failures too; otherwise they are recomputed on every visit.
        self.table[i][s] = result
        return result

    def solve(self):
        """Return ``(difference, partition_1, partition_2)``.

        ``difference`` is the minimal ``|S1 - S2|``; the two partitions are
        disjoint sets of indices into ``A`` that together cover every index.
        ``partition_1`` is the subset whose sum is the largest reachable value
        not exceeding ``S // 2``.  Empty input yields ``(0, set(), set())``.
        """
        if self.n == 0:
            return (0, set(), set())

        # Largest subset sum that does not exceed half of the total.
        max_sum = None
        for x in range(self.S // 2, -1, -1):  # [S/2..0]
            if self.M(self.n - 1, x) == 1:
                max_sum = x
                break
        assert max_sum is not None  # x == 0 is always reachable.

        # Walk the trace from the last row upwards to recover the indices.
        i, s, has = self.n - 1, max_sum, []
        while i > 0 and s > 0:
            assert self.table[i][s] == 1
            if self.trace[i][s] == 1:
                has.append(i)
                s -= self.A[i]
            i -= 1
        if s != 0:
            # Whatever is left at row 0 can only be A[0] itself.
            assert s == self.A[0]
            has.append(0)

        chosen = set(has)
        return (self.S - 2 * max_sum,
                chosen,
                set(range(self.n)) - chosen)
# =======================
# TESTS
# =======================
def test_simple():
    """One element: the difference is that element and one side is empty."""
    outcome = BalancedPartition([1]).solve()
    difference = outcome[0]
    assert difference == 1
    # The result tuple must contain both the singleton and the empty subset.
    assert {0} in outcome
    assert set() in outcome
def test_1():
    """[1, 2, 3] splits evenly into indices {0, 1} and {2}."""
    outcome = BalancedPartition([1, 2, 3]).solve()
    difference = outcome[0]
    assert difference == 0
    for expected_subset in ({0, 1}, {2}):
        assert expected_subset in outcome
def test_2():
    """All-zero input has a total of 0, so the difference is 0."""
    difference = BalancedPartition([0, 0, 0]).solve()[0]
    assert difference == 0
def test_3():
    """Check the difference and that the result is a true partition of A."""
    A = [8, 2, 4, 0, 1]
    result = BalancedPartition(A).solve()
    assert result[0] == 1
    P1, P2 = result[1:]
    # The two index sets are disjoint and together cover every index.
    assert P1 & P2 == set()
    # range (not xrange) so the test also runs on Python 3.
    assert P1 | P2 == set(range(len(A)))
    # The reported difference matches the sums of the returned subsets.
    assert abs(sum(A[i] for i in P1) - sum(A[i] for i in P2)) == result[0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment