Skip to content

Instantly share code, notes, and snippets.

@sylvchev
Forked from gemeinl/to_scikit_learn_API.ipynb
Created June 15, 2021 11:21
Show Gist options
  • Save sylvchev/270f1377a997d06a56cf49bea36ed013 to your computer and use it in GitHub Desktop.
Save sylvchev/270f1377a997d06a56cf49bea36ed013 to your computer and use it in GitHub Desktop.
Braindecode with scikit-learn pipeline chaining
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/gemeinl/anaconda3/envs/new_braindecode/lib/python3.7/site-packages/sklearn/utils/deprecation.py:144: FutureWarning: The sklearn.metrics.scorer module is deprecated in version 0.22 and will be removed in version 0.24. The corresponding classes / functions should instead be imported from sklearn.metrics. Anything that cannot be imported from sklearn.metrics is now part of the private API.\n",
" warnings.warn(message, FutureWarning)\n"
]
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import torch\n",
"from matplotlib.lines import Line2D\n",
"from sklearn.base import TransformerMixin\n",
"from sklearn.pipeline import Pipeline\n",
"from skorch.callbacks import LRScheduler\n",
"from skorch.helper import predefined_split\n",
"\n",
"from braindecode import EEGClassifier\n",
"from braindecode.datasets.moabb import MOABBDataset\n",
"from braindecode.datautil.preprocess import (\n",
"    MNEPreproc, NumpyPreproc, exponential_moving_standardize, preprocess)\n",
"from braindecode.datautil.windowers import (\n",
"    create_windows_from_events, create_fixed_length_windows)\n",
"from braindecode.models import ShallowFBCSPNet\n",
"from braindecode.util import set_random_seeds"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class Preprocessor(TransformerMixin):\n",
"    \"\"\"Base class for the stateless transformers below: `fit` learns nothing.\"\"\"\n",
"\n",
"    def fit(self, X, y=None):\n",
"        return self\n",
"\n",
"\n",
"class EventWindower(Preprocessor):\n",
"    \"\"\"Cut the dataset into trial windows with `create_windows_from_events`.\n",
"\n",
"    All constructor arguments are stored and forwarded to\n",
"    `create_windows_from_events` after the dataset itself.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, *args, **kwargs):\n",
"        self.args = args\n",
"        self.kwargs = kwargs\n",
"\n",
"    def transform(self, X):\n",
"        # The dataset (`concat_ds`) is the first parameter of the windowing\n",
"        # function, so pass it positionally: `concat_ds=X` combined with a\n",
"        # non-empty `*self.args` raises 'got multiple values for argument'.\n",
"        return create_windows_from_events(X, *self.args, **self.kwargs)\n",
"\n",
"\n",
"class FixedLengthWindower(Preprocessor):\n",
"    \"\"\"Cut the dataset into fixed-length windows with `create_fixed_length_windows`.\n",
"\n",
"    All constructor arguments are stored and forwarded to\n",
"    `create_fixed_length_windows` after the dataset itself.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, *args, **kwargs):\n",
"        self.args = args\n",
"        self.kwargs = kwargs\n",
"\n",
"    def transform(self, X):\n",
"        # The dataset (`concat_ds`) is the first parameter of the windowing\n",
"        # function, so pass it positionally: `concat_ds=X` combined with a\n",
"        # non-empty `*self.args` raises 'got multiple values for argument'.\n",
"        return create_fixed_length_windows(X, *self.args, **self.kwargs)\n",
"\n",
"\n",
"class MNETransformer(Preprocessor):\n",
"    \"\"\"Apply an MNE method to every recording via braindecode's `MNEPreproc`.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    fn : str\n",
"        Name of the MNE method, e.g. 'pick_types' or 'filter'.\n",
"    **kwargs\n",
"        Keyword arguments for `fn`, passed on through `MNEPreproc`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, fn, **kwargs):\n",
"        self.pre = MNEPreproc(fn=fn, **kwargs)\n",
"\n",
"    def transform(self, X):\n",
"        # `preprocess` changes X in place (its return value is not used), so\n",
"        # transforming the same dataset twice applies `fn` twice.\n",
"        preprocess(X, [self.pre])\n",
"        return X\n",
"\n",
"\n",
"class NumpyTransformer(Preprocessor):\n",
"    \"\"\"Apply a function to the data of every recording via braindecode's `NumpyPreproc`.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    fn : callable\n",
"        Function applied to the data, e.g. a scaling or\n",
"        `exponential_moving_standardize` (presumably it receives the signal\n",
"        array; see the NumpyPreproc docs).\n",
"    **kwargs\n",
"        Keyword arguments for `fn`, passed on through `NumpyPreproc`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, fn, **kwargs):\n",
"        self.pre = NumpyPreproc(fn=fn, **kwargs)\n",
"\n",
"    def transform(self, X):\n",
"        # `preprocess` changes X in place (its return value is not used), so\n",
"        # transforming the same dataset twice applies `fn` twice.\n",
"        preprocess(X, [self.pre])\n",
"        return X"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Known from the BNCI2014001 experimental design\n",
"sfreq = 250  # sampling frequency [Hz]\n",
"n_classes = 4  # feet, left_hand, right_hand, tongue\n",
"n_chans = 22  # EEG channels\n",
"original_trial_duration = 4  # trial length [s]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Preprocessing parameters\n",
"low_cut_hz = 4.  # low cut frequency for filtering [Hz]\n",
"high_cut_hz = 38.  # high cut frequency for filtering [Hz]\n",
"\n",
"# Parameters for exponential moving standardization\n",
"factor_new = 1e-3\n",
"init_block_size = 1000\n",
"\n",
"# Windowing: a negative offset starts each window before the trial onset\n",
"trial_start_offset_seconds = -0.5\n",
"trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)\n",
"\n",
"# Model parameters\n",
"seed = 20200220  # random seed to make results reproducible\n",
"# The input covers the whole trial plus the pre-onset part; the offset is\n",
"# negative, so subtracting it makes the window longer (4 s + 0.5 s here).\n",
"input_window_samples = int(original_trial_duration * sfreq - trial_start_offset_samples)\n",
"\n",
"# Training parameters\n",
"batch_size = 64\n",
"n_epochs = 4\n",
"# Learning rate and weight decay found to work well for the shallow network.\n",
"# For Deep4Net use lr = 1 * 0.01 and weight_decay = 0.5 * 0.001 instead.\n",
"lr = 0.0625 * 0.01\n",
"weight_decay = 0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create the model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Run on the GPU when one is available\n",
"cuda = torch.cuda.is_available()\n",
"device = 'cuda' if cuda else 'cpu'\n",
"if cuda:\n",
"    # Let cuDNN benchmark and pick the fastest convolution algorithms.\n",
"    # NOTE: this can make GPU runs non-deterministic despite the seeding below.\n",
"    torch.backends.cudnn.benchmark = True\n",
"\n",
"# Seed the random number generators so results can be reproduced\n",
"set_random_seeds(seed=seed, cuda=cuda)\n",
"\n",
"# Constructing the network initialises its weights, so it must come after seeding\n",
"model = ShallowFBCSPNet(\n",
"    n_chans,\n",
"    n_classes,\n",
"    input_window_samples=input_window_samples,\n",
"    final_conv_length='auto',\n",
")\n",
"\n",
"# Send model to GPU\n",
"if cuda:\n",
"    model.cuda()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the pipeline\n",
"\n",
"Chain all preprocessing steps and the classifier in a single scikit-learn `Pipeline`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def to_microvolts(data):\n",
"    \"\"\"Scale the signals from volts to microvolts.\"\"\"\n",
"    return data * 1e6\n",
"\n",
"\n",
"def split_by_session(X, y=None):\n",
"    \"\"\"Return (train, valid): train on session_T, validate on session_E.\n",
"\n",
"    Used as skorch's `train_split`. `y` is optional because newer skorch\n",
"    versions call `train_split(dataset)` without targets when `fit` gets no `y`.\n",
"    A named function (instead of a lambda) also keeps the pipeline picklable.\n",
"    \"\"\"\n",
"    sessions = X.split(\"session\")\n",
"    return sessions[\"session_T\"], sessions[\"session_E\"]\n",
"\n",
"\n",
"pipe = Pipeline([\n",
"    (\"pick_channels\", MNETransformer(\n",
"        fn='pick_types',\n",
"        eeg=True,\n",
"        meg=False,\n",
"        stim=False)\n",
"    ),\n",
"    (\"convert_to_microvolts\", NumpyTransformer(\n",
"        fn=to_microvolts)\n",
"    ),\n",
"    (\"bandpass\", MNETransformer(\n",
"        fn='filter',\n",
"        l_freq=low_cut_hz,\n",
"        h_freq=high_cut_hz)\n",
"    ),\n",
"    (\"standardize\", NumpyTransformer(\n",
"        fn=exponential_moving_standardize,\n",
"        factor_new=factor_new,\n",
"        init_block_size=init_block_size)\n",
"    ),\n",
"    (\"create_compute_windows\", EventWindower(\n",
"        trial_start_offset_samples=trial_start_offset_samples,\n",
"        trial_stop_offset_samples=0, preload=True)\n",
"    ),\n",
"    (\"classifier\", EEGClassifier(\n",
"        model,\n",
"        criterion=torch.nn.NLLLoss,\n",
"        optimizer=torch.optim.AdamW,\n",
"        train_split=split_by_session,\n",
"        optimizer__lr=lr,\n",
"        optimizer__weight_decay=weight_decay,\n",
"        batch_size=batch_size,\n",
"        callbacks=[\n",
"            \"accuracy\",\n",
"            # T_max is tied to n_epochs; keep it in sync with the epochs passed to fit\n",
"            (\"lr_scheduler\", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),\n",
"        ],\n",
"        device=device)),\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the data"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"48 events found\n",
"Event IDs: [1 2 3 4]\n",
"48 events found\n",
"Event IDs: [1 2 3 4]\n",
"48 events found\n",
"Event IDs: [1 2 3 4]\n",
"48 events found\n",
"Event IDs: [1 2 3 4]\n",
"48 events found\n",
"Event IDs: [1 2 3 4]\n",
"48 events found\n",
"Event IDs: [1 2 3 4]\n",
"48 events found\n",
"Event IDs: [1 2 3 4]\n",
"48 events found\n",
"Event IDs: [1 2 3 4]\n",
"48 events found\n",
"Event IDs: [1 2 3 4]\n",
"48 events found\n",
"Event IDs: [1 2 3 4]\n",
"48 events found\n",
"Event IDs: [1 2 3 4]\n",
"48 events found\n",
"Event IDs: [1 2 3 4]\n"
]
}
],
"source": [
"subject_id = 3\n",
"dataset = MOABBDataset(dataset_name=\"BNCI2014001\", subject_ids=[subject_id])\n",
"# The model and windowing parameters above assume this sampling frequency\n",
"assert all(ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
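"## Fit the pipeline\n",
"\n",
"Note: the preprocessing steps modify `dataset` in place. Re-run the data-loading cell above before fitting again; otherwise, for example, the conversion to microvolts is applied twice."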
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Filtering raw data in 1 contiguous segment\n",
"Setting up band-pass filter from 4 - 38 Hz\n",
"\n",
"FIR filter parameters\n",
"---------------------\n",
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
"- Windowed time-domain design (firwin) method\n",
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
"- Lower passband edge: 4.00\n",
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n",
"- Upper passband edge: 38.00 Hz\n",
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n",
"- Filter length: 413 samples (1.652 sec)\n",
"\n",
"Filtering raw data in 1 contiguous segment\n",
"Setting up band-pass filter from 4 - 38 Hz\n",
"\n",
"FIR filter parameters\n",
"---------------------\n",
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
"- Windowed time-domain design (firwin) method\n",
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
"- Lower passband edge: 4.00\n",
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n",
"- Upper passband edge: 38.00 Hz\n",
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n",
"- Filter length: 413 samples (1.652 sec)\n",
"\n",
"Filtering raw data in 1 contiguous segment\n",
"Setting up band-pass filter from 4 - 38 Hz\n",
"\n",
"FIR filter parameters\n",
"---------------------\n",
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
"- Windowed time-domain design (firwin) method\n",
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
"- Lower passband edge: 4.00\n",
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n",
"- Upper passband edge: 38.00 Hz\n",
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n",
"- Filter length: 413 samples (1.652 sec)\n",
"\n",
"Filtering raw data in 1 contiguous segment\n",
"Setting up band-pass filter from 4 - 38 Hz\n",
"\n",
"FIR filter parameters\n",
"---------------------\n",
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
"- Windowed time-domain design (firwin) method\n",
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
"- Lower passband edge: 4.00\n",
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n",
"- Upper passband edge: 38.00 Hz\n",
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n",
"- Filter length: 413 samples (1.652 sec)\n",
"\n",
"Filtering raw data in 1 contiguous segment\n",
"Setting up band-pass filter from 4 - 38 Hz\n",
"\n",
"FIR filter parameters\n",
"---------------------\n",
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
"- Windowed time-domain design (firwin) method\n",
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
"- Lower passband edge: 4.00\n",
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n",
"- Upper passband edge: 38.00 Hz\n",
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n",
"- Filter length: 413 samples (1.652 sec)\n",
"\n",
"Filtering raw data in 1 contiguous segment\n",
"Setting up band-pass filter from 4 - 38 Hz\n",
"\n",
"FIR filter parameters\n",
"---------------------\n",
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
"- Windowed time-domain design (firwin) method\n",
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
"- Lower passband edge: 4.00\n",
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n",
"- Upper passband edge: 38.00 Hz\n",
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n",
"- Filter length: 413 samples (1.652 sec)\n",
"\n",
"Filtering raw data in 1 contiguous segment\n",
"Setting up band-pass filter from 4 - 38 Hz\n",
"\n",
"FIR filter parameters\n",
"---------------------\n",
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
"- Windowed time-domain design (firwin) method\n",
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
"- Lower passband edge: 4.00\n",
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n",
"- Upper passband edge: 38.00 Hz\n",
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n",
"- Filter length: 413 samples (1.652 sec)\n",
"\n",
"Filtering raw data in 1 contiguous segment\n",
"Setting up band-pass filter from 4 - 38 Hz\n",
"\n",
"FIR filter parameters\n",
"---------------------\n",
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
"- Windowed time-domain design (firwin) method\n",
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
"- Lower passband edge: 4.00\n",
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n",
"- Upper passband edge: 38.00 Hz\n",
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n",
"- Filter length: 413 samples (1.652 sec)\n",
"\n",
"Filtering raw data in 1 contiguous segment\n",
"Setting up band-pass filter from 4 - 38 Hz\n",
"\n",
"FIR filter parameters\n",
"---------------------\n",
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
"- Windowed time-domain design (firwin) method\n",
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
"- Lower passband edge: 4.00\n",
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n",
"- Upper passband edge: 38.00 Hz\n",
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n",
"- Filter length: 413 samples (1.652 sec)\n",
"\n",
"Filtering raw data in 1 contiguous segment\n",
"Setting up band-pass filter from 4 - 38 Hz\n",
"\n",
"FIR filter parameters\n",
"---------------------\n",
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
"- Windowed time-domain design (firwin) method\n",
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
"- Lower passband edge: 4.00\n",
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n",
"- Upper passband edge: 38.00 Hz\n",
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n",
"- Filter length: 413 samples (1.652 sec)\n",
"\n",
"Filtering raw data in 1 contiguous segment\n",
"Setting up band-pass filter from 4 - 38 Hz\n",
"\n",
"FIR filter parameters\n",
"---------------------\n",
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
"- Windowed time-domain design (firwin) method\n",
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
"- Lower passband edge: 4.00\n",
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n",
"- Upper passband edge: 38.00 Hz\n",
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n",
"- Filter length: 413 samples (1.652 sec)\n",
"\n",
"Filtering raw data in 1 contiguous segment\n",
"Setting up band-pass filter from 4 - 38 Hz\n",
"\n",
"FIR filter parameters\n",
"---------------------\n",
"Designing a one-pass, zero-phase, non-causal bandpass filter:\n",
"- Windowed time-domain design (firwin) method\n",
"- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation\n",
"- Lower passband edge: 4.00\n",
"- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)\n",
"- Upper passband edge: 38.00 Hz\n",
"- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)\n",
"- Filter length: 413 samples (1.652 sec)\n",
"\n",
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
"48 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 48 events and 1125 original time points ...\n",
"0 bad epochs dropped\n",
"0 bad epochs dropped\n",
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
"48 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 48 events and 1125 original time points ...\n",
"0 bad epochs dropped\n",
"0 bad epochs dropped\n",
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
"48 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 48 events and 1125 original time points ...\n",
"0 bad epochs dropped\n",
"0 bad epochs dropped\n",
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
"48 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 48 events and 1125 original time points ...\n",
"0 bad epochs dropped\n",
"0 bad epochs dropped\n",
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
"48 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 48 events and 1125 original time points ...\n",
"0 bad epochs dropped\n",
"0 bad epochs dropped\n",
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
"48 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 48 events and 1125 original time points ...\n",
"0 bad epochs dropped\n",
"0 bad epochs dropped\n",
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
"48 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 48 events and 1125 original time points ...\n",
"0 bad epochs dropped\n",
"0 bad epochs dropped\n",
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
"48 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 48 events and 1125 original time points ...\n",
"0 bad epochs dropped\n",
"0 bad epochs dropped\n",
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
"48 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 48 events and 1125 original time points ...\n",
"0 bad epochs dropped\n",
"0 bad epochs dropped\n",
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
"48 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 48 events and 1125 original time points ...\n",
"0 bad epochs dropped\n",
"0 bad epochs dropped\n",
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
"48 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 48 events and 1125 original time points ...\n",
"0 bad epochs dropped\n",
"0 bad epochs dropped\n",
"Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
"48 matching events found\n",
"No baseline correction applied\n",
"Adding metadata with 4 columns\n",
"0 projection items activated\n",
"Loading data for 48 events and 1125 original time points ...\n",
"0 bad epochs dropped\n",
"0 bad epochs dropped\n",
" epoch train_accuracy train_loss valid_accuracy valid_loss dur\n",
"------- ---------------- ------------ ---------------- ------------ ------\n",
" 1 \u001b[36m0.2500\u001b[0m \u001b[32m1.5919\u001b[0m \u001b[35m0.2500\u001b[0m \u001b[31m6.2938\u001b[0m 1.0413\n",
" 2 0.2500 \u001b[32m1.1950\u001b[0m 0.2500 7.2211 0.2248\n",
" 3 0.2500 \u001b[32m1.0809\u001b[0m 0.2500 \u001b[31m5.8693\u001b[0m 0.2272\n",
" 4 \u001b[36m0.2569\u001b[0m \u001b[32m1.0008\u001b[0m \u001b[35m0.2535\u001b[0m \u001b[31m4.5076\u001b[0m 0.2266\n"
]
}
],
"source": [
"# `classifier__epochs` is routed by the Pipeline to the classifier's fit as `epochs`\n",
"pipe.fit(dataset, classifier__epochs=n_epochs);"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x216 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Extract loss and accuracy values for plotting from the classifier's history\n",
"history = pipe.named_steps['classifier'].history\n",
"results_columns = ['train_loss', 'valid_loss', 'train_accuracy', 'valid_accuracy']\n",
"history_df = (\n",
"    pd.DataFrame(history[:, results_columns], columns=results_columns,\n",
"                 index=history[:, 'epoch'])\n",
"    # misclassification in percent is easier to compare visually with the loss\n",
"    .assign(train_misclass=lambda d: 100 - 100 * d.train_accuracy,\n",
"            valid_misclass=lambda d: 100 - 100 * d.valid_accuracy)\n",
")\n",
"\n",
"# The 'seaborn' style was renamed 'seaborn-v0_8' in matplotlib 3.6 and the old\n",
"# name was removed in 3.8, so use whichever one is installed. The context\n",
"# manager keeps the style from leaking into later plots.\n",
"plot_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'\n",
"with plt.style.context(plot_style):\n",
"    fig, ax1 = plt.subplots(figsize=(8, 3))\n",
"    history_df.loc[:, ['train_loss', 'valid_loss']].plot(\n",
"        ax=ax1, style=['-', ':'], marker='o', color='tab:blue', legend=False, fontsize=14)\n",
"    ax1.tick_params(axis='y', labelcolor='tab:blue', labelsize=14)\n",
"    ax1.set_ylabel(\"Loss\", color='tab:blue', fontsize=14)\n",
"    ax1.set_xlabel(\"Epoch\", fontsize=14)\n",
"\n",
"    ax2 = ax1.twinx()  # second y-axis that shares the epoch x-axis\n",
"    history_df.loc[:, ['train_misclass', 'valid_misclass']].plot(\n",
"        ax=ax2, style=['-', ':'], marker='o', color='tab:red', legend=False)\n",
"    ax2.tick_params(axis='y', labelcolor='tab:red', labelsize=14)\n",
"    ax2.set_ylabel(\"Misclassification Rate [%]\", color='tab:red', fontsize=14)\n",
"    # Raise the upper limit to at least 85 % to make room for the legend, but\n",
"    # never lower it, which would clip high misclassification rates.\n",
"    bottom, top = ax2.get_ylim()\n",
"    ax2.set_ylim(bottom, max(top, 85))\n",
"\n",
"    # Line style (not color) tells train and valid apart on both axes\n",
"    handles = [\n",
"        Line2D([0], [0], color='black', linewidth=1, linestyle='-', label='Train'),\n",
"        Line2D([0], [0], color='black', linewidth=1, linestyle=':', label='Valid'),\n",
"    ]\n",
"    ax2.legend(handles=handles, fontsize=14)\n",
"    fig.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "new_braindecode",
"language": "python",
"name": "new_braindecode"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment