Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save kingjr/7b4aa44438781e138fcc to your computer and use it in GitHub Desktop.
Save kingjr/7b4aa44438781e138fcc to your computer and use it in GitHub Desktop.
Set of temporal generalization examples using MNE-python & sklearn
"""
========================================================
Decoding sensor space data with over-time generalization
========================================================
This example runs the analysis computed in:
Jean-Remi King, Alexandre Gramfort, Aaron Schurger, Lionel Naccache
and Stanislas Dehaene, "Two distinct dynamic modes subtend the detection of
unexpected sounds", PLOS ONE, 2013
The idea is to learn at one time instant and assess if the decoder
can predict accurately over time.
"""
print(__doc__)
# Authors: Alexandre Gramfort <[email protected]>
# Denis Engemann <[email protected]>
# Jean-Remi King <[email protected]>
#
# License: BSD (3-clause)
import numpy as np
import matplotlib.pyplot as plt
import mne
from mne.datasets import sample
from mne.decoding import time_generalization
from mne.utils import create_slices
from mne.fixes import partial
data_path = sample.data_path()
#
# Load and filter data, set up epochs
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
events_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
# read file
raw = mne.io.Raw(raw_fname, preload=True)
# pick MEG channels only
picks = mne.pick_types(raw.info, meg=True, exclude='bads')
# band pass filtering signals: time generalization is here applied with an
# evoked signals.
raw.filter(1, 30, method='iir')
# get events
events = mne.read_events(events_fname)
event_id = {'AudL': 1, 'VisL': 3, 'AudR': 2, 'VisR': 4}
event_id_train = ['AudL', 'VisL']
event_id_gen = ['AudR', 'VisR']
tmin, tmax = -0.1, 0.5
# Read epochs
decim = 3 # decimate to make the example faster to run
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True,
picks=picks, baseline=None, preload=True,
reject=dict(mag=5e-12), decim=decim)
# select training condition
epochs_list = [epochs[k] for k in event_id_train]
mne.epochs.equalize_epoch_counts(epochs_list)
# select generalization condition
epochs_list_gen = [epochs[k] for k in event_id_gen]
mne.epochs.equalize_epoch_counts(epochs_list_gen)
#
# -----------------------------
# 1. Classic decoding over time
# -----------------------------
# In this scenario, each classifier is trained and tested at a single time
# point. This results in the classic decoding score across time
#
# 1.1. Setup sliding window parameters
train_slices = partial(create_slices, length=2)
# 1.2. Run decoding
# Compute Area Under the Curver (AUC) Receiver Operator Curve (ROC) score
# of time generalization. A perfect decoding would lead to AUCs of 1.
# Chance level is at 0.5.
# The default classifier is a linear SVM (C=1) after feature scaling.
results = time_generalization(epochs_list, train_slices=train_slices,
relative_test_slice=True,
clf=None, cv=5, scoring="roc_auc",
shuffle=True, n_jobs=1)
# 1.3 Retrieve results
scores = results['scores']
# Note that time corresponds to the *start* of the classifier. Larger window
# length imply that later time points will be used by each classifier.
train_times = 1e3 * results['train_times']
# 1.4 Visualize
# Vizualize
def plot_decode(ax, scores, time, event_id):
ax.plot(time, np.diag(scores), label="Classif. score")
ax.axhline(0.5, color='k', linestyle='--', label="Chance level")
ax.axvline(0, color='r', label='Stim onset')
ax.set_ylim(0, 1)
ax.set_xlabel('Time (ms)')
ax.set_ylabel('ROC classification score')
ax.set_title('Decoding (%s vs. %s)' % tuple(event_id))
ax.legend(loc='best')
fig, ax = plt.subplots(1, 1)
plot_decode(ax, scores, train_times, event_id_train)
mne.viz.tight_layout(fig=fig)
plt.show()
# --------------------------------------------------------------
# 2. Generalization over time + generalization across conditions
# --------------------------------------------------------------
# In this scenario, each classifier is trained in a particular time point, and
# subsquently tested on its ability to generalize across other time points.
#
# Note that each classifier can also be tested on its ability to generalize to
# a second dataset. Specifically here, the classifiers are trained on
# epochs_list (left audio versus left visual), and generalized to
# epochs_list_gen (right audio versus right visual)
# The sliding window parameters can be modified with create_slice(). Here,
# the classifiers are non-overlapping and use 2 consecutive time samples.
train_slices = partial(create_slices, length=2, step=2)
# Run main script
results = time_generalization(epochs_list, train_slices=train_slices,
epochs_list_gen=epochs_list_gen,
relative_test_slice=False, n_jobs=-1)
train_times = 1e3 * results['train_times']
test_times = 1e3 * results['test_times']
scores = results['scores']
scores_gen = results['scores_gen']
# Vizualize
def plot_time_gen(ax, plt, scores, train_time, test_time, event_id):
im = ax.imshow(scores, interpolation='nearest', origin='lower',
extent=[test_times[0], test_times[-1],
train_times[0], train_times[-1]],
vmin=0., vmax=1.)
ax.set_xlabel('Times Test (ms)')
ax.set_ylabel('Times Train (ms)')
ax.set_title('Time generalization (%s vs. %s)' % tuple(event_id))
ax.axvline(0, color='k')
ax.axhline(0, color='k')
plt.colorbar(im, ax=ax)
fig, ax = plt.subplots(2, 2, figsize=(12, 8))
ax1, ax2, ax3, ax4 = ax.T.flatten()
# plot time generalization for (cross-validated) training set
plot_decode(ax1, scores, train_times, event_id_train)
plot_time_gen(ax3, plt, scores, train_times, test_times, event_id_train)
# plot time generalization for generalization dataset
plot_decode(ax2, scores_gen, train_times, event_id_train)
plot_time_gen(ax4, plt, scores_gen, train_times, test_times, event_id_gen)
mne.viz.tight_layout(fig=fig)
plt.show()
# -----------------------------------
# 3. Advanced temporal generalization
# -----------------------------------
# Other temporal generalization scenario can be implemented by modifying the
# training and testing slices.
fig, ax = plt.subplots(1, 3, figsize=(18, 4))
ax1, ax2, ax3 = ax.T.flatten()
# Generalize around diagonal
test_slices = create_slices(-5, 5) # generalize +/- around train sample
results = time_generalization(epochs_list, test_slices=test_slices,
relative_test_slice=True, n_jobs=1)
train_times = 1e3 * results['train_times']
scores = results['scores']
plot_time_gen(ax1, plt, scores, train_times, train_times, event_id_train)
# Generalize between specific time points
train_slices = create_slices(10, 15)
test_slices = [create_slices(0, 30) for _ in range(len(train_slices))]
results = time_generalization(epochs_list, train_slices=train_slices,
test_slices=test_slices,
relative_test_slice=False, n_jobs=1)
train_times = 1e3 * results['train_times']
test_times = 1e3 * results['test_times']
scores = results['scores']
plot_time_gen(ax2, plt, scores, train_times, test_times, event_id_train)
# Generalize at specific time points
train_slices = create_slices(0, 31)
test_slices = [create_slices(10, 15) for _ in range(len(train_slices))]
results = time_generalization(epochs_list, test_slices=test_slices,
train_slices=train_slices,
relative_test_slice=False, n_jobs=1)
train_times = 1e3 * results['train_times']
test_times = 1e3 * results['test_times']
scores = results['scores']
plot_time_gen(ax3, plt, scores, train_times, test_times, event_id_train)
mne.viz.tight_layout(fig=fig)
plt.show()
# --------------------------------------
# 4. Use different parallelization modes
# --------------------------------------
import time
# Re-create heavier dataset
picks = mne.pick_types(raw.info, meg='grad', stim=False, ecg=False,
eog=False, exclude='bads')
epochs = mne.Epochs(raw, events, event_id, -.5, 1.5, picks=picks,
baseline=(None, 0), preload=True)
epochs_list = [epochs[k] for k in event_id_train]
def timeit(train_slices, cv, parallel):
test_slices = [create_slices(0, len(train_slices), step=1,
length=train_slices[0].stop -
train_slices[0].start)] * len(slices)
t = time.time()
results = time_generalization(epochs_list, train_slices=train_slices,
test_slices=test_slices, n_jobs=-1,
parallel_across=parallel, cv=cv)
print(time.time() - t)
return results
train_slices = create_slices(0, 200, length=50, step=1)
timeit(train_slices, 2, 'folds')
timeit(train_slices, 2, 'time_samples')
slices = create_slices(10, 20)
timeit(train_slices, 6, 'folds')
timeit(train_slices, 6, 'time_samples')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment