Skip to content

Instantly share code, notes, and snippets.

@larsoner
Created April 22, 2019 12:09
Show Gist options
  • Save larsoner/8566361d3e011c9ce34c1d321f0f4ef1 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
"""
Simulate the sample dataset
===========================
Here we use :func:`mne.simulation.simulate_raw` to simulate the sample dataset
and then do source localization on the result.
"""
import os.path as op
import numpy as np
import mne

# Locate the MNE "sample" dataset (data_path() downloads it on first use)
data_path = mne.datasets.sample.data_path()
subject = 'sample'
meg_path = op.join(data_path, 'MEG', subject)
subjects_dir = op.join(data_path, 'subjects')
# Load the raw recording; add an average-EEG-reference projector and keep
# only the first 60 s to keep the simulation fast
raw = mne.io.read_raw_fif(op.join(meg_path, 'sample_audvis_raw.fif'))
raw.set_eeg_reference(projection=True).crop(0, 60)  # for speed
# Precomputed forward solution (oct-6 source space, MEG+EEG)
forward = mne.read_forward_solution(op.join(
    meg_path, 'sample_audvis-meg-eeg-oct-6-fwd.fif'))
# Our standard sample IDs
event_id = {'auditory/left': 1, 'auditory/right': 2, 'visual/left': 3,
            'visual/right': 4, 'smiley': 5, 'button': 32}
# Events from the experiment
events = mne.find_events(raw)
##############################################################################
# Set up simulation
# -----------------
# Make a dict that maps conditions to activation strengths within aparc.a2009s
# labels:
# Each condition maps to a list of (aparc.a2009s label name, amplitude in
# nAm) pairs; note the hemisphere contralateral to the stimulus gets the
# larger amplitude (200 vs 100 nAm). 'smiley' and 'button' events have no
# entry and will produce silent (zero) source estimates.
activations = {
    'auditory/left':
        [('G_temp_sup-G_T_transv-lh', 100),  # label, activation (nAm)
         ('G_temp_sup-G_T_transv-rh', 200)],
    'auditory/right':
        [('G_temp_sup-G_T_transv-lh', 200),
         ('G_temp_sup-G_T_transv-rh', 100)],
    'visual/left':
        [('S_calcarine-lh', 100),
         ('S_calcarine-rh', 200)],
    'visual/right':
        [('S_calcarine-lh', 200),
         ('S_calcarine-rh', 100)],
}
##############################################################################
# Set up the simulation generator
# -------------------------------
# Eventually this sort of thing could be handled inside MNE, but for now:
def stc_gen(activations, annot='aparc.a2009s'):
    """Generate STCs based on surface labels.

    Yields ``(stc, event_id)`` tuples, one per inter-event interval, for
    consumption by :func:`mne.simulation.simulate_raw`. Each STC spans the
    samples from one event to the next and contains a Hanning-windowed
    pulse in the labels listed in ``activations`` for that event's
    condition (zeros for unknown/unlisted conditions).

    Uses module-level globals: ``forward``, ``raw``, ``events``,
    ``event_id``, ``subject``, ``subjects_dir``.
    """
    print('  Loading labels')
    # Load the labels
    src = forward['src']
    # Vertices actually present in the source space, per hemisphere
    vertices = [s['vertno'] for s in src]
    n_vertices = sum(len(v) for v in vertices)
    # Row offsets of each hemisphere in the stacked [lh; rh] vertex array
    offsets = [0, len(vertices[0])]
    # Unique label names referenced anywhere in the activation spec
    label_names = sorted(set(activation[0]
                             for activation_list in activations.values()
                             for activation in activation_list))
    # dotter: (n_vertices, n_labels) matrix mapping per-label amplitudes to
    # per-vertex currents (filled with sign-flipped, normalized weights)
    dotter = np.zeros((n_vertices, len(label_names)))
    used = np.zeros(n_vertices, bool)  # which source vertices get activity
    for li, name in enumerate(label_names):
        # regexp match must be unique, otherwise the mapping is ambiguous
        label = mne.read_labels_from_annot(
            subject, annot, subjects_dir=subjects_dir, regexp=name,
            verbose=False)
        if len(label) != 1:
            raise RuntimeError('Ambiguous label, found %d != 1: %s'
                               % (len(label), name,))
        label = label[0]
        hidx = dict(lh=0, rh=1)[label.hemi]
        # Keep only label vertices that exist in this source space
        these_vertices = label.vertices[np.in1d(label.vertices,
                                                src[hidx]['vertno'])]
        label = mne.Label(these_vertices, hemi=label.hemi)
        label.values[:] = 1
        # Sign flips align source orientations within the label so the
        # activations don't cancel
        flip = mne.label_sign_flip(label, src)
        # Rows of `dotter` (within the stacked vertex array) for this label
        lidx = np.where(np.in1d(src[hidx]['vertno'], [label.vertices]))[0]
        lidx += offsets[hidx]
        dotter[lidx, li] = flip / len(flip)  # normalize
        used[lidx] = True
    # For efficiency, restrict our actually used vertices
    vertices = [vertices[0][used[:offsets[1]]], vertices[1][used[offsets[1]:]]]
    n_vertices = sum(len(v) for v in vertices)
    del offsets
    dotter = dotter[used]
    assert dotter.shape[0] == n_vertices
    print('  Reduced from %d sources to %d for %d labels'
          % (len(used), used.sum(), len(label_names)))
    # Actually do the generation
    count = 0
    tstep = 1. / raw.info['sfreq']
    # We rely on there being pre-event samples to emit as silence below
    assert events[0, 0] > raw.first_samp
    rev_id = {val: key for key, val in event_id.items()}
    # one signal for all for now: a 200 ms Hanning pulse, 1 nAm peak scale
    n_active = int(round(0.2 * raw.info['sfreq']))
    signal = 1e-9 * np.hanning(n_active)
    # Infinite generator: the consumer (simulate_raw) is presumed to stop
    # pulling once the Raw duration is covered -- TODO confirm; iterating
    # past len(events) would trip the interval assertion below
    while True:
        if count == 0:
            # Silence from the start of the data up to the first event
            n_samp = events[0, 0] - raw.first_samp
            id_ = 0  # not a real event id -> zeros emitted below
        else:
            if count == len(events):
                # Tail: from the last event to the end of the data
                n_samp = len(raw.times) - (events[-1, 0] - raw.first_samp)
            else:
                assert 0 < count < len(events)
                n_samp = events[count, 0] - events[count - 1, 0]
            id_ = events[count - 1, 2]
        assert n_samp > 0
        # One amplitude time course per label for this interval
        data = np.zeros((len(label_names), n_samp))
        if id_ not in rev_id or rev_id[id_] not in activations:
            # Unknown or unlisted condition: leave this interval silent
            print('    Ignoring event %d / %d (unknown id: %d)'
                  % (count, len(events), id_))
        else:
            key = rev_id[id_]
            print('    Generating STC for event %d / %d (type %s)'
                  % (count, len(events), key))
            for name, amp in activations[key]:
                row_idx = label_names.index(name)
                # Truncate the pulse if the interval is shorter than it
                time_sl = slice(0, min(len(signal), data.shape[1]))
                data[row_idx, time_sl] = amp * signal[time_sl]
        # Project label activations into source space
        data = np.dot(dotter, data)
        # Yield the SourceEstimate
        yield (mne.SourceEstimate(data, vertices, 0, tstep), id_)
        count += 1
##############################################################################
# Run the simulation
# ------------------
# Compute a noise covariance to add noise to the data
# Baseline-only epochs (tmax=0, i.e. pre-stimulus data) for noise estimation
noise_epochs = mne.Epochs(raw, events, tmax=0)
# Keep all channels in the covariance, including ones marked bad
noise_epochs.info['bads'] = []
noise_cov = mne.compute_covariance(noise_epochs)
# Simulate sensor data from the label-based STC generator, adding noise
# from `noise_cov` plus blink and ECG artifacts; random_state fixed for
# reproducibility. The three None positional args accept mne defaults --
# presumably trans/src/bem, superseded by forward=; verify against the
# simulate_raw signature for this MNE version.
raw_sim = mne.simulation.simulate_raw(
    raw, stc_gen(activations), None, None, None, noise_cov,
    blink=True, ecg=True, forward=forward, random_state=0, verbose=True)
##############################################################################
# Source localize
# ---------------
# Inverse solution parameters: dSPM with lambda2 = 1/9 (i.e. assumed SNR=3)
method = 'dSPM'
lambda2 = 1. / 9.
# Epoch the simulated data around the experimental events
epochs = mne.Epochs(raw_sim, events, event_id)
# Build the inverse operator from the simulated data's measurement info
inv = mne.minimum_norm.make_inverse_operator(epochs.info, forward, noise_cov)
# Average one auditory and one visual condition, then localize each
evoked_aud = epochs['auditory/left'].average()
evoked_vis = epochs['visual/right'].average()
stc_aud = mne.minimum_norm.apply_inverse(evoked_aud, inv, lambda2, method)
stc_vis = mne.minimum_norm.apply_inverse(evoked_vis, inv, lambda2, method)
# Contrast the two conditions and visualize the difference on the brain
stc_diff = stc_aud - stc_vis
brain = stc_diff.plot(subjects_dir=subjects_dir, initial_time=0.1)
@larsoner
Copy link
Author

Screen Shot 2019-03-18 at 22 39 05

Screen Shot 2019-03-18 at 22 39 10

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment