Skip to content

Instantly share code, notes, and snippets.

@robintibor
Created June 11, 2021 13:39
Show Gist options
  • Save robintibor/9a6fdc36db7f23bf6ba9927e50646e94 to your computer and use it in GitHub Desktop.
Save robintibor/9a6fdc36db7f23bf6ba9927e50646e94 to your computer and use it in GitHub Desktop.
Temporal Filters as Convolutions - different options
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "3e618e81",
"metadata": {},
"source": [
"# Temporal Filters Through Convolutions - different options"
]
},
{
"cell_type": "markdown",
"id": "0d534d03",
"metadata": {},
"source": [
"## Some data loading to have example data"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0eac5fa0",
"metadata": {},
"outputs": [],
"source": [
"from braindecode.datasets.moabb import MOABBDataset\n",
"import mne\n",
"\n",
"subject_id = 3\n",
"dataset = MOABBDataset(dataset_name=\"BNCI2014001\", subject_ids=[subject_id])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bed120cf",
"metadata": {},
"outputs": [],
"source": [
"from braindecode.datautil.preprocess import exponential_moving_standardize\n",
"from braindecode.datautil.preprocess import Preprocessor, preprocess\n",
"\n",
"low_cut_hz = 4.  # low cut frequency for filtering\n",
"high_cut_hz = 38.  # high cut frequency for filtering\n",
"# Parameters for exponential moving standardization\n",
"factor_new = 1e-3\n",
"init_block_size = 1000\n",
"\n",
"# MNEPreproc / NumpyPreproc are deprecated (they emit UserWarnings) and are replaced by\n",
"# Preprocessor: apply_on_array=False takes over from MNEPreproc (fn names an MNE method),\n",
"# apply_on_array=True takes over from NumpyPreproc (fn is applied to the data array).\n",
"preprocessors = [\n",
"    # keep only EEG sensors\n",
"    Preprocessor(fn='pick_types', apply_on_array=False, eeg=True, meg=False, stim=False),\n",
"    # convert from volt to microvolt, directly modifying the numpy array\n",
"    Preprocessor(fn=lambda x: x * 1e6, apply_on_array=True),\n",
"    # bandpass filter\n",
"    Preprocessor(fn='filter', apply_on_array=False, l_freq=low_cut_hz, h_freq=high_cut_hz),\n",
"    # exponential moving standardization\n",
"    Preprocessor(fn=exponential_moving_standardize, apply_on_array=True,\n",
"                 factor_new=factor_new, init_block_size=init_block_size),\n",
"]\n",
"\n",
"# Transform the data; the return value is not used, later cells read `dataset` directly\n",
"preprocess(dataset, preprocessors)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f04ae50e",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from braindecode.datautil.windowers import create_windows_from_events\n",
"\n",
"trial_start_offset_seconds = -0.5\n",
"# Extract sampling frequency, check that they are same in all datasets\n",
"sfreq = dataset.datasets[0].raw.info['sfreq']\n",
"assert all([ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets])\n",
"# Calculate the trial start offset in samples.\n",
"trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)\n",
"\n",
"# Create windows using braindecode function for this. It needs parameters to define how\n",
"# trials should be used.\n",
"windows_dataset = create_windows_from_events(\n",
" dataset,\n",
" trial_start_offset_samples=trial_start_offset_samples,\n",
" trial_stop_offset_samples=0,\n",
" preload=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "d93b4e7b",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from skorch.utils import to_tensor\n",
"# small batch of 12 examples\n",
"X = torch.stack([to_tensor(windows_dataset[i][0], 'cpu') for i in range(12)], dim=0)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "db80a711",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([12, 22, 1125])"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "dee8a2e0",
"metadata": {},
"outputs": [],
"source": [
"n_chans = X.shape[1]"
]
},
{
"cell_type": "markdown",
"id": "e3395576",
"metadata": {},
"source": [
"### Regular Conv1D will do all_channels to all_channels"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "dbba163e",
"metadata": {},
"outputs": [],
"source": [
"from torch import nn\n",
"\n",
"n_time_kernel = 9  # temporal kernel length, in samples\n",
"# Ordinary Conv1d: every output channel combines all input channels\n",
"regular_conv = nn.Conv1d(\n",
"    in_channels=n_chans,\n",
"    out_channels=n_chans,\n",
"    kernel_size=n_time_kernel,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "c033c373",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([22, 22, 9])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Weight layout is (out_channels, in_channels, kernel_size): each of the 22 outputs\n",
"# has its own 9-sample kernel for each of the 22 inputs, hence [22, 22, 9]\n",
"regular_conv.weight.size()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "9915527d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([12, 22, 1117])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# No padding, so the time axis shrinks by kernel_size - 1: 1125 - 9 + 1 = 1117\n",
"regular_conv(X).size()"
]
},
{
"cell_type": "markdown",
"id": "ffedc91c",
"metadata": {},
"source": [
"## Grouped Conv to make different filters for each channel"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "b6b202a5",
"metadata": {},
"outputs": [],
"source": [
"n_filters_per_chan = 4\n",
"# groups=n_chans splits the convolution into one group per EEG channel, so each\n",
"# channel is filtered only by its own n_filters_per_chan temporal filters\n",
"grouped_conv = nn.Conv1d(\n",
"    in_channels=n_chans,\n",
"    out_channels=n_chans * n_filters_per_chan,\n",
"    kernel_size=n_time_kernel,\n",
"    groups=n_chans,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "b0504179",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([88, 1, 9])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Weight layout is (out_channels, in_channels / groups, kernel_size): 22 * 4 = 88\n",
"# filters in total, each of which sees exactly one input channel\n",
"grouped_conv.weight.size()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "2d8d33be",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([12, 88, 1117])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 22 * 4 = 88 output channels; outputs 4*c to 4*c+3 come from input channel c\n",
"grouped_conv(X).size()"
]
},
{
"cell_type": "markdown",
"id": "fce0d966",
"metadata": {},
"source": [
"## Reshapes to apply same filters *independently* to each channel\n",
"\n",
"This is what we do in Deep4/Shallow"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "cdbbb3a4",
"metadata": {},
"outputs": [],
"source": [
"from braindecode.models.modules import Expression"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "c6c83be8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([12, 1, 1125, 22])"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Add a singleton 'image' dimension and move the EEG channels into the last axis:\n",
"# (#examples, #EEG channels, #timesteps) -> (#examples, 1, #timesteps, #EEG channels)\n",
"put_channels_into_image_dim = Expression(lambda x: x.unsqueeze(1).transpose(2, 3))\n",
"put_channels_into_image_dim(X).size()"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "391c4af7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 1, 9, 1])"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The kernel spans n_time_kernel time steps but only 1 EEG channel, so the same\n",
"# 4 temporal filters are applied independently to every EEG channel\n",
"conv_image = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(n_time_kernel, 1))\n",
"conv_image.weight.size()"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "e02640cd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([12, 4, 1117, 22])"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# -> (#examples, #temporal filters, #timesteps, #EEG channels)\n",
"conv_image(put_channels_into_image_dim(X)).size()"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "ab116ad8",
"metadata": {},
"outputs": [],
"source": [
"# You can also fold the temporal filters and EEG channels back into a single channel\n",
"# dimension and wrap everything in one nn.Sequential. conv_image is reused here, so\n",
"# this module shares its weights. Output channel f * #EEG channels + c holds\n",
"# temporal filter f applied to EEG channel c.\n",
"conv_independent_filters = nn.Sequential(\n",
"    put_channels_into_image_dim,\n",
"    conv_image,\n",
"    # (N, F, T, C) -> (N, F, C, T) -> (N, F * C, T)\n",
"    Expression(lambda x: x.transpose(2, 3).flatten(start_dim=1, end_dim=2)),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "727d0b0c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([12, 88, 1117])"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# -> (#examples, #temporal filters * #EEG channels, #timesteps)\n",
"conv_independent_filters(X).size()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment