Last active
June 15, 2021 11:21
-
-
Save gemeinl/d64c014debb5f58e4feacb57a8656ed0 to your computer and use it in GitHub Desktop.
Braindecode with scikit-learn pipeline chaining
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/gemeinl/anaconda3/envs/new_braindecode/lib/python3.7/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.metrics.scorer module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.metrics. Anything that cannot be imported from sklearn.metrics is now part of the private API.\n", | |
" warnings.warn(message, FutureWarning)\n" | |
] | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"from sklearn.pipeline import Pipeline\n", | |
"from skorch.callbacks import LRScheduler\n", | |
"from skorch.helper import predefined_split\n", | |
"from sklearn.base import TransformerMixin\n", | |
"\n", | |
"from braindecode import EEGClassifier\n", | |
"from braindecode.util import set_random_seeds\n", | |
"from braindecode.models import ShallowFBCSPNet\n", | |
"from braindecode.datautil.preprocess import exponential_moving_standardize\n", | |
"from braindecode.datasets.moabb import MOABBDataset\n", | |
"from braindecode.datautil.windowers import (\n", | |
" create_windows_from_events, create_fixed_length_windows)\n", | |
"from braindecode.datautil.preprocess import (\n", | |
" MNEPreproc, NumpyPreproc, preprocess)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class Preprocessor(TransformerMixin):
    """Base class for stateless pipeline steps.

    ``fit`` learns nothing and returns ``self``; subclasses only implement
    ``transform``. ``fit_transform`` is inherited from ``TransformerMixin``.
    """

    def fit(self, X, y=None, **fit_params):
        """Do nothing and return ``self`` so the step can sit in a Pipeline.

        ``**fit_params`` is accepted and ignored: ``TransformerMixin.fit_transform``
        forwards any fit parameters routed to this step on to ``fit``, and
        without it such parameters would raise ``TypeError``.
        """
        return self


class EventWindower(Preprocessor):
    """Pipeline step wrapping ``create_windows_from_events``.

    All constructor arguments are stored in ``self.args`` / ``self.kwargs``
    and forwarded unchanged after the dataset on every ``transform`` call.
    """

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def transform(self, X):
        """Return the windows ``create_windows_from_events`` cuts from ``X``."""
        # Pass ``X`` positionally: with ``concat_ds=X, *self.args`` any stored
        # positional argument is bound to ``concat_ds`` first and the call fails
        # with "got multiple values for argument 'concat_ds'".
        return create_windows_from_events(X, *self.args, **self.kwargs)


class FixedLengthWindower(Preprocessor):
    """Pipeline step wrapping ``create_fixed_length_windows``.

    All constructor arguments are stored in ``self.args`` / ``self.kwargs``
    and forwarded unchanged after the dataset on every ``transform`` call.
    """

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def transform(self, X):
        """Return the windows ``create_fixed_length_windows`` cuts from ``X``."""
        # ``X`` goes first for the same reason as in ``EventWindower.transform``.
        return create_fixed_length_windows(X, *self.args, **self.kwargs)


class MNETransformer(Preprocessor):
    """Pipeline step that runs a single ``MNEPreproc`` through ``preprocess``.

    Args:
        fn: passed to ``MNEPreproc`` as ``fn`` (used below with
            ``'pick_types'`` and ``'filter'``).
        **kwargs: further keyword arguments passed to ``MNEPreproc``.
    """

    def __init__(self, fn, **kwargs):
        self.pre = MNEPreproc(fn=fn, **kwargs)

    def transform(self, X):
        """Run the stored preprocessor on ``X`` and return ``X`` itself.

        The return value of ``preprocess`` is not used, so this relies on
        ``preprocess`` modifying ``X`` in place.
        """
        preprocess(X, [self.pre])
        return X


class NumpyTransformer(Preprocessor):
    """Pipeline step that runs a single ``NumpyPreproc`` through ``preprocess``.

    Args:
        fn: passed to ``NumpyPreproc`` as ``fn`` (used below with a lambda
            and with ``exponential_moving_standardize``).
        **kwargs: further keyword arguments passed to ``NumpyPreproc``.
    """

    def __init__(self, fn, **kwargs):
        self.pre = NumpyPreproc(fn=fn, **kwargs)

    def transform(self, X):
        """Run the stored preprocessor on ``X`` and return ``X`` itself.

        The return value of ``preprocess`` is not used, so this relies on
        ``preprocess`` modifying ``X`` in place.
        """
        preprocess(X, [self.pre])
        return X
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Facts about the recordings, known from the experimental design of the
# BNCI2014001 data loaded further below.
sfreq, n_classes, n_chans, original_trial_duration = (
    250,  # sampling frequency [Hz]
    4,    # target classes: feet, left hand, right hand, tongue
    22,   # EEG channels
    4,    # trial duration [s]
)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# ---- Preprocessing ----
low_cut_hz = 4.   # lower edge of the band-pass filter [Hz]
high_cut_hz = 38.  # upper edge of the band-pass filter [Hz]

# Exponential moving standardization settings
factor_new = 1e-3
init_block_size = 1000

# Offset of the window start relative to each trial start, in seconds and in
# samples; a negative value extends every window to before the trial start.
trial_start_offset_seconds = -0.5
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# ---- Model ----
seed = 20200220  # fixed random seed so results are reproducible
# Full trial length plus the extra pre-trial part, in samples
input_window_samples = int(
    original_trial_duration * sfreq - trial_start_offset_samples
)

# ---- Training ----
batch_size = 64
n_epochs = 4
# Learning rate / weight decay found to work well for the shallow network
lr = 0.0625 * 0.01
weight_decay = 0
# Values to use for Deep4 instead:
#   lr = 1 * 0.01
#   weight_decay = 0.5 * 0.001
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Create the model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Prefer the GPU whenever one is available
cuda = torch.cuda.is_available()
if cuda:
    device = 'cuda'
    # Let cuDNN benchmark and pick the fastest convolution algorithms
    torch.backends.cudnn.benchmark = True
else:
    device = 'cpu'

# Seed the RNGs before the network is built so that results, including the
# weight initialisation, can be reproduced
set_random_seeds(seed=seed, cuda=cuda)

model = ShallowFBCSPNet(
    n_chans,
    n_classes,
    input_window_samples=input_window_samples,
    final_conv_length='auto',
)

if cuda:
    model.cuda()  # move the parameters to the GPU
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Chain all preprocessing steps and the classifier in a single scikit-learn pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def split_by_session(X, y=None, **fit_params):
    """Split the windowed dataset into ``(train, valid)`` by recording session.

    Used as skorch ``train_split``: trains on ``"session_T"`` and validates on
    ``"session_E"``. ``y`` and ``**fit_params`` are accepted but unused. Newer
    skorch versions call ``train_split`` without ``y`` when ``y`` is None (as
    in ``pipe.fit(dataset, ...)``), which a two-argument lambda cannot handle.
    """
    sessions = X.split("session")  # split once instead of once per subset
    return sessions["session_T"], sessions["session_E"]


pipe = Pipeline([
    # Keep only the EEG channels.
    ("pick_channels", MNETransformer(
        fn='pick_types',
        eeg=True,
        meg=False,
        stim=False)
     ),
    # Scale the signal by 1e6 (to microvolts, as the step name says).
    ("convert_to_microvolts", NumpyTransformer(
        fn=lambda x: x * 1e6)
     ),
    # Band-pass filter between low_cut_hz and high_cut_hz.
    ("bandpass", MNETransformer(
        fn='filter',
        l_freq=low_cut_hz,
        h_freq=high_cut_hz)
     ),
    # Exponential moving standardization with the configured parameters.
    ("standardize", NumpyTransformer(
        fn=exponential_moving_standardize,
        factor_new=factor_new,
        init_block_size=init_block_size)
     ),
    # Cut one window per trial, starting trial_start_offset_samples
    # relative to each trial start.
    ("create_compute_windows", EventWindower(
        trial_start_offset_samples=trial_start_offset_samples,
        trial_stop_offset_samples=0, preload=True)
     ),
    ("classifier", EEGClassifier(
        model,
        criterion=torch.nn.NLLLoss,
        optimizer=torch.optim.AdamW,
        # NOTE: session_E doubles as the validation set, so the valid_*
        # metrics are measured on the evaluation session. None of the
        # callbacks below uses them for early stopping or model selection.
        train_split=split_by_session,
        optimizer__lr=lr,
        optimizer__weight_decay=weight_decay,
        batch_size=batch_size,
        callbacks=[
            "accuracy",
            ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
        ],
        device=device)),
])
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Load the data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"48 events found\n", | |
"Event IDs: [1 2 3 4]\n", | |
"48 events found\n", | |
"Event IDs: [1 2 3 4]\n", | |
"48 events found\n", | |
"Event IDs: [1 2 3 4]\n", | |
"48 events found\n", | |
"Event IDs: [1 2 3 4]\n", | |
"48 events found\n", | |
"Event IDs: [1 2 3 4]\n", | |
"48 events found\n", | |
"Event IDs: [1 2 3 4]\n", | |
"48 events found\n", | |
"Event IDs: [1 2 3 4]\n", | |
"48 events found\n", | |
"Event IDs: [1 2 3 4]\n", | |
"48 events found\n", | |
"Event IDs: [1 2 3 4]\n", | |
"48 events found\n", | |
"Event IDs: [1 2 3 4]\n", | |
"48 events found\n", | |
"Event IDs: [1 2 3 4]\n", | |
"48 events found\n", | |
"Event IDs: [1 2 3 4]\n" | |
] | |
} | |
], | |
"source": [ | |
subject_id = 3
dataset = MOABBDataset(dataset_name="BNCI2014001", subject_ids=[subject_id])

# Every loaded recording must have the sampling rate the parameters assume
recording_sfreqs = [ds.raw.info['sfreq'] for ds in dataset.datasets]
assert all(rec_sfreq == sfreq for rec_sfreq in recording_sfreqs)
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Fit the pipeline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Filtering raw data in 1 contiguous segment\n", | |
"Setting up band-pass filter from 4 - 38 Hz\n", | |
"\n", | |
"FIR filter parameters\n", | |
"---------------------\n", | |
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n", | |
"- Windowed time-domain design (firwin) method\n", | |
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", | |
"- Lower passband edge: 4.00\n", | |
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n", | |
"- Upper passband edge: 38.00 Hz\n", | |
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n", | |
"- Filter length: 413 samples (1.652 sec)\n", | |
"\n", | |
"Filtering raw data in 1 contiguous segment\n", | |
"Setting up band-pass filter from 4 - 38 Hz\n", | |
"\n", | |
"FIR filter parameters\n", | |
"---------------------\n", | |
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n", | |
"- Windowed time-domain design (firwin) method\n", | |
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", | |
"- Lower passband edge: 4.00\n", | |
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n", | |
"- Upper passband edge: 38.00 Hz\n", | |
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n", | |
"- Filter length: 413 samples (1.652 sec)\n", | |
"\n", | |
"Filtering raw data in 1 contiguous segment\n", | |
"Setting up band-pass filter from 4 - 38 Hz\n", | |
"\n", | |
"FIR filter parameters\n", | |
"---------------------\n", | |
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n", | |
"- Windowed time-domain design (firwin) method\n", | |
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", | |
"- Lower passband edge: 4.00\n", | |
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n", | |
"- Upper passband edge: 38.00 Hz\n", | |
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n", | |
"- Filter length: 413 samples (1.652 sec)\n", | |
"\n", | |
"Filtering raw data in 1 contiguous segment\n", | |
"Setting up band-pass filter from 4 - 38 Hz\n", | |
"\n", | |
"FIR filter parameters\n", | |
"---------------------\n", | |
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n", | |
"- Windowed time-domain design (firwin) method\n", | |
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", | |
"- Lower passband edge: 4.00\n", | |
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n", | |
"- Upper passband edge: 38.00 Hz\n", | |
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n", | |
"- Filter length: 413 samples (1.652 sec)\n", | |
"\n", | |
"Filtering raw data in 1 contiguous segment\n", | |
"Setting up band-pass filter from 4 - 38 Hz\n", | |
"\n", | |
"FIR filter parameters\n", | |
"---------------------\n", | |
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n", | |
"- Windowed time-domain design (firwin) method\n", | |
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", | |
"- Lower passband edge: 4.00\n", | |
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n", | |
"- Upper passband edge: 38.00 Hz\n", | |
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n", | |
"- Filter length: 413 samples (1.652 sec)\n", | |
"\n", | |
"Filtering raw data in 1 contiguous segment\n", | |
"Setting up band-pass filter from 4 - 38 Hz\n", | |
"\n", | |
"FIR filter parameters\n", | |
"---------------------\n", | |
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n", | |
"- Windowed time-domain design (firwin) method\n", | |
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", | |
"- Lower passband edge: 4.00\n", | |
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n", | |
"- Upper passband edge: 38.00 Hz\n", | |
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n", | |
"- Filter length: 413 samples (1.652 sec)\n", | |
"\n", | |
"Filtering raw data in 1 contiguous segment\n", | |
"Setting up band-pass filter from 4 - 38 Hz\n", | |
"\n", | |
"FIR filter parameters\n", | |
"---------------------\n", | |
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n", | |
"- Windowed time-domain design (firwin) method\n", | |
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", | |
"- Lower passband edge: 4.00\n", | |
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n", | |
"- Upper passband edge: 38.00 Hz\n", | |
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n", | |
"- Filter length: 413 samples (1.652 sec)\n", | |
"\n", | |
"Filtering raw data in 1 contiguous segment\n", | |
"Setting up band-pass filter from 4 - 38 Hz\n", | |
"\n", | |
"FIR filter parameters\n", | |
"---------------------\n", | |
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n", | |
"- Windowed time-domain design (firwin) method\n", | |
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", | |
"- Lower passband edge: 4.00\n", | |
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n", | |
"- Upper passband edge: 38.00 Hz\n", | |
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n", | |
"- Filter length: 413 samples (1.652 sec)\n", | |
"\n", | |
"Filtering raw data in 1 contiguous segment\n", | |
"Setting up band-pass filter from 4 - 38 Hz\n", | |
"\n", | |
"FIR filter parameters\n", | |
"---------------------\n", | |
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n", | |
"- Windowed time-domain design (firwin) method\n", | |
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", | |
"- Lower passband edge: 4.00\n", | |
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n", | |
"- Upper passband edge: 38.00 Hz\n", | |
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n", | |
"- Filter length: 413 samples (1.652 sec)\n", | |
"\n", | |
"Filtering raw data in 1 contiguous segment\n", | |
"Setting up band-pass filter from 4 - 38 Hz\n", | |
"\n", | |
"FIR filter parameters\n", | |
"---------------------\n", | |
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n", | |
"- Windowed time-domain design (firwin) method\n", | |
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", | |
"- Lower passband edge: 4.00\n", | |
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n", | |
"- Upper passband edge: 38.00 Hz\n", | |
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n", | |
"- Filter length: 413 samples (1.652 sec)\n", | |
"\n", | |
"Filtering raw data in 1 contiguous segment\n", | |
"Setting up band-pass filter from 4 - 38 Hz\n", | |
"\n", | |
"FIR filter parameters\n", | |
"---------------------\n", | |
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n", | |
"- Windowed time-domain design (firwin) method\n", | |
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", | |
"- Lower passband edge: 4.00\n", | |
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n", | |
"- Upper passband edge: 38.00 Hz\n", | |
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n", | |
"- Filter length: 413 samples (1.652 sec)\n", | |
"\n", | |
"Filtering raw data in 1 contiguous segment\n", | |
"Setting up band-pass filter from 4 - 38 Hz\n", | |
"\n", | |
"FIR filter parameters\n", | |
"---------------------\n", | |
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n", | |
"- Windowed time-domain design (firwin) method\n", | |
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n", | |
"- Lower passband edge: 4.00\n", | |
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n", | |
"- Upper passband edge: 38.00 Hz\n", | |
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n", | |
"- Filter length: 413 samples (1.652 sec)\n", | |
"\n", | |
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n", | |
"48 matching events found\n", | |
"No baseline correction applied\n", | |
"Adding metadata with 4 columns\n", | |
"0 projection items activated\n", | |
"Loading data for 48 events and 1125 original time points ...\n", | |
"0 bad epochs dropped\n", | |
"0 bad epochs dropped\n", | |
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n", | |
"48 matching events found\n", | |
"No baseline correction applied\n", | |
"Adding metadata with 4 columns\n", | |
"0 projection items activated\n", | |
"Loading data for 48 events and 1125 original time points ...\n", | |
"0 bad epochs dropped\n", | |
"0 bad epochs dropped\n", | |
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n", | |
"48 matching events found\n", | |
"No baseline correction applied\n", | |
"Adding metadata with 4 columns\n", | |
"0 projection items activated\n", | |
"Loading data for 48 events and 1125 original time points ...\n", | |
"0 bad epochs dropped\n", | |
"0 bad epochs dropped\n", | |
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n", | |
"48 matching events found\n", | |
"No baseline correction applied\n", | |
"Adding metadata with 4 columns\n", | |
"0 projection items activated\n", | |
"Loading data for 48 events and 1125 original time points ...\n", | |
"0 bad epochs dropped\n", | |
"0 bad epochs dropped\n", | |
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n", | |
"48 matching events found\n", | |
"No baseline correction applied\n", | |
"Adding metadata with 4 columns\n", | |
"0 projection items activated\n", | |
"Loading data for 48 events and 1125 original time points ...\n", | |
"0 bad epochs dropped\n", | |
"0 bad epochs dropped\n", | |
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n", | |
"48 matching events found\n", | |
"No baseline correction applied\n", | |
"Adding metadata with 4 columns\n", | |
"0 projection items activated\n", | |
"Loading data for 48 events and 1125 original time points ...\n", | |
"0 bad epochs dropped\n", | |
"0 bad epochs dropped\n", | |
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n", | |
"48 matching events found\n", | |
"No baseline correction applied\n", | |
"Adding metadata with 4 columns\n", | |
"0 projection items activated\n", | |
"Loading data for 48 events and 1125 original time points ...\n", | |
"0 bad epochs dropped\n", | |
"0 bad epochs dropped\n", | |
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n", | |
"48 matching events found\n", | |
"No baseline correction applied\n", | |
"Adding metadata with 4 columns\n", | |
"0 projection items activated\n", | |
"Loading data for 48 events and 1125 original time points ...\n", | |
"0 bad epochs dropped\n", | |
"0 bad epochs dropped\n", | |
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n", | |
"48 matching events found\n", | |
"No baseline correction applied\n", | |
"Adding metadata with 4 columns\n", | |
"0 projection items activated\n", | |
"Loading data for 48 events and 1125 original time points ...\n", | |
"0 bad epochs dropped\n", | |
"0 bad epochs dropped\n", | |
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n", | |
"48 matching events found\n", | |
"No baseline correction applied\n", | |
"Adding metadata with 4 columns\n", | |
"0 projection items activated\n", | |
"Loading data for 48 events and 1125 original time points ...\n", | |
"0 bad epochs dropped\n", | |
"0 bad epochs dropped\n", | |
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n", | |
"48 matching events found\n", | |
"No baseline correction applied\n", | |
"Adding metadata with 4 columns\n", | |
"0 projection items activated\n", | |
"Loading data for 48 events and 1125 original time points ...\n", | |
"0 bad epochs dropped\n", | |
"0 bad epochs dropped\n", | |
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n", | |
"48 matching events found\n", | |
"No baseline correction applied\n", | |
"Adding metadata with 4 columns\n", | |
"0 projection items activated\n", | |
"Loading data for 48 events and 1125 original time points ...\n", | |
"0 bad epochs dropped\n", | |
"0 bad epochs dropped\n", | |
" epoch train_accuracy train_loss valid_accuracy valid_loss dur\n", | |
"------- ---------------- ------------ ---------------- ------------ ------\n", | |
" 1 \u001b[36m0.2500\u001b[0m \u001b[32m1.5919\u001b[0m \u001b[35m0.2500\u001b[0m \u001b[31m6.2938\u001b[0m 1.0413\n", | |
" 2 0.2500 \u001b[32m1.1950\u001b[0m 0.2500 7.2211 0.2248\n", | |
" 3 0.2500 \u001b[32m1.0809\u001b[0m 0.2500 \u001b[31m5.8693\u001b[0m 0.2272\n", | |
" 4 \u001b[36m0.2569\u001b[0m \u001b[32m1.0008\u001b[0m \u001b[35m0.2535\u001b[0m \u001b[31m4.5076\u001b[0m 0.2266\n" | |
] | |
} | |
], | |
"source": [ | |
# Run every step in order. ``classifier__epochs`` is routed to the fit call of
# the step named "classifier" as ``epochs``. Pipeline.fit returns the pipeline
# itself, so ``pipe`` is fitted in place; ``;`` suppresses the repr output.
pipe.fit(dataset, classifier__epochs=n_epochs);
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 576x216 with 2 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import pandas as pd

# Extract loss and accuracy values for plotting from the classifier's history
clf_history = pipe.named_steps['classifier'].history
results_columns = ['train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy']
df = pd.DataFrame(clf_history[:, results_columns], columns=results_columns,
                  index=clf_history[:, 'epoch'])

# Express misclassification in percent for easier visual comparison to the loss
df = df.assign(train_misclass=100 - 100 * df.train_accuracy,
               valid_misclass=100 - 100 * df.valid_accuracy)

# Matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and 3.8
# removed the old name, so use whichever name this installation provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

fig, ax1 = plt.subplots(figsize=(8, 3))
df.loc[:, ['train_loss', 'valid_loss']].plot(
    ax=ax1, style=['-', ':'], marker='o', color='tab:blue', legend=False, fontsize=14)
ax1.tick_params(axis='y', labelcolor='tab:blue', labelsize=14)
ax1.set_ylabel("Loss", color='tab:blue', fontsize=14)
ax1.set_xlabel("Epoch", fontsize=14)

ax2 = ax1.twinx()  # second y-axis sharing the epoch x-axis
df.loc[:, ['train_misclass', 'valid_misclass']].plot(
    ax=ax2, style=['-', ':'], marker='o', color='tab:red', legend=False)
ax2.tick_params(axis='y', labelcolor='tab:red', labelsize=14)
ax2.set_ylabel("Misclassification Rate [%]", color='tab:red', fontsize=14)
ax2.set_ylim(ax2.get_ylim()[0], 85)  # make some room for the legend

# Proxy artists: the legend explains the line style (train vs. valid), while
# the color already distinguishes loss (blue) from misclassification (red).
handles = [
    Line2D([0], [0], color='black', linewidth=1, linestyle='-', label='Train'),
    Line2D([0], [0], color='black', linewidth=1, linestyle=':', label='Valid'),
]
ax2.legend(handles, [h.get_label() for h in handles], fontsize=14)
fig.tight_layout()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "new_braindecode", | |
"language": "python", | |
"name": "new_braindecode" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment