Created
February 15, 2018 14:13
-
-
Save dokato/7a997b2a94a0ec6384a5fd0e91e45f8b to your computer and use it in GitHub Desktop.
Find closest orthogonal matrix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def find_closest_orthogonal_matrix(A):
    """
    Find the closest orthogonal matrix to *A* using an iterative method.

    Based on the code of the REMOVE_SOURCE_LEAKAGE function from the OSL
    Matlab package.

    Args:
        A (numpy.array): array shaped k, n, where k is number of channels,
            n - data points. Anything `np.asarray` accepts is allowed;
            integer and boolean data are converted to float first.

    Returns:
        L (numpy.array): orthogonalized matrix with amplitudes preserved,
            shaped k, n like the input.

    Reading:
        Colclough GL et al., A symmetric multivariate leakage correction
        for MEG connectomes., Neuroimage. 2015 Aug 15;117:439-48.
        doi: 10.1016/j.neuroimage.2015.03.071
    """
    # np.finfo only accepts inexact dtypes, so integer/boolean (or list)
    # input has to be promoted before the tolerance can be derived.
    A = np.asarray(A)
    if not np.issubdtype(A.dtype, np.inexact):
        A = A.astype(float)

    MAX_ITER = 2000
    # Largest singular value of the data sets the scale of the tolerance.
    sigma_max = np.linalg.svd(A.T, full_matrices=False, compute_uv=False)[0]
    TOLERANCE = max(1, max(A.shape) * sigma_max) * np.finfo(A.dtype).eps  # TODO

    def converged(rho, prev_rho):
        # Relative-difference test written without a division, so that two
        # exactly-zero residuals count as converged instead of giving 0/0.
        return 2 * abs(rho - prev_rho) <= TOLERANCE * (abs(rho) + abs(prev_rho))

    A_b = A.conj()
    # Per-channel norms; A * conj(A) has zero imaginary part, keep it real.
    d = np.sqrt(np.sum((A * A_b).real, axis=1))
    prev_rho = None
    for _ in range(MAX_ITER):
        # Scale each channel by its current weight, then take the polar
        # factor u @ vh of the scaled (n, k) data.
        u, _s, vh = np.linalg.svd(A.T * d, full_matrices=False)
        V = np.dot(u, vh)
        # TODO check if rank is full
        d = np.sum(A_b * V.T, axis=1)
        L = (V * d).T
        E = A - L
        # Frobenius norm of the residual, kept real for complex input too.
        rho = np.sqrt(np.sum((E * E.conj()).real))
        if prev_rho is not None and converged(rho, prev_rho):
            break
        prev_rho = rho
    return L
if __name__ == '__main__':
    # Simulated data: three channels of white noise, 100 samples each.
    n_samples = 100
    signals = np.random.randn(3, n_samples)
    # Leak channel 0 into channel 1, and a fraction of channel 1 into
    # channel 2, so the raw channels are visibly correlated.
    signals[1, :] += signals[0, :]
    signals[2, :] += 0.2 * signals[1, :]
    # Add a sinusoidal component to channel 2.
    phase = np.linspace(0, 2 * np.pi, n_samples) * 2 * np.pi * 5
    signals[2, :] += np.sin(phase)

    # Correlation matrix before and after orthogonalisation.
    print(np.corrcoef(signals))
    orthogonalized = find_closest_orthogonal_matrix(signals)
    print(np.corrcoef(orthogonalized))

    # To inspect the change visually:
    # import matplotlib.pyplot as plt
    # plt.plot(signals.T); plt.figure(); plt.plot(orthogonalized.T); plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Your function seems to produce corrections very similar to mine. A quick np.allclose check returns True.