Testing to see if we recover the expected covariance from simulated noise signals
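"""Simulate sensor noise with a known covariance, verify that the covariance
is recovered, project the noise to source space, and accumulate the
per-threshold cluster statistics used by pTFCE, looping to check that the
resulting null distributions converge across iterations."""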
import os
from collections import namedtuple

import numpy as np
from scipy.special import gamma
from scipy.stats import norm, zscore
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import mne

n_jobs = 4
verbose = False
rng = np.random.default_rng(seed=15485863)  # the one millionth prime
# plt.ion()
def make_noise(raw, noise_cov, seed=None, n_jobs=1, verbose=None):
    """Simulate a Raw with the same covariance structure as ``noise_cov``."""
    # instantiate random number generator
    rng = np.random.default_rng(seed=seed)
    # compute colorer
    whitener, ch_names, rank, colorer = mne.cov.compute_whitener(
        noise_cov, info=raw.info, picks=None, rank=None, scalings=None,
        return_rank=True, pca=False, return_colorer=True, verbose=verbose)
    # make appropriately-colored noise
    white_noise = rng.normal(size=(raw.info['nchan'], raw.n_times))
    colored_noise = colorer @ white_noise
    colored_raw = mne.io.RawArray(colored_noise, raw.info, raw.first_samp,
                                  verbose=verbose)
    # make sure it worked
    sim_cov = mne.cov.compute_raw_covariance(colored_raw, n_jobs=n_jobs)
    assert np.corrcoef(
        sim_cov.data.ravel(), noise_cov.data.ravel())[0, 1] > 0.999
    np.testing.assert_allclose(np.linalg.norm(sim_cov.data),
                               np.linalg.norm(noise_cov.data),
                               rtol=1e-2, atol=0.)
    return colored_raw
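# Note on make_noise: compute_whitener's `colorer` is (approximately) a matrix
# square root of the covariance, i.e. colorer @ colorer.T ~ noise_cov.data (up
# to rank deficiency from projections / bad channels). White noise W has
# identity covariance, so colorer @ W has covariance colorer @ colorer.T,
# which is what the assertions at the end of make_noise verify empirically.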
def get_inverse(raw, noise_cov, subject, subjects_dir, n_jobs=1,
                verbose=None):
    src = mne.setup_source_space(
        subject, spacing='oct6', add_dist='patch', subjects_dir=subjects_dir,
        verbose=verbose)
    model = mne.make_bem_model(
        subject=subject, ico=4, subjects_dir=subjects_dir, verbose=verbose)
    bem = mne.make_bem_solution(model, verbose=verbose)
    # NB: relies on the module-level `sample_data_folder` defined below
    trans = os.path.join(
        sample_data_folder, 'MEG', 'sample', 'sample_audvis_raw-trans.fif')
    fwd = mne.make_forward_solution(
        raw.info, trans=trans, src=src, bem=bem, meg=True, eeg=True, mindist=0,
        n_jobs=n_jobs, verbose=verbose)
    # make inverse operator
    inverse_operator = mne.minimum_norm.make_inverse_operator(
        raw.info, fwd, noise_cov, loose=0.2, depth=0.8, verbose=verbose)
    return inverse_operator
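# Example usage (sketch; mirrors the main script below):
#     inv = get_inverse(raw, noise_cov, 'sample', subjects_dir, n_jobs=4)
#     stc = mne.minimum_norm.apply_inverse_raw(raw, inv, lambda2=1. / 9.,
#                                              method='dSPM')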
def ptfce_null(X, adjacency,
               # tail, exclude, out_type, check_disjoint, buffer_size, # TODO
               ):
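    """Compute per-threshold cluster statistics for one (noise) sample.

    For each pTFCE threshold: binarize the z-scored data, find connected
    clusters under ``adjacency``, and record the cluster sizes along with the
    expected mean cluster size and suprathreshold probability that feed the
    pTFCE p-value computation.
    """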
    X = zscore(X)
    thresholds = compute_thresholds(X)
    # prepare our data container
    entries = ['threshold', 'n_suprathresh_sources', 'n_clusters', 'clusters',
               'cluster_sizes', 'mean_cluster_size',  # 'cluster_size_pdf',
               'cluster_size_pvals', 'prob_suprathresh']
    Thresh = namedtuple('Thresh', entries)
    # compute empirical distributions
    results = list()
    for ix, thresh in enumerate(thresholds):
        write_log = not (ix % 10)
        suprathresh = (X >= thresh).astype(int)
        n_suprathresh_sources = suprathresh.sum()
        # quit when we're above the highest data point
        if n_suprathresh_sources == 0:
            break
        # find clusters
        if write_log:
            msg = f'clustering threshold {ix:>2}: {thresh:< 6.3f}... '
            print(msg, end='')
        clusters = mne.stats.cluster_level._get_components(suprathresh,
                                                           adjacency)
        sizes = np.array([len(clust) for clust in clusters])
        if write_log:
            print(f'{len(clusters):>6} clusters; largest: {max(sizes):>6}')
        expected_mean_cluster_size = n_suprathresh_sources / len(clusters)
        # compute p-values of attested cluster sizes at this threshold
        cluster_size_pvals = cluster_pval(sizes, expected_mean_cluster_size)
        # store the various numbers we care about
        this_thresh = Thresh(
            thresh, n_suprathresh_sources, len(clusters), clusters, sizes,
            expected_mean_cluster_size, cluster_size_pvals,
            prob_suprathresh=n_suprathresh_sources / X.size)
        results.append(this_thresh)
    return results
def _get_lambda_h(expected_mean_cluster_size, ndim):
    """Rate param. of exponential distr. of clust. sizes at given threshold."""
    return (expected_mean_cluster_size / gamma(1 + ndim / 2)) ** (-2 / ndim)
def cluster_pval(observed_cluster_size, expected_mean_cluster_size, ndim=2):
    """Compute p-value of observed cluster size (Spisák eq. 14)."""
    lambda_h = _get_lambda_h(expected_mean_cluster_size, ndim)
    _x = observed_cluster_size ** (2 / 3)
    return _x * np.exp(-1 * lambda_h * _x)
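# Written out: with s the observed cluster size and
# lambda_h = (E[s] / gamma(1 + ndim / 2)) ** (-2 / ndim), the value returned
# above is s ** (2/3) * exp(-lambda_h * s ** (2/3)).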
def compute_thresholds(data, n_thresh=100):
    """Compute pTFCE thresholds."""
    min_logp_thresh = 0.
    max_logp_thresh = norm.logsf(data.max())
    logp_thresholds = np.linspace(min_logp_thresh, max_logp_thresh, n_thresh)
    return norm.isf(np.exp(logp_thresholds))
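# The thresholds are evenly spaced on the log-p scale: from log(p) = 0 (p = 1,
# i.e. threshold -inf, at which every source is suprathreshold) down to the
# log survival probability of the largest observed z-score, so the final
# threshold lands exactly on data.max().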
def _cluster_size_pdf(data, expected_mean_cluster_size, ndim=2):
    """Compute distribution of cluster sizes (follows an exponential dist.)."""
    # TODO: not sure if we need this function?
    lambda_h = _get_lambda_h(expected_mean_cluster_size, ndim)
    cluster_sizes = np.arange(data.size) + 1
    return (2 * lambda_h * np.exp(-1 * lambda_h * cluster_sizes)
            / 3 * cluster_sizes ** (1 / 3))
# load sample data
print('Loading sample data')
sample_data_folder = mne.datasets.sample.data_path()
sample_data_raw_file = os.path.join(sample_data_folder, 'MEG', 'sample',
                                    'sample_audvis_raw.fif')
raw = mne.io.read_raw_fif(sample_data_raw_file, preload=True, verbose=verbose)
raw.pick(['grad'])  # keep only gradiometers (ditches stim, ECG, EOG, etc.)
raw.drop_channels(raw.info['bads'])
raw.apply_proj()
# get inverse operator
print('Creating inverse operator')
snr = 3.
lambda2 = 1. / snr ** 2
subject = 'sample'
subjects_dir = os.path.join(sample_data_folder, 'subjects')
# this covariance does double duty: it colors the simulated noise and serves
# as the noise covariance of the inverse operator
noise_cov = mne.cov.compute_raw_covariance(raw, n_jobs=n_jobs)
inverse_operator = get_inverse(
    raw, noise_cov, subject, subjects_dir, n_jobs=n_jobs, verbose=verbose)
src_adjacency = mne.spatial_src_adjacency(inverse_operator['src'])
adjacency = None
all_results = pd.DataFrame()
os.makedirs('figures', exist_ok=True)
os.makedirs('simulation-data', exist_ok=True)
def do_pairplot(dataframe, vars_to_keep, fname, ci=None, alpha=1):
    # seaborn calls the diag function once per diagonal axes, in order, so an
    # iterator over the variable names yields the right label each time
    names = iter(vars_to_keep)

    def diagfunc(x, **kwargs):
        """Write variable names on the diagonal axes of a PairGrid."""
        ax = plt.gca()
        ax.annotate(next(names), xy=(0.5, 0.5), xycoords='axes fraction',
                    ha='center', va='center', size='large')
        ax.set_axis_off()  # XXX this line doesn't work

    g = sns.PairGrid(dataframe, vars=vars_to_keep, diag_sharey=False,
                     palette='muted')
    g.map_offdiag(sns.lineplot, ci=ci, alpha=alpha, sort=False)
    g.map_diag(diagfunc)
    g.fig.set_size_inches(12, 12)
    g.fig.subplots_adjust(left=0.08)
    g.fig.savefig(fname)
# loop to test convergence of the null
for ix in range(100):
    print(f'iteration {ix}')
    # make cov-matched noise data, then crop for clustering speed
    colored_raw = make_noise(raw, noise_cov, seed=rng, n_jobs=n_jobs,
                             verbose=verbose)
    colored_raw.crop(tmax=0.1)
    # make STC from noise raw
    stc = mne.minimum_norm.apply_inverse_raw(
        colored_raw, inverse_operator, lambda2, method='dSPM', pick_ori=None,
        verbose=verbose)
    # expand adjacency to temporal dimension
    if adjacency is None:
        adjacency = mne.stats.combine_adjacency(len(stc.times), src_adjacency)
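    # combine_adjacency chains a lattice over the len(stc.times) time samples
    # with the spatial source adjacency, giving one sparse adjacency over all
    # (time, source) pairs, with time as the outer (slowest-varying) dimension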
    # construct data array (transposed to (n_times, n_sources) so that the
    # flattened ordering matches the combined adjacency)
    X = stc.data.T.reshape(-1)
    # compute null distributions
    results = ptfce_null(X, adjacency)
    # make dataframe
    vars_to_keep = ['threshold', 'n_suprathresh_sources', 'n_clusters',
                    'mean_cluster_size', 'prob_suprathresh']
    df = pd.DataFrame(results)[vars_to_keep]
    df['iteration'] = ix
    all_results = pd.concat([all_results, df])
    # plot
    fname = f'figures/iter_{ix:03}.png'
    do_pairplot(df, vars_to_keep, fname)

# save simulations
all_results.to_csv('simulation-data/simulations.csv')
fname = 'figures/summary.png'
do_pairplot(all_results, vars_to_keep, fname, ci='sd', alpha=0.3)
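# To inspect convergence afterwards, one might (sketch) reload the CSV and
# overlay the per-iteration curves, e.g.:
#     df = pd.read_csv('simulation-data/simulations.csv', index_col=0)
#     sns.lineplot(data=df, x='threshold', y='prob_suprathresh',
#                  units='iteration', estimator=None, alpha=0.3)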