Skip to content

Instantly share code, notes, and snippets.

@robintibor
Last active February 4, 2021 14:06
Show Gist options
  • Save robintibor/62de2854c92ca55f0d2324686c745d45 to your computer and use it in GitHub Desktop.
Save robintibor/62de2854c92ca55f0d2324686c745d45 to your computer and use it in GitHub Desktop.
High Gamma Decoding and Gradient Visualization

First call:

python run_exp_and_store_amp_grads_corr.py subject_id

e.g.

python run_exp_and_store_amp_grads_corr.py 2

for subject id 2. Check inside run_exp_and_store_amp_grads_corr.py to see where results will be stored, which hyperparameters are used, etc. Do this for all subjects (1-14).

Afterwards, run the notebook below to produce the visualizations.

Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
# Authors: Robin Schirrmeister <[email protected]>
#
# License: BSD (3-clause)
import argparse
import logging
import os.path
import sys
import numpy as np
import torch
from braindecode import EEGClassifier
from braindecode.datasets.moabb import MOABBDataset
from braindecode.datautil.preprocess import MNEPreproc, NumpyPreproc, preprocess
from braindecode.datautil.preprocess import exponential_moving_standardize
from braindecode.datautil.windowers import create_windows_from_events
from braindecode.models import Deep4Net
from braindecode.models import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model, get_output_shape
from braindecode.training.losses import CroppedLoss
from braindecode.util import set_random_seeds
from braindecode.visualization.gradients import compute_amplitude_gradients
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split
from torch.utils.data import Subset
# Module-level logger; when run as a script, the __main__ block at the bottom
# of this file configures its output via logging.basicConfig.
log = logging.getLogger(__name__)
def load_preprocessed_data(subject_id, low_cut_hz, high_cut_hz):
    """Load one subject of the MOABB Schirrmeister2017 dataset and preprocess it.

    Preprocessing, in order: keep only the sensors listed in ``C_sensors``
    (in that order), convert volts to microvolts, clip to [-800, 800] uV,
    resample to 250 Hz, filter with MNE, then apply exponential moving
    standardization.

    Parameters
    ----------
    subject_id : int
        Subject to load from the MOABB dataset.
    low_cut_hz, high_cut_hz : float or None
        Passed to MNE's ``filter`` as ``l_freq`` and ``h_freq``.

    Returns
    -------
    MOABBDataset
        The dataset, preprocessed in place by ``preprocess``.
    """
    log.info("Load dataset...")
    dataset = MOABBDataset(dataset_name="Schirrmeister2017", subject_ids=[subject_id])
    # Central ("C") sensors to keep; ``ordered=True`` below keeps this order.
    C_sensors = [
        'FC5', 'FC1', 'FC2', 'FC6', 'C3', 'Cz', 'C4', 'CP5',
        'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
        'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h', 'FCC5h',
        'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h', 'CPP5h',
        'CPP3h', 'CPP4h', 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h', 'CCP1h',
        'CCP2h', 'CPP1h', 'CPP2h']
    # Parameters for exponential moving standardization
    factor_new = 1e-3
    init_block_size = 1000
    log.info("Preprocess dataset...")
    preprocessors = [
        # keep only C sensors
        MNEPreproc(fn='pick_channels', ch_names=C_sensors, ordered=True),
        # convert from volt to microvolt, directly modifying the numpy array
        NumpyPreproc(fn=lambda x: x * 1e6),
        # clip to +-800 uV, presumably to limit the influence of large artifacts
        NumpyPreproc(fn=lambda x: np.clip(x, -800, 800)),
        MNEPreproc(fn='resample', sfreq=250),
        # bandpass filter
        MNEPreproc(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
        # exponential moving standardization
        NumpyPreproc(fn=exponential_moving_standardize, factor_new=factor_new,
                     init_block_size=init_block_size)
    ]
    # Transform the data (in place)
    preprocess(dataset, preprocessors)
    return dataset
def create_cropped_model(model_name, n_chans, n_classes=4):
    """Create a ConvNet set up for efficient cropped decoding.

    The final convolution length is set by hand (30 for the shallow net, 2 for
    the deep net) instead of being derived from an input window length. This
    keeps the ConvNet's receptive field smaller than the input windows. The
    strided model is then transformed with ``to_dense_prediction_model`` so it
    outputs dense predictions, i.e. predictions for all crops.

    Parameters
    ----------
    model_name : {'shallow', 'deep'}
        Selects ShallowFBCSPNet or Deep4Net.
    n_chans : int
        Number of EEG input channels.
    n_classes : int, default 4
        Number of output classes. 4 presumably matches the Schirrmeister2017
        dataset used by this script -- confirm if the dataset changes.

    Returns
    -------
    torch.nn.Module
        The model, moved to the GPU when one is available.

    Raises
    ------
    ValueError
        If ``model_name`` is not 'shallow' or 'deep'.
    """
    # Validate first so a bad name fails before seeding or touching the GPU
    if model_name not in ('shallow', 'deep'):
        raise ValueError(
            f"model_name must be 'shallow' or 'deep', got {model_name!r}")
    cuda = torch.cuda.is_available()  # use the GPU if one is available
    if cuda:
        # Let cuDNN benchmark and pick convolution algorithms
        torch.backends.cudnn.benchmark = True
    seed = 20200220  # random seed to make results reproducible
    set_random_seeds(seed=seed, cuda=cuda)
    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            n_chans,
            n_classes,
            input_window_samples=None,  # no need to provide if final_conv_length given
            final_conv_length=30,
        )
    else:
        model = Deep4Net(
            n_chans,
            n_classes,
            input_window_samples=None,  # no need to provide if final_conv_length given
            final_conv_length=2,
        )
    # Send model to GPU
    if cuda:
        model.cuda()
    # Transform the strided model into one that outputs dense predictions,
    # so it yields predictions for all crops.
    to_dense_prediction_model(model)
    return model
def cut_windows(dataset, input_window_samples, window_stride_samples):
    """Cut the preprocessed recordings into compute windows for cropped decoding.

    Unlike trialwise decoding, cropped decoding needs an explicit window size
    and window stride. Windows start 0.5 s before each trial onset and do not
    extend past the trial stop (``trial_stop_offset_samples=0``).
    ``drop_last_window=False`` is passed so the trailing part of each trial is
    presumably still covered by a last, possibly overlapping, window -- confirm
    against braindecode's windower documentation.

    Parameters
    ----------
    dataset
        Preprocessed dataset (e.g. from ``load_preprocessed_data``). All of its
        recordings must share the same sampling frequency.
    input_window_samples : int
        Window length in samples.
    window_stride_samples : int
        Distance between consecutive window starts, in samples.

    Returns
    -------
    The preloaded windows dataset produced by ``create_windows_from_events``.
    """
    sampling_rates = [ds.raw.info['sfreq'] for ds in dataset.datasets]
    sfreq = sampling_rates[0]
    # A single seconds->samples conversion is only valid for one shared rate
    assert all(rate == sfreq for rate in sampling_rates)
    start_offset_seconds = -0.5
    return create_windows_from_events(
        dataset,
        trial_start_offset_samples=int(start_offset_seconds * sfreq),
        trial_stop_offset_samples=0,
        window_size_samples=input_window_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=False,
        preload=True,
    )
def split_into_train_valid(windows_dataset, n_windows_per_trial=2,
                           train_fraction=0.8):
    """Split the 'train' runs of a windowed dataset into train and valid subsets.

    Only the windows under the ``'train'`` key of ``windows_dataset.split('run')``
    are used; other entries are ignored. The first ``train_fraction`` of those
    windows form the training set and the rest form the validation set. The
    cut point is rounded down to a multiple of ``n_windows_per_trial``.
    Assuming each trial's windows are contiguous, this keeps every trial
    entirely on one side of the split.

    Parameters
    ----------
    windows_dataset
        Object whose ``split('run')`` returns a mapping with a ``'train'`` entry.
    n_windows_per_trial : int, default 2
        Number of windows cut from each trial. The default is hand-set for this
        script's window settings and must be updated if window size or stride
        change.
    train_fraction : float, default 0.8
        Fraction of windows (before rounding down) assigned to training.

    Returns
    -------
    tuple of torch.utils.data.Subset
        ``(train_set, valid_set)``.

    Raises
    ------
    ValueError
        If ``n_windows_per_trial`` is smaller than 1.
    """
    # Guard against a modulo-by-zero (or negative) trial size below
    if n_windows_per_trial < 1:
        raise ValueError(
            f"n_windows_per_trial must be >= 1, got {n_windows_per_trial}")
    full_train_set = windows_dataset.split('run')['train']
    n_windows = len(full_train_set)
    n_split = int(np.round(train_fraction * n_windows))
    # Round down to whole trials so no trial is split between train and valid
    n_split -= n_split % n_windows_per_trial
    train_set = Subset(full_train_set, range(0, n_split))
    valid_set = Subset(full_train_set, range(n_split, n_windows))
    return train_set, valid_set
def run_training(model, model_name, train_set, valid_set, device, n_epochs):
    """Train ``model`` in cropped mode with braindecode's EEGClassifier.

    Uses AdamW with model-specific learning rate and weight decay, a
    cropped NLL loss, ``valid_set`` as the fixed validation split, and
    cosine-annealing learning-rate scheduling over the epochs.

    Parameters
    ----------
    model : torch.nn.Module
        Dense-prediction model (e.g. from ``create_cropped_model``).
    model_name : {'shallow', 'deep'}
        Selects the learning rate and weight decay.
    train_set, valid_set
        Training and validation datasets.
    device : str
        Device passed to EEGClassifier (e.g. 'cuda' or 'cpu').
    n_epochs : int
        Number of training epochs.

    Returns
    -------
    EEGClassifier
        The fitted classifier.

    Raises
    ------
    ValueError
        If ``model_name`` is not 'shallow' or 'deep'.
    """
    if model_name == 'shallow':
        # These values we found good for shallow network:
        lr = 0.0625 * 0.01
        weight_decay = 0
    elif model_name == 'deep':
        # For deep4 they should be:
        lr = 1 * 0.01
        weight_decay = 0.5 * 0.001
    else:
        raise ValueError(
            f"model_name must be 'shallow' or 'deep', got {model_name!r}")
    batch_size = 64
    # CosineAnnealingLR divides by T_max, so keep it >= 1 even for n_epochs=1
    t_max = max(1, n_epochs - 1)
    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.nll_loss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        callbacks=[
            "accuracy", ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=t_max)),
        ],
        device=device,
    )
    # Model training for a specified number of epochs. `y` is None as it is
    # already supplied in the dataset.
    clf.fit(train_set, y=None, epochs=n_epochs)
    return clf
def compute_and_store_amp_grads(model, train_set, filename):
    """Save ``model``'s amplitude gradients on ``train_set``, averaged over axis 1.

    The gradients come from braindecode's ``compute_amplitude_gradients``
    (batch size 64). Their mean over axis 1 is written to ``filename`` with
    ``np.save``.
    """
    grads = compute_amplitude_gradients(model, train_set, batch_size=64)
    # NOTE(review): axis 1 is assumed to index the compute windows -- confirm
    # against the output layout of compute_amplitude_gradients.
    window_mean = np.mean(grads, axis=1)
    np.save(filename, window_mean)
def run_exp(
        subject_id, low_cut_hz, high_cut_hz, n_epochs,
        model_name, output_dir):
    """Train a cropped ConvNet for one subject and store its amplitude gradients.

    Steps:

    1. Load and preprocess the subject's data.
    2. Build the model.
    3. Cut windows whose stride equals the model's number of predictions per
       window.
    4. Split the training runs into train/valid.
    5. Train.
    6. Save the averaged amplitude gradients to
       ``<output_dir>/<subject_id>_avg_amp_grads.npy``.

    Parameters
    ----------
    subject_id : int
        Subject to run.
    low_cut_hz, high_cut_hz : float or None
        Filter edges used during preprocessing.
    n_epochs : int
        Number of training epochs.
    model_name : {'deep', 'shallow'}
        Which ConvNet to train.
    output_dir : str
        Directory for the result file; created if it does not exist.

    Raises
    ------
    ValueError
        If ``model_name`` is not 'deep' or 'shallow'. This is checked before
        any data is loaded.
    """
    # Fail fast, before the expensive data loading
    if model_name not in ('deep', 'shallow'):
        raise ValueError(
            f"model_name must be 'deep' or 'shallow', got {model_name!r}")
    log.info(f"Load and preprocess data for subject {subject_id}...")
    dataset = load_preprocessed_data(subject_id, low_cut_hz, high_cut_hz)
    # Extract number of chans from dataset to create model
    n_chans = dataset[0][0].shape[0]
    log.info("Create cropped model...")
    model = create_cropped_model(model_name, n_chans)
    # Cut windows from the preprocessed data, using number of predictions
    # per compute window to cut non-overlapping fully covering windows
    # (except for overlap of last window to stay within trial bounds)
    log.info("Cut windows from dataset ...")
    input_window_samples = 1000
    # To know the model's receptive field, we calculate the shape of model
    # output for a dummy input.
    n_preds_per_input = get_output_shape(model, n_chans, input_window_samples)[2]
    windows_dataset = cut_windows(
        dataset, input_window_samples, window_stride_samples=n_preds_per_input)
    # Split into train and valid, ignoring final evaluation for now
    log.info("Split into train and valid...")
    train_set, valid_set = split_into_train_valid(windows_dataset)
    # Run actual training; like create_cropped_model, use the GPU only when
    # one is available instead of assuming CUDA.
    log.info("Run training...")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    run_training(model, model_name, train_set, valid_set, device, n_epochs)
    log.info("Compute and store amplitude gradients ...")
    # np.save does not create missing directories
    os.makedirs(output_dir, exist_ok=True)
    amp_grads_filename = os.path.join(output_dir, f"{subject_id}_avg_amp_grads.npy")
    compute_and_store_amp_grads(model, train_set, filename=amp_grads_filename)
    log.info("... Done.")
if __name__ == '__main__':
    # Command-line entry point: train on one subject and store its
    # amplitude gradients. Repeat for every subject (1-14).
    parser = argparse.ArgumentParser(
        description=(
            "Train a cropped ConvNet on one subject of the Schirrmeister2017 "
            "(High-Gamma) dataset and store its averaged amplitude gradients."
        )
    )
    parser.add_argument(
        'subject_id', type=int, choices=range(1, 15), metavar='subject_id',
        help='Subject to run (1-14).')
    args = parser.parse_args()
    subject_id = args.subject_id
    # Both cutoffs None: passed straight to MNE's filter; presumably this
    # leaves the data unfiltered so high-gamma content is kept -- confirm.
    low_cut_hz = None  # low cut frequency for filtering
    high_cut_hz = None  # high cut frequency for filtering
    n_epochs = 20
    model_name = 'deep'
    output_dir = './results/'
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)
    run_exp(
        subject_id, low_cut_hz, high_cut_hz, n_epochs,
        model_name, output_dir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment