Last active
July 10, 2020 08:27
-
-
Save robintibor/a22b99dc054d0e84d57ec9c50d90aa8e to your computer and use it in GitHub Desktop.
Decoding Attention Level with Braindecode
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"\n", | |
"data = pd.read_csv('./user1_test10_attention.csv')\n", | |
"\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"torch.backends.cudnn.benchmark = True\n", | |
"\n", | |
"window_size_samples = 512\n", | |
"\n", | |
"X = [np.array(data.iloc[i:i+window_size_samples,:-1]) \n", | |
" for i in range(data.shape[0] - window_size_samples)]\n", | |
"X = np.array(X).transpose(0,2,1).astype(np.float32)\n", | |
"y = np.array(data.iloc[window_size_samples:,-1]).astype(np.int64) - 1\n", | |
"\n", | |
"print(X.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from braindecode.datautil import create_from_X_y\n", | |
"# or in prev versions: from braindecode.datasets.xy import create_from_X_y\n", | |
"\n", | |
"\n", | |
"n_split = 1200\n", | |
"train_X, train_y = X[:n_split], y[:n_split]\n", | |
"valid_X, valid_y = X[n_split:], y[n_split:]\n", | |
"means = train_X.mean(axis=(0,2), keepdims=True)\n", | |
"stds = train_X.std(axis=(0,2), keepdims=True)\n", | |
" \n", | |
"# standardize per channel\n", | |
"train_X = (train_X - means) / (stds)\n", | |
"valid_X = (valid_X - means) / (stds)\n", | |
"\n", | |
"\n", | |
"train_set = create_from_X_y(train_X, train_y, drop_last_window=False)\n", | |
"\n", | |
"valid_set = create_from_X_y(valid_X, valid_y, drop_last_window=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from braindecode.util import set_random_seeds\n", | |
"from braindecode.models import ShallowFBCSPNet\n", | |
"\n", | |
"cuda = torch.cuda.is_available() # check if GPU is available, if True chooses to use it\n", | |
"device = 'cuda' if cuda else 'cpu'\n", | |
"if cuda:\n", | |
" torch.backends.cudnn.benchmark = True\n", | |
"seed = 20200220 # random seed to make results reproducible\n", | |
"# Set random seed to be able to reproduce results\n", | |
"set_random_seeds(seed=seed, cuda=cuda)\n", | |
"\n", | |
"n_classes=5\n", | |
"# Extract number of chans and time steps from dataset\n", | |
"n_chans = train_set[0][0].shape[0]\n", | |
"input_window_samples = train_set[0][0].shape[1]\n", | |
"\n", | |
"model = ShallowFBCSPNet(\n", | |
" n_chans,\n", | |
" n_classes,\n", | |
" input_window_samples=input_window_samples,\n", | |
" final_conv_length='auto',\n", | |
")\n", | |
"\n", | |
"# Send model to GPU\n", | |
"if cuda:\n", | |
" model.cuda()\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" epoch train_accuracy train_loss valid_accuracy valid_loss dur\n", | |
"------- ---------------- ------------ ---------------- ------------ ------\n", | |
" 1 \u001b[36m0.7033\u001b[0m \u001b[32m0.4628\u001b[0m \u001b[35m0.2282\u001b[0m \u001b[31m3.0022\u001b[0m 0.8967\n", | |
" 2 \u001b[36m0.9333\u001b[0m \u001b[32m0.2477\u001b[0m 0.1919 9.7948 0.8895\n", | |
" 3 0.0167 \u001b[32m0.1764\u001b[0m 0.0000 14.5916 0.8951\n", | |
" 4 0.8750 0.1936 0.1919 5.9776 0.8068\n", | |
" 5 0.8433 \u001b[32m0.0970\u001b[0m 0.1919 12.8362 0.8077\n", | |
" 6 \u001b[36m0.9600\u001b[0m \u001b[32m0.0949\u001b[0m 0.1919 10.6701 0.8044\n", | |
" 7 \u001b[36m0.9817\u001b[0m 0.1093 0.1919 7.0251 0.8871\n", | |
" 8 0.9075 0.0982 0.1919 13.2528 0.8184\n", | |
" 9 0.9458 0.0954 0.1919 8.4267 0.8855\n", | |
" 10 0.9300 \u001b[32m0.0668\u001b[0m 0.1919 10.8124 0.7976\n", | |
" 11 0.9500 \u001b[32m0.0601\u001b[0m 0.1919 11.6038 0.8066\n", | |
" 12 0.9450 0.0976 0.1919 7.3305 0.8016\n", | |
" 13 \u001b[36m0.9858\u001b[0m 0.0736 0.1890 10.4158 0.7983\n", | |
" 14 0.9850 \u001b[32m0.0513\u001b[0m 0.1148 9.0304 0.7998\n", | |
" 15 0.3525 0.0526 0.1294 10.5825 0.7980\n", | |
" 16 0.9192 0.0518 0.1919 8.6944 0.8031\n", | |
" 17 0.9708 \u001b[32m0.0451\u001b[0m 0.1512 9.5866 0.7927\n", | |
" 18 0.9417 \u001b[32m0.0313\u001b[0m 0.1919 4.9543 0.7941\n", | |
" 19 0.9642 0.0418 \u001b[35m0.3735\u001b[0m \u001b[31m1.6894\u001b[0m 0.8016\n", | |
" 20 0.6983 0.0419 0.1919 2.9650 0.8049\n", | |
" 21 0.8667 0.0421 0.1919 13.6571 0.8049\n", | |
" 22 \u001b[36m0.9867\u001b[0m 0.0377 0.1570 11.6214 0.8081\n", | |
" 23 0.9767 0.0344 0.1555 9.7715 0.8004\n", | |
" 24 0.7550 \u001b[32m0.0269\u001b[0m \u001b[35m0.9840\u001b[0m \u001b[31m0.1476\u001b[0m 0.8031\n", | |
" 25 0.9708 \u001b[32m0.0268\u001b[0m 0.0727 12.5223 0.8024\n", | |
" 26 0.9808 0.0417 0.0828 11.4290 0.8188\n", | |
" 27 \u001b[36m0.9917\u001b[0m 0.0351 0.1512 10.7532 0.8042\n", | |
" 28 0.9883 0.0279 0.3459 2.4153 0.9086\n", | |
" 29 \u001b[36m0.9967\u001b[0m \u001b[32m0.0247\u001b[0m 0.1686 7.5120 0.9758\n", | |
" 30 0.9825 \u001b[32m0.0204\u001b[0m 0.1381 7.3847 0.9257\n", | |
" 31 0.9942 0.0217 0.1802 10.1529 0.9700\n", | |
" 32 0.9925 0.0233 0.1584 6.3587 0.9665\n", | |
" 33 \u001b[36m0.9983\u001b[0m \u001b[32m0.0150\u001b[0m 0.1890 11.0188 0.9673\n", | |
" 34 \u001b[36m1.0000\u001b[0m 0.0184 0.1846 7.7606 0.9717\n", | |
" 35 0.9983 0.0170 0.1788 10.1285 0.9677\n", | |
" 36 0.9992 0.0200 0.1890 7.5814 1.0036\n", | |
" 37 0.9992 0.0222 0.1759 8.4908 0.9762\n", | |
" 38 1.0000 \u001b[32m0.0136\u001b[0m 0.1730 9.0105 0.9359\n", | |
" 39 0.9992 0.0205 0.1817 8.3676 0.8884\n", | |
" 40 0.9992 0.0184 0.1730 8.9246 0.9631\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<class 'braindecode.classifier.EEGClassifier'>[initialized](\n", | |
" module_=ShallowFBCSPNet(\n", | |
" (ensuredims): Ensure4d()\n", | |
" (dimshuffle): Expression(expression=transpose_time_to_spat) \n", | |
" (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))\n", | |
" (conv_spat): Conv2d(40, 40, kernel_size=(1, 8), stride=(1, 1), bias=False)\n", | |
" (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (conv_nonlin_exp): Expression(expression=square) \n", | |
" (pool): AvgPool2d(kernel_size=(75, 1), stride=(15, 1), padding=0)\n", | |
" (pool_nonlin_exp): Expression(expression=safe_log) \n", | |
" (drop): Dropout(p=0.5, inplace=False)\n", | |
" (conv_classifier): Conv2d(40, 5, kernel_size=(28, 1), stride=(1, 1))\n", | |
" (softmax): LogSoftmax()\n", | |
" (squeeze): Expression(expression=squeeze_final_output) \n", | |
" ),\n", | |
")" | |
] | |
}, | |
"execution_count": 35, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from skorch.callbacks import LRScheduler\n", | |
"from skorch.helper import predefined_split\n", | |
"\n", | |
"from braindecode import EEGClassifier\n", | |
"# These values we found good for shallow network:\n", | |
"lr = 0.0625 * 0.01\n", | |
"weight_decay = 0\n", | |
"\n", | |
"# For deep4 they should be:\n", | |
"# lr = 1 * 0.01\n", | |
"# weight_decay = 0.5 * 0.001\n", | |
"\n", | |
"batch_size = 32\n", | |
"n_epochs = 40\n", | |
"\n", | |
"clf = EEGClassifier(\n", | |
" model,\n", | |
" criterion=torch.nn.NLLLoss,\n", | |
" optimizer=torch.optim.AdamW,\n", | |
" train_split=predefined_split(valid_set), # using valid_set for validation\n", | |
" optimizer__lr=lr,\n", | |
" optimizer__weight_decay=weight_decay,\n", | |
" batch_size=batch_size,\n", | |
" callbacks=[\n", | |
" \"accuracy\", (\"lr_scheduler\", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),\n", | |
" ],\n", | |
" device=device,\n", | |
")\n", | |
"# Model training for a specified number of epochs. `y` is None as it is already supplied\n", | |
"# in the dataset.\n", | |
"clf.fit(train_set, y=None, epochs=n_epochs)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment