Skip to content

Instantly share code, notes, and snippets.

@leopd
Created July 18, 2013 17:12
Show Gist options
  • Save leopd/6031089 to your computer and use it in GitHub Desktop.
Save leopd/6031089 to your computer and use it in GitHub Desktop.
An example of a dynamic programming solution to the problem of counting the n×n 0/1 matrices (n even) that have exactly n/2 zeroes and n/2 ones in each row and each column. For n = 2k the count is term k of OEIS A058527.
import copy
import itertools
import logging
import math
def balance_cnt(n):
    """Count the n x n 0/1 matrices with exactly n/2 ones in every row and column.

    Uses the row-by-row dynamic programming outlined here:
    https://en.wikipedia.org/wiki/Dynamic_programming#A_type_of_balanced_0.E2.80.931_matrix

    ``balance_cnt(2 * k)`` is term ``k`` of https://oeis.org/A058527
    (1, 2, 90, 297200, ...).

    :param n: matrix size; must be a non-negative even integer.
    :returns: the number of balanced matrices (1 for the empty 0 x 0 matrix).
    :raises ValueError: if ``n`` is negative or odd.
    """
    # Raise instead of assert: asserts are stripped under ``python -O``.
    if n < 0 or n % 2 != 0:
        raise ValueError("n must be a non-negative even integer, got %r" % (n,))
    half = n // 2  # floor division: ``n/2`` is a float on Python 3
    memo = {}

    def canonical(columns):
        # Permuting the columns of a valid completion gives a valid completion
        # of the permuted state, so the count depends only on the multiset of
        # column states. A sorted tuple is a hashable canonical memo key.
        return tuple(sorted(columns))

    def validcntmemo(A, pre=""):
        """Memoized wrapper around ``validcnt``; ``A`` must be canonical."""
        if A not in memo:
            memo[A] = validcnt(A, pre)
        return memo[A]

    def validcnt(A, pre=""):
        """Return the number of ways to fill the remaining rows.

        ``A`` is a canonical tuple of n pairs (a, b), one per column:
        a := number of 0s still to place in that column,
        b := number of 1s still to place in that column.
        Every pair sums to the number of rows still to fill.
        ``pre`` is an indentation prefix for the debug trace.
        """
        logging.debug("%sCalled with %s", pre, A)
        left = sum(A[0]) if A else 0
        assert len(A) == n
        assert all(a + b == left for a, b in A)
        if left == 0:
            # Nothing left to place: exactly one (empty) completion. Only
            # reached for n == 0, since recursion stops at left == 1.
            return 1
        if left == 1:
            # The last row is forced: each column takes its one remaining
            # value. It is valid only if that row has exactly n/2 ones.
            ones = sum(b for _, b in A)
            if ones == half:
                logging.debug("%sFOUND a solution", pre)
                return 1
            logging.debug("%sNo solutions because %d ones here", pre, ones)
            return 0
        # More than one row left: try every row with exactly n/2 ones,
        # i.e. every choice of n/2 columns that receive a 1, and sum them up.
        tot = 0
        for ones_cols in itertools.combinations(range(n), half):
            chosen = set(ones_cols)
            AA = [(a, b - 1) if j in chosen else (a - 1, b)
                  for j, (a, b) in enumerate(A)]
            if any(a < 0 or b < 0 for a, b in AA):
                # Some column has run out of the digit this row puts there.
                continue
            logging.debug("%s recursing to %s", pre, AA)
            tot += validcntmemo(canonical(AA), " " + pre)
        logging.debug("%sAnswer for %s is %d", pre, A, tot)
        return tot

    # Initiate recursion: every column needs n/2 zeroes and n/2 ones.
    return validcntmemo(canonical([(half, half)] * n))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment