Last active: March 24, 2021 15:08
-
-
Save kingjr/59f6795e9fd9019ff3842a8af9218d79 to your computer and use it in GitHub Desktop.
2v2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn.metrics import pairwise_distances | |
from numpy.random import permutation | |
from time import time | |
def cc_2v2(true, pred, metric="cosine"):
    """Estimate 2-vs-2 accuracy with an explicit Python loop over random pairs.

    Every row index ``i`` is paired with a randomly drawn, different row index
    ``j``. For each pair the 2x2 distance matrix between ``true[[i, j]]`` and
    ``pred[[i, j]]`` is computed. The pair counts as correct when the matched
    distances ``d(true[i], pred[i]) + d(true[j], pred[j])`` are strictly
    smaller than the mismatched ones ``d(true[i], pred[j]) + d(true[j], pred[i])``.

    Parameters
    ----------
    true, pred : numpy.ndarray
        Arrays with the same number of rows (samples). They are fancy-indexed
        with ``[[i, j]]`` and passed to ``pairwise_distances``, so they are
        presumably 2-D (samples x features).
    metric : str or callable
        Distance metric forwarded to ``sklearn.metrics.pairwise_distances``.

    Returns
    -------
    float
        Fraction of the ``len(true)`` sampled pairs that were scored correct.

    Raises
    ------
    ValueError
        If fewer than two samples are given (no pair of distinct rows exists).
    """
    assert len(true) == len(pred)
    ns = len(true)
    if ns < 2:
        raise ValueError(f"2v2 accuracy needs at least 2 samples, got {ns}")
    first = permutation(ns)  # first member of each pair
    second = permutation(ns)  # second member of each pair
    # Redraw the first index of every pair that picked the same row twice.
    # Each clash gets its own uniform draw over all ns rows, which keeps the
    # indices unbiased and lets the loop terminate (with probability 1) for
    # any ns >= 2.
    clash = first == second
    while clash.any():
        first[clash] = np.random.randint(ns, size=clash.sum())
        clash = first == second
    correct = 0.
    for i, j in zip(first, second):
        # 2x2 matrix: r[a, b] = distance between true[[i, j]][a] and pred[[i, j]][b].
        r = pairwise_distances(true[[i, j]], pred[[i, j]], metric=metric)
        diag = np.trace(r)  # matched pairs (true_i/pred_i, true_j/pred_j)
        cross = r.sum() - diag  # mismatched pairs
        correct += 1. * (diag < cross)
    return correct / ns
def jr_2v2(true, pred, metric="cosine"):
    """Vectorized 2-vs-2 accuracy (Tsonova et al 2019 https://arxiv.org/pdf/2009.08424.pdf).

    Every row index ``i`` is paired with a randomly drawn, different row index
    ``j``. A pair counts as correct when the matched distances
    ``d(true[i], pred[i]) + d(true[j], pred[j])`` are strictly smaller than the
    mismatched ones ``d(true[i], pred[j]) + d(true[j], pred[i])``. All pairs
    are scored at once from a single full distance matrix.

    Parameters
    ----------
    true, pred : numpy.ndarray
        Arrays with the same number of rows (samples). They are passed to
        ``pairwise_distances``, so they are presumably 2-D
        (samples x features).
    metric : str or callable
        Distance metric forwarded to ``sklearn.metrics.pairwise_distances``.

    Returns
    -------
    float
        Fraction of the ``len(true)`` sampled pairs that were scored correct.

    Raises
    ------
    ValueError
        If fewer than two samples are given (no pair of distinct rows exists).
    """
    assert len(true) == len(pred)
    ns = len(true)
    if ns < 2:
        raise ValueError(f"2v2 accuracy needs at least 2 samples, got {ns}")
    first = permutation(ns)  # first member of each pair
    second = permutation(ns)  # second member of each pair
    # Redraw the first index of every pair that picked the same row twice.
    # Each clash gets its own uniform draw over all ns rows, which keeps the
    # indices unbiased and lets the loop terminate (with probability 1) for
    # any ns >= 2.
    clash = first == second
    while clash.any():
        first[clash] = np.random.randint(ns, size=clash.sum())
        clash = first == second
    # Full (ns, ns) matrix, r[a, b] = d(true[a], pred[b]); memory grows as ns**2.
    # The caller's metric must be forwarded here, otherwise sklearn's default
    # (euclidean) is used regardless of the `metric` argument.
    r = pairwise_distances(true, pred, metric=metric)
    s1 = r[first, first] + r[second, second]  # matched distances per pair
    s2 = r[first, second] + r[second, first]  # mismatched distances per pair
    return np.mean(1. * (s1 < s2))
# Simulated data: pred is a noisy, down-scaled copy of true.
n_samples, n_voxels = 1000, 20
true = np.random.randn(n_samples, n_voxels)
snr = .1
pred = snr*true + np.random.randn(n_samples, n_voxels)

# Both estimators draw their random pairs from NumPy's global RNG. Re-seeding
# before each call makes them score the same pairs, provided they consume the
# RNG identically and use the same metric. Without the re-seed the two
# accuracies differ by sampling noise and the equality check below fails.
seed = 0

np.random.seed(seed)
start = time()
cc = cc_2v2(true, pred)
print('cc acc=', cc, 'time', time()-start)

np.random.seed(seed)
start = time()
jr = jr_2v2(true, pred)
print('jr acc=', jr, 'time', time()-start)

assert jr == cc, (jr, cc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.