Created
November 12, 2019 21:54
-
-
Save olooney/036f7cade5fa40e8cce14668b3be27a4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from scipy.optimize import linear_sum_assignment | |
import numpy as np | |
def maximize_trace(X):
    """
    Maximize the trace of a square matrix using only row and
    column permutations. Furthermore, sort the trace
    in descending order so that the largest value ends
    up in the upper left and the smallest in the lower right.

    One practical use is with non-deterministic clustering
    algorithms like k-means or GMM; this can be used
    to bring the two cluster assignments into maximal
    agreement for ease of comparison.

    Args:
        X: square 2-D array-like of numbers (anything accepted
           by ``np.asarray``).

    Returns:
        (X_result, row_permutation, column_permutation)

        ``X_result`` is the permuted matrix, and the two index arrays
        satisfy ``X[row_permutation, :][:, column_permutation] == X_result``.
    """
    # Accept nested lists as well as arrays; negation below needs an ndarray.
    X = np.asarray(X)

    # Use the Hungarian algorithm to bring the largest
    # values to the diagonal. linear_sum_assignment minimizes cost,
    # so negate the matrix to maximize the sum instead.
    row_order, column_order = linear_sum_assignment(-X)
    X_optimal = X[row_order, :][:, column_order]

    # Sort the diagonal in descending order. Applying the same
    # permutation to rows and columns keeps the diagonal entries on
    # the diagonal, so the trace is unchanged.
    diag = np.diag(X_optimal)
    diag_order = np.argsort(diag)[::-1]
    X_result = X_optimal[diag_order, :][:, diag_order]

    # Also return the permutations used: compose the assignment
    # ordering with the diagonal sort.
    row_permutation = np.arange(X.shape[0])[row_order][diag_order]
    column_permutation = np.arange(X.shape[1])[column_order][diag_order]

    return X_result, row_permutation, column_permutation
def test_maximize_trace():
    """
    Randomized smoke test for maximize_trace.

    Takes a reference matrix, shuffles its rows and columns
    independently, runs maximize_trace on the shuffled copy, and
    counts how often the result equals the reference matrix.
    Mismatches are printed together with the shuffled input, and a
    final score line is printed at the end.
    """
    X = np.array([[20, 1, 1, 1],
                  [13, 15, 3, 1],
                  [ 2, 7, 12, 4],
                  [ 2, 2, 5, 10]])
    N = 100
    n_rows, n_cols = X.shape

    n_correct = 0
    for _ in range(N):
        # Draw the row shuffle first, then the column shuffle.
        row_shuffle = np.random.permutation(range(n_rows))
        col_shuffle = np.random.permutation(range(n_cols))
        X_random = X[row_shuffle, :][:, col_shuffle]

        X_result, _rows, _columns = maximize_trace(X_random)

        if np.all(X == X_result):
            n_correct += 1
            continue

        # Report the failing case for inspection.
        print("input:", X_random)
        print("incorrect result:", X_result)

    print(f"Final Score: {n_correct}/{N}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment