Created
October 24, 2023 18:08
-
-
Save robintibor/6a07c3f8b736778cf0f9051449878712 to your computer and use it in GitHub Desktop.
BCIC IV 2a and HGD code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.backends.cudnn as cudnn | |
import torch | |
from hyperoptim.parse import cartesian_dict_of_lists_product, \ | |
product_of_list_of_lists_of_dicts | |
import logging | |
import time | |
import os | |
os.sys.path.insert(0, '/home/schirrmr/code/invertible-reimplement/') | |
logging.basicConfig(format='%(asctime)s | %(levelname)s : %(message)s') | |
log = logging.getLogger(__name__) | |
log.setLevel('INFO') | |
def get_templates():
    """Return the template definitions of this configuration (none are used)."""
    templates = dict()
    return templates
def get_grid_param_list():
    """Assemble the parameter grid of the BCIC IV 2a reproduction experiment.

    Every parameter group is either a literal list of dicts or the result of
    ``cartesian_dict_of_lists_product``; the groups are then combined with
    ``product_of_list_of_lists_of_dicts`` and that result is returned.
    """
    expand = cartesian_dict_of_lists_product

    save_params = [{
        'save_folder': '/home/schirrmr/data/exps/braindecode/bcic-iv-2a-reproduction/trial-fixed-preproc/',
    }]
    debug_params = [{'debug': False}]

    data_params = expand({
        'subject_id': range(1, 10),
        'low_cut_hz': [0, 4],
    })
    train_params = expand({
        'n_epochs': [800],
        'cropped_or_trial': ['cropped'],
    })
    random_params = expand({'seed': range(0, 3)})
    model_params = expand({'model_name': ['shallow', 'deep']})

    # The order of the groups is kept exactly as in the combined list below.
    all_groups = [
        save_params,
        data_params,
        train_params,
        debug_params,
        random_params,
        model_params,
    ]
    return product_of_list_of_lists_of_dicts(all_groups)
def sample_config_params(rng, params):
    """Return ``params`` unchanged; ``rng`` is accepted but not used."""
    sampled = params
    return sampled
def run(
        ex,
        subject_id,
        debug,
        low_cut_hz,
        n_epochs,
        seed,
        model_name,
        cropped_or_trial,
):
    """Run one BCIC IV 2a experiment and store its results in ``ex.info``.

    Parameters
    ----------
    ex
        Experiment object; only its ``info`` mapping is written to here.
    subject_id, low_cut_hz, n_epochs, seed, model_name
        Forwarded as keyword arguments to the experiment function.
    debug
        When true, the number of epochs is overridden to 3.
    cropped_or_trial
        ``'cropped'`` selects ``run_exp_cropped``, ``'trial'`` selects
        ``run_exp_trial``.

    ``ex.info`` receives ``finished``, ``runtime`` and a float for every key
    of the last history entry that is not in the ignore list.
    """
    if not debug:
        log.setLevel('INFO')

    # Only these arguments are understood by the experiment functions.
    exp_kwargs = {
        'subject_id': subject_id,
        'low_cut_hz': low_cut_hz,
        'n_epochs': 3 if debug else n_epochs,
        'seed': seed,
        'model_name': model_name,
    }
    assert cropped_or_trial in ['cropped', 'trial']

    torch.backends.cudnn.benchmark = True
    import sys
    # NOTE(review): basicConfig does nothing when the root logger already has
    # handlers, which the module-level basicConfig call presumably installed.
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)

    t_start = time.time()
    ex.info['finished'] = False
    from braindecode.experiments.bcic_iv_2a.run import run_exp_cropped, run_exp_trial
    if cropped_or_trial != 'cropped':
        assert cropped_or_trial == 'trial'
        clf = run_exp_trial(**exp_kwargs)
    else:
        clf = run_exp_cropped(**exp_kwargs)
    run_time = time.time() - t_start
    ex.info['finished'] = True

    # History entries that are not stored as results.
    skipped_keys = (
        'batches', 'epoch', 'train_batch_count', 'valid_batch_count',
        'train_loss_best',
        'valid_loss_best', 'train_trial_accuracy_best',
        'valid_trial_accuracy_best',
    )
    for key, val in clf.history[-1].items():
        if key not in skipped_keys:
            ex.info[key] = float(val)
    ex.info['runtime'] = run_time
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Cropped Decoding on BCIC IV 2a Competition Set with skorch and moabb. | |
===================================================================== | |
""" | |
# Authors: Maciej Sliwowski <[email protected]> | |
# Robin Tibor Schirrmeister <[email protected]> | |
# Lukas Gemein <[email protected]> | |
# Hubert Banville <[email protected]> | |
# | |
# License: BSD-3 | |
from braindecode.datautil.preprocess import preprocess, MNEPreproc, NumpyPreproc | |
from braindecode.datautil.preprocess import exponential_moving_standardize | |
from braindecode.util import set_random_seeds | |
from braindecode.models.util import to_dense_prediction_model, get_output_shape | |
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet | |
from braindecode.models.deep4 import Deep4Net | |
from braindecode.training.losses import CroppedNLLLoss | |
from braindecode.datasets import MOABBDataset | |
from braindecode.classifier import EEGClassifier | |
from braindecode.datautil.windowers import create_windows_from_events | |
from functools import partial | |
import numpy as np | |
import torch | |
import mne | |
from skorch.callbacks import LRScheduler | |
mne.set_log_level('ERROR') | |
def run_exp_cropped(subject_id, model_name, low_cut_hz, n_epochs, seed):
    """Fit a cropped-decoding ConvNet on one subject of MOABB ``BNCI2014001``.

    Parameters
    ----------
    subject_id
        Subject passed to ``MOABBDataset``.
    model_name
        ``'shallow'`` (``ShallowFBCSPNet``) or ``'deep'`` (``Deep4Net``).
    low_cut_hz
        Lower edge of the band-pass filter; the upper edge is fixed at 38 Hz.
    n_epochs
        Number of training epochs; ``n_epochs - 1`` is the cosine-annealing
        ``T_max``.
    seed
        Seed handed to ``set_random_seeds`` before the model is created.

    Returns
    -------
    The fitted ``EEGClassifier``. ``session_T`` is the training split and
    ``session_E`` the validation split.
    """
    assert model_name in ['shallow', 'deep']

    # Fixed settings of this experiment.
    n_chans, n_classes = 22, 4
    input_window_samples = 1000
    batch_size = 64
    high_cut_hz = 38.
    trial_start_offset_seconds = -0.5
    # Parameters of the exponential moving standardization.
    factor_new, init_block_size = 1e-3, 1000

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.backends.cudnn.benchmark = True
    device = 'cuda' if use_cuda else 'cpu'
    set_random_seeds(seed=seed, cuda=use_cuda)

    if model_name == "shallow":
        lr, weight_decay = 0.0625 * 0.01, 0
        model = ShallowFBCSPNet(
            n_chans,
            n_classes,
            input_window_samples=input_window_samples,
            final_conv_length=30,
        )
    elif model_name == "deep":
        lr, weight_decay = 1 * 0.01, 0.5 * 0.001
        model = Deep4Net(
            n_chans,
            n_classes,
            input_window_samples=input_window_samples,
            final_conv_length=2,
        )
    if use_cuda:
        model.cuda()

    # Make the model output dense predictions; the number of predictions per
    # input window is later used as the window stride.
    to_dense_prediction_model(model)
    output_shape = get_output_shape(model, n_chans, input_window_samples)
    n_preds_per_input = output_shape[2]

    dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])
    pick_eeg = MNEPreproc(fn='pick_types', eeg=True, meg=False, stim=False)
    # Convert from volt to microvolt, directly modifying the numpy array.
    to_microvolt = NumpyPreproc(fn=lambda x: x * 1e6)
    bandpass = MNEPreproc(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz)
    standardize = NumpyPreproc(fn=exponential_moving_standardize,
                               factor_new=factor_new,
                               init_block_size=init_block_size)
    preprocess(dataset, [pick_eeg, to_microvolt, bandpass, standardize])

    # All recordings must share one sampling frequency.
    sfreqs = [ds.raw.info['sfreq'] for ds in dataset.datasets]
    assert len(np.unique(sfreqs)) == 1
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreqs[0])

    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=0,
        window_size_samples=input_window_samples,
        window_stride_samples=n_preds_per_input,
        drop_last_window=False,
        preload=True,
    )

    class TrainTestBCICIV2aSplit:
        """Split callable returning the ``session_T`` and ``session_E`` parts."""

        def __call__(self, dataset, y, **kwargs):
            by_session = dataset.split('session')
            return by_session['session_T'], by_session['session_E']

    # seems n_epochs - 1 leads to desired behavior of lr=0 after end of training?
    lr_scheduler = LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)
    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedNLLLoss,
        optimizer=torch.optim.AdamW,
        train_split=TrainTestBCICIV2aSplit(),
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        callbacks=["accuracy", ("lr_scheduler", lr_scheduler)],
        device=device,
    )
    clf.fit(windows_dataset, y=None, epochs=n_epochs)
    return clf
def run_exp_trial(subject_id, model_name, low_cut_hz, n_epochs, seed):
    """Fit a trialwise-decoding ConvNet on one subject of MOABB ``BNCI2014001``.

    Parameters
    ----------
    subject_id
        Subject passed to ``MOABBDataset``.
    model_name
        ``'shallow'`` (``ShallowFBCSPNet``) or ``'deep'`` (``Deep4Net``).
    low_cut_hz
        Lower edge of the band-pass filter; the upper edge is fixed at 38 Hz.
    n_epochs
        Number of training epochs; ``n_epochs - 1`` is the cosine-annealing
        ``T_max``.
    seed
        Seed handed to ``set_random_seeds`` before the model is created.

    Returns
    -------
    The fitted ``EEGClassifier``. ``session_T`` is the training split and
    ``session_E`` the validation split.

    Raises
    ------
    AssertionError
        If ``model_name`` is not one of the two supported names.
    """
    # Same up-front check as in run_exp_cropped; without it an unknown name
    # only fails later because ``model`` and ``lr`` were never assigned.
    assert model_name in ['shallow', 'deep']
    high_cut_hz = 38.
    trial_start_offset_seconds = -0.5
    # Window size and window stride are both set to this value below.
    input_window_samples = 1125
    batch_size = 64
    # Parameters of the exponential moving standardization.
    factor_new = 1e-3
    init_block_size = 1000
    cuda = torch.cuda.is_available()
    device = 'cuda' if cuda else 'cpu'
    if cuda:
        torch.backends.cudnn.benchmark = True
    n_classes = 4
    n_chans = 22
    set_random_seeds(seed=seed, cuda=cuda)
    if model_name == "shallow":
        model = ShallowFBCSPNet(
            n_chans,
            n_classes,
            input_window_samples=input_window_samples,
            final_conv_length='auto',
        )
        lr = 0.0625 * 0.01
        weight_decay = 0
    elif model_name == "deep":
        model = Deep4Net(
            n_chans,
            n_classes,
            input_window_samples=input_window_samples,
            final_conv_length='auto',
        )
        lr = 1 * 0.01
        weight_decay = 0.5 * 0.001
    if cuda:
        model.cuda()
    dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])
    preprocessors = [
        # keep only EEG sensors
        MNEPreproc(fn='pick_types', eeg=True, meg=False, stim=False),
        # convert from volt to microvolt, directly modifying the numpy array
        NumpyPreproc(fn=lambda x: x * 1e6),
        # bandpass filter
        MNEPreproc(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
        # exponential moving standardization
        NumpyPreproc(fn=exponential_moving_standardize,
                     factor_new=factor_new,
                     init_block_size=init_block_size),
    ]
    # Transform the data
    preprocess(dataset, preprocessors)
    # All recordings must share one sampling frequency.
    sfreqs = [ds.raw.info['sfreq'] for ds in dataset.datasets]
    assert len(np.unique(sfreqs)) == 1
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreqs[0])
    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=0,
        window_size_samples=input_window_samples,
        window_stride_samples=input_window_samples,
        drop_last_window=False,
        preload=True,
    )

    class TrainTestBCICIV2aSplit(object):
        """Split callable returning the ``session_T`` and ``session_E`` parts."""

        def __call__(self, dataset, y, **kwargs):
            splitted = dataset.split('session')
            return splitted['session_T'], splitted['session_E']

    clf = EEGClassifier(
        model,
        cropped=False,
        criterion=torch.nn.NLLLoss,
        optimizer=torch.optim.AdamW,
        train_split=TrainTestBCICIV2aSplit(),
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        iterator_train__shuffle=True,
        batch_size=batch_size,
        callbacks=[
            "accuracy",
            # seems n_epochs - 1 leads to desired behavior of lr=0 after end of training?
            ("lr_scheduler",
             LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
        ],
        device=device,
    )
    clf.fit(windows_dataset, y=None, epochs=n_epochs)
    return clf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.backends.cudnn as cudnn | |
import torch | |
from hyperoptim.parse import cartesian_dict_of_lists_product, \ | |
product_of_list_of_lists_of_dicts | |
import logging | |
import time | |
import os | |
logging.basicConfig(format='%(asctime)s | %(levelname)s : %(message)s') | |
log = logging.getLogger(__name__) | |
log.setLevel('INFO') | |
def get_templates():
    """Return the template definitions of this configuration (none are used)."""
    no_templates = dict()
    return no_templates
def get_grid_param_list():
    """Assemble the parameter grid of the HGD decoding experiment.

    Every parameter group is either a literal list of dicts or the result of
    ``cartesian_dict_of_lists_product``; the groups are then combined with
    ``product_of_list_of_lists_of_dicts`` and that result is returned.
    """
    expand = cartesian_dict_of_lists_product

    save_params = [{
        'save_folder': '/home/schirrmr/data/exps/braindecode/hgd-decoding-without-gamma/',
    }]
    debug_params = [{'debug': False}]

    data_params = expand({
        'subject_id': range(1, 15),
        'low_cut_hz': [0, 4],
        'high_cut_hz': [None, 45],
        'exponential_moving_fn': ['standardize'],
        'only_C_sensors': [True],
        'do_common_average_reference': [False],
        'use_final_eval': [True],
    })
    train_params = expand({'n_epochs': [800]})
    random_params = expand({'seed': range(5)})
    model_params = expand({'model_name': ['deep', 'shallow']})
    store_params = expand({
        'save_amp_grads': [False],
        'save_model': [False],
    })
    moabb_params = expand({'new_moabb': [True]})

    # The order of the groups is kept exactly as in the combined list below.
    all_groups = [
        save_params,
        data_params,
        train_params,
        debug_params,
        random_params,
        model_params,
        store_params,
        moabb_params,
    ]
    return product_of_list_of_lists_of_dicts(all_groups)
def sample_config_params(rng, params):
    """Return ``params`` unchanged; ``rng`` is accepted but not used."""
    chosen = params
    return chosen
def run(
        ex,
        subject_id,
        low_cut_hz,
        high_cut_hz,
        exponential_moving_fn,
        n_epochs,
        model_name,
        seed,
        debug,
        only_C_sensors,
        do_common_average_reference,
        use_final_eval,
        save_amp_grads,
        save_model,
        new_moabb,
):
    """Run one HGD experiment and store its results in ``ex.info``.

    Parameters
    ----------
    ex
        Experiment object; ``ex.observers[0].dir`` is used as output
        directory and ``ex.info`` receives the results.
    new_moabb
        When true, a fixed local moabb checkout is put first on ``sys.path``;
        it is not forwarded to ``run_exp``.
    debug
        When true, the number of epochs is overridden to 3; the flag itself
        is forwarded as well.
    All remaining parameters
        Forwarded unchanged as keyword arguments to ``run_exp``.

    ``ex.info`` receives ``finished``, ``runtime`` and a float for every key
    of the last history entry that is not in the ignore list.
    """
    if new_moabb:
        os.sys.path.insert(0, '/home/schirrmr/braindecode/code/moabb/')
    if not debug:
        log.setLevel('INFO')

    exp_kwargs = {
        'subject_id': subject_id,
        'low_cut_hz': low_cut_hz,
        'high_cut_hz': high_cut_hz,
        'exponential_moving_fn': exponential_moving_fn,
        'n_epochs': 3 if debug else n_epochs,
        'model_name': model_name,
        'seed': seed,
        'debug': debug,
        'only_C_sensors': only_C_sensors,
        'do_common_average_reference': do_common_average_reference,
        'use_final_eval': use_final_eval,
        'save_amp_grads': save_amp_grads,
        'save_model': save_model,
        # Results are written into the directory of the first observer.
        'output_dir': ex.observers[0].dir,
    }

    torch.backends.cudnn.benchmark = True
    import sys
    # NOTE(review): basicConfig does nothing when the root logger already has
    # handlers, which the module-level basicConfig call presumably installed.
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)

    t_start = time.time()
    ex.info['finished'] = False
    from braindecode.experiments.hgd.run import run_exp
    clf = run_exp(**exp_kwargs)
    run_time = time.time() - t_start
    ex.info['finished'] = True

    # History entries that are not stored as results.
    skipped_keys = (
        'batches', 'epoch', 'train_batch_count', 'valid_batch_count',
        'train_loss_best',
        'valid_loss_best', 'train_trial_accuracy_best',
        'valid_trial_accuracy_best',
    )
    for key, val in clf.history[-1].items():
        if key not in skipped_keys:
            ex.info[key] = float(val)
    ex.info['runtime'] = run_time
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Authors: Robin Schirrmeister <[email protected]> | |
# | |
# License: BSD (3-clause) | |
import argparse | |
import logging | |
import os.path | |
import sys | |
from functools import partial | |
import numpy as np | |
import torch | |
from torch import nn | |
from braindecode import EEGClassifier | |
from braindecode.datasets.moabb import MOABBDataset | |
from braindecode.preprocessing.preprocess import Preprocessor, preprocess | |
from braindecode.preprocessing.preprocess import exponential_moving_standardize | |
from braindecode.preprocessing.preprocess import exponential_moving_demean | |
from braindecode.preprocessing.windowers import create_windows_from_events | |
from braindecode.models import Deep4Net, EEGResNet | |
from braindecode.models import ShallowFBCSPNet | |
from braindecode.models.util import to_dense_prediction_model, get_output_shape | |
from braindecode.training.losses import CroppedLoss | |
from braindecode.util import set_random_seeds | |
from braindecode.visualization.gradients import compute_amplitude_gradients | |
from skorch.callbacks import LRScheduler | |
from skorch.helper import predefined_split | |
from torch.utils.data import Subset | |
log = logging.getLogger(__name__) | |
def load_preprocessed_data(subject_id, low_cut_hz, high_cut_hz, exponential_moving_fn,
                           only_C_sensors, do_common_average_reference, set_name):
    """Load one subject through MOABB and apply the preprocessing chain.

    Parameters
    ----------
    subject_id
        Subject passed to ``MOABBDataset``.
    low_cut_hz, high_cut_hz
        Passed as ``l_freq`` / ``h_freq`` to the ``filter`` preprocessor.
    exponential_moving_fn
        ``'standardize'`` or ``'demean'``; selects the exponential moving
        function applied as the last step.
    only_C_sensors
        For ``set_name == 'hgd'``: pick the C-sensor list if true, the full
        EEG sensor list otherwise.
    do_common_average_reference
        If true, add an average ``set_eeg_reference`` step.
    set_name
        ``'hgd'`` (``Schirrmeister2017``) or ``'bcic_iv_2a'``
        (``BNCI2014001``).

    Returns
    -------
    The dataset after ``preprocess`` has been applied to it.
    """
    log.info("Load dataset...")
    if set_name != 'hgd':
        assert set_name == "bcic_iv_2a"
        moabb_name = "BNCI2014001"
    else:
        moabb_name = "Schirrmeister2017"
    dataset = MOABBDataset(dataset_name=moabb_name, subject_ids=[subject_id])

    c_sensor_names = [
        'FC5', 'FC1', 'FC2', 'FC6', 'C3', 'Cz', 'C4', 'CP5',
        'CP1', 'CP2', 'CP6', 'FC3', 'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6',
        'CP3', 'CPz', 'CP4', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h', 'FCC5h',
        'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h', 'CPP5h',
        'CPP3h', 'CPP4h', 'CPP6h', 'FFC1h', 'FFC2h', 'FCC1h', 'FCC2h', 'CCP1h',
        'CCP2h', 'CPP1h', 'CPP2h']
    eeg_sensor_names = [
        'Fp1', 'Fp2', 'Fpz', 'F7', 'F3', 'Fz', 'F4', 'F8',
        'FC5', 'FC1', 'FC2', 'FC6', 'M1', 'T7', 'C3', 'Cz', 'C4', 'T8', 'M2',
        'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8', 'POz', 'O1',
        'Oz', 'O2', 'AF7', 'AF3', 'AF4', 'AF8', 'F5', 'F1', 'F2', 'F6', 'FC3',
        'FCz', 'FC4', 'C5', 'C1', 'C2', 'C6', 'CP3', 'CPz', 'CP4', 'P5', 'P1',
        'P2', 'P6', 'PO5', 'PO3', 'PO4', 'PO6', 'FT7', 'FT8', 'TP7', 'TP8',
        'PO7', 'PO8', 'FT9', 'FT10', 'TPP9h', 'TPP10h', 'PO9', 'PO10', 'P9',
        'P10', 'AFF1', 'AFz', 'AFF2', 'FFC5h', 'FFC3h', 'FFC4h', 'FFC6h', 'FCC5h',
        'FCC3h', 'FCC4h', 'FCC6h', 'CCP5h', 'CCP3h', 'CCP4h', 'CCP6h', 'CPP5h',
        'CPP3h', 'CPP4h', 'CPP6h', 'PPO1', 'PPO2', 'I1', 'Iz', 'I2', 'AFp3h',
        'AFp4h', 'AFF5h', 'AFF6h', 'FFT7h', 'FFC1h', 'FFC2h', 'FFT8h', 'FTT9h',
        'FTT7h', 'FCC1h', 'FCC2h', 'FTT8h', 'FTT10h', 'TTP7h', 'CCP1h', 'CCP2h',
        'TTP8h', 'TPP7h', 'CPP1h', 'CPP2h', 'TPP8h', 'PPO9h', 'PPO5h', 'PPO6h',
        'PPO10h', 'POO9h', 'POO3h', 'POO4h', 'POO10h', 'OI1h', 'OI2h']
    sensor_names = c_sensor_names if only_C_sensors else eeg_sensor_names

    # Parameters for exponential moving standardization
    factor_new, init_block_size = 1e-3, 1000

    log.info("Preprocess dataset...")
    moving_fns = {
        'standardize': exponential_moving_standardize,
        'demean': exponential_moving_demean,
    }
    moving_fn = moving_fns[exponential_moving_fn]

    # Channel selection differs per dataset: explicit names for HGD, all EEG
    # channels for BCIC IV 2a.
    if set_name == "hgd":
        pick_step = Preprocessor(fn='pick_channels', ch_names=sensor_names, ordered=True)
    else:
        assert set_name == 'bcic_iv_2a'
        pick_step = Preprocessor("pick_types", eeg=True, meg=False, stim=False)

    preprocessors = [
        Preprocessor(fn='load_data'),
        pick_step,
        # Scale the array by 1e6, then clip it to [-800, 800].
        Preprocessor(fn=lambda x: x * 1e6, apply_on_array=True),
        Preprocessor(fn=lambda x: np.clip(x, -800, 800), apply_on_array=True),
    ]
    if do_common_average_reference:
        preprocessors.append(Preprocessor(fn='set_eeg_reference', ref_channels='average'))
    preprocessors += [
        Preprocessor(fn='resample', sfreq=250),
        # bandpass filter
        Preprocessor(fn='filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
        # exponential moving standardization / demeaning
        Preprocessor(fn=moving_fn, factor_new=factor_new,
                     init_block_size=init_block_size, apply_on_array=True),
    ]

    # Transform the data
    preprocess(dataset, preprocessors)
    return dataset
def create_cropped_model(model_name, n_chans, resnet_init_a, seed=None):
    """Create a 4-class ConvNet prepared for cropped decoding.

    The length of the final convolution (or final pooling for the ResNet) is
    set by hand, so ``input_window_samples`` does not have to be given.

    Parameters
    ----------
    model_name
        ``'shallow'``, ``'resnet'`` or ``'deep'``.
    n_chans
        Number of input channels of the model.
    resnet_init_a
        ``a`` argument of ``nn.init.kaiming_normal_``; only used for
        ``'resnet'``.
    seed
        If given, ``set_random_seeds`` is called with it before the model is
        built. With the default ``None`` the random state is left untouched,
        so that seeding done by the caller stays in effect. (Previously a
        fixed seed was always set here, which overrode the caller's seed.)

    Returns
    -------
    The model, moved to the GPU if CUDA is available. The shallow and deep
    models are converted with ``to_dense_prediction_model``.
    """
    cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
    if cuda:
        torch.backends.cudnn.benchmark = True
    if seed is not None:
        # Only reseed on request; see the ``seed`` parameter above.
        set_random_seeds(seed=seed, cuda=cuda)
    n_classes = 4
    if model_name == 'shallow':
        model = ShallowFBCSPNet(
            n_chans,
            n_classes,
            input_window_samples=None,  # no need to provide if final_conv_length given
            final_conv_length=30,
        )
    elif model_name == 'resnet':
        model = EEGResNet(
            n_chans,
            n_classes,
            input_window_samples=None,  # no need to provide if final_pool_length given
            n_first_filters=48,
            final_pool_length=10,
            conv_weight_init_fn=partial(nn.init.kaiming_normal_, a=resnet_init_a))
    else:
        assert model_name == 'deep'
        model = Deep4Net(
            n_chans,
            n_classes,
            input_window_samples=None,  # no need to provide if final_conv_length given
            final_conv_length=2,
        )
    # Send model to GPU
    if cuda:
        model.cuda()
    # Transform the strided model into one that outputs dense predictions,
    # so predictions can be obtained for all crops. Not applied to the ResNet.
    if model_name in ["shallow", "deep"]:
        to_dense_prediction_model(model)
    return model
def cut_windows(dataset, input_window_samples, window_stride_samples,
                trial_start_offset_seconds=-0.5, mapping=None):
    """Cut the continuous recordings into (possibly overlapping) windows.

    In contrast to trialwise decoding, an explicit window size and window
    stride are supplied to ``create_windows_from_events``.

    Parameters
    ----------
    dataset
        Dataset whose ``datasets`` each provide ``raw.info['sfreq']``.
    input_window_samples
        Window size in samples.
    window_stride_samples
        Stride between windows in samples.
    trial_start_offset_seconds
        Offset of the trial start in seconds; converted to samples with the
        common sampling frequency. Defaults to the previously fixed -0.5.
    mapping
        Event-name to class-index mapping handed to
        ``create_windows_from_events``. Defaults to the previously fixed
        ``{'left_hand': 0, 'right_hand': 1, 'feet': 2, 'rest': 3}``.
        NOTE(review): BNCI2014001 presumably uses other event names (e.g.
        'tongue' instead of 'rest') — confirm and pass a matching mapping.

    Returns
    -------
    The windows dataset created by ``create_windows_from_events``.
    """
    if mapping is None:
        mapping = {'left_hand': 0, 'right_hand': 1, 'feet': 2, 'rest': 3}
    # Extract sampling frequency, check that it is the same in all datasets
    sfreq = dataset.datasets[0].raw.info['sfreq']
    assert all(ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets)
    # Calculate the trial start offset in samples.
    trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)
    windows_dataset = create_windows_from_events(
        dataset,
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=0,
        window_size_samples=input_window_samples,
        window_stride_samples=window_stride_samples,
        drop_last_window=False,
        preload=True,
        mapping=mapping,
    )
    return windows_dataset
def split_into_train_valid(windows_dataset, use_final_eval, n_windows_per_trial=2):
    """Split the windows dataset into a training and a validation set.

    If any entry of ``windows_dataset.description.session`` equals
    ``'session_T'`` (BCIC IV 2a case), the dataset is split by ``'session'``
    into ``'session_T'`` / ``'session_E'``; otherwise it is split by
    ``'run'`` into ``'train'`` / ``'test'``.

    Parameters
    ----------
    windows_dataset
        Dataset providing ``description.session`` and ``split(key)``.
    use_final_eval
        If true, return the train part and the test part as they are.
        Otherwise only the train part is used: roughly its first 80% become
        the training set and the remainder the validation set.
    n_windows_per_trial
        Number of windows cut from each trial. The split index is rounded
        down to a multiple of it so that no trial is divided between the two
        sets. Defaults to the previously hand-set value 2.

    Returns
    -------
    tuple
        ``(train_set, valid_set)``.

    Raises
    ------
    ValueError
        If ``n_windows_per_trial`` is smaller than 1.
    """
    if n_windows_per_trial < 1:
        raise ValueError(
            f"n_windows_per_trial must be >= 1, got {n_windows_per_trial}")
    if sum(windows_dataset.description.session == 'session_T') > 0:
        # BCIC IV 2a case
        splitted = windows_dataset.split("session")
        train_key, test_key = 'session_T', 'session_E'
    else:
        splitted = windows_dataset.split('run')
        train_key, test_key = 'train', 'test'
    if use_final_eval:
        return splitted[train_key], splitted[test_key]
    full_train_set = splitted[train_key]
    n_split = int(np.round(0.8 * len(full_train_set)))
    # Ensure the split index is a multiple of the number of windows per trial.
    n_split -= n_split % n_windows_per_trial
    train_set = Subset(full_train_set, range(0, n_split))
    valid_set = Subset(full_train_set, range(n_split, len(full_train_set)))
    return train_set, valid_set
def run_training(model, model_name, train_set, valid_set, device, n_epochs, resnet_lr,
                 resnet_weight_decay):
    """Train ``model`` with cropped decoding and return the fitted classifier.

    Parameters
    ----------
    model
        Network handed to ``EEGClassifier``.
    model_name
        ``'deep'``, ``'shallow'`` or ``'resnet'``; selects learning rate and
        weight decay.
    train_set, valid_set
        Training data and the predefined validation split.
    device
        Device argument of ``EEGClassifier``.
    n_epochs
        Number of epochs; ``n_epochs - 1`` is the cosine-annealing ``T_max``.
    resnet_lr, resnet_weight_decay
        Learning rate and weight decay used when ``model_name == 'resnet'``.

    Returns
    -------
    The fitted ``EEGClassifier``.
    """
    assert model_name in ['deep', 'shallow', 'resnet']
    if model_name == 'resnet':
        # Supplied by the caller.
        lr, weight_decay = resnet_lr, resnet_weight_decay
    elif model_name == 'shallow':
        # These values we found good for shallow network:
        lr, weight_decay = 0.0625 * 0.01, 0
    else:
        assert model_name == 'deep'
        # For deep4 they should be:
        lr, weight_decay = 1 * 0.01, 0.5 * 0.001

    lr_scheduler = LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)
    clf = EEGClassifier(
        model,
        cropped=True,
        criterion=CroppedLoss,
        criterion__loss_function=torch.nn.functional.nll_loss,
        optimizer=torch.optim.AdamW,
        train_split=predefined_split(valid_set),
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        iterator_train__shuffle=True,
        batch_size=64,
        callbacks=["accuracy", ("lr_scheduler", lr_scheduler)],
        device=device,
    )
    # `y` is None as it is already supplied in the dataset.
    clf.fit(train_set, y=None, epochs=n_epochs)
    return clf
def compute_and_store_amp_grads(model, train_set, filename):
    """Compute amplitude gradients on ``train_set`` and save their mean.

    The result of ``compute_amplitude_gradients`` (batch size 64) is averaged
    over axis 1 (the compute windows) and written to ``filename`` with
    ``np.save``.
    """
    grads_per_filter = compute_amplitude_gradients(model, train_set, batch_size=64)
    np.save(filename, np.mean(grads_per_filter, axis=1))
def run_exp(
        seed,
        subject_id,
        low_cut_hz,
        high_cut_hz,
        exponential_moving_fn,
        n_epochs,
        model_name,
        output_dir,
        only_C_sensors,
        do_common_average_reference,
        use_final_eval,
        save_amp_grads,
        save_model,
        resnet_lr,
        resnet_weight_decay,
        resnet_init_a,
        debug,
        set_name):
    """Run one complete cropped-decoding experiment for a single subject.

    Steps: load and preprocess the data, create the model, cut windows,
    split into train and validation set, train, and optionally store
    amplitude gradients and the model.

    Parameters
    ----------
    seed
        Random seed; set before the data is loaded and again after the model
        has been created.
    subject_id, low_cut_hz, high_cut_hz, exponential_moving_fn,
    only_C_sensors, do_common_average_reference, set_name
        Forwarded to ``load_preprocessed_data``.
    n_epochs
        Number of training epochs.
    model_name
        ``'deep'``, ``'shallow'`` or ``'resnet'``.
    output_dir
        Directory for the amplitude gradients and the model; created if
        something is going to be stored.
    use_final_eval
        Forwarded to ``split_into_train_valid``.
    save_amp_grads
        If true, store averaged amplitude gradients as
        ``{subject_id}_avg_amp_grads.npy``.
    save_model
        If true and ``debug`` is false, store the model as ``model.pth``.
    resnet_lr, resnet_weight_decay, resnet_init_a
        Forwarded to ``run_training`` / ``create_cropped_model``.
    debug
        If true, the model is never saved.

    Returns
    -------
    The fitted classifier returned by ``run_training``.
    """
    assert model_name in ['deep', 'shallow', 'resnet']
    # Use the GPU only if there is one instead of assuming CUDA.
    cuda = torch.cuda.is_available()
    device = 'cuda' if cuda else 'cpu'
    set_random_seeds(seed, cuda)
    log.info(f"Load and preprocess data for subject {subject_id}...")
    dataset = load_preprocessed_data(
        subject_id, low_cut_hz, high_cut_hz,
        exponential_moving_fn=exponential_moving_fn,
        only_C_sensors=only_C_sensors,
        do_common_average_reference=do_common_average_reference,
        set_name=set_name,
    )
    # Extract number of chans from dataset to create model
    n_chans = dataset[0][0].shape[0]
    log.info("Create cropped model...")
    model = create_cropped_model(model_name, n_chans, resnet_init_a)
    # Seed again so that the randomness used during training follows `seed`
    # regardless of any seeding done while the model was created.
    set_random_seeds(seed, cuda)
    # Cut windows from the preprocessed data, using number of predictions
    # per compute window to cut non-overlapping fully covering windows
    # (except for overlap of last window to stay within trial bounds)
    log.info("Cut windows from dataset ...")
    input_window_samples = 1000
    # To know the model's receptive field, we calculate the shape of model
    # output for a dummy input.
    n_preds_per_input = get_output_shape(model, n_chans, input_window_samples)[2]
    windows_dataset = cut_windows(
        dataset, input_window_samples, window_stride_samples=n_preds_per_input)
    log.info("Split into train and valid...")
    train_set, valid_set = split_into_train_valid(
        windows_dataset, use_final_eval=use_final_eval)
    # Run actual training
    log.info("Run training...")
    clf = run_training(model, model_name, train_set, valid_set, device, n_epochs,
                       resnet_lr, resnet_weight_decay)
    store_model = (not debug) and save_model
    if save_amp_grads or store_model:
        # Make sure the target directory exists before anything is written.
        os.makedirs(output_dir, exist_ok=True)
    if save_amp_grads:
        log.info("Compute and store amplitude gradients ...")
        amp_grads_filename = os.path.join(output_dir, f"{subject_id}_avg_amp_grads.npy")
        compute_and_store_amp_grads(model, train_set, filename=amp_grads_filename)
    if store_model:
        log.info("Save model ...")
        torch.save(model, os.path.join(output_dir, "model.pth"))
    log.info("... Done.")
    return clf
if __name__ == '__main__':
    # Command line entry point: run a short experiment for one subject.
    parser = argparse.ArgumentParser(
        description="""Run a cropped decoding experiment for a single subject."""
    )
    parser.add_argument('subject_id', type=int,
                        help='''Run for subject id....''')
    parser.add_argument('--set_name', default='hgd',
                        choices=['hgd', 'bcic_iv_2a'],
                        help='''Dataset to run on.''')
    args = parser.parse_args()
    logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                        level=logging.DEBUG, stream=sys.stdout)
    # run_exp has no default values, so every argument is passed explicitly
    # and by keyword (the former positional call did not match its signature).
    run_exp(
        seed=20200220,
        subject_id=args.subject_id,
        low_cut_hz=None,  # low cut frequency for filtering
        high_cut_hz=None,  # high cut frequency for filtering
        exponential_moving_fn='standardize',
        n_epochs=20,
        model_name='deep',
        output_dir='./results/',
        only_C_sensors=True,
        do_common_average_reference=False,
        use_final_eval=False,
        save_amp_grads=False,
        save_model=False,
        # Only used for model_name == 'resnet'.
        resnet_lr=None,
        resnet_weight_decay=None,
        resnet_init_a=None,
        debug=False,
        set_name=args.set_name,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment