Skip to content

Instantly share code, notes, and snippets.

@majiang
Created May 12, 2012 13:26
Show Gist options
  • Select an option

  • Save majiang/2666505 to your computer and use it in GitHub Desktop.

Select an option

Save majiang/2666505 to your computer and use it in GitHub Desktop.
hungarian method O(n^4)
def hungarian(cost, tol=0):
    '''Solve the assignment problem on a complete bipartite graph.

    Finds a perfect matching between n workers S and n jobs T of minimum
    total cost with the Hungarian method, following
    http://en.wikipedia.org/wiki/Hungarian_algorithm#The_algorithm_in_terms_of_bipartite_graphs
    Each round is a breadth-first search over tight edges; overall O(n^4).

    Args:
        cost: square matrix (n rows, each of length n); cost[i][j] is the
            cost of the edge from worker S[i] to job T[j].  Costs may be
            negative.
        tol: an edge (i, j) counts as tight when its reduced cost
            cost[i][j] - potential_S[i] - potential_T[j] has absolute value
            at most tol.  The default 0 means exact comparison (appropriate
            for int / Fraction costs).  For float costs pass a small positive
            value so rounding error cannot stall the search; the result may
            then be off by at most about n * tol.

    Returns:
        The minimum total cost of a perfect assignment, i.e. the min-plus
        ("tropical") determinant of cost.  An empty matrix gives 0.

    Raises:
        ValueError: if cost is not square, or if the internal invariant that
            every potential increment is strictly positive is violated.
    '''
    from collections import deque

    n = len(cost)
    if any(len(row) != n for row in cost):
        raise ValueError('cost must be a square matrix')

    # orientation[i][j] == 1: unmatched edge, oriented S[i] -> T[j];
    # orientation[i][j] == -1: matched edge, oriented S[i] <- T[j].
    orientation = [[1] * n for _ in range(n)]
    # Potentials must start dual-feasible (potential_S[i] + potential_T[j]
    # <= cost[i][j] for all edges).  Row minima guarantee this for any real
    # costs, including negative ones, where all-zero potentials would not.
    potential_S = [min(row) for row in cost]
    potential_T = [0] * n

    def slack(i, j):
        # Reduced cost of edge (i, j); stays >= 0 while potentials are feasible.
        return cost[i][j] - potential_S[i] - potential_T[j]

    while True:
        uncovered_S = [True] * n
        uncovered_T = [True] * n
        for i in range(n):
            for j in range(n):
                if orientation[i][j] == -1:
                    uncovered_S[i] = False
                    uncovered_T[j] = False
        if not (any(uncovered_S) or any(uncovered_T)):
            break  # the matching is perfect

        istight = [[abs(slack(i, j)) <= tol for j in range(n)]
                   for i in range(n)]

        # Breadth-first search from every uncovered worker, following tight
        # unmatched edges S -> T and tight matched edges T -> S.
        # reachable_S[i] / reachable_T[j] hold the predecessor (side, index)
        # of a reached vertex, True for a search root, False if unreached.
        reachable_S = list(uncovered_S)
        reachable_T = [False] * n
        queue = deque((0, i) for i in range(n) if uncovered_S[i])
        while queue:
            side, v = queue.popleft()
            if side == 0:  # at worker S[v]
                for j in range(n):
                    if (orientation[v][j] == 1 and istight[v][j]
                            and not reachable_T[j]):
                        reachable_T[j] = (0, v)
                        queue.append((1, j))
            else:  # at job T[v]
                for i in range(n):
                    if (orientation[i][v] == -1 and istight[i][v]
                            and not reachable_S[i]):
                        reachable_S[i] = (1, v)
                        queue.append((0, i))

        end = next((j for j in range(n) if uncovered_T[j] and reachable_T[j]),
                   None)
        if end is not None:
            # Augmenting path found: walk predecessors back to an uncovered
            # worker, then flip every edge on the path, growing the matching
            # by one edge.
            j = end
            _, i = reachable_T[j]
            path = [(i, j)]
            while not uncovered_S[i]:
                _, j = reachable_S[i]
                path.append((i, j))
                _, i = reachable_T[j]
                path.append((i, j))
            for i, j in path:
                orientation[i][j] *= -1
        else:
            # No augmenting path over tight edges: shift potentials by the
            # smallest slack on an edge leaving the reached set, which makes
            # at least one new edge tight while keeping feasibility.
            delta = min(
                slack(i, j)
                for i in range(n) if reachable_S[i]
                for j in range(n) if not reachable_T[j]
            )
            if delta <= 0:
                raise ValueError('delta <= 0')
            for i in range(n):
                if reachable_S[i]:
                    potential_S[i] += delta
            for j in range(n):
                if reachable_T[j]:
                    potential_T[j] -= delta

    # Worker i does job j exactly when orientation[i][j] == -1.
    return sum(cost[i][j]
               for i in range(n) for j in range(n)
               if orientation[i][j] == -1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment