Last active
April 1, 2019 20:41
-
-
Save larsoner/17be3f92ae8cc7ed08bdf90f98ca60f0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import numpy as np | |
from numpy.testing import assert_allclose | |
from scipy import linalg | |
from scipy.stats import ortho_group | |
def _multi_corr(x, y, rescale=True): | |
"""Compute correlations between terms in a rotation-invariant way.""" | |
assert x.ndim == 2 | |
assert x.shape == y.shape | |
# Remove mean | |
x = x - x.mean(axis=-1, keepdims=True) | |
y = y - y.mean(axis=-1, keepdims=True) | |
# How much they jointly explain | |
_, S_x, V_x = linalg.svd(x, full_matrices=False) | |
_, S_y, V_y = linalg.svd(y, full_matrices=False) | |
# Make x and y each have unit explained variance | |
for sing, V in ((S_x, V_x), (S_y, V_y)): | |
# Discard component directions that are tiny (consider as | |
# numerical noise, e.g. if using a sphere head model) | |
if sing.any(): | |
if rescale: | |
sing[:] = np.where(sing > sing[0] * 1e-5, 1., 0) | |
sing /= np.sqrt((sing * sing).sum()) | |
V *= sing[:, np.newaxis] | |
R = linalg.svdvals(np.dot(V_x, V_y.T)).sum() | |
return np.minimum(R, 1.) # rounding errors | |
n_times = 1000
rng = np.random.RandomState(0)
X, Y, Z = np.eye(3)  # unit vectors along the three cardinal axes
# NOTE(review): every random draw visible in this script goes through
# ``rng``; this global seed is kept in case other code relies on it.
np.random.seed(0)
# Make two random orthogonal matrices that reorient activation; stacked,
# their rows are six unit orientation vectors, shape (6, 3)
cardinal = np.array([X, Y, Z])
while True:
    ori_rand = np.concatenate(
        ortho_group.rvs(3, 2, random_state=rng), axis=0)
    # Angle (degrees) between each orientation and each cardinal axis; the
    # clip keeps arccos defined if |cos| exceeds 1 through rounding
    cos = np.clip(np.abs(np.dot(ori_rand, cardinal.T)), 0., 1.)
    angle = np.rad2deg(np.arccos(cos))
    # Ensure rand oris have some, but not too much, of each direction.
    # Redraw instead of failing; a qualifying first draw is used as-is, so
    # the RNG stream is then the same as with a single unconditional draw.
    if ((angle > 10) & (angle < 88)).all():
        break
# And allow random strengths for each source
strength_rand = rng.randn(6)
def _orthogonalize(a, b): | |
a -= np.dot(a, b) / np.dot(b, b) * b | |
# Generate three signals, and three other signals orthogonal to those.
# First make the rows within each set mutually orthogonal with Gram-Schmidt:
# each row only has the *earlier*, already-orthogonalized rows projected out,
# so orthogonality established for earlier pairs is never undone.
signals, signals_orth = rng.randn(2, 3, n_times)
for this_s in (signals, signals_orth):
    for si, s in enumerate(this_s):
        for other in this_s[:si]:
            # In place: ``s`` is a row view, so ``this_s`` is updated
            _orthogonalize(s, other)
    # Check every pair, not just the one most recently processed
    gram = np.dot(this_s, this_s.T)
    assert_allclose(gram[np.triu_indices(len(this_s), 1)], 0., atol=1e-6)
for s_o in signals_orth:
    # Orthogonalize s_o (in place, a row of signals_orth) relative to each
    # member of signals. Sequential projection is exact only when the rows
    # of ``signals`` are themselves mutually orthogonal.
    for s in signals:
        _orthogonalize(s_o, s)
    # corrcoef removes the mean first, so a zero dot product does not give
    # exactly zero correlation -- hence the loose tolerance
    corrs = np.corrcoef(s_o, signals)[0, 1:]
    assert_allclose(corrs, 0., atol=2e-3)  # s_o is orthogonal to signals
    corrs_orth = np.corrcoef(s_o, signals_orth)[0, 1:]
    # Ensure that _multi_corr does the same thing as |corrcoef| for 1D data
    for other, these_corrs in ((signals, corrs),
                               (signals_orth, corrs_orth)):
        for s, this_corr in zip(other, these_corrs):
            corr = _multi_corr(s[np.newaxis], s_o[np.newaxis])
            assert_allclose(corr, np.abs(this_corr), atol=1e-7)
# Ensure we have rotation invariance of our measure
atol = 2e-3
# Signed cyclic permutation of rows: applied as perm.T @ b it yields
# rows (-b[2], b[0], b[1])
perm = np.array([Y, Z, -X])
for a, b, corr in ((signals, signals, 1.),
                   (signals_orth, signals_orth, 1.),
                   (signals, signals_orth, 0.)):
    # Create signals that are scaled and rotated versions of originals
    for rescale in (True, False):
        # One common strength per input, then random rotations
        x = np.dot(ori_rand[:3].T, strength_rand[0] * a)
        y = np.dot(ori_rand[3:].T, strength_rand[1] * b)
        mc = _multi_corr(x, y, rescale)
        assert_allclose(mc, corr, atol=atol)
        # Identity vs. signed permutation, no scaling. Since these do not
        # mix the signals at all, "rescale" does not matter
        x = np.dot(np.array([X, Y, Z]).T, a)
        y = np.dot(perm.T, b)
        mc = _multi_corr(x, y, rescale)
        assert_allclose(mc, corr, atol=atol)
        # A different strength per source, then random rotations
        x = np.dot(ori_rand[:3].T, strength_rand[:3, np.newaxis] * a)
        y = np.dot(ori_rand[3:].T, strength_rand[3:, np.newaxis] * b)
        mc = _multi_corr(x, y, rescale)
        if corr == 1. and not rescale:
            # If we apply different scalings to each source and then rotate,
            # the correlation coefficient should go down if we don't normalize
            # the components
            assert 0.1 < mc < 0.9
        else:
            assert_allclose(mc, corr, atol=atol)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment