# Authors: Robin Schirrmeister <[email protected]> |
# |
# License: BSD (3-clause) |
import argparse |
import logging |
import os.path |
import sys |
import numpy as np |
import torch |
from braindecode import EEGClassifier |
from braindecode.datasets.moabb import MOABBDataset |
from braindecode.datautil.preprocess import MNEPreproc, NumpyPreproc, preprocess |
from braindecode.datautil.preprocess import exponential_moving_standardize |
from braindecode.datautil.windowers import create_windows_from_events |
from braindecode.models import Deep4Net |
from braindecode.models import ShallowFBCSPNet |
from braindecode.models.util import to_dense_prediction_model, get_output_shape |
from braindecode.training.losses import CroppedLoss |
from braindecode.util import set_random_seeds |
from braindecode.visualization.gradients import compute_amplitude_gradients |
from skorch.callbacks import LRScheduler |
from skorch.helper import predefined_split |
from torch.utils.data import Subset |
# Logger for this module; every progress message of the experiment goes here.
log = logging.getLogger(name=__name__)
def load_preprocessed_data(subject_id, low_cut_hz, high_cut_hz):
    """Load one subject of the MOABB "Schirrmeister2017" dataset and preprocess it.

    Preprocessing steps, in order:

    1. keep only the C-area sensors listed below (in that order),
    2. multiply the signal by 1e6 (volt -> microvolt),
    3. clip values to [-800, 800],
    4. resample to 250 Hz,
    5. band-pass filter with MNE ``filter``,
    6. apply exponential moving standardization.

    Parameters
    ----------
    subject_id : int
        Subject to load.
    low_cut_hz, high_cut_hz : float or None
        Passed to MNE ``filter`` as ``l_freq`` / ``h_freq``.

    Returns
    -------
    MOABBDataset
        The loaded dataset, after ``preprocess`` has been applied to it.
    """
    log.info("Load dataset...")
    dataset = MOABBDataset(dataset_name="Schirrmeister2017", subject_ids=[subject_id])
    C_sensors = [
        'FC5', 'FC1', 'FC2', 'FC6', 'C3', 'Cz', 'C4', 'CP5',
        'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
        'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h', 'FCC5h',
        'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h', 'CPP5h',
        'CPP3h', 'CPP4h', 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h', 'CCP1h',
        'CCP2h', 'CPP1h', 'CPP2h']
    # Parameters for exponential moving standardization
    factor_new = 1e-3
    init_block_size = 1000
    log.info("Preprocess dataset...")
    preprocessors = [
        # keep only C sensors
        MNEPreproc(fn='pick_channels', ch_names=C_sensors, ordered=True),
        # convert from volt to microvolt, directly modifying the numpy array
        NumpyPreproc(fn=lambda x: x * 1e6),
        NumpyPreproc(fn=lambda x: np.clip(x, -800, 800)),
        MNEPreproc(fn='resample', sfreq=250),
        # bandpass filter
        MNEPreproc(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
        # exponential moving standardization
        NumpyPreproc(fn=exponential_moving_standardize, factor_new=factor_new,
                     init_block_size=init_block_size),
    ]
    # Transform the data; ``preprocess`` is called for its effect on ``dataset``
    # (its return value is not used).
    preprocess(dataset, preprocessors)
    return dataset
def create_cropped_model(model_name, n_chans, n_classes=4):
    """Create a ShallowFBCSPNet or Deep4Net prepared for cropped decoding.

    The length of the final convolution layer is set by hand (30 for the
    shallow net, 2 for the deep net) instead of passing
    ``input_window_samples``. This keeps the ConvNet's receptive field
    smaller than the input windows, so cropped decoding is efficient. The
    model is then transformed with ``to_dense_prediction_model`` so that it
    outputs predictions for all crops.

    Random seeds are fixed (seed 20200220) for reproducibility. The model
    is moved to the GPU when CUDA is available.

    Parameters
    ----------
    model_name : {'shallow', 'deep'}
        Which network to build.
    n_chans : int
        Number of input EEG channels.
    n_classes : int, default 4
        Number of output classes.

    Returns
    -------
    torch.nn.Module
        The dense-prediction model.

    Raises
    ------
    ValueError
        If ``model_name`` is neither ``'shallow'`` nor ``'deep'``.
    """
    # Validate before touching global state (seeds, cuDNN flags).
    if model_name not in ('shallow', 'deep'):
        raise ValueError(
            f"model_name must be 'shallow' or 'deep', got {model_name!r}")

    cuda = torch.cuda.is_available()  # use the GPU whenever one is available
    if cuda:
        # Let cuDNN benchmark and pick the fastest convolution algorithms.
        torch.backends.cudnn.benchmark = True
    # Set random seed to be able to reproduce results
    set_random_seeds(seed=20200220, cuda=cuda)

    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            n_chans,
            n_classes,
            input_window_samples=None,  # not needed when final_conv_length is given
            final_conv_length=30,
        )
    else:
        model = Deep4Net(
            n_chans,
            n_classes,
            input_window_samples=None,  # not needed when final_conv_length is given
            final_conv_length=2,
        )

    if cuda:
        model.cuda()

    # Transform the strided model into one that outputs dense predictions,
    # so predictions for all crops are obtained in one forward pass.
    to_dense_prediction_model(model)
    return model
def cut_windows(dataset, input_window_samples, window_stride_samples):
    """Cut fixed-size, trial-aligned windows out of a preprocessed dataset.

    Cropped decoding (unlike trialwise decoding) needs an explicit window size
    and window stride for ``create_windows_from_events``. Windows may start
    0.5 s before each trial onset (``trial_start_offset_samples``) and are
    bounded by the trial end (``trial_stop_offset_samples=0``).
    ``drop_last_window=False`` keeps a final window that does not fit the
    stride exactly.

    Parameters
    ----------
    dataset : BaseConcatDataset-like
        Preprocessed dataset. It must expose ``.datasets``, each with ``.raw.info['sfreq']``.
    input_window_samples : int
        Window length in samples.
    window_stride_samples : int
        Stride between consecutive windows in samples.

    Returns
    -------
    The windowed dataset created by ``create_windows_from_events`` (preloaded).
    """
    # All recordings must share one sampling rate, taken from the first one.
    reference_sfreq = dataset.datasets[0].raw.info['sfreq']
    same_rate = [sub.raw.info['sfreq'] == reference_sfreq for sub in dataset.datasets]
    assert all(same_rate)

    # -0.5 s converted to samples: start windows half a second before trial onset.
    offset_seconds = -0.5
    start_offset = int(offset_seconds * reference_sfreq)

    return create_windows_from_events(
        dataset,
        trial_start_offset_samples=start_offset,
        trial_stop_offset_samples=0,
        window_size_samples=input_window_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=False,
        preload=True,
    )
def split_into_train_valid(windows_dataset, train_fraction=0.8,
                           n_windows_per_trial=2):
    """Split the ``'train'`` run of ``windows_dataset`` into train and valid subsets.

    The first ``train_fraction`` of the windows (rounded, then rounded down to
    a multiple of ``n_windows_per_trial``) become the training set. The
    remaining windows become the validation set. The final evaluation set
    (other runs) is ignored here.

    Parameters
    ----------
    windows_dataset : object
        Anything whose ``split('run')`` returns a mapping with a ``'train'``
        entry supporting ``len`` and integer indexing.
    train_fraction : float, default 0.8
        Fraction of the ``'train'`` windows used for training.
    n_windows_per_trial : int, default 2
        Number of windows cut per trial. It is set by hand and must match the
        windowing used to build ``windows_dataset`` (TODO: derive it from the
        data instead of hard-coding it).

    Returns
    -------
    (torch.utils.data.Subset, torch.utils.data.Subset)
        ``(train_set, valid_set)``.

    Raises
    ------
    ValueError
        If ``n_windows_per_trial`` is smaller than 1.
    """
    if n_windows_per_trial < 1:
        raise ValueError(
            f"n_windows_per_trial must be >= 1, got {n_windows_per_trial}")
    full_train_set = windows_dataset.split('run')['train']
    n_split = int(np.round(train_fraction * len(full_train_set)))
    # Ensure the split point is a multiple of the number of windows per trial.
    # Assuming windows are stored trial by trial, this avoids putting windows
    # of one trial into both train and valid.
    n_split -= n_split % n_windows_per_trial
    train_set = Subset(full_train_set, range(0, n_split))
    valid_set = Subset(full_train_set, range(n_split, len(full_train_set)))
    return train_set, valid_set
def run_training(model, model_name, train_set, valid_set, device, n_epochs):
    """Train ``model`` with cropped decoding and return the fitted classifier.

    Uses AdamW with a cosine-annealing learning-rate schedule and
    ``CroppedLoss`` wrapping NLL loss. ``valid_set`` is used as the
    predefined validation split.

    Parameters
    ----------
    model : torch.nn.Module
        Dense-prediction model (see ``create_cropped_model``).
    model_name : {'shallow', 'deep'}
        Selects the learning rate / weight decay pair.
    train_set, valid_set : torch Dataset
        Training and validation windows.
    device : str
        Device for skorch, e.g. ``'cuda'`` or ``'cpu'``.
    n_epochs : int
        Number of training epochs.

    Returns
    -------
    EEGClassifier
        The fitted classifier.

    Raises
    ------
    ValueError
        If ``model_name`` is neither ``'shallow'`` nor ``'deep'``.
    """
    # (learning rate, weight decay) found to work well for each network.
    hyperparams = {
        'shallow': (0.0625 * 0.01, 0),
        'deep': (1 * 0.01, 0.5 * 0.001),
    }
    if model_name not in hyperparams:
        raise ValueError(
            f"model_name must be 'shallow' or 'deep', got {model_name!r}")
    lr, weight_decay = hyperparams[model_name]
    batch_size = 64
    # The CosineAnnealingLR schedule divides by T_max, so keep it >= 1.
    # Otherwise a single-epoch run fails with a division by zero.
    t_max = max(n_epochs - 1, 1)
    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.nll_loss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        callbacks=[
            "accuracy", ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=t_max)),
        ],
        device=device,
    )
    # Model training for a specified number of epochs. `y` is None as it is
    # already supplied in the dataset.
    clf.fit(train_set, y=None, epochs=n_epochs)
    return clf
def compute_and_store_amp_grads(model, train_set, filename):
    """Compute amplitude gradients of ``model`` on ``train_set`` and save their mean.

    The result of braindecode's ``compute_amplitude_gradients`` (batch size 64)
    is averaged over axis 1 and written to ``filename`` with ``np.save``. The
    original code describes axis 1 as the compute windows; TODO confirm
    against the braindecode documentation.
    """
    gradients = compute_amplitude_gradients(model, train_set, batch_size=64)
    mean_over_windows = np.mean(gradients, axis=1)
    np.save(filename, mean_over_windows)
def run_exp(
        subject_id, low_cut_hz, high_cut_hz, n_epochs,
        model_name, output_dir):
    """Run one cropped-decoding experiment for a single subject.

    Steps: load and preprocess the data, build the cropped model, cut windows,
    split into train/valid, train, then save the averaged amplitude gradients
    to ``<output_dir>/<subject_id>_avg_amp_grads.npy``.

    Parameters
    ----------
    subject_id : int
        Subject to run.
    low_cut_hz, high_cut_hz : float or None
        Band-pass cut-off frequencies.
    n_epochs : int
        Number of training epochs.
    model_name : {'shallow', 'deep'}
        Network to train.
    output_dir : str
        Directory for the results. It is created if it does not exist.
    """
    # Create the output directory up front. Otherwise a missing directory
    # makes ``np.save`` fail only after the (long) training has finished.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # Train on the GPU only when one is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    log.info(f"Load and preprocess data for subject {subject_id}...")
    dataset = load_preprocessed_data(subject_id, low_cut_hz, high_cut_hz)
    # Extract number of chans from dataset to create model
    n_chans = dataset[0][0].shape[0]
    log.info("Create cropped model...")
    model = create_cropped_model(model_name, n_chans)
    # Cut windows from the preprocessed data, using the number of predictions
    # per input window as stride. This gives non-overlapping, fully covering
    # windows (except for overlap of the last window to stay within trial bounds).
    log.info("Cut windows from dataset ...")
    input_window_samples = 1000
    # To know the model's receptive field, compute the shape of the model
    # output for a dummy input.
    n_preds_per_input = get_output_shape(model, n_chans, input_window_samples)[2]
    windows_dataset = cut_windows(
        dataset, input_window_samples, window_stride_samples=n_preds_per_input)
    # Split into train and valid, ignoring final evaluation for now
    log.info("Split into train and valid...")
    train_set, valid_set = split_into_train_valid(windows_dataset)
    # Run actual training
    log.info("Run training...")
    run_training(model, model_name, train_set, valid_set, device, n_epochs)
    log.info("Compute and store amplitude gradients ...")
    amp_grads_filename = os.path.join(output_dir, f"{subject_id}_avg_amp_grads.npy")
    compute_and_store_amp_grads(model, train_set, filename=amp_grads_filename)
    log.info("... Done.")
if __name__ == '__main__':
    # Command-line entry point: run one experiment for the given subject.
    # All optional flags default to the values previously hard-coded here.
    parser = argparse.ArgumentParser(
        description="Train a cropped-decoding ConvNet on one subject of the "
                    "Schirrmeister2017 dataset and store the averaged "
                    "amplitude gradients. Example: %(prog)s 1"
    )
    parser.add_argument('subject_id', type=int,
                        help='ID of the subject to run the experiment for.')
    parser.add_argument('--model', dest='model_name', choices=['deep', 'shallow'],
                        default='deep',
                        help='Network to train (default: %(default)s).')
    parser.add_argument('--n-epochs', type=int, default=20,
                        help='Number of training epochs (default: %(default)s).')
    parser.add_argument('--low-cut-hz', type=float, default=None,
                        help='Low cut frequency for filtering (default: none).')
    parser.add_argument('--high-cut-hz', type=float, default=None,
                        help='High cut frequency for filtering (default: none).')
    parser.add_argument('--output-dir', default='./results/',
                        help='Directory for the results (default: %(default)s).')
    args = parser.parse_args()
    subject_id = args.subject_id
    low_cut_hz = args.low_cut_hz  # low cut frequency for filtering
    high_cut_hz = args.high_cut_hz  # high cut frequency for filtering
    n_epochs = args.n_epochs
    model_name = args.model_name
    output_dir = args.output_dir
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)
    run_exp(
        subject_id, low_cut_hz, high_cut_hz, n_epochs,
        model_name, output_dir)