Created June 16, 2015 18:35
Save lebedov/9fa8b5a02a0e764cd40c to your computer and use it in GitHub Desktop.
Find permutation of matrix that maximizes its trace using the Munkres algorithm.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
Find permutation of matrix that maximizes its trace using the Munkres algorithm. | |
Reference | |
--------- | |
https://stat.ethz.ch/pipermail/r-help/2010-April/236664.html | |
""" | |
import itertools | |
import sys | |
import munkres | |
import numpy as np | |
def permute_cols(a, inds):
    """
    Permute the columns of matrix `a`.

    Parameters
    ----------
    a : array_like
        2-D matrix whose columns are rearranged; it need not be square.
    inds : iterable of (int, int)
        Pairs ``(from, to)``: column `from` of `a` becomes column `to` of
        the result. For a true permutation every `to` should appear exactly
        once; destination columns that are never targeted are left as zeros.

    Returns
    -------
    numpy.ndarray
        New array with the same shape and dtype as `a`.

    Notes
    -----
    Columns are copied by index rather than by multiplying with a
    permutation matrix. Non-finite entries (inf/nan) are therefore moved
    unchanged instead of producing ``inf * 0 == nan`` elsewhere in the row.
    This also avoids a full matrix product.
    """
    a = np.asarray(a)
    out = np.zeros_like(a)
    # reshape(-1, 2) also copes with an empty `inds`, yielding a (0, 2) array,
    # so the assignment below becomes a no-op.
    pairs = np.asarray(list(inds), dtype=np.intp).reshape(-1, 2)
    src, dst = pairs[:, 0], pairs[:, 1]
    out[:, dst] = a[:, src]
    return out
def maximize_trace(a):
    """
    Return a column-permuted copy of square matrix `a` with maximal trace.

    The assignment is solved with the Munkres (Hungarian) algorithm on the
    cost ``d[j, i] = ||e_j - a[i, :]||**2``, where ``e_j`` is the j-th unit
    row. This is the cost of placing row ``i`` of `a` in row ``j`` when
    minimizing the Frobenius norm of ``np.dot(p, a) - np.eye(n)`` over
    permutation matrices ``p``.

    Summed over an assignment, the cost equals
    ``n + ||a||_F**2 - 2 * sum(a[i, j])``, so the optimum maximizes
    ``sum(a[i, j])`` over the chosen ``(j, i)`` pairs. Column ``j`` is then
    moved to position ``i``, which puts each chosen ``a[i, j]`` on the
    diagonal.

    Parameters
    ----------
    a : array_like
        Square 2-D matrix.

    Returns
    -------
    numpy.ndarray
        Column-permuted version of `a` with maximal trace.

    Raises
    ------
    ValueError
        If `a` is not a square 2-D matrix.
    """
    a = np.asarray(a)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if a.ndim != 2 or a.shape[0] != a.shape[1]:
        raise ValueError('a must be a square matrix, got shape %r' % (a.shape,))
    n = a.shape[0]
    if n == 0:
        # The empty matrix has only one (trivial) permutation.
        return a.copy()
    # Compute costs in at least the platform integer type. Storing them in
    # np.zeros_like(a) collapsed every cost to True/False for bool input and
    # could silently overflow small integer dtypes such as int8.
    work = a.astype(np.promote_types(a.dtype, np.int_))
    # ||e_j - a[i, :]||**2 == 1 - 2*a[i, j] + ||a[i, :]||**2. Evaluate it for
    # every (j, i) at once instead of in an O(n**3) Python double loop.
    row_sq_norms = (work ** 2).sum(axis=1)
    d = 1 - 2 * work.T + row_sq_norms[np.newaxis, :]
    # Hand munkres a plain nested list of Python numbers rather than numpy rows.
    inds = munkres.Munkres().compute(d.tolist())
    # Each (j, i) pair moves column j to column i, putting a[i, j] on the diagonal.
    return permute_cols(a, inds)
if __name__ == '__main__':
    # Demo: permute a random n x n integer matrix so that its trace is maximal.
    # Single-argument print(...) calls behave identically under Python 2 and 3;
    # the former print statements were a SyntaxError under Python 3.
    n = 6
    a = np.random.randint(0, 10, n**2).reshape(n, n)
    print('original: ')
    print(a)
    print('trace: %d' % a.trace())
    ap = maximize_trace(a)
    print('permuted: ')
    print(ap)
    print('trace: %d' % ap.trace())
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment