Skip to content

Instantly share code, notes, and snippets.

@robintibor
Last active July 10, 2020 08:27
Show Gist options
  • Save robintibor/a22b99dc054d0e84d57ec9c50d90aa8e to your computer and use it in GitHub Desktop.
Save robintibor/a22b99dc054d0e84d57ec9c50d90aa8e to your computer and use it in GitHub Desktop.
Decoding Attention Level with Braindecode
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"torch.backends.cudnn.benchmark = True\n",
"\n",
"# Every column except the last is used as an input channel; the last column\n",
"# holds the label (presumably a 1-based attention level -- TODO confirm).\n",
"data = pd.read_csv('./user1_test10_attention.csv')\n",
"\n",
"window_size_samples = 512\n",
"n_windows = data.shape[0] - window_size_samples\n",
"if n_windows <= 0:\n",
"    raise ValueError(\n",
"        f'Need more than {window_size_samples} rows to build a window, '\n",
"        f'got {data.shape[0]}.'\n",
"    )\n",
"\n",
"# Convert to NumPy once instead of slicing the DataFrame for every window.\n",
"signals = data.iloc[:, :-1].to_numpy()\n",
"labels = data.iloc[:, -1].to_numpy()\n",
"\n",
"# Stride-1 sliding windows: window i covers rows [i, i + window_size_samples).\n",
"X = np.stack([signals[i:i + window_size_samples] for i in range(n_windows)])\n",
"# (window, time, channel) -> (window, channel, time)\n",
"X = X.transpose(0, 2, 1).astype(np.float32)\n",
"\n",
"# The target of window i is the label of the first row after that window;\n",
"# subtracting 1 shifts every label down by one.\n",
"y = labels[window_size_samples:].astype(np.int64) - 1\n",
"\n",
"assert len(X) == len(y), (X.shape, y.shape)\n",
"print(X.shape)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from braindecode.datautil import create_from_X_y\n",
"# or in prev versions: from braindecode.datasets.xy import create_from_X_y\n",
"\n",
"# Chronological split: the first n_split windows train, the rest validate.\n",
"# NOTE(review): the windows were cut with stride 1, so windows on either side\n",
"# of the boundary share almost all of their samples; consider dropping\n",
"# window_size_samples windows after n_split to avoid train/valid overlap.\n",
"n_split = 1200\n",
"train_X, train_y = X[:n_split], y[:n_split]\n",
"valid_X, valid_y = X[n_split:], y[n_split:]\n",
"\n",
"# Per-channel statistics over windows (axis 0) and time (axis 2), computed on\n",
"# the training windows only and reused unchanged for the validation windows.\n",
"means = train_X.mean(axis=(0, 2), keepdims=True)\n",
"stds = train_X.std(axis=(0, 2), keepdims=True)\n",
"# A constant channel has std 0; dividing by it would fill the data with\n",
"# NaN/inf, so such channels are only centred.\n",
"stds[stds == 0] = 1.0\n",
"\n",
"# standardize per channel\n",
"train_X = (train_X - means) / stds\n",
"valid_X = (valid_X - means) / stds\n",
"\n",
"train_set = create_from_X_y(train_X, train_y, drop_last_window=False)\n",
"valid_set = create_from_X_y(valid_X, valid_y, drop_last_window=False)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from braindecode.models import ShallowFBCSPNet\n",
"from braindecode.util import set_random_seeds\n",
"\n",
"# Run on the GPU when one is available, otherwise fall back to the CPU.\n",
"cuda = torch.cuda.is_available()\n",
"device = 'cuda' if cuda else 'cpu'\n",
"if cuda:\n",
"    torch.backends.cudnn.benchmark = True\n",
"\n",
"# Fixed random seed to make results reproducible.\n",
"seed = 20200220\n",
"set_random_seeds(seed=seed, cuda=cuda)\n",
"\n",
"n_classes = 5\n",
"# Number of channels and time steps, read off the first training example.\n",
"n_chans, input_window_samples = train_set[0][0].shape[:2]\n",
"\n",
"model = ShallowFBCSPNet(\n",
"    n_chans,\n",
"    n_classes,\n",
"    input_window_samples=input_window_samples,\n",
"    final_conv_length='auto',\n",
")\n",
"\n",
"# Move the model to the GPU\n",
"if cuda:\n",
"    model.cuda()\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" epoch train_accuracy train_loss valid_accuracy valid_loss dur\n",
"------- ---------------- ------------ ---------------- ------------ ------\n",
" 1 \u001b[36m0.7033\u001b[0m \u001b[32m0.4628\u001b[0m \u001b[35m0.2282\u001b[0m \u001b[31m3.0022\u001b[0m 0.8967\n",
" 2 \u001b[36m0.9333\u001b[0m \u001b[32m0.2477\u001b[0m 0.1919 9.7948 0.8895\n",
" 3 0.0167 \u001b[32m0.1764\u001b[0m 0.0000 14.5916 0.8951\n",
" 4 0.8750 0.1936 0.1919 5.9776 0.8068\n",
" 5 0.8433 \u001b[32m0.0970\u001b[0m 0.1919 12.8362 0.8077\n",
" 6 \u001b[36m0.9600\u001b[0m \u001b[32m0.0949\u001b[0m 0.1919 10.6701 0.8044\n",
" 7 \u001b[36m0.9817\u001b[0m 0.1093 0.1919 7.0251 0.8871\n",
" 8 0.9075 0.0982 0.1919 13.2528 0.8184\n",
" 9 0.9458 0.0954 0.1919 8.4267 0.8855\n",
" 10 0.9300 \u001b[32m0.0668\u001b[0m 0.1919 10.8124 0.7976\n",
" 11 0.9500 \u001b[32m0.0601\u001b[0m 0.1919 11.6038 0.8066\n",
" 12 0.9450 0.0976 0.1919 7.3305 0.8016\n",
" 13 \u001b[36m0.9858\u001b[0m 0.0736 0.1890 10.4158 0.7983\n",
" 14 0.9850 \u001b[32m0.0513\u001b[0m 0.1148 9.0304 0.7998\n",
" 15 0.3525 0.0526 0.1294 10.5825 0.7980\n",
" 16 0.9192 0.0518 0.1919 8.6944 0.8031\n",
" 17 0.9708 \u001b[32m0.0451\u001b[0m 0.1512 9.5866 0.7927\n",
" 18 0.9417 \u001b[32m0.0313\u001b[0m 0.1919 4.9543 0.7941\n",
" 19 0.9642 0.0418 \u001b[35m0.3735\u001b[0m \u001b[31m1.6894\u001b[0m 0.8016\n",
" 20 0.6983 0.0419 0.1919 2.9650 0.8049\n",
" 21 0.8667 0.0421 0.1919 13.6571 0.8049\n",
" 22 \u001b[36m0.9867\u001b[0m 0.0377 0.1570 11.6214 0.8081\n",
" 23 0.9767 0.0344 0.1555 9.7715 0.8004\n",
" 24 0.7550 \u001b[32m0.0269\u001b[0m \u001b[35m0.9840\u001b[0m \u001b[31m0.1476\u001b[0m 0.8031\n",
" 25 0.9708 \u001b[32m0.0268\u001b[0m 0.0727 12.5223 0.8024\n",
" 26 0.9808 0.0417 0.0828 11.4290 0.8188\n",
" 27 \u001b[36m0.9917\u001b[0m 0.0351 0.1512 10.7532 0.8042\n",
" 28 0.9883 0.0279 0.3459 2.4153 0.9086\n",
" 29 \u001b[36m0.9967\u001b[0m \u001b[32m0.0247\u001b[0m 0.1686 7.5120 0.9758\n",
" 30 0.9825 \u001b[32m0.0204\u001b[0m 0.1381 7.3847 0.9257\n",
" 31 0.9942 0.0217 0.1802 10.1529 0.9700\n",
" 32 0.9925 0.0233 0.1584 6.3587 0.9665\n",
" 33 \u001b[36m0.9983\u001b[0m \u001b[32m0.0150\u001b[0m 0.1890 11.0188 0.9673\n",
" 34 \u001b[36m1.0000\u001b[0m 0.0184 0.1846 7.7606 0.9717\n",
" 35 0.9983 0.0170 0.1788 10.1285 0.9677\n",
" 36 0.9992 0.0200 0.1890 7.5814 1.0036\n",
" 37 0.9992 0.0222 0.1759 8.4908 0.9762\n",
" 38 1.0000 \u001b[32m0.0136\u001b[0m 0.1730 9.0105 0.9359\n",
" 39 0.9992 0.0205 0.1817 8.3676 0.8884\n",
" 40 0.9992 0.0184 0.1730 8.9246 0.9631\n"
]
},
{
"data": {
"text/plain": [
"<class 'braindecode.classifier.EEGClassifier'>[initialized](\n",
" module_=ShallowFBCSPNet(\n",
" (ensuredims): Ensure4d()\n",
" (dimshuffle): Expression(expression=transpose_time_to_spat) \n",
" (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))\n",
" (conv_spat): Conv2d(40, 40, kernel_size=(1, 8), stride=(1, 1), bias=False)\n",
" (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv_nonlin_exp): Expression(expression=square) \n",
" (pool): AvgPool2d(kernel_size=(75, 1), stride=(15, 1), padding=0)\n",
" (pool_nonlin_exp): Expression(expression=safe_log) \n",
" (drop): Dropout(p=0.5, inplace=False)\n",
" (conv_classifier): Conv2d(40, 5, kernel_size=(28, 1), stride=(1, 1))\n",
" (softmax): LogSoftmax()\n",
" (squeeze): Expression(expression=squeeze_final_output) \n",
" ),\n",
")"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from braindecode import EEGClassifier\n",
"from skorch.callbacks import LRScheduler\n",
"from skorch.helper import predefined_split\n",
"\n",
"# Hyperparameters found to work well for the shallow network:\n",
"lr = 0.0625 * 0.01\n",
"weight_decay = 0\n",
"\n",
"# For deep4 they should be:\n",
"# lr = 1 * 0.01\n",
"# weight_decay = 0.5 * 0.001\n",
"\n",
"batch_size = 32\n",
"n_epochs = 40\n",
"\n",
"clf = EEGClassifier(\n",
"    model,\n",
"    criterion=torch.nn.NLLLoss,\n",
"    optimizer=torch.optim.AdamW,\n",
"    optimizer__lr=lr,\n",
"    optimizer__weight_decay=weight_decay,\n",
"    batch_size=batch_size,\n",
"    device=device,\n",
"    # use valid_set for validation\n",
"    train_split=predefined_split(valid_set),\n",
"    callbacks=[\n",
"        'accuracy',\n",
"        ('lr_scheduler', LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),\n",
"    ],\n",
")\n",
"\n",
"# Train for n_epochs epochs. `y` is None because the labels are already\n",
"# supplied in the dataset.\n",
"clf.fit(train_set, y=None, epochs=n_epochs)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment