Created
July 10, 2020 08:31
-
-
Save robintibor/59280edaa48030183536af4464f9b94a to your computer and use it in GitHub Desktop.
Decoding Attention Level with Braindecode (Smaller Windows)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(2272, 8, 128)\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"\n", | |
"DATA_PATH = './user1_test10_attention.csv'\n", | |
"window_size_samples = 128\n", | |
"\n", | |
"# All columns but the last are input channels; the last column is the label.\n", | |
"data = pd.read_csv(DATA_PATH)\n", | |
"samples = data.iloc[:, :-1].to_numpy()\n", | |
"labels = data.iloc[:, -1].to_numpy()\n", | |
"\n", | |
"n_windows = len(data) - window_size_samples\n", | |
"if n_windows <= 0:\n", | |
"    raise ValueError(\n", | |
"        f'need more than {window_size_samples} rows, got {len(data)}')\n", | |
"\n", | |
"# Window i covers rows i .. i+window_size_samples-1 and takes the label of the\n", | |
"# row right after it (row i+window_size_samples).\n", | |
"# NOTE(review): confirm predicting the *next* row's label is intended rather\n", | |
"# than the label of the window's last row.\n", | |
"X = np.stack([samples[i:i + window_size_samples] for i in range(n_windows)])\n", | |
"X = X.transpose(0, 2, 1).astype(np.float32)  # (n_windows, n_chans, n_times)\n", | |
"# Subtract 1 from the label column (presumably 1-based labels -> 0-based).\n", | |
"y = labels[window_size_samples:].astype(np.int64) - 1\n", | |
"assert len(X) == len(y)\n", | |
"\n", | |
"print(X.shape)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from braindecode.datautil import create_from_X_y\n", | |
"# or in prev versions: from braindecode.datasets.xy import create_from_X_y\n", | |
"\n", | |
"n_split = 1600\n", | |
"# Window i covers rows i .. i+window_size_samples-1, so consecutive windows\n", | |
"# overlap by window_size_samples - 1 rows and the windows right after the\n", | |
"# split would reuse raw rows of the last training windows. Start validation\n", | |
"# at the first window that shares no rows with the training windows.\n", | |
"valid_start = n_split + window_size_samples - 1\n", | |
"train_X, train_y = X[:n_split], y[:n_split]\n", | |
"valid_X, valid_y = X[valid_start:], y[valid_start:]\n", | |
"assert len(valid_X) > 0, 'no validation windows left after the split'\n", | |
"\n", | |
"# Standardize each channel with statistics from the training windows only.\n", | |
"means = train_X.mean(axis=(0, 2), keepdims=True)\n", | |
"stds = train_X.std(axis=(0, 2), keepdims=True)\n", | |
"stds[stds == 0] = 1  # a constant channel would otherwise divide by zero\n", | |
"train_X = (train_X - means) / stds\n", | |
"valid_X = (valid_X - means) / stds\n", | |
"\n", | |
"train_set = create_from_X_y(train_X, train_y, drop_last_window=False)\n", | |
"valid_set = create_from_X_y(valid_X, valid_y, drop_last_window=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"from braindecode.util import set_random_seeds\n", | |
"from braindecode.models import ShallowFBCSPNet\n", | |
"\n", | |
"cuda = torch.cuda.is_available()  # use the GPU when one is available\n", | |
"device = 'cuda' if cuda else 'cpu'\n", | |
"if cuda:\n", | |
"    # NOTE: benchmark mode lets cuDNN pick convolution algorithms at runtime,\n", | |
"    # which can make GPU runs non-deterministic despite the seed below.\n", | |
"    torch.backends.cudnn.benchmark = True\n", | |
"seed = 20200220\n", | |
"set_random_seeds(seed=seed, cuda=cuda)\n", | |
"\n", | |
"# One output per class index. y holds the label column minus 1, so deriving\n", | |
"# the count from it covers every label present instead of a hard-coded 5.\n", | |
"n_classes = int(y.max()) + 1\n", | |
"# Number of channels and time samples per window, read from the first example\n", | |
"n_chans = train_set[0][0].shape[0]\n", | |
"input_window_samples = train_set[0][0].shape[1]\n", | |
"\n", | |
"model = ShallowFBCSPNet(\n", | |
"    n_chans,\n", | |
"    n_classes,\n", | |
"    input_window_samples=input_window_samples,\n", | |
"    final_conv_length='auto',\n", | |
")\n", | |
"\n", | |
"if cuda:\n", | |
"    model.cuda()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" epoch train_accuracy train_loss valid_accuracy valid_loss dur\n", | |
"------- ---------------- ------------ ---------------- ------------ ------\n", | |
" 1 \u001b[36m0.5519\u001b[0m \u001b[32m1.1509\u001b[0m \u001b[35m0.0000\u001b[0m \u001b[31m2.2398\u001b[0m 1.3503\n", | |
" 2 \u001b[36m0.6606\u001b[0m \u001b[32m0.7179\u001b[0m \u001b[35m0.1726\u001b[0m 4.3311 1.3391\n", | |
" 3 \u001b[36m0.9113\u001b[0m \u001b[32m0.4896\u001b[0m 0.1726 2.9743 1.3597\n", | |
" 4 \u001b[36m0.9581\u001b[0m \u001b[32m0.3229\u001b[0m 0.1726 4.4045 1.3775\n", | |
" 5 0.9300 \u001b[32m0.2481\u001b[0m 0.1726 4.0017 1.3307\n", | |
" 6 \u001b[36m0.9750\u001b[0m \u001b[32m0.2348\u001b[0m 0.1726 5.6199 1.3246\n", | |
" 7 \u001b[36m0.9881\u001b[0m \u001b[32m0.1922\u001b[0m 0.1726 6.0158 1.3421\n", | |
" 8 0.9400 \u001b[32m0.1526\u001b[0m 0.1726 7.7374 1.2859\n", | |
" 9 0.9738 \u001b[32m0.1486\u001b[0m 0.1726 4.7569 1.1396\n", | |
" 10 0.9675 \u001b[32m0.1331\u001b[0m 0.1726 5.1696 1.1235\n", | |
" 11 0.8912 \u001b[32m0.1180\u001b[0m 0.1726 7.8588 1.1291\n", | |
" 12 0.9487 0.1250 \u001b[35m0.3839\u001b[0m \u001b[31m1.3711\u001b[0m 1.0218\n", | |
" 13 0.9775 0.1919 0.1726 3.1013 1.0330\n", | |
" 14 \u001b[36m0.9975\u001b[0m \u001b[32m0.1103\u001b[0m 0.1726 4.3662 1.3610\n", | |
" 15 0.9956 \u001b[32m0.1028\u001b[0m 0.1726 4.7099 1.3303\n", | |
" 16 0.9794 \u001b[32m0.0814\u001b[0m 0.1726 6.6700 1.3671\n", | |
" 17 0.9888 0.0982 0.2827 4.3589 1.3759\n", | |
" 18 0.9862 \u001b[32m0.0740\u001b[0m 0.1726 6.2744 1.3433\n", | |
" 19 0.9831 0.0955 0.1726 6.5324 1.3806\n", | |
" 20 0.9788 \u001b[32m0.0713\u001b[0m 0.1726 7.6627 1.3642\n", | |
" 21 0.9956 \u001b[32m0.0658\u001b[0m 0.2783 4.2936 1.3517\n", | |
" 22 0.9806 \u001b[32m0.0624\u001b[0m 0.1726 7.2627 1.3614\n", | |
" 23 0.9925 \u001b[32m0.0518\u001b[0m 0.1726 5.6753 1.3646\n", | |
" 24 0.9875 0.0534 0.3363 4.2460 1.3827\n", | |
" 25 0.9800 0.0572 0.1726 7.7427 0.9306\n", | |
" 26 0.9862 0.0651 0.2500 6.2331 0.9330\n", | |
" 27 0.9962 0.0600 0.2768 4.6223 0.9309\n", | |
" 28 0.9744 0.0559 0.2976 2.9818 0.9232\n", | |
" 29 0.9894 0.0612 0.2619 6.2250 0.9244\n", | |
" 30 0.9900 0.0623 0.2158 6.8295 0.9265\n", | |
" 31 0.9881 0.0578 0.3333 4.4365 1.0223\n", | |
" 32 0.9956 \u001b[32m0.0495\u001b[0m 0.2798 4.9459 0.9264\n", | |
" 33 0.9956 0.0525 0.2902 6.0256 1.0238\n", | |
" 34 0.9975 0.0507 0.3110 4.6143 0.9232\n", | |
" 35 0.9900 \u001b[32m0.0462\u001b[0m 0.2857 6.0690 0.9318\n", | |
" 36 0.9931 0.0610 0.2991 5.7356 0.9115\n", | |
" 37 0.9938 0.0516 0.2976 5.8834 0.9155\n", | |
" 38 0.9944 0.0504 0.2976 5.8595 0.9123\n", | |
" 39 0.9931 0.0680 0.2946 5.9841 0.9106\n", | |
" 40 \u001b[36m0.9981\u001b[0m 0.0526 0.3140 5.2729 0.9140\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"<class 'braindecode.classifier.EEGClassifier'>[initialized](\n", | |
" module_=ShallowFBCSPNet(\n", | |
" (ensuredims): Ensure4d()\n", | |
" (dimshuffle): Expression(expression=transpose_time_to_spat) \n", | |
" (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))\n", | |
" (conv_spat): Conv2d(40, 40, kernel_size=(1, 8), stride=(1, 1), bias=False)\n", | |
" (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (conv_nonlin_exp): Expression(expression=square) \n", | |
" (pool): AvgPool2d(kernel_size=(75, 1), stride=(15, 1), padding=0)\n", | |
" (pool_nonlin_exp): Expression(expression=safe_log) \n", | |
" (drop): Dropout(p=0.5, inplace=False)\n", | |
" (conv_classifier): Conv2d(40, 5, kernel_size=(2, 1), stride=(1, 1))\n", | |
" (softmax): LogSoftmax()\n", | |
" (squeeze): Expression(expression=squeeze_final_output) \n", | |
" ),\n", | |
")" | |
] | |
}, | |
"execution_count": 42, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from skorch.callbacks import LRScheduler\n", | |
"from skorch.helper import predefined_split\n", | |
"\n", | |
"from braindecode import EEGClassifier\n", | |
"\n", | |
"# Hyperparameters suggested for ShallowFBCSPNet.\n", | |
"# (Deep4Net would instead use lr = 1 * 0.01 and weight_decay = 0.5 * 0.001.)\n", | |
"lr = 0.0625 * 0.01\n", | |
"weight_decay = 0\n", | |
"batch_size = 32\n", | |
"n_epochs = 40\n", | |
"\n", | |
"# 'accuracy' adds accuracy columns to the epoch log; the scheduler\n", | |
"# cosine-anneals the learning rate over the training run.\n", | |
"callbacks = [\n", | |
"    \"accuracy\",\n", | |
"    (\"lr_scheduler\", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),\n", | |
"]\n", | |
"\n", | |
"clf = EEGClassifier(\n", | |
"    model,\n", | |
"    criterion=torch.nn.NLLLoss,  # the model already ends in LogSoftmax\n", | |
"    optimizer=torch.optim.AdamW,\n", | |
"    train_split=predefined_split(valid_set),  # validate on the fixed valid_set\n", | |
"    optimizer__lr=lr,\n", | |
"    optimizer__weight_decay=weight_decay,\n", | |
"    batch_size=batch_size,\n", | |
"    callbacks=callbacks,\n", | |
"    device=device,\n", | |
")\n", | |
"\n", | |
"# Labels live inside train_set, hence y=None.\n", | |
"clf.fit(train_set, y=None, epochs=n_epochs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment