Skip to content

Instantly share code, notes, and snippets.

@robintibor
Last active January 15, 2024 10:16
Show Gist options
  • Save robintibor/6a95a85088e651392c1bb4f912d1528e to your computer and use it in GitHub Desktop.
Save robintibor/6a95a85088e651392c1bb4f912d1528e to your computer and use it in GitHub Desktop.
HGD Reproduction HBM 2017
import torch.backends.cudnn as cudnn
import torch
from hyperoptim.parse import cartesian_dict_of_lists_product, \
product_of_list_of_lists_of_dicts
import logging
import time
import os
# Put the local invertible-reimplement checkout at the front of sys.path so it
# is found before any other module of the same name.
os.sys.path.insert(0, '/home/schirrmr/code/invertible-reimplement/')
# Configure the root logger's message format (no explicit level is passed here).
logging.basicConfig(format='%(asctime)s | %(levelname)s : %(message)s')
# Module-level logger for this config file, emitting INFO and above.
log = logging.getLogger(__name__)
log.setLevel('INFO')
def get_templates():
    """Return the template definitions for this experiment config.

    No templates are used, so the result is always a new empty dict.
    """
    templates = dict()
    return templates
def get_grid_param_list():
    """Assemble the hyperparameter grid for the HGD decoding experiments.

    Each group of settings is expanded with
    ``cartesian_dict_of_lists_product`` and the groups are then combined with
    ``product_of_list_of_lists_of_dicts`` (presumably yielding one dict per
    combination of all groups -- confirm against hyperoptim).

    Returns
    -------
    The combined grid as returned by ``product_of_list_of_lists_of_dicts``.
    """
    # Where results are written; a single fixed option.
    save_params = [
        {'save_folder': '/home/schirrmr/data/exps/braindecode/hgd-decoding/'},
    ]

    # Debug mode is switched off for the real grid.
    debug_params = [{'debug': False}]

    # Data loading / preprocessing options.
    data_params = cartesian_dict_of_lists_product({
        'subject_id': range(1, 15),
        'low_cut_hz': [0, 4],
        'high_cut_hz': [None],
        'exponential_moving_fn': ['standardize', 'demean'],
        'only_C_sensors': [True],
        'do_common_average_reference': [True],
        'use_final_eval': [False],
    })

    # Training length.
    train_params = cartesian_dict_of_lists_product({
        'n_epochs': [800],
    })

    # Random seeds to repeat each configuration with.
    random_params = cartesian_dict_of_lists_product({
        'seed': range(0, 3),
    })

    # Network architectures to compare.
    model_params = cartesian_dict_of_lists_product({
        'model_name': ['deep', 'shallow'],
    })

    # Which artefacts to store after training.
    store_params = cartesian_dict_of_lists_product({
        'save_amp_grads': [False],
        'save_model': [False],
    })

    return product_of_list_of_lists_of_dicts([
        save_params,
        data_params,
        train_params,
        debug_params,
        random_params,
        model_params,
        store_params,
    ])
def sample_config_params(rng, params):
    """Return ``params`` unchanged.

    No random sampling is performed for this config, so ``rng`` is accepted
    but never used.
    """
    sampled = params
    return sampled
def run(
        ex,
        subject_id,
        low_cut_hz,
        high_cut_hz,
        exponential_moving_fn,
        n_epochs,
        model_name,
        seed,
        debug,
        only_C_sensors,
        do_common_average_reference,
        use_final_eval,
        save_amp_grads,
        save_model,
        resnet_lr=1e-3,
        resnet_weight_decay=1e-5,
        resnet_init_a=1,
        set_name='hgd',
        drop_channel_prob=0.,
):
    """Run one experiment configuration and record its results on ``ex``.

    All arguments except ``ex`` are forwarded as keyword arguments to
    ``braindecode.experiments.hgd.run.run_exp``, together with an
    ``output_dir`` taken from the first observer of ``ex``.

    Parameters
    ----------
    ex
        Experiment object; must expose ``observers[0].dir`` and a dict-like
        ``info`` (presumably a Sacred experiment with a file observer --
        confirm against the launcher).
    subject_id, low_cut_hz, high_cut_hz, exponential_moving_fn, n_epochs,
    model_name, seed, debug, only_C_sensors, do_common_average_reference,
    use_final_eval, save_amp_grads, save_model
        Grid parameters, passed through to ``run_exp``.
    resnet_lr, resnet_weight_decay, resnet_init_a, set_name, drop_channel_prob
        Additional arguments that ``run_exp`` requires but the grid does not
        supply. The defaults mirror the values used by the ``__main__``
        section of the ``run_exp`` script, so existing grids keep working.

    Side effects
    ------------
    Sets ``ex.info['finished']``, ``ex.info['runtime']`` and one float entry
    per key of the last history row of the trained classifier (minus a set of
    bookkeeping keys).
    """
    # Snapshot the call arguments first, while they are the only locals.
    # Copying avoids mutating / aliasing the interpreter's own locals mapping.
    kwargs = dict(locals())
    kwargs.pop('ex')

    if debug:
        # Keep debug runs short.
        kwargs['n_epochs'] = 3
    else:
        log.setLevel('INFO')

    file_obs = ex.observers[0]
    kwargs['output_dir'] = file_obs.dir

    torch.backends.cudnn.benchmark = True

    import sys
    # NOTE: logging.basicConfig does nothing when the root logger already has
    # handlers (e.g. from the module-level basicConfig call in this file).
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)

    start_time = time.time()
    ex.info['finished'] = False

    from braindecode.experiments.hgd.run import run_exp
    clf = run_exp(**kwargs)

    run_time = time.time() - start_time
    ex.info['finished'] = True

    # Bookkeeping entries of the history row that are not stored as results.
    ignore_keys = {
        'batches', 'epoch', 'train_batch_count', 'valid_batch_count',
        'train_loss_best',
        'valid_loss_best', 'train_trial_accuracy_best',
        'valid_trial_accuracy_best',
    }
    last_epoch = clf.history[-1]
    for key, val in last_epoch.items():
        if key not in ignore_keys:
            ex.info[key] = float(val)
    ex.info['runtime'] = run_time
# Authors: Robin Schirrmeister <[email protected]>
#
# License: BSD (3-clause)
import argparse
import logging
import os.path
import sys
from functools import partial
import numpy as np
import torch
from torch import nn
from braindecode import EEGClassifier
from braindecode.datasets.moabb import MOABBDataset
from braindecode.preprocessing.preprocess import Preprocessor, preprocess
from braindecode.preprocessing.preprocess import exponential_moving_standardize
from braindecode.preprocessing.preprocess import exponential_moving_demean
from braindecode.preprocessing.windowers import create_windows_from_events
from braindecode.models import Deep4Net, EEGResNet
from braindecode.models import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model, get_output_shape
from braindecode.training.losses import CroppedLoss
from braindecode.util import set_random_seeds
from braindecode.visualization.gradients import compute_amplitude_gradients
from skorch.callbacks import LRScheduler
from skorch.helper import predefined_split
from torch.utils.data import Subset
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
def load_preprocessed_data(subject_id, low_cut_hz, high_cut_hz, exponential_moving_fn,
                           only_C_sensors, do_common_average_reference, set_name):
    """Load one subject's recordings via MOABB and preprocess them in place.

    Parameters
    ----------
    subject_id
        Subject to load (passed to ``MOABBDataset`` as a one-element list).
    low_cut_hz, high_cut_hz
        Passed as ``l_freq`` / ``h_freq`` to the ``filter`` preprocessor.
    exponential_moving_fn
        ``'standardize'`` or ``'demean'``; selects the exponential moving
        function applied last. Any other value raises ``KeyError``.
    only_C_sensors
        If true, the 45 listed C-sensors are picked, otherwise the full EEG
        sensor list. Only used for ``set_name == 'hgd'``.
    do_common_average_reference
        If true, an average reference step is inserted before resampling.
    set_name
        ``'hgd'`` (MOABB ``Schirrmeister2017``) or ``'bcic_iv_2a'``
        (MOABB ``BNCI2014001``); anything else fails an assertion.

    Returns
    -------
    The ``MOABBDataset`` after ``preprocess`` has been applied to it.
    """
    log.info("Load dataset...")
    if set_name == 'hgd':
        moabb_name = "Schirrmeister2017"
    else:
        assert set_name == "bcic_iv_2a"
        moabb_name = "BNCI2014001"
    dataset = MOABBDataset(dataset_name=moabb_name, subject_ids=[subject_id])

    C_sensors = [
        'FC5', 'FC1', 'FC2', 'FC6', 'C3', 'Cz', 'C4', 'CP5',
        'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
        'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h', 'FCC5h',
        'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h', 'CPP5h',
        'CPP3h', 'CPP4h', 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h', 'CCP1h',
        'CCP2h', 'CPP1h', 'CPP2h']
    EEG_sensors = [
        'Fp1', 'Fp2', 'Fpz', 'F7', 'F3', 'Fz', 'F4', 'F8',
        'FC5', 'FC1', 'FC2', 'FC6', 'M1', 'T7', 'C3', 'Cz', 'C4', 'T8', 'M2',
        'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8', 'POz', 'O1',
        'Oz', 'O2', 'AF7', 'AF3', 'AF4', 'AF8', 'F5', 'F1', 'F2', 'F6', 'FC3',
        'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6', 'CP3', 'CPz', 'CP4', 'P5', 'P1',
        'P2', 'P6', 'PO5', 'PO3', 'PO4', 'PO6', 'FT7', 'FT8', 'TP7', 'TP8',
        'PO7', 'PO8', 'FT9', 'FT10', 'TPP9h', 'TPP10h', 'PO9', 'PO10', 'P9',
        'P10', 'AFF1', 'AFz', 'AFF2', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h', 'FCC5h',
        'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h', 'CPP5h',
        'CPP3h', 'CPP4h', 'CPP6h', 'PPO1', 'PPO2', 'I1', 'Iz', 'I2', 'AFp3h',
        'AFp4h', 'AFF5h', 'AFF6h', 'FFT7h', 'FFC1h', 'FFC2h', 'FFT8h', 'FTT9h',
        'FTT7h', 'FCC1h', 'FCC2h', 'FTT8h', 'FTT10h', 'TTP7h', 'CCP1h', 'CCP2h',
        'TTP8h', 'TPP7h', 'CPP1h', 'CPP2h', 'TPP8h', 'PPO9h', 'PPO5h', 'PPO6h',
        'PPO10h', 'POO9h', 'POO3h', 'POO4h', 'POO10h', 'OI1h', 'OI2h']
    sensor_names = C_sensors if only_C_sensors else EEG_sensors

    # Settings shared by both exponential moving functions.
    factor_new = 1e-3
    init_block_size = 1000

    log.info("Preprocess dataset...")
    available_moving_fns = {
        'standardize': exponential_moving_standardize,
        'demean': exponential_moving_demean,
    }
    moving_fn = available_moving_fns[exponential_moving_fn]

    # Step 1: make sure the raw data is loaded into memory.
    steps = [Preprocessor(fn='load_data')]

    # Step 2: channel selection, which depends on the dataset.
    if set_name == "hgd":
        steps.append(
            Preprocessor(fn='pick_channels', ch_names=sensor_names, ordered=True))
    else:
        assert set_name == 'bcic_iv_2a'
        # Keep EEG sensors only.
        steps.append(Preprocessor("pick_types", eeg=True, meg=False, stim=False))

    # Step 3: rescale by 1e6 (presumably volts to microvolts -- confirm) and
    # clip the result to +-800.
    steps.append(Preprocessor(fn=lambda x: x * 1e6, apply_on_array=True))
    steps.append(Preprocessor(fn=lambda x: np.clip(x, -800, 800), apply_on_array=True))

    # Step 4 (optional): re-reference to the channel average.
    if do_common_average_reference:
        steps.append(Preprocessor(fn='set_eeg_reference', ref_channels='average'))

    # Step 5: resample, filter, then exponential moving standardize/demean.
    steps.append(Preprocessor(fn='resample', sfreq=250))
    steps.append(Preprocessor(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz))
    steps.append(Preprocessor(fn=moving_fn, factor_new=factor_new,
                              init_block_size=init_block_size, apply_on_array=True))

    # Apply all steps to the dataset.
    preprocess(dataset, steps)
    return dataset
def create_cropped_model(model_name, n_chans, resnet_init_a, seed=None):
    """Create a 4-class ConvNet prepared for cropped decoding.

    The final layer length is set explicitly (``final_conv_length`` /
    ``final_pool_length``) so that no ``input_window_samples`` is needed.
    The shallow and deep models are additionally converted with
    ``to_dense_prediction_model`` so they output dense predictions for all
    crops; the ResNet is returned unconverted.

    Parameters
    ----------
    model_name
        ``'shallow'``, ``'resnet'`` or ``'deep'``; anything else fails an
        assertion.
    n_chans
        Number of input channels.
    resnet_init_a
        ``a`` argument of ``nn.init.kaiming_normal_`` used for the ResNet
        convolution weights; unused for the other models.
    seed
        Optional seed. When given, all RNGs are re-seeded with it right
        before the model is built. When ``None`` (default) the RNG state set
        by the caller is left untouched. Previously a fixed seed (20200220)
        was always applied here, which silently overrode the experiment seed
        set in ``run_exp`` so that different seeds produced identical runs.

    Returns
    -------
    The model, moved to the GPU when CUDA is available.
    """
    cuda = torch.cuda.is_available()
    if cuda:
        torch.backends.cudnn.benchmark = True

    if seed is not None:
        set_random_seeds(seed=seed, cuda=cuda)

    n_classes = 4
    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            n_chans,
            n_classes,
            input_window_samples=None,  # not needed when final_conv_length is given
            final_conv_length=30,
        )
    elif model_name == 'resnet':
        model = EEGResNet(
            n_chans,
            n_classes,
            input_window_samples=None,  # not needed when final_pool_length is given
            n_first_filters=48,
            final_pool_length=10,
            conv_weight_init_fn=partial(nn.init.kaiming_normal_, a=resnet_init_a))
    else:
        assert model_name == 'deep'
        model = Deep4Net(
            n_chans,
            n_classes,
            input_window_samples=None,  # not needed when final_conv_length is given
            final_conv_length=2,
        )

    # Send model to GPU
    if cuda:
        model.cuda()

    # Turn the strided model into one that outputs dense predictions, so that
    # predictions for all crops are obtained in one forward pass.
    if model_name in ["shallow", "deep"]:
        to_dense_prediction_model(model)
    return model
def cut_windows(dataset, input_window_samples, window_stride_samples):
    """Cut the continuous recordings into fixed-size compute windows.

    Trials start 0.5 s before the event onset and stop at the trial end.
    Unlike trialwise decoding, an explicit window size and stride are passed
    to ``create_windows_from_events``.

    Parameters
    ----------
    dataset
        Dataset whose ``datasets`` each carry a ``raw`` with
        ``info['sfreq']``; all sampling rates must be equal.
    input_window_samples
        Window size in samples.
    window_stride_samples
        Stride between consecutive windows in samples.

    Returns
    -------
    The windows dataset created by ``create_windows_from_events``.
    """
    trial_start_offset_seconds = -0.5

    # All recordings have to share one sampling frequency.
    sampling_rates = [ds.raw.info['sfreq'] for ds in dataset.datasets]
    sfreq = sampling_rates[0]
    assert all(rate == sfreq for rate in sampling_rates)

    # Convert the trial start offset from seconds to samples.
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

    # Event name -> integer class label.
    label_mapping = {'left_hand': 0, 'right_hand': 1, 'feet': 2, 'rest': 3}

    return create_windows_from_events(
        dataset,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=0,
        window_size_samples=input_window_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=False,
        preload=True,
        mapping=label_mapping,
    )
def split_into_train_valid(windows_dataset, use_final_eval):
    """Split a windows dataset into a training and a validation/evaluation set.

    If the description contains a ``'session_T'`` session (BCIC IV 2a case),
    the dataset is split by ``"session"`` into ``'session_T'`` /
    ``'session_E'``; otherwise it is split by ``'run'`` into ``'0train'`` /
    ``'1test'``.

    Parameters
    ----------
    windows_dataset
        Object with a ``description.session`` column and a ``split`` method.
    use_final_eval
        If true, return the full training part and the held-out test part.
        If false, the test part is ignored and the training part is divided
        into the first ~80% (training) and the remainder (validation).

    Returns
    -------
    tuple
        ``(train_set, valid_set)``.
    """
    print("description", windows_dataset.description)
    has_session_T = sum(windows_dataset.description.session == 'session_T') > 0
    if has_session_T:
        # BCIC IV 2a case
        split_by, train_key, test_key = "session", 'session_T', 'session_E'
    else:
        split_by, train_key, test_key = 'run', '0train', '1test'
    splitted = windows_dataset.split(split_by)
    print("splitted", splitted)

    if use_final_eval:
        return splitted[train_key], splitted[test_key]

    full_train_set = splitted[train_key]
    n_total = len(full_train_set)
    # Number of windows per trial, set by hand; the split point is rounded
    # down to a multiple of it so that no trial is cut in half.
    n_windows_per_trial = 2
    n_split = int(np.round(0.8 * n_total))
    n_split -= n_split % n_windows_per_trial

    train_set = Subset(full_train_set, range(0, n_split))
    valid_set = Subset(full_train_set, range(n_split, n_total))
    return train_set, valid_set
def run_training(model, model_name, train_set, valid_set, device, n_epochs, resnet_lr,
                 resnet_weight_decay, drop_channel_prob):
    """Train ``model`` with cropped decoding and return the fitted classifier.

    Parameters
    ----------
    model
        Network to train.
    model_name
        ``'deep'``, ``'shallow'`` or ``'resnet'``; selects learning rate and
        weight decay.
    train_set, valid_set
        Training data and the predefined validation split.
    device
        Device string handed to ``EEGClassifier``.
    n_epochs
        Number of training epochs; also determines the cosine annealing
        period.
    resnet_lr, resnet_weight_decay
        Optimizer settings used only when ``model_name == 'resnet'``.
    drop_channel_prob
        Second positional argument of ``ChannelsDropout`` (presumably the
        per-channel drop probability, with the first argument ``1`` being the
        probability of applying the transform -- confirm against braindecode).

    Returns
    -------
    The fitted ``EEGClassifier``.
    """
    from braindecode.augmentation import AugmentedDataLoader, ChannelsDropout

    assert model_name in ['deep', 'shallow', 'resnet']
    if model_name == 'shallow':
        # These values we found good for shallow network:
        lr = 0.0625 * 0.01
        weight_decay = 0
    elif model_name == 'resnet':
        # No established values for the ResNet; taken from the caller.
        lr = resnet_lr
        weight_decay = resnet_weight_decay
    else:
        assert model_name == 'deep'
        # Values for the deep4 network:
        lr = 1 * 0.01
        weight_decay = 0.5 * 0.001
    batch_size = 64

    transforms = [ChannelsDropout(1, drop_channel_prob)]

    # Guard against T_max == 0 (n_epochs == 1), which would make the cosine
    # schedule divide by zero. Unchanged for n_epochs >= 2.
    scheduler_t_max = max(n_epochs - 1, 1)

    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.nll_loss,
        iterator_train=AugmentedDataLoader,
        iterator_train__transforms=transforms,  # This sets the augmentations to use
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        callbacks=[
            "accuracy",
            ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=scheduler_t_max)),
        ],
        device=device,
        # Ordered by integer label so that classes[i] names label i; matches
        # the mapping used when cutting windows
        # (left_hand=0, right_hand=1, feet=2, rest=3).
        classes=["left_hand", "right_hand", "feet", "rest"],
    )
    # Model training for a specified number of epochs. `y` is None as it is
    # already supplied in the dataset.
    clf.fit(train_set, y=None, epochs=n_epochs)
    return clf
def compute_and_store_amp_grads(model, train_set, filename):
    """Compute amplitude gradients on ``train_set`` and save their mean to disk.

    The gradients returned by ``compute_amplitude_gradients`` (batch size 64)
    are averaged over axis 1 (the compute-window axis, per the original
    comment -- confirm against braindecode) and written to ``filename`` with
    ``np.save``.
    """
    per_window_amp_grads = compute_amplitude_gradients(model, train_set, batch_size=64)
    np.save(filename, np.mean(per_window_amp_grads, axis=1))
def run_exp(
        seed,
        subject_id,
        low_cut_hz,
        high_cut_hz,
        exponential_moving_fn,
        n_epochs,
        model_name,
        output_dir,
        only_C_sensors,
        do_common_average_reference,
        use_final_eval,
        save_amp_grads,
        save_model,
        resnet_lr,
        resnet_weight_decay,
        resnet_init_a,
        debug,
        set_name,
        drop_channel_prob):
    """Run the full pipeline for one subject: load, window, split, train, store.

    Parameters
    ----------
    seed
        Seed passed to ``set_random_seeds`` before anything else runs.
    subject_id, low_cut_hz, high_cut_hz, exponential_moving_fn,
    only_C_sensors, do_common_average_reference, set_name
        Forwarded to ``load_preprocessed_data``.
    n_epochs, resnet_lr, resnet_weight_decay, drop_channel_prob
        Forwarded to ``run_training``.
    model_name
        ``'deep'``, ``'shallow'`` or ``'resnet'``.
    resnet_init_a
        Forwarded to ``create_cropped_model``.
    output_dir
        Directory for the amplitude gradients and the saved model; created
        on demand if it does not exist.
    use_final_eval
        Forwarded to ``split_into_train_valid``.
    save_amp_grads
        If true, store averaged amplitude gradients as
        ``<subject_id>_avg_amp_grads.npy``.
    save_model
        If true and ``debug`` is false, store the model as ``model.pth``.
    debug
        Disables model saving when true.

    Returns
    -------
    The fitted classifier returned by ``run_training``.
    """
    assert model_name in ['deep', 'shallow', 'resnet']

    # Use the GPU only when one is actually present instead of assuming CUDA.
    cuda = torch.cuda.is_available()
    device = 'cuda' if cuda else 'cpu'
    set_random_seeds(seed, cuda)

    log.info(f"Load and preprocess data for subject {subject_id}...")
    dataset = load_preprocessed_data(
        subject_id, low_cut_hz, high_cut_hz,
        exponential_moving_fn=exponential_moving_fn,
        only_C_sensors=only_C_sensors,
        do_common_average_reference=do_common_average_reference,
        set_name=set_name,
    )

    # Extract number of chans from dataset to create model
    n_chans = dataset[0][0].shape[0]
    log.info("Create cropped model...")
    model = create_cropped_model(model_name, n_chans, resnet_init_a)

    # Cut windows from the preprocessed data, using number of predictions
    # per compute window to cut non-overlapping fully covering windows
    # (except for overlap of last window to stay within trial bounds)
    log.info("Cut windows from dataset ...")
    input_window_samples = 1000
    # To know the model's receptive field, we calculate the shape of the
    # model output for a dummy input.
    n_preds_per_input = get_output_shape(model, n_chans, input_window_samples)[2]
    windows_dataset = cut_windows(
        dataset, input_window_samples, window_stride_samples=n_preds_per_input)

    # Split into train and valid (or train and final evaluation set)
    log.info("Split into train and valid...")
    train_set, valid_set = split_into_train_valid(
        windows_dataset, use_final_eval=use_final_eval)

    # Run actual training
    log.info("Run training...")
    clf = run_training(model, model_name, train_set, valid_set, device, n_epochs,
                       resnet_lr, resnet_weight_decay, drop_channel_prob)

    if save_amp_grads:
        log.info("Compute and store amplitude gradients ...")
        os.makedirs(output_dir, exist_ok=True)
        amp_grads_filename = os.path.join(output_dir, f"{subject_id}_avg_amp_grads.npy")
        compute_and_store_amp_grads(model, train_set, filename=amp_grads_filename)

    if (not debug) and save_model:
        log.info("Save model ...")
        os.makedirs(output_dir, exist_ok=True)
        torch.save(model, os.path.join(output_dir, "model.pth"))

    log.info("... Done.")
    return clf
if __name__ == '__main__':
    # Command-line entry point: train the deep network for a single subject
    # of the HGD dataset with fixed settings and evaluate on the final test set.
    parser = argparse.ArgumentParser(
        description="""Train and evaluate a cropped-decoding ConvNet for one subject.
        Example: %(prog)s 1 """
    )
    parser.add_argument('subject_id', type=int,
                        help='Id of the subject to run the experiment for.')
    args = parser.parse_args()

    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)

    # Keyword arguments keep this call correct even if the parameter order of
    # run_exp changes.
    run_exp(
        seed=0,
        subject_id=args.subject_id,
        low_cut_hz=None,  # low cut frequency for filtering
        high_cut_hz=None,  # high cut frequency for filtering
        exponential_moving_fn="standardize",
        n_epochs=800,
        model_name='deep',
        output_dir='./results/',
        only_C_sensors=True,
        do_common_average_reference=False,
        use_final_eval=True,
        save_amp_grads=False,
        save_model=False,
        resnet_lr=1e-3,
        resnet_weight_decay=1e-5,
        resnet_init_a=1,
        debug=False,
        set_name="hgd",
        drop_channel_prob=0.,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment