|
# Authors: Robin Schirrmeister <[email protected]> |
|
# |
|
# License: BSD (3-clause) |
|
|
|
|
|
import argparse |
|
import logging |
|
import os.path |
|
import sys |
|
|
|
import numpy as np |
|
import torch |
|
from braindecode import EEGClassifier |
|
from braindecode.datasets.moabb import MOABBDataset |
|
from braindecode.datautil.preprocess import MNEPreproc, NumpyPreproc, preprocess |
|
from braindecode.datautil.preprocess import exponential_moving_standardize |
|
from braindecode.datautil.windowers import create_windows_from_events |
|
from braindecode.models import Deep4Net |
|
from braindecode.models import ShallowFBCSPNet |
|
from braindecode.models.util import to_dense_prediction_model, get_output_shape |
|
from braindecode.training.losses import CroppedLoss |
|
from braindecode.util import set_random_seeds |
|
from braindecode.visualization.gradients import compute_amplitude_gradients |
|
from skorch.callbacks import LRScheduler |
|
from skorch.helper import predefined_split |
|
from torch.utils.data import Subset |
|
|
|
log = logging.getLogger(__name__) |
|
|
|
|
|
def load_preprocessed_data(subject_id, low_cut_hz, high_cut_hz):
    """Load one subject of the Schirrmeister2017 dataset and preprocess it.

    The preprocessing steps are applied in this order: pick the sensors
    listed in ``C_sensors``, scale by 1e6 (volt to microvolt), clip to
    [-800, 800], resample to 250 Hz, filter with the given cut-offs and
    finally apply exponential moving standardization.

    Parameters
    ----------
    subject_id : int
        Subject to load through ``MOABBDataset``.
    low_cut_hz : float | None
        Forwarded as ``l_freq`` to the MNE ``filter`` step.
    high_cut_hz : float | None
        Forwarded as ``h_freq`` to the MNE ``filter`` step.

    Returns
    -------
    dataset
        The ``MOABBDataset`` instance that was handed to ``preprocess``.
    """
    # NOTE: this function used to assert on ``model_name``, which is not one
    # of its parameters and only resolved by accident when the file was run
    # as a script. Data loading does not depend on the model, so the check
    # was dropped here.
    log.info("Load dataset...")
    dataset = MOABBDataset(
        dataset_name="Schirrmeister2017", subject_ids=[subject_id])

    C_sensors = [
        'FC5', 'FC1', 'FC2', 'FC6', 'C3', 'Cz', 'C4', 'CP5',
        'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
        'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h', 'FCC5h',
        'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h', 'CPP5h',
        'CPP3h', 'CPP4h', 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h', 'CCP1h',
        'CCP2h', 'CPP1h', 'CPP2h']

    # Parameters for exponential moving standardization
    factor_new = 1e-3
    init_block_size = 1000

    log.info("Preprocess dataset...")
    preprocessors = [
        # keep only C sensors, in the order given above
        MNEPreproc(fn='pick_channels', ch_names=C_sensors, ordered=True),
        # convert from volt to microvolt, directly modifying the numpy array
        NumpyPreproc(fn=lambda x: x * 1e6),
        # limit extreme amplitudes to +-800 (microvolt after the step above)
        NumpyPreproc(fn=lambda x: np.clip(x, -800, 800)),
        MNEPreproc(fn='resample', sfreq=250),
        # bandpass filter
        MNEPreproc(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
        # exponential moving standardization
        NumpyPreproc(fn=exponential_moving_standardize,
                     factor_new=factor_new,
                     init_block_size=init_block_size),
    ]

    # Transform the data
    preprocess(dataset, preprocessors)
    return dataset
|
|
|
|
|
def create_cropped_model(model_name, n_chans):
    """Create a ConvNet for cropped decoding and make it predict densely.

    The length of the final convolution layer is set by hand (30 for the
    shallow net, 2 for the deep net) so that the receptive field of the
    ConvNet stays smaller than the input window; ``input_window_samples``
    therefore does not have to be given to the model.

    Parameters
    ----------
    model_name : str
        ``'shallow'`` for ``ShallowFBCSPNet`` or ``'deep'`` for ``Deep4Net``.
    n_chans : int
        Number of input channels handed to the model constructor.

    Returns
    -------
    model
        The model after ``to_dense_prediction_model`` was applied; it is
        moved to the GPU when CUDA is available.

    Raises
    ------
    ValueError
        If ``model_name`` is neither ``'shallow'`` nor ``'deep'``.
    """
    # Validate before any side effect with a real exception: an ``assert``
    # is stripped under ``python -O`` and would let an unknown name silently
    # fall through to the Deep4Net branch.
    if model_name not in ('shallow', 'deep'):
        raise ValueError(
            f"model_name must be 'shallow' or 'deep', got {model_name!r}")

    cuda = torch.cuda.is_available()  # use the GPU whenever one is available
    if cuda:
        torch.backends.cudnn.benchmark = True

    # Set random seed to be able to reproduce results
    seed = 20200220
    set_random_seeds(seed=seed, cuda=cuda)

    n_classes = 4

    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            n_chans,
            n_classes,
            input_window_samples=None,  # not needed, final_conv_length given
            final_conv_length=30,
        )
    else:
        model = Deep4Net(
            n_chans,
            n_classes,
            input_window_samples=None,  # not needed, final_conv_length given
            final_conv_length=2,
        )

    # Send model to GPU
    if cuda:
        model.cuda()

    # Transform the model with strides into a model that outputs dense
    # predictions, so it can be used to obtain predictions for all crops.
    to_dense_prediction_model(model)
    return model
|
|
|
|
|
def cut_windows(dataset, input_window_samples, window_stride_samples):
    """Cut the preprocessed recordings into fixed-size compute windows.

    In contrast to trialwise decoding, an explicit window size and window
    stride are supplied to ``create_windows_from_events``. Every window may
    start up to 0.5 s before the trial onset; the last window of a trial is
    kept (``drop_last_window=False``) and the data is preloaded.

    Parameters
    ----------
    dataset
        Object exposing ``datasets``, each entry having ``raw.info['sfreq']``.
    input_window_samples : int
        Passed as ``window_size_samples``.
    window_stride_samples : int
        Passed as ``window_stride_samples``.

    Returns
    -------
    The result of ``create_windows_from_events``.
    """
    # All recordings must share one sampling frequency, otherwise a single
    # offset in samples would not mean the same thing for each of them.
    sampling_rates = [ds.raw.info['sfreq'] for ds in dataset.datasets]
    reference_rate = sampling_rates[0]
    assert all(rate == reference_rate for rate in sampling_rates)

    # Trial start offset, first in seconds and then converted to samples.
    start_offset_seconds = -0.5
    start_offset_samples = int(start_offset_seconds * reference_rate)

    return create_windows_from_events(
        dataset,
        trial_start_offset_samples=start_offset_samples,
        trial_stop_offset_samples=0,
        window_size_samples=input_window_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=False,
        preload=True,
    )
|
|
|
|
|
def split_into_train_valid(windows_dataset, n_windows_per_trial=2,
                           train_fraction=0.8):
    """Split the ``'train'`` run of a windows dataset into train and valid.

    The dataset is first split by ``'run'`` and only its ``'train'`` part is
    used. The leading ``train_fraction`` of that part becomes the training
    set and the remainder the validation set.

    Parameters
    ----------
    windows_dataset
        Object whose ``split('run')`` returns a mapping with a ``'train'``
        entry that supports ``len`` and integer indexing.
    n_windows_per_trial : int, optional
        The split index is rounded down to a multiple of this value.
        Defaults to 2, the value that used to be set by hand.
    train_fraction : float, optional
        Fraction of windows that go into the training set (default 0.8).

    Returns
    -------
    train_set, valid_set : torch.utils.data.Subset
        Views onto the first and the remaining windows of the train run.

    Raises
    ------
    ValueError
        If ``n_windows_per_trial`` is smaller than 1.
    """
    if n_windows_per_trial < 1:
        raise ValueError(
            f"n_windows_per_trial must be >= 1, got {n_windows_per_trial!r}")

    full_train_set = windows_dataset.split('run')['train']
    n_total = len(full_train_set)

    n_split = int(np.round(train_fraction * n_total))
    # Ensure the boundary is a multiple of the number of windows per trial.
    # Presumably this keeps all windows of one trial on the same side; it
    # relies on every trial contributing exactly that many consecutive
    # windows -- verify when changing the window size or stride.
    n_split -= n_split % n_windows_per_trial

    train_set = Subset(full_train_set, range(0, n_split))
    valid_set = Subset(full_train_set, range(n_split, n_total))
    return train_set, valid_set
|
|
|
|
|
def run_training(model, model_name, train_set, valid_set, device, n_epochs):
    """Train ``model`` with cropped decoding and return the fitted classifier.

    Parameters
    ----------
    model
        Network handed to ``EEGClassifier``.
    model_name : str
        ``'shallow'`` or ``'deep'``; selects learning rate and weight decay.
    train_set
        Dataset used for fitting; it already contains the targets.
    valid_set
        Dataset used as the predefined validation split.
    device : str
        Device string forwarded to ``EEGClassifier``.
    n_epochs : int
        Number of training epochs.

    Returns
    -------
    clf : EEGClassifier
        The classifier after ``fit`` has been called.

    Raises
    ------
    ValueError
        If ``model_name`` is neither ``'shallow'`` nor ``'deep'``.
    """
    if model_name == 'shallow':
        # These values we found good for shallow network:
        lr = 0.0625 * 0.01
        weight_decay = 0
    elif model_name == 'deep':
        # For deep4 they should be:
        lr = 1 * 0.01
        weight_decay = 0.5 * 0.001
    else:
        # A real exception instead of ``assert``: asserts vanish under
        # ``python -O`` and the hyperparameters would then stay undefined
        # or be silently wrong.
        raise ValueError(
            f"model_name must be 'shallow' or 'deep', got {model_name!r}")

    batch_size = 64

    # The cosine schedule spans the whole training. ``T_max`` must stay
    # positive, otherwise the scheduler divides by zero when only a single
    # epoch is requested.
    cosine_t_max = max(n_epochs - 1, 1)

    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.nll_loss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        callbacks=[
            "accuracy",
            ("lr_scheduler",
             LRScheduler('CosineAnnealingLR', T_max=cosine_t_max)),
        ],
        device=device,
    )
    # Model training for a specified number of epochs. `y` is None as it is
    # already supplied in the dataset.
    clf.fit(train_set, y=None, epochs=n_epochs)
    return clf
|
|
|
|
|
def compute_and_store_amp_grads(model, train_set, filename):
    """Compute amplitude gradients on ``train_set`` and save their average.

    The gradients returned by ``compute_amplitude_gradients`` (called with a
    batch size of 64) are averaged over axis 1, i.e. across compute windows,
    and the result is written to ``filename`` with ``np.save``.
    """
    per_window_grads = compute_amplitude_gradients(
        model, train_set, batch_size=64)
    # Collapse the compute-window axis before storing.
    np.save(filename, np.mean(per_window_grads, axis=1))
|
|
|
|
|
def run_exp(
        subject_id, low_cut_hz, high_cut_hz, n_epochs,
        model_name, output_dir):
    """Run the whole experiment for one subject.

    Loads and preprocesses the data, builds the cropped model, cuts compute
    windows, splits them into train and valid, trains the model and finally
    stores the averaged amplitude gradients as
    ``<output_dir>/<subject_id>_avg_amp_grads.npy``.

    Parameters
    ----------
    subject_id : int
        Subject to run the experiment for.
    low_cut_hz, high_cut_hz : float | None
        Cut-off frequencies forwarded to the preprocessing.
    n_epochs : int
        Number of training epochs.
    model_name : str
        ``'shallow'`` or ``'deep'``.
    output_dir : str
        Directory for the result file; it is created when missing.
    """
    # Create the output directory first so that an unusable path fails
    # immediately instead of after the complete training run.
    os.makedirs(output_dir, exist_ok=True)

    log.info(f"Load and preprocess data for subject {subject_id}...")
    dataset = load_preprocessed_data(subject_id, low_cut_hz, high_cut_hz)

    # Extract number of chans from dataset to create model
    n_chans = dataset[0][0].shape[0]
    log.info("Create cropped model...")
    model = create_cropped_model(model_name, n_chans)

    # Cut windows from the preprocessed data, using number of predictions
    # per compute window to cut non-overlapping fully covering windows
    # (except for overlap of last window to stay within trial bounds)
    log.info("Cut windows from dataset ...")
    input_window_samples = 1000
    # To know the model's receptive field, we calculate the shape of model
    # output for a dummy input.
    n_preds_per_input = get_output_shape(
        model, n_chans, input_window_samples)[2]
    windows_dataset = cut_windows(
        dataset, input_window_samples,
        window_stride_samples=n_preds_per_input)

    # Split into train and valid, ignoring final evaluation for now
    log.info("Split into train and valid...")
    train_set, valid_set = split_into_train_valid(windows_dataset)

    # Run actual training. The device follows GPU availability, matching
    # the check used when the model is created, so that the experiment also
    # runs on machines without CUDA.
    log.info("Run training...")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    run_training(model, model_name, train_set, valid_set, device, n_epochs)

    log.info("Compute and store amplitude gradients ...")
    amp_grads_filename = os.path.join(
        output_dir, f"{subject_id}_avg_amp_grads.npy")
    compute_and_store_amp_grads(model, train_set, filename=amp_grads_filename)
    log.info("... Done.")
|
|
|
|
|
if __name__ == '__main__':
    # Command line entry point: the subject id is the only argument, all
    # other experiment settings are fixed below.
    parser = argparse.ArgumentParser(
        # The previous description mentioned a YAML experiment file and a
        # config path, neither of which this script accepts.
        description="""Train a cropped-decoding ConvNet on one subject of
        the Schirrmeister2017 dataset and store the averaged amplitude
        gradients.
        Example: ./train_experiments.py 1 """
    )
    parser.add_argument('subject_id', type=int,
                        help='''Run for subject id....''')
    args = parser.parse_args()
    subject_id = args.subject_id
    low_cut_hz = None  # low cut frequency for filtering
    high_cut_hz = None  # high cut frequency for filtering
    n_epochs = 20
    model_name = 'deep'
    output_dir = './results/'
    # Log everything from DEBUG upwards to stdout.
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)
    run_exp(
        subject_id, low_cut_hz, high_cut_hz, n_epochs,
        model_name, output_dir)