WIP pTFCE
import os
from functools import partial
from time import perf_counter
from contextlib import contextmanager
from typing import Iterator

import numpy as np
from scipy.integrate import trapezoid  # quad
# from scipy.interpolate import interp1d
from scipy.stats import norm, gaussian_kde
import matplotlib.pyplot as plt
import mne
import warnings

warnings.filterwarnings('error', 'Creating an ndarray from ragged nested')
warnings.filterwarnings('error', 'invalid value encountered in')

n_jobs = 4
n_iter = 5
verbose = False
save_diagnostic_plots = True
rng = np.random.default_rng(seed=15485863)  # the one millionth prime
sample_data_folder = mne.datasets.sample.data_path()


@contextmanager
def timer(description: str) -> Iterator[None]:
    """Print elapsed wall-clock time for the enclosed block."""
    if description:
        print(description)
    start = perf_counter()
    yield
    elapsed_time = perf_counter() - start
    space = ' ' if description else ''
    print(f'elapsed time{space}{description}: {elapsed_time:.4f} sec.')


def get_sensor_data():
    """Load or compute Evoked."""
    print('Loading sample data')
    sample_data_raw_file = os.path.join(
        sample_data_folder, 'MEG', 'sample', 'sample_audvis_raw.fif')
    raw = mne.io.read_raw_fif(sample_data_raw_file, verbose=verbose)
    events = mne.find_events(raw, stim_channel='STI 014')
    raw.pick(['grad'])  # keep only gradiometers (drops stim, EEG, EOG, mags)
    raw.drop_channels(raw.info['bads'])
    event_dict = {'auditory/left': 1}
    raw, events = raw.resample(sfreq=100, events=events, n_jobs=n_jobs)
    evk_fname = 'ptfce-ave.fif'
    cov_fname = 'ptfce-cov.fif'
    try:
        evoked = mne.read_evokeds(evk_fname)[0]
        noise_cov = mne.read_cov(cov_fname)
    except FileNotFoundError:
        noise_cov = mne.cov.compute_raw_covariance(raw, n_jobs=n_jobs)
        epochs = mne.Epochs(raw, events, event_id=event_dict, preload=False,
                            proj=True)
        evoked = epochs.average()
        mne.write_evokeds(evk_fname, evoked)
        mne.write_cov(cov_fname, noise_cov)
    return raw, evoked, noise_cov


def get_inverse(raw, noise_cov, subject, subjects_dir, n_jobs=1,
                verbose=None):
    """Load or compute the inverse operator."""
    fname = 'ptfce-inv.fif'
    try:
        inverse = mne.minimum_norm.read_inverse_operator(fname)
        print('Loaded inverse from disk.')
    except FileNotFoundError:
        src = mne.setup_source_space(
            subject, spacing='oct6', add_dist='patch',
            subjects_dir=subjects_dir, verbose=verbose)
        model = mne.make_bem_model(
            subject='sample', ico=4, subjects_dir=subjects_dir,
            verbose=verbose)
        bem = mne.make_bem_solution(model, verbose=verbose)
        trans = os.path.join(
            sample_data_folder, 'MEG', 'sample', 'sample_audvis_raw-trans.fif')
        fwd = mne.make_forward_solution(
            raw.info, trans=trans, src=src, bem=bem, meg=True, eeg=True,
            mindist=0, n_jobs=n_jobs, verbose=verbose)
        inverse = mne.minimum_norm.make_inverse_operator(
            raw.info, fwd, noise_cov, loose=0.2, depth=0.8, verbose=verbose)
        mne.minimum_norm.write_inverse_operator(
            fname, inverse, verbose=verbose)
    return inverse


def make_noise(raw, noise_cov, seed=None, n_jobs=1, verbose=None):
    """Simulate raw data whose covariance matches the recording's noise."""
    # instantiate random number generator
    rng = np.random.default_rng(seed=seed)
    # compute colorer
    whitener, ch_names, rank, colorer = mne.cov.compute_whitener(
        noise_cov, info=raw.info, picks=None, rank=None, scalings=None,
        return_rank=True, pca=False, return_colorer=True, verbose=verbose)
    # make appropriately-colored noise
    white_noise = rng.normal(size=(raw.info['nchan'], raw.n_times))
    colored_noise = colorer @ white_noise
    colored_raw = mne.io.RawArray(colored_noise, raw.info, raw.first_samp,
                                  verbose=verbose)
    # make sure it worked
    sim_cov = mne.cov.compute_raw_covariance(colored_raw, n_jobs=n_jobs)
    assert np.corrcoef(
        sim_cov.data.ravel(), noise_cov.data.ravel())[0, 1] > 0.999
    np.testing.assert_allclose(np.linalg.norm(sim_cov.data),
                               np.linalg.norm(noise_cov.data),
                               rtol=1e-2, atol=0.)
    return colored_raw
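

# --- illustrative sketch (toy 2-channel covariance, not the sample data):
# "coloring" works because if L @ L.T == C, then white noise w ~ N(0, I)
# maps to L @ w ~ N(0, C). MNE's ``colorer`` plays the role of L here.
_C_demo = np.array([[2.0, 0.6],
                    [0.6, 1.0]])
_L_demo = np.linalg.cholesky(_C_demo)
_w_demo = np.random.default_rng(0).normal(size=(2, 100_000))
_emp_cov = np.cov(_L_demo @ _w_demo)  # empirical covariance of colored noise
assert np.allclose(_emp_cov, _C_demo, atol=0.05)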


def calc_thresholds(data, n_thresh=100):
    """Compute pTFCE thresholds."""
    min_logp_thresh = 0.
    max_logp_thresh = -1 * norm.logsf(data.max())
    logp_thresholds = np.linspace(min_logp_thresh, max_logp_thresh, n_thresh)
    delta_logp_thresh = np.diff(logp_thresholds[:2])[0]
    all_thresholds = norm.isf(np.exp(-1 * logp_thresholds))
    # # avoid NaNs
    # all_thresholds[all_thresholds == -np.inf] = np.finfo(
    #     all_thresholds.dtype).min / 100
    return all_thresholds, delta_logp_thresh
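

# quick sanity sketch (hypothetical input, not part of the pipeline): the
# thresholds are evenly spaced in -log(p) space, so the first maps back to
# p=1 (a threshold of -inf) and the last to the data maximum itself.
_demo_thr, _demo_dlp = calc_thresholds(np.array([3.0]), n_thresh=5)
assert np.isneginf(_demo_thr[0]) and np.isclose(_demo_thr[-1], 3.0)
assert np.isclose(_demo_dlp, -norm.logsf(3.0) / 4)  # (max - 0) / (5 - 1)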


def _find_clusters(data, threshold, adjacency):
    """Find indices of vertices that form clusters at the given threshold."""
    suprathresh = (data > threshold)
    # XXX THIS IS THE TIE-IN TO EXISTING MNE-PYTHON CLUSTERING CODE XXX
    # XXX ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ XXX
    clusters = mne.stats.cluster_level._get_components(suprathresh, adjacency)
    return clusters  # list of arrays of vertex numbers


def _get_cluster_sizes(clusters):
    """Get sizes of clusters (helper function)."""
    return np.array([len(clust) for clust in clusters], dtype=int)
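

# --- illustrative sketch (toy graph and data, hypothetical; not used by the
# pipeline): the same "clusters = connected components of the suprathreshold
# mask" idea, expressed with public scipy API instead of the private
# mne.stats.cluster_level._get_components.
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

_demo_adj = coo_matrix(([1, 1, 1, 1], ([0, 1, 2, 3], [1, 0, 3, 2])),
                       shape=(4, 4))  # edges: 0-1 and 2-3
_demo_vals = np.array([2.0, 2.5, 0.1, 3.0])
_demo_mask = _demo_vals > 1.0  # vertex 2 is subthreshold
_demo_sub = _demo_adj.toarray()[np.ix_(_demo_mask, _demo_mask)]
_n_comp, _ = connected_components(_demo_sub, directed=False)
assert _n_comp == 2  # {0, 1} form one cluster; {3} is a singleton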


def _calc_thresholded_source_prior(threshold, noise):
    """Find empirically the probability of a source being suprathreshold.

    Vectorized over thresholds. Equivalent in pTFCE source code is dvox()
    (density of the normal distribution, since they deal only with z-scored
    images).
    """
    noise = np.atleast_2d(noise.ravel())  # (1, noise.size)
    thresh = np.atleast_2d(threshold).T   # (thresh.size, 1)
    suprathresh = (noise > thresh)        # (thresh.size, noise.size)
    n_suprathresh_src = suprathresh.sum(axis=-1)  # (thresh.size,)
    assert n_suprathresh_src.shape[0] == thresh.size
    return n_suprathresh_src / noise.size
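

# tiny numeric check (made-up noise values, not part of the pipeline): with
# noise [0, 1, 2, 3], 3 of 4 samples exceed 0.5 and 1 of 4 exceeds 2.5.
assert np.allclose(
    _calc_thresholded_source_prior(np.array([0.5, 2.5]), np.arange(4.)),
    [0.75, 0.25])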


def _cluster_size_density_factory(sizes):
    """Find empirically the distribution (density func) of cluster sizes."""
    unique_sizes = np.unique(sizes)
    if len(unique_sizes) == 0:
        return lambda x: np.atleast_1d(np.zeros_like(x, float))
    elif len(unique_sizes) == 1:
        # can't use gaussian_kde (LinAlgError) so make unimodal prob mass func:
        return lambda x: np.atleast_1d(x == unique_sizes[0]).astype(float)
    else:
        return gaussian_kde(sizes)
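

# quick sketch of the factory's three branches (toy sizes, hypothetical):
# no clusters → zero density everywhere; one unique size → point mass;
# otherwise → gaussian KDE over the observed sizes.
assert _cluster_size_density_factory(np.array([], dtype=int))(5.0) == 0.
assert _cluster_size_density_factory(np.array([3]))(3)[0] == 1.
assert _cluster_size_density_factory(np.array([1, 2, 2, 4]))(2.0) > 0.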


def _suprathresh_density_given_cluster_size(thresholds, all_thresholds,
                                            observed_cluster_size,
                                            source_prior_density_func,
                                            cluster_size_prior_density_func):
    """PDF of threshold or activation value, given an observed cluster size.

    Bayes' rule: p(h|c) = p(h) · p(c|h) / ∫ p(h) · p(c|h) dh.
    Equivalent in pTFCE source code is dvox.clust().
    """
    numer = (source_prior_density_func(thresholds) *  # p(hᵢ)
             cluster_size_prior_density_func(thresholds,
                                             observed_cluster_size))  # p(c|hᵢ)
    y = (source_prior_density_func(all_thresholds)  # p(h)
         * cluster_size_prior_density_func(all_thresholds,
                                           observed_cluster_size))  # p(c|h)
    denom = trapezoid(x=all_thresholds, y=y)  # integral
    # func = interp1d(all_thresholds, y, kind='linear', bounds_error=False,
    #                 fill_value=tuple(y[[0, -1]]), assume_sorted=True)
    # denom = quad(func, -np.inf, np.inf)[0]
    return numer / denom


def _prob_suprathresh_given_cluster_size(threshold, all_thresholds,
                                         observed_cluster_size,
                                         source_prior_density_func,
                                         cluster_size_prior_density_func):
    """P(V ≥ h | c): integrate the conditional density from ``threshold`` up.

    Equivalent in pTFCE source code is pvox.clust().
    """
    thresh_ix = all_thresholds.tolist().index(threshold)
    x = all_thresholds[thresh_ix:]
    y = _suprathresh_density_given_cluster_size(
        x, all_thresholds, observed_cluster_size,
        source_prior_density_func, cluster_size_prior_density_func)
    integral = trapezoid(x=x, y=y)
    # func = interp1d(x, y, kind='linear', bounds_error=False,
    #                 fill_value=tuple(y[[0, -1]]), assume_sorted=True)
    # integral = quad(func, x[0], np.inf)[0]
    return integral


def ptfce(stc, adjacency, info, noise_cov, inverse, inverse_kw=None,
          n_iter=n_iter, seed=None, n_jobs=1, verbose=None):
    """Compute pTFCE-adjusted p-values for a SourceEstimate."""
    # arg parsing
    if inverse_kw is None:
        inverse_kw = inverse_kwargs
    # instantiate random number generator
    rng = np.random.default_rng(seed=seed)
    # compute pTFCE thresholds
    data = stc.data.reshape(-1)
    all_thresholds, delta_logp_thresh = calc_thresholds(data)
    # compute colorer
    whitener, ch_names, rank, colorer = mne.cov.compute_whitener(
        noise_cov, info=info, picks=None, rank=None, scalings=None,
        return_rank=True, pca=False, return_colorer=True, verbose=verbose)
    # make appropriately-colored noise
    white_noise = rng.normal(size=(n_iter, info['nchan'], len(stc.times)))
    colored_noise = colorer[np.newaxis, ...] @ white_noise
    epochs = mne.EpochsArray(colored_noise, info, tmin=stc.tmin,
                             verbose=verbose)
    # make STCs from noise
    noise_stcs = mne.minimum_norm.apply_inverse_epochs(
        epochs, inverse, verbose=verbose, return_generator=False, **inverse_kw)
    with timer('calculating source activation prior'):
        all_noise = np.array([_stc.data.reshape(-1) for _stc in noise_stcs])
        thresholded_source_prior_density_func = partial(
            _calc_thresholded_source_prior, noise=all_noise)  # p(h)
        if save_diagnostic_plots:
            noise_kde = gaussian_kde(all_noise.ravel())  # p(v)
            source_priors_at_thresh = noise_kde(all_thresholds)
            msg = 'NaNs in source prior'
            assert np.all(np.isfinite(source_priors_at_thresh)), msg
            msg = 'negative density in source prior'
            assert np.all(source_priors_at_thresh >= 0), msg
            subtitle = f'({n_iter} noise iterations)'
            fig, ax = plt.subplots()
            x = np.linspace(all_noise.min(), all_noise.max(), 100)
            y = thresholded_source_prior_density_func(threshold=x)
            ax.plot(x, y)
            ax.set(title=f'probability of suprathresholdness\n{subtitle}',
                   xlabel='threshold', ylabel='probability')
            fig.savefig('source-suprathresholdness-probability.png')
            plt.close(fig)
            y = noise_kde(x)
            fig, ax = plt.subplots()
            ax.plot(x, y)
            ax.set(title=f'source activation density\n{subtitle}',
                   xlabel='activation', ylabel='density')
            fig.savefig('source-activation-density.png')
            plt.close(fig)
        del all_noise
    with timer('finding clusters in noise simulations'):
        all_clusters = list()
        for iter_ix, noise_stc in enumerate(noise_stcs):
            print(f'iteration {iter_ix}, threshold ', end='', flush=True)
            noise = noise_stc.data.reshape(-1)
            this_clusters = list()
            for thresh_ix, threshold in enumerate(all_thresholds):
                # progress bar
                if not thresh_ix % 5:
                    print(f'{thresh_ix} ', end='', flush=True)
                # compute cluster size prior
                clust = _find_clusters(noise, threshold, adjacency)
                this_clusters.append(clust)
            print()
            all_clusters.append(this_clusters)
    with timer('calculating cluster size distribution from noise'):
        # pool obs across epochs & thresholds → total prob of each cluster size
        sizes = _get_cluster_sizes([clust for _iter in all_clusters
                                    for thresh in _iter for clust in thresh])
        cluster_size_density_func = _cluster_size_density_factory(sizes)
        # XXX would fitting a particular distribution be more appropriate here?
        # (scipy.special.gammaincc is a reasonable candidate...)
        # Or is there some way to make it a probability _mass_ function and
        # fill in the gaps of sizes not observed in the noise sims?
        # (maybe interp1d, and then evaluate at all the integers?)
        sizes_at_thresh = list()
        for thresh_ix in range(len(all_thresholds)):
            # estimate prob. density of cluster size at each threshold: p(c|h)
            clusts_at_thresh = [_iter[thresh_ix] for _iter in all_clusters]
            _sizes_at_thresh = _get_cluster_sizes(
                [clust for _iter in clusts_at_thresh for clust in _iter])
            sizes_at_thresh.append(_sizes_at_thresh)
    def cluster_size_prior_density_func(thresholds, observed_cluster_size):
        """PDF of cluster size, given threshold.

        Equivalent in pTFCE source code is dclust(), which is derived from
        the Euler characteristic density of a gaussian field of given
        dimension.
        """
        this_thresholds = np.array(thresholds)
        thresh_ixs = np.nonzero(np.in1d(all_thresholds, this_thresholds))[0]
        # noise_cluster_sizes = [
        #     sizes_at_thresh[this_ix] for this_ix in thresh_ixs]
        # return np.array([
        #     _cluster_size_density_factory(this_sizes)(observed_cluster_size)[0]
        #     for this_sizes in noise_cluster_sizes])
        # clearer version of above:
        densities = list()
        for thresh_ix in thresh_ixs:
            noise_cluster_sizes = sizes_at_thresh[thresh_ix]
            density_func = _cluster_size_density_factory(noise_cluster_sizes)
            density = density_func(observed_cluster_size)[0]
            densities.append(density)
        return np.array(densities)
    if save_diagnostic_plots:
        x = np.unique(sizes)
        y = cluster_size_density_func(x)
        fig, ax = plt.subplots()
        ax.semilogx(x, y)
        ax.set(title=f'cluster size density across all thresholds\n{subtitle}',
               xlabel='cluster size', ylabel='density')
        fig.savefig('cluster-size-density.png')
        plt.close(fig)
        fig, ax = plt.subplots()
        ax.plot(all_thresholds, source_priors_at_thresh)
        ax.set(title='prior probability of source suprathresholdness'
                     f'\n{subtitle}',
               xlabel='threshold', ylabel='probability')
        fig.savefig('prior.png')
        plt.close(fig)
    # apply to the real data
    with timer('finding clusters in real data'):
        print('threshold number: ', end='', flush=True)
        unaggregated_probs = np.ones(
            (len(all_thresholds), *data.shape), dtype=float)
        all_data_clusters_by_thresh = list()
        all_data_cluster_sizes_by_thresh = list()
        for thresh_ix, threshold in enumerate(all_thresholds):
            # progress bar
            if not thresh_ix % 5:
                print(f'{thresh_ix} ', end='', flush=True)
            # find clusters in data STC
            data_clusters = _find_clusters(data, threshold, adjacency)
            data_cluster_sizes = _get_cluster_sizes(data_clusters)
            all_data_clusters_by_thresh.append(data_clusters)
            all_data_cluster_sizes_by_thresh.append(data_cluster_sizes)
            uniq_data_cluster_sizes = np.unique(data_cluster_sizes)
            # compute unaggregated probs. (the call to
            # _prob_suprathresh_given_cluster_size is slow, so do it only once
            # for each unique cluster size)
            uniq_data_cluster_probs = {
                size: _prob_suprathresh_given_cluster_size(
                    threshold, all_thresholds, size,
                    thresholded_source_prior_density_func,
                    cluster_size_prior_density_func)
                for size in uniq_data_cluster_sizes}
            # prepare prob array that will zip with clusters
            data_cluster_probs = np.array(
                [uniq_data_cluster_probs[size] for size in data_cluster_sizes])
            # assign probs to vertices in thresh-appropriate slice of big array
            for clust, prob in zip(data_clusters, data_cluster_probs):
                # make sure we're not overwriting anything
                assert np.all(unaggregated_probs[thresh_ix][clust] == 1.)
                unaggregated_probs[thresh_ix][clust] = prob
        print()
    with timer('aggregating and adjusting probabilities'):
        # S(x) = ∑ᵢ -log(P(V ≥ hᵢ|cᵢ)) at voxel position x (equation 10)
        _neglogp = np.sum(-1 * np.log(unaggregated_probs), axis=0)
        # (sqrt(Δk * (8S(x) + Δk)) - Δk) / 2 (equation 9; note that Δk is
        # subtracted *outside* the square root)
        _adjust = (np.sqrt(delta_logp_thresh
                           * (8 * _neglogp + delta_logp_thresh))
                   - delta_logp_thresh) / 2
        # neglogp → regular p-values
        _ptfce = np.exp(-1 * _adjust)
    # reshape and return as STC
    stc_ptfce = stc.copy()
    ptfce_data = _ptfce.reshape(stc.data.shape)
    stc_ptfce.data = ptfce_data
    return (stc_ptfce, noise_stcs, all_thresholds, unaggregated_probs,
            thresholded_source_prior_density_func,
            sizes, cluster_size_density_func,
            all_data_clusters_by_thresh, all_data_cluster_sizes_by_thresh)
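

# sanity sketch for the equation-9 adjustment (toy numbers, hypothetical): a
# voxel that never joins any cluster has unaggregated probs of 1 at every
# threshold, so S(x) = 0 and the adjusted score is 0 (pTFCE p-value of 1).
_dk_demo = 0.5  # stand-in for delta_logp_thresh
_S_demo = np.sum(-np.log(np.ones(100)))  # equation 10 with all probs = 1
_adj_demo = (np.sqrt(_dk_demo * (8 * _S_demo + _dk_demo)) - _dk_demo) / 2
assert np.isclose(_adj_demo, 0.) and np.isclose(np.exp(-_adj_demo), 1.)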


# # # # # # # # # # # #
#  ACTUALLY RUN STUFF  #
# # # # # # # # # # # #

# get the sensor data
raw, evoked, noise_cov = get_sensor_data()

# get the inverse operator
print('Creating inverse operator')
snr = 3.
lambda2 = 1. / snr ** 2
inverse_kwargs = dict(lambda2=lambda2, method='dSPM', pick_ori=None,
                      use_cps=True)
subject = 'sample'
subjects_dir = os.path.join(sample_data_folder, 'subjects')
inverse_operator = get_inverse(
    raw, noise_cov, subject, subjects_dir, n_jobs=n_jobs, verbose=verbose)
src_adjacency = mne.spatial_src_adjacency(inverse_operator['src'])

# make STC from data
print('Creating STC from data')
stc = mne.minimum_norm.apply_inverse(
    evoked, inverse_operator, verbose=verbose, **inverse_kwargs)

# expand adjacency to temporal dimension
adjacency = mne.stats.combine_adjacency(len(stc.times), src_adjacency)

# compute pTFCE
with timer('running pTFCE'):
    (stc_ptfce, noise_stcs, all_thresholds, unaggregated_probs,
     thresholded_source_prior_density_func, sizes,
     cluster_size_density_func,
     all_data_clusters_by_thresh, all_data_cluster_sizes_by_thresh
     ) = ptfce(
         stc, adjacency, raw.info, noise_cov, inverse_operator,
         n_jobs=n_jobs, seed=rng)