Skip to content

Instantly share code, notes, and snippets.

@papaeye
Created September 22, 2014 18:41
Show Gist options
  • Save papaeye/30dc1eba80b422230cb4 to your computer and use it in GitHub Desktop.
from itertools import repeat
class KuhnMunkres(object):
    """Maximum-weight assignment solver using vertex labels and slack values.

    ``weights[x][y]`` is the weight of pairing row ``x`` with column ``y``.
    Row labels start at each row's maximum and column labels at zero, so
    ``lx[x] + ly[y] - weights[x][y]`` (the slack) is never negative; an edge
    with zero slack is "tight" and may be used in the matching.

    The matrix must not have more rows than columns, because every row is
    matched to its own column.  Transpose the input first if it does.

    Attributes set by ``__init__``:
        weights: the weight matrix as given (not copied).
        rows, cols: matrix dimensions.
        X, Y: ``range`` objects over the row and column indices.
        lx, ly: current row and column labels.
        matchx: ``matchx[x]`` is the column matched to row ``x`` or ``None``.
        matchy: ``matchy[y]`` is the row matched to column ``y`` or ``None``.
        S: per-row flags, ``True`` for rows in the current alternating tree.
        T: per-column parent row in the current tree, ``None`` if unvisited.
        slacky: per-column ``[slack, row]`` pair holding the smallest slack
            from any tree row to that column and the row that achieves it.
    """

    def __init__(self, weights):
        """Store ``weights`` and build the initial labels and empty matching."""
        self.weights = weights
        self.rows = rows = len(weights)
        # An empty matrix has no first row to measure; treat it as 0 columns
        # instead of failing on weights[0].
        self.cols = cols = len(weights[0]) if rows else 0
        self.X = range(rows)
        self.Y = range(cols)
        self.lx = [max(row) for row in weights]
        self.ly = [0] * cols
        self.matchx = [None] * rows
        self.matchy = [None] * cols
        self.S = []
        self.T = []
        self.slacky = []

    def solve(self):
        """Match every row to a distinct column and return ``matchx``.

        The returned list is the solver's own ``matchx`` attribute, where
        ``matchx[x]`` is the column assigned to row ``x``.  Calling
        ``solve()`` again returns the same matching.

        Raises:
            ValueError: if the matrix has more rows than columns.
        """
        if self.rows > self.cols:
            # Without this check the search below runs out of unvisited
            # columns and fails with an opaque "min() arg is an empty
            # sequence" ValueError.
            raise ValueError(
                "weights has more rows (%d) than columns (%d); "
                "transpose it first" % (self.rows, self.cols))
        for root in self.X:
            # A row that is already matched (e.g. by a previous solve() call)
            # must not be used as a root again: re-rooting it corrupts the
            # augmenting-path walk.
            if self.matchx[root] is not None:
                continue
            self.prepare(root)
            self.match()
        return self.matchx

    def prepare(self, root):
        """Reset the search state so the alternating tree contains only ``root``."""
        self.S[:] = [False] * self.rows
        self.S[root] = True
        self.T[:] = [None] * self.cols
        # Every column's best known slack initially comes from the root row.
        self.slacky[:] = [[self.slack(root, y), root] for y in self.Y]

    def slack(self, x, y):
        """Return ``lx[x] + ly[y] - weights[x][y]`` for row ``x``, column ``y``."""
        return self.lx[x] + self.ly[y] - self.weights[x][y]

    def match(self):
        """Grow the tree set up by ``prepare`` until the matching is augmented."""
        while True:
            # Pick the unvisited column with the smallest slack.  The key is
            # ([slack, row], column), so ties fall to the lower row index and
            # then to the lower column index.
            (delta, x), y = min((self.slacky[y], y)
                                for y, parent in enumerate(self.T)
                                if parent is None)
            if delta > 0:
                # No tight edge leaves the tree yet; shift labels to make one.
                self.update_labels(delta)
            self.T[y] = x
            if self.matchy[y] is None:
                # Reached a free column: flip the path back to the root.
                self.augment_match(y)
                break
            # Column y is taken; pull its partner row into the tree and let
            # it lower the recorded slack of the columns still unvisited.
            x = self.matchy[y]
            self.S[x] = True
            for y, v in enumerate(self.T):
                if v is None:
                    s = self.slack(x, y)
                    if s < self.slacky[y][0]:
                        self.slacky[y][:] = s, x

    def update_labels(self, delta):
        """Shift labels by ``delta``.

        Tree rows lose ``delta`` and tree columns gain it, which leaves the
        slack of edges inside the tree unchanged; the recorded slack of every
        unvisited column drops by ``delta``.
        """
        for x, v in enumerate(self.S):
            if v is True:
                self.lx[x] -= delta
        for y, v in enumerate(self.T):
            if v is not None:
                self.ly[y] += delta
            else:
                self.slacky[y][0] -= delta

    def augment_match(self, y):
        """Flip the matching along the tree path that ends at free column ``y``."""
        while y is not None:
            x = self.T[y]
            # Remember the column row x is giving up; it is re-matched to its
            # own tree parent on the next pass.  The root has no previous
            # match, which ends the walk.
            z = self.matchx[x]
            self.matchx[x] = y
            self.matchy[y] = x
            y = z
def kuhn_munkres(weights):
    """Return an iterator of ``(row, column)`` index pairs for ``weights``.

    ``weights`` is a rectangular matrix given as a sequence of rows.  When it
    has more rows than columns it is transposed before solving, because the
    solver needs at least as many columns as rows; in that case one pair is
    produced per column, in column order, and the surplus rows do not appear.
    Otherwise one pair is produced per row, in row order.

    Every item is a 2-tuple indexed against the original, untransposed
    matrix.  An empty matrix, or one whose rows are empty, yields no pairs.
    """
    # len() rather than truthiness so array-like inputs whose truth value is
    # ambiguous keep working.
    if len(weights) == 0 or len(weights[0]) == 0:
        return iter(())
    if len(weights) <= len(weights[0]):
        return enumerate(KuhnMunkres(weights).solve())
    transposed = list(zip(*weights))
    matching = KuhnMunkres(transposed).solve()
    # The solver ran on the transpose, so its "rows" are our columns; swap
    # each pair back and emit real tuples rather than one-shot iterators.
    return ((row, col) for col, row in enumerate(matching))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment