Last active
March 7, 2019 16:11
-
-
Save gemeinl/e5f7615fafd65b4ab54c77be2c7df52c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "import logging\nlogger = logging.getLogger()\nlogger.setLevel(logging.DEBUG)", | |
| "execution_count": 1, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Absolute, machine-specific locations of the joblib files holding the real and the fake data.\ndata_path = \"/media/Data1/LukasG/KayGan/train_data.jblib\"\nfake_data_path = \"/media/Data1/LukasG/KayGan/train_data_fake.jblib\"\n# Later cells divide the number of samples by fs to obtain a duration in seconds.\nfs = 250\n# 21 names used as the row index of every example DataFrame.\nchannels = ['Fp1','Fp2','F7','F3','Fz','F4','F8','T7','C3','Cz','C4','T8','P7',\n 'P3','Pz','P4','P8','O1','O2','M1','M2']\n# NOTE(review): presumably the integer encoding of the two class labels -- confirm;\n# neither name is referenced again in the visible part of this notebook.\nleft_hand = 0\nright_hand = 1", | |
| "execution_count": 96, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "import sys", | |
| "execution_count": 3, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "sys.path.insert(1, \"/home/lukasg/code/braindecode/\")", | |
| "execution_count": 4, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "from brainfeatures.data_set.abstract_data_set import DataSet", | |
| "execution_count": 5, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "import pandas as pd", | |
| "execution_count": 6, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "import joblib", | |
| "execution_count": 7, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "class KayGan(DataSet):\n    \"\"\"Data set exposing one subset of a joblib-serialized file.\n\n    Indexing returns ``(signals, fs, label)``; ``signals`` is a pandas\n    DataFrame whose row index is ``channel_names``.\n    \"\"\"\n\n    def __init__(self, path, subset, channel_names, fs=250):\n        \"\"\"Load ``subset`` from the joblib file at ``path``.\n\n        path: file opened in binary mode and passed to ``joblib.load``.\n        subset: key used to index the loaded object (the notebook passes\n            'train', 'valid' and 'test').\n        channel_names: row labels applied to every returned DataFrame.\n        fs: value returned as the second element of every example\n            (defaults to 250, the value that used to be hard-coded).\n        \"\"\"\n        # NOTE: joblib.load unpickles arbitrary objects; only open trusted files.\n        with open(path, 'rb') as fo:\n            self.loaded = joblib.load(fo)[subset]\n        self.fs = fs\n        self.subset = subset\n        # Store the names handed in by the caller so that __getitem__ does not\n        # depend on a notebook-level `channels` variable being defined.\n        self.channel_names = list(channel_names)\n\n    def __getitem__(self, idx):\n        \"\"\"Return ``(DataFrame, fs, label)`` for the example at ``idx``.\"\"\"\n        x = self.loaded.X[idx]\n        return pd.DataFrame(x, index=self.channel_names), self.fs, self.loaded.y[idx]\n\n    def __len__(self):\n        \"\"\"Return the number of examples in the loaded subset.\"\"\"\n        return len(self.loaded.X)", | |
| "execution_count": 8, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "def load_train_valid_test_set(data_path, channel_names):\n    \"\"\"Build the train, valid and test KayGan data sets from one file.\n\n    data_path: path handed to every KayGan instance.\n    channel_names: channel labels handed to every KayGan instance.\n\n    Returns the tuple ``(train_set, valid_set, test_set)``.\n    \"\"\"\n    # Pass the caller's channel_names through instead of reading the\n    # notebook-level `channels` variable.\n    train_set, valid_set, test_set = (\n        KayGan(path=data_path, subset=subset, channel_names=channel_names)\n        for subset in (\"train\", \"valid\", \"test\")\n    )\n    return train_set, valid_set, test_set", | |
| "execution_count": 71, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "train_set, valid_set, test_set = load_train_valid_test_set(data_path, channels)", | |
| "execution_count": 72, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "len(train_set), len(valid_set), len(test_set)", | |
| "execution_count": 73, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 73, | |
| "data": { | |
| "text/plain": "(358, 90, 80)" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "fake_train_set, fake_valid_set, fake_test_set = load_train_valid_test_set(fake_data_path, channels)", | |
| "execution_count": 98, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "len(fake_train_set), len(fake_valid_set), len(fake_test_set)", | |
| "execution_count": 99, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 99, | |
| "data": { | |
| "text/plain": "(358, 90, 80)" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "x, fs, y = train_set[0]", | |
| "execution_count": 14, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "x.shape, fs, y", | |
| "execution_count": 15, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 15, | |
| "data": { | |
| "text/plain": "((21, 384), 250, 1)" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "x.shape[-1]/fs", | |
| "execution_count": 16, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 16, | |
| "data": { | |
| "text/plain": "1.536" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "from brainfeatures.feature_generation.generate_features import default_feature_generation_params", | |
| "execution_count": 74, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Set the epoch duration to the length of x: samples along its last axis divided by fs.\n# x and fs are the values unpacked from train_set[0] in an earlier cell.\n# NOTE: update() modifies the imported default_feature_generation_params object in place,\n# so the change stays in effect for the rest of the kernel session.\ndefault_feature_generation_params.update({\"epoch_duration_s\": x.shape[-1]/fs})", | |
| "execution_count": 75, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "from brainfeatures.experiment.experiment import Experiment", | |
| "execution_count": 26, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "def generate_features(list_of_data_sets, n_jobs=50, feature_generation_params=None):\n    \"\"\"Run feature generation on every data set and collect the results.\n\n    list_of_data_sets: iterable of data sets; each one is passed as\n        ``devel_set`` to its own ``Experiment``.\n    n_jobs: forwarded to ``Experiment`` (defaults to 50, the value that\n        used to be hard-coded).\n    feature_generation_params: forwarded to ``Experiment``; ``None`` selects\n        the notebook-level ``default_feature_generation_params``.\n\n    Returns a list holding one concatenated pandas object per input data\n    set, in input order.\n    \"\"\"\n    if feature_generation_params is None:\n        feature_generation_params = default_feature_generation_params\n    feats = []\n    for ds in list_of_data_sets:\n        exp = Experiment(\n            devel_set=ds,\n            estimator=None,\n            preproc_function=None,\n            feature_generation_params=feature_generation_params,\n            n_jobs=n_jobs,\n        )\n        exp.run()\n        # NOTE(review): reads the private `_features` attribute of Experiment;\n        # presumably a list of per-example frames -- confirm against brainfeatures.\n        feats.append(pd.concat(exp._features[\"devel\"], ignore_index=True))\n    return feats", | |
| "execution_count": 76, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "train_feats, valid_feats, test_feats = generate_features([train_set, valid_set, test_set])", | |
| "execution_count": 77, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": "INFO:root:Started on 2019-03-07 at 16:32:13.629218\nINFO:root:Loading devel (preprocessed)\nINFO:root:Generating features (devel)\nINFO:root:Finished on 2019-03-07 at 16:32:36.524810.\nINFO:root:Started on 2019-03-07 at 16:32:37.095011\nINFO:root:Loading devel (preprocessed)\nINFO:root:Generating features (devel)\nINFO:root:Finished on 2019-03-07 at 16:32:41.971529.\nINFO:root:Started on 2019-03-07 at 16:32:42.064338\nINFO:root:Loading devel (preprocessed)\nINFO:root:Generating features (devel)\nINFO:root:Finished on 2019-03-07 at 16:32:47.496739.\n", | |
| "name": "stderr" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "train_feats.shape, valid_feats.shape, test_feats.shape", | |
| "execution_count": 78, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 78, | |
| "data": { | |
| "text/plain": "((358, 8463), (90, 8463), (80, 8463))" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "train_feats[:3]", | |
| "execution_count": 181, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 181, | |
| "data": { | |
| "text/plain": " cwt_bounded_variation_a7_Fp1 cwt_bounded_variation_a7_Fp2 \\\n0 12.491957 14.566269 \n1 49.169237 18.202797 \n2 9.294037 7.581290 \n\n cwt_bounded_variation_a7_F7 cwt_bounded_variation_a7_F3 \\\n0 23.488305 13.358394 \n1 16.328832 36.492521 \n2 33.104700 18.012985 \n\n cwt_bounded_variation_a7_Fz cwt_bounded_variation_a7_F4 \\\n0 10.170716 20.240715 \n1 13.102106 14.456780 \n2 9.382596 10.728926 \n\n cwt_bounded_variation_a7_F8 cwt_bounded_variation_a7_T7 \\\n0 48.457450 72.134396 \n1 50.510341 56.999656 \n2 13.532158 84.702942 \n\n cwt_bounded_variation_a7_C3 cwt_bounded_variation_a7_Cz \\\n0 17.169753 12.602528 \n1 17.496603 28.061434 \n2 8.270255 8.210823 \n\n ... time_zero-crossing-derivative_T8 \\\n0 ... 95.0 \n1 ... 92.0 \n2 ... 94.0 \n\n time_zero-crossing-derivative_P7 time_zero-crossing-derivative_P3 \\\n0 87.0 100.0 \n1 93.0 97.0 \n2 94.0 96.0 \n\n time_zero-crossing-derivative_Pz time_zero-crossing-derivative_P4 \\\n0 102.0 95.0 \n1 92.0 99.0 \n2 100.0 91.0 \n\n time_zero-crossing-derivative_P8 time_zero-crossing-derivative_O1 \\\n0 93.0 101.0 \n1 89.0 100.0 \n2 90.0 95.0 \n\n time_zero-crossing-derivative_O2 time_zero-crossing-derivative_M1 \\\n0 98.0 83.0 \n1 98.0 87.0 \n2 94.0 102.0 \n\n time_zero-crossing-derivative_M2 \n0 94.0 \n1 88.0 \n2 92.0 \n\n[3 rows x 8463 columns]", | |
| "text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>cwt_bounded_variation_a7_Fp1</th>\n <th>cwt_bounded_variation_a7_Fp2</th>\n <th>cwt_bounded_variation_a7_F7</th>\n <th>cwt_bounded_variation_a7_F3</th>\n <th>cwt_bounded_variation_a7_Fz</th>\n <th>cwt_bounded_variation_a7_F4</th>\n <th>cwt_bounded_variation_a7_F8</th>\n <th>cwt_bounded_variation_a7_T7</th>\n <th>cwt_bounded_variation_a7_C3</th>\n <th>cwt_bounded_variation_a7_Cz</th>\n <th>...</th>\n <th>time_zero-crossing-derivative_T8</th>\n <th>time_zero-crossing-derivative_P7</th>\n <th>time_zero-crossing-derivative_P3</th>\n <th>time_zero-crossing-derivative_Pz</th>\n <th>time_zero-crossing-derivative_P4</th>\n <th>time_zero-crossing-derivative_P8</th>\n <th>time_zero-crossing-derivative_O1</th>\n <th>time_zero-crossing-derivative_O2</th>\n <th>time_zero-crossing-derivative_M1</th>\n <th>time_zero-crossing-derivative_M2</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>12.491957</td>\n <td>14.566269</td>\n <td>23.488305</td>\n <td>13.358394</td>\n <td>10.170716</td>\n <td>20.240715</td>\n <td>48.457450</td>\n <td>72.134396</td>\n <td>17.169753</td>\n <td>12.602528</td>\n <td>...</td>\n <td>95.0</td>\n <td>87.0</td>\n <td>100.0</td>\n <td>102.0</td>\n <td>95.0</td>\n <td>93.0</td>\n <td>101.0</td>\n <td>98.0</td>\n <td>83.0</td>\n <td>94.0</td>\n </tr>\n <tr>\n <th>1</th>\n <td>49.169237</td>\n <td>18.202797</td>\n <td>16.328832</td>\n <td>36.492521</td>\n <td>13.102106</td>\n <td>14.456780</td>\n <td>50.510341</td>\n <td>56.999656</td>\n <td>17.496603</td>\n <td>28.061434</td>\n <td>...</td>\n <td>92.0</td>\n <td>93.0</td>\n <td>97.0</td>\n <td>92.0</td>\n <td>99.0</td>\n <td>89.0</td>\n <td>100.0</td>\n <td>98.0</td>\n 
<td>87.0</td>\n <td>88.0</td>\n </tr>\n <tr>\n <th>2</th>\n <td>9.294037</td>\n <td>7.581290</td>\n <td>33.104700</td>\n <td>18.012985</td>\n <td>9.382596</td>\n <td>10.728926</td>\n <td>13.532158</td>\n <td>84.702942</td>\n <td>8.270255</td>\n <td>8.210823</td>\n <td>...</td>\n <td>94.0</td>\n <td>94.0</td>\n <td>96.0</td>\n <td>100.0</td>\n <td>91.0</td>\n <td>90.0</td>\n <td>95.0</td>\n <td>94.0</td>\n <td>102.0</td>\n <td>92.0</td>\n </tr>\n </tbody>\n</table>\n<p>3 rows × 8463 columns</p>\n</div>" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "fake_train_feats, fake_valid_feats, fake_test_feats = generate_features([fake_train_set, fake_valid_set, \n fake_test_set])", | |
| "execution_count": 100, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": "INFO:root:Started on 2019-03-07 at 16:37:41.840862\nINFO:root:Loading devel (preprocessed)\nINFO:root:Generating features (devel)\nINFO:root:Finished on 2019-03-07 at 16:37:59.955030.\nINFO:root:Started on 2019-03-07 at 16:38:00.470538\nINFO:root:Loading devel (preprocessed)\nINFO:root:Generating features (devel)\nINFO:root:Finished on 2019-03-07 at 16:38:05.723335.\nINFO:root:Started on 2019-03-07 at 16:38:05.823948\nINFO:root:Loading devel (preprocessed)\nINFO:root:Generating features (devel)\nINFO:root:Finished on 2019-03-07 at 16:38:11.449096.\n", | |
| "name": "stderr" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "fake_train_feats.shape, fake_valid_feats.shape, fake_test_feats.shape", | |
| "execution_count": 102, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 102, | |
| "data": { | |
| "text/plain": "((358, 8463), (90, 8463), (80, 8463))" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "from sklearn.ensemble import RandomForestClassifier", | |
| "execution_count": 173, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "from sklearn.metrics import accuracy_score", | |
| "execution_count": 174, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "def evaluate(clf, train_feats, train_y, valid_feats, valid_y, test_feats, test_y):\n    \"\"\"Fit ``clf`` on the train split and score it on all three splits.\n\n    clf: estimator providing ``fit(features, labels)`` and ``predict(features)``.\n    *_feats / *_y: features and labels of the train, valid and test splits.\n\n    Returns a dict with the keys 'train', 'valid' and 'test', each mapped to\n    the ``accuracy_score`` of the predictions on that split.\n    \"\"\"\n    clf.fit(train_feats, train_y)\n    # Predict and score in the order train, valid, test.\n    splits = (\n        (\"train\", train_feats, train_y),\n        (\"valid\", valid_feats, valid_y),\n        (\"test\", test_feats, test_y),\n    )\n    return {\n        name: accuracy_score(labels, clf.predict(feats))\n        for name, feats, labels in splits\n    }", | |
| "execution_count": 175, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": {}, | |
| "cell_type": "markdown", | |
| "source": "train on real data, predict real data" | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "evaluate(RandomForestClassifier(n_estimators=100, random_state=0),\n train_feats, [y for x ,fs, y in train_set], \n valid_feats, [y for x, fs, y in valid_set], \n test_feats, [y for x, fs, y in test_set])", | |
| "execution_count": 176, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 176, | |
| "data": { | |
| "text/plain": "{'train': 1.0, 'valid': 0.6111111111111112, 'test': 0.625}" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": {}, | |
| "cell_type": "markdown", | |
| "source": "train on fake data, predict fake data" | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "evaluate(RandomForestClassifier(n_estimators=100, random_state=0),\n fake_train_feats, [y for x ,fs, y in fake_train_set], \n fake_valid_feats, [y for x, fs, y in fake_valid_set], \n fake_test_feats, [y for x, fs, y in fake_test_set])", | |
| "execution_count": 177, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 177, | |
| "data": { | |
| "text/plain": "{'train': 1.0, 'valid': 0.7, 'test': 0.675}" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": {}, | |
| "cell_type": "markdown", | |
| "source": "train on real data, predict fake data" | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "evaluate(RandomForestClassifier(n_estimators=100, random_state=0),\n train_feats, [y for x ,fs, y in train_set], \n fake_valid_feats, [y for x, fs, y in fake_valid_set], \n fake_test_feats, [y for x, fs, y in fake_test_set])", | |
| "execution_count": 178, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 178, | |
| "data": { | |
| "text/plain": "{'train': 1.0, 'valid': 0.6111111111111112, 'test': 0.5875}" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": {}, | |
| "cell_type": "markdown", | |
| "source": "train on fake data, predict real data" | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "evaluate(RandomForestClassifier(n_estimators=100, random_state=0),\n fake_train_feats, [y for x ,fs, y in fake_train_set], \n valid_feats, [y for x, fs, y in valid_set], \n test_feats, [y for x, fs, y in test_set])", | |
| "execution_count": 179, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 179, | |
| "data": { | |
| "text/plain": "{'train': 1.0, 'valid': 0.6111111111111112, 'test': 0.5375}" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "", | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ], | |
| "metadata": { | |
| "language_info": { | |
| "pygments_lexer": "ipython3", | |
| "codemirror_mode": { | |
| "version": 3, | |
| "name": "ipython" | |
| }, | |
| "version": "3.6.7", | |
| "nbconvert_exporter": "python", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "file_extension": ".py" | |
| }, | |
| "kernelspec": { | |
| "name": "brainfeatures", | |
| "display_name": "brainfeatures", | |
| "language": "python" | |
| }, | |
| "gist_id": "e5f7615fafd65b4ab54c77be2c7df52c" | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment