Skip to content

Instantly share code, notes, and snippets.

@papaeye
Created September 22, 2014 18:36
Show Gist options
  • Save papaeye/cf029a358ca460be8922 to your computer and use it in GitHub Desktop.
Save papaeye/cf029a358ca460be8922 to your computer and use it in GitHub Desktop.
from functools import partial
from itertools import product
# Marks stored in the cells of the mask matrix M (unmarked cells hold None).
# NOTE(review): STARRED compares equal to True, the value used for the cover
# flags kept in the extra row/column of the same matrix.
STARRED = 1  # cell is part of the current assignment
PRIMED = 2  # uncovered zero marked while searching for an augmenting path
def kuhn_munkres(costs):
    """Solve the assignment problem with the Kuhn-Munkres (Hungarian) method.

    Args:
        costs: Rectangular cost matrix given as a sequence of rows. It is
            not modified; the steps operate on a working copy.

    Returns:
        An iterator of ``(row, column)`` tuples indexing into ``costs``,
        one for every starred cell of the final mask matrix. The pairs are
        always plain tuples, whether or not the matrix had to be transposed
        internally.
    """
    if len(costs) == 0:
        # An empty matrix has no cells to assign.
        return iter(())
    C, rows, cols, rotated = _step0(costs)
    # Mask matrix: M[i][j] holds the mark of cell (i, j); the extra column
    # M[i][cols] and extra row M[rows][j] hold the row / column cover flags.
    M = [[None] * (cols + 1) for _ in range(rows + 1)]
    # Each step returns the next step to run, or None when finished.
    step = _step1
    while step is not None:
        step = step(C, M, rows, cols)
    starred = ((i, j) for i, j in _iter_indices(rows, cols)
               if M[i][j] == STARRED)
    if rotated:
        # The working matrix was transposed, so swap back to the caller's
        # (row, column) orientation. Yield real tuples rather than
        # ``reversed`` iterator objects.
        return ((j, i) for i, j in starred)
    return starred
def _step0(costs):
rows, cols = len(costs), len(costs[0])
if rows <= cols:
C = [row[:] for row in costs]
rotated = False
else:
C = [list(x) for x in zip(*costs)]
rows, cols = cols, rows
rotated = True
return C, rows, cols, rotated
def _step1(C, M, rows, cols):
    """Subtract each row's minimum from the first ``cols`` entries of that row.

    ``C`` is modified in place. Always hands over to ``_step2``.
    """
    for cost_row in C:
        smallest = min(cost_row)
        for col in range(cols):
            cost_row[col] = cost_row[col] - smallest
    return _step2
def _step2(C, M, rows, cols):
    """Star zeros of ``C`` so that no row or column holds two stars.

    The cover flags in ``M`` are used temporarily to remember which rows
    and columns already received a star, and are reset to False before
    returning. Always hands over to ``_step3``.
    """
    for r, c in _iter_indices(rows, cols):
        row_taken = M[r][cols]
        col_taken = M[rows][c]
        if C[r][c] == 0 and not row_taken and not col_taken:
            M[r][c] = STARRED
            M[r][cols] = True
            M[rows][c] = True
    # Reset the temporary bookkeeping.
    for r in range(rows):
        M[r][cols] = False
    for c in range(cols):
        M[rows][c] = False
    return _step3
def _step3(C, M, rows, cols):
    """Cover every column that contains a starred cell.

    Returns ``_step4`` while fewer than ``rows`` columns are covered, and
    None otherwise.
    """
    column_covers = M[rows]
    for r, c in _iter_indices(rows, cols):
        if M[r][c] == STARRED:
            column_covers[c] = True
    if column_covers.count(True) >= rows:
        return None
    return _step4
def _step4(C, M, rows, cols):
    """Prime uncovered zeros, adjusting the covers as it goes.

    For each uncovered zero found in one scan of the matrix, the cell is
    primed. If its row has no starred cell, ``_step5`` (bound to that cell)
    is returned; otherwise the row is covered and the column of the star is
    uncovered. When the scan ends without such a cell, ``_step6`` is
    returned.
    """
    for r, c in _iter_indices(rows, cols):
        if not C[r][c] and not M[r][cols] and not M[rows][c]:
            M[r][c] = PRIMED
            star_col = _find_in_row(M, r, STARRED)
            if star_col is None:
                return partial(_step5, i=r, j=c)
            M[r][cols] = True
            M[rows][star_col] = False
    return _step6
def _step5(C, M, rows, cols, i, j):
    """Flip the marks along the alternating path starting at primed cell (i, j).

    The path alternates between the starred cell in the current column and
    the primed cell in that star's row, ending at a column with no star.
    Starred cells on the path are unmarked and all other path cells become
    starred; afterwards all covers are cleared and remaining primes erased.
    Always hands over to ``_step3``.
    """
    path = [(i, j)]
    star_row = _find_in_column(M, j, STARRED)
    while star_row is not None:
        path.append((star_row, j))
        j = _find_in_row(M, star_row, PRIMED)
        path.append((star_row, j))
        star_row = _find_in_column(M, j, STARRED)
    # Toggle marks along the path.
    for r, c in path:
        M[r][c] = None if M[r][c] == STARRED else STARRED
    # Remove every row and column cover.
    for r in range(rows):
        M[r][cols] = False
    for c in range(cols):
        M[rows][c] = False
    # Drop the primes that are left.
    for r, c in _iter_indices(rows, cols):
        if M[r][c] == PRIMED:
            M[r][c] = None
    return _step3
def _step6(C, M, rows, cols):
    """Shift costs by the smallest value not covered by any row or column.

    That value is added to every cell of each covered row and subtracted
    from every cell of each uncovered column of ``C`` (in place). Always
    hands over to ``_step4``.
    """
    delta = min(
        C[r][c]
        for r, c in _iter_indices(rows, cols)
        if not (M[r][cols] or M[rows][c])
    )
    covered_rows = [r for r in range(rows) if M[r][cols]]
    uncovered_cols = [c for c in range(cols) if not M[rows][c]]
    for r in covered_rows:
        for c in range(cols):
            C[r][c] += delta
    for c in uncovered_cols:
        for r in range(rows):
            C[r][c] -= delta
    return _step4
def _iter_indices(rows, cols):
return product(range(rows), range(cols))
def _find_in_column(matrix, j, value):
for i, row in enumerate(matrix):
if row[j] == value:
return i
def _find_in_row(matrix, i, value):
for j, v in enumerate(matrix[i]):
if v == value:
return j
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment