Created
June 11, 2021 13:39
-
-
Save robintibor/9a6fdc36db7f23bf6ba9927e50646e94 to your computer and use it in GitHub Desktop.
Temporal Filters as Convolutions - different options
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "3e618e81", | |
"metadata": {}, | |
"source": [ | |
"# Temporal Filters Through Convolutions - different options" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "0d534d03", | |
"metadata": {}, | |
"source": [ | |
"## Load some example data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "0eac5fa0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import mne\n", | |
"from braindecode.datasets.moabb import MOABBDataset\n", | |
"\n", | |
"# Load a single subject of the MOABB dataset BNCI2014001 as example data\n", | |
"subject_id = 3\n", | |
"dataset = MOABBDataset(\n", | |
"    dataset_name=\"BNCI2014001\",\n", | |
"    subject_ids=[subject_id],\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "bed120cf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from braindecode.datautil.preprocess import exponential_moving_standardize\n", | |
"from braindecode.datautil.preprocess import Preprocessor, preprocess\n", | |
"\n", | |
"low_cut_hz = 4. # low cut frequency for filtering\n", | |
"high_cut_hz = 38. # high cut frequency for filtering\n", | |
"# Parameters for exponential moving standardization\n", | |
"factor_new = 1e-3\n", | |
"init_block_size = 1000\n", | |
"\n", | |
"# MNEPreproc / NumpyPreproc are deprecated; Preprocessor replaces both:\n", | |
"# apply_on_array=False -> call the named method on the MNE object (old MNEPreproc)\n", | |
"# apply_on_array=True  -> apply the function to the numpy data (old NumpyPreproc)\n", | |
"preprocessors = [\n", | |
"    # keep only EEG sensors\n", | |
"    Preprocessor('pick_types', apply_on_array=False, eeg=True, meg=False, stim=False),\n", | |
"    # convert from volt to microvolt, directly modifying the numpy array\n", | |
"    Preprocessor(lambda x: x * 1e6, apply_on_array=True),\n", | |
"    # bandpass filter\n", | |
"    Preprocessor('filter', apply_on_array=False, l_freq=low_cut_hz, h_freq=high_cut_hz),\n", | |
"    # exponential moving standardization\n", | |
"    Preprocessor(exponential_moving_standardize, apply_on_array=True,\n", | |
"                 factor_new=factor_new, init_block_size=init_block_size),\n", | |
"]\n", | |
"\n", | |
"# Transform the data; the return value is not kept, later cells use `dataset` directly\n", | |
"preprocess(dataset, preprocessors)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "f04ae50e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from braindecode.datautil.windowers import create_windows_from_events\n", | |
"\n", | |
"trial_start_offset_seconds = -0.5\n", | |
"# Extract sampling frequency, check that it is the same in all datasets\n", | |
"sfreq = dataset.datasets[0].raw.info['sfreq']\n", | |
"assert all(ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets), (\n", | |
"    'All recordings must share one sampling frequency')\n", | |
"# Calculate the trial start offset in samples (negative, since the offset is -0.5 s).\n", | |
"trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)\n", | |
"\n", | |
"# Create windows using braindecode function for this. It needs parameters to define how\n", | |
"# trials should be used.\n", | |
"windows_dataset = create_windows_from_events(\n", | |
"    dataset,\n", | |
"    trial_start_offset_samples=trial_start_offset_samples,\n", | |
"    trial_stop_offset_samples=0,\n", | |
"    preload=True,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "d93b4e7b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from skorch.utils import to_tensor\n", | |
"\n", | |
"# Small example batch: stack element [0] (the EEG window array) of the first 12 windows\n", | |
"X = torch.stack(\n", | |
"    [to_tensor(windows_dataset[i][0], 'cpu') for i in range(12)],\n", | |
"    dim=0,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "db80a711", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([12, 22, 1125])" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# (#examples, #EEG channels, #timesteps)\n", | |
"X.size()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "dee8a2e0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"n_chans = X.size(1)  # number of EEG channels" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e3395576", | |
"metadata": {}, | |
"source": [ | |
"## Regular Conv1d: every output channel uses all input channels" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "dbba163e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torch import nn\n", | |
"\n", | |
"n_time_kernel = 9  # temporal kernel length in samples\n", | |
"# every output channel is computed from all n_chans input channels\n", | |
"regular_conv = nn.Conv1d(in_channels=n_chans, out_channels=n_chans, kernel_size=n_time_kernel)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"id": "c033c373", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([22, 22, 9])" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# weight is (out_channels, in_channels, kernel): each of the 22 outputs\n", | |
"# uses all 22 inputs, hence [22, 22, 9]\n", | |
"regular_conv.weight.size()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"id": "9915527d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([12, 22, 1117])" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# no padding: 1125 - 9 + 1 = 1117 timesteps remain\n", | |
"regular_conv(X).size()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ffedc91c", | |
"metadata": {}, | |
"source": [ | |
"## Grouped Conv1d: separate filters for each channel" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"id": "b6b202a5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# groups=n_chans: each input channel gets its own 4 filters\n", | |
"grouped_conv = nn.Conv1d(\n", | |
"    in_channels=n_chans, out_channels=4 * n_chans, kernel_size=n_time_kernel, groups=n_chans\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"id": "b0504179", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([88, 1, 9])" | |
] | |
}, | |
"execution_count": 31, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Each channel is convolved only with its own filters: here 4 different\n", | |
"# filters per channel, each seeing a single input channel -> 22*4=88 in total\n", | |
"grouped_conv.weight.size()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"id": "2d8d33be", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([12, 88, 1117])" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# -> 22*4 = 88 output channels\n", | |
"grouped_conv(X).size()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "fce0d966", | |
"metadata": {}, | |
"source": [ | |
"## Reshape to apply the same filters *independently* to each channel\n", | |
"\n", | |
"This is the approach used in the braindecode Deep4/Shallow models." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"id": "cdbbb3a4", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from braindecode.models.modules import Expression" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 47, | |
"id": "c6c83be8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([12, 1, 1125, 22])" | |
] | |
}, | |
"execution_count": 47, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Move the EEG channels into a new \"image\" dimension:\n", | |
"# (#examples, #EEG channels, #timesteps) -> (#examples, 1, #timesteps, #EEG channels)\n", | |
"put_channels_into_image_dim = Expression(lambda x: x.unsqueeze(-1).transpose(1, 3))\n", | |
"put_channels_into_image_dim(X).size()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 51, | |
"id": "391c4af7", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([4, 1, 9, 1])" | |
] | |
}, | |
"execution_count": 51, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# The (n_time_kernel x 1) kernel spans only time, so the same 4 temporal\n", | |
"# filters are applied independently to each EEG channel\n", | |
"conv_image = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(n_time_kernel, 1))\n", | |
"conv_image.weight.size()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 49, | |
"id": "e02640cd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([12, 4, 1117, 22])" | |
] | |
}, | |
"execution_count": 49, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# -> (#examples, #temporal filters, #timesteps, #EEG channels)\n", | |
"conv_image(put_channels_into_image_dim(X)).size()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"id": "ab116ad8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# You could also merge EEG channels + temporal filters back into the channel\n", | |
"# dimension and put everything into one nn.Sequential.\n", | |
"# The merged channel index is filter_idx * n_chans + chan_idx (filter-major).\n", | |
"# Note: this reuses the conv_image module (and therefore its weights) from above.\n", | |
"\n", | |
"conv_independent_filters = nn.Sequential(\n", | |
"    put_channels_into_image_dim,\n", | |
"    conv_image,\n", | |
"    # (#examples, #filters, #timesteps, #chans) -> (#examples, #filters * #chans, #timesteps)\n", | |
"    Expression(lambda x: x.transpose(2, 3).flatten(start_dim=1, end_dim=2)),\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 50, | |
"id": "727d0b0c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([12, 88, 1117])" | |
] | |
}, | |
"execution_count": 50, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# -> (#examples, #temporal filters * #EEG channels, #timesteps) = (12, 4 * 22, 1117)\n", | |
"conv_independent_filters(X).size()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment