Skip to content

Instantly share code, notes, and snippets.

@gemeinl
Last active March 7, 2019 16:11
Show Gist options
  • Select an option

  • Save gemeinl/e5f7615fafd65b4ab54c77be2c7df52c to your computer and use it in GitHub Desktop.

Select an option

Save gemeinl/e5f7615fafd65b4ab54c77be2c7df52c to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import logging\nlogger = logging.getLogger()\nlogger.setLevel(logging.DEBUG)",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "data_path = \"/media/Data1/LukasG/KayGan/train_data.jblib\"\nfake_data_path = \"/media/Data1/LukasG/KayGan/train_data_fake.jblib\"\nfs = 250\nchannels = ['Fp1','Fp2','F7','F3','Fz','F4','F8','T7','C3','Cz','C4','T8','P7',\n 'P3','Pz','P4','P8','O1','O2','M1','M2']\nleft_hand = 0\nright_hand = 1",
"execution_count": 96,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import sys",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "sys.path.insert(1, \"/home/lukasg/code/braindecode/\")",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from brainfeatures.data_set.abstract_data_set import DataSet",
"execution_count": 5,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import pandas as pd",
"execution_count": 6,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import joblib",
"execution_count": 7,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class KayGan(DataSet):\n def __init__(self, path, subset, channel_names):\n with open(path, 'rb') as fo: \n self.loaded = joblib.load(fo)[subset]\n self.fs = 250\n self.subset = subset\n \n def __getitem__(self, idx):\n x = self.loaded.X[idx]\n return pd.DataFrame(x, index=channels), self.fs, self.loaded.y[idx]\n \n def __len__(self):\n return len(self.loaded.X)",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def load_train_valid_test_set(data_path, channel_names):\n train_set = KayGan(\n path=data_path,\n subset=\"train\",\n channel_names=channels\n )\n valid_set = KayGan(\n path=data_path,\n subset=\"valid\",\n channel_names=channels\n )\n test_set = KayGan(\n path=data_path,\n subset=\"test\",\n channel_names=channels\n )\n return train_set, valid_set, test_set",
"execution_count": 71,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "train_set, valid_set, test_set = load_train_valid_test_set(data_path, channels)",
"execution_count": 72,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "len(train_set), len(valid_set), len(test_set)",
"execution_count": 73,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 73,
"data": {
"text/plain": "(358, 90, 80)"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "fake_train_set, fake_valid_set, fake_test_set = load_train_valid_test_set(fake_data_path, channels)",
"execution_count": 98,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "len(fake_train_set), len(fake_valid_set), len(fake_test_set)",
"execution_count": 99,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 99,
"data": {
"text/plain": "(358, 90, 80)"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "x, fs, y = train_set[0]",
"execution_count": 14,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "x.shape, fs, y",
"execution_count": 15,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 15,
"data": {
"text/plain": "((21, 384), 250, 1)"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "x.shape[-1]/fs",
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 16,
"data": {
"text/plain": "1.536"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from brainfeatures.feature_generation.generate_features import default_feature_generation_params",
"execution_count": 74,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "default_feature_generation_params.update({\"epoch_duration_s\": x.shape[-1]/fs})",
"execution_count": 75,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from brainfeatures.experiment.experiment import Experiment",
"execution_count": 26,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def generate_features(list_of_data_sets):\n feats = []\n for ds in list_of_data_sets:\n exp = Experiment(\n devel_set=ds,\n estimator=None,\n preproc_function=None,\n feature_generation_params=default_feature_generation_params,\n n_jobs=50\n )\n exp.run()\n feats.append(pd.concat(exp._features[\"devel\"], ignore_index=True))\n return feats",
"execution_count": 76,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "train_feats, valid_feats, test_feats = generate_features([train_set, valid_set, test_set])",
"execution_count": 77,
"outputs": [
{
"output_type": "stream",
"text": "INFO:root:Started on 2019-03-07 at 16:32:13.629218\nINFO:root:Loading devel (preprocessed)\nINFO:root:Generating features (devel)\nINFO:root:Finished on 2019-03-07 at 16:32:36.524810.\nINFO:root:Started on 2019-03-07 at 16:32:37.095011\nINFO:root:Loading devel (preprocessed)\nINFO:root:Generating features (devel)\nINFO:root:Finished on 2019-03-07 at 16:32:41.971529.\nINFO:root:Started on 2019-03-07 at 16:32:42.064338\nINFO:root:Loading devel (preprocessed)\nINFO:root:Generating features (devel)\nINFO:root:Finished on 2019-03-07 at 16:32:47.496739.\n",
"name": "stderr"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "train_feats.shape, valid_feats.shape, test_feats.shape",
"execution_count": 78,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 78,
"data": {
"text/plain": "((358, 8463), (90, 8463), (80, 8463))"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "train_feats[:3]",
"execution_count": 181,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 181,
"data": {
"text/plain": " cwt_bounded_variation_a7_Fp1 cwt_bounded_variation_a7_Fp2 \\\n0 12.491957 14.566269 \n1 49.169237 18.202797 \n2 9.294037 7.581290 \n\n cwt_bounded_variation_a7_F7 cwt_bounded_variation_a7_F3 \\\n0 23.488305 13.358394 \n1 16.328832 36.492521 \n2 33.104700 18.012985 \n\n cwt_bounded_variation_a7_Fz cwt_bounded_variation_a7_F4 \\\n0 10.170716 20.240715 \n1 13.102106 14.456780 \n2 9.382596 10.728926 \n\n cwt_bounded_variation_a7_F8 cwt_bounded_variation_a7_T7 \\\n0 48.457450 72.134396 \n1 50.510341 56.999656 \n2 13.532158 84.702942 \n\n cwt_bounded_variation_a7_C3 cwt_bounded_variation_a7_Cz \\\n0 17.169753 12.602528 \n1 17.496603 28.061434 \n2 8.270255 8.210823 \n\n ... time_zero-crossing-derivative_T8 \\\n0 ... 95.0 \n1 ... 92.0 \n2 ... 94.0 \n\n time_zero-crossing-derivative_P7 time_zero-crossing-derivative_P3 \\\n0 87.0 100.0 \n1 93.0 97.0 \n2 94.0 96.0 \n\n time_zero-crossing-derivative_Pz time_zero-crossing-derivative_P4 \\\n0 102.0 95.0 \n1 92.0 99.0 \n2 100.0 91.0 \n\n time_zero-crossing-derivative_P8 time_zero-crossing-derivative_O1 \\\n0 93.0 101.0 \n1 89.0 100.0 \n2 90.0 95.0 \n\n time_zero-crossing-derivative_O2 time_zero-crossing-derivative_M1 \\\n0 98.0 83.0 \n1 98.0 87.0 \n2 94.0 102.0 \n\n time_zero-crossing-derivative_M2 \n0 94.0 \n1 88.0 \n2 92.0 \n\n[3 rows x 8463 columns]",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>cwt_bounded_variation_a7_Fp1</th>\n <th>cwt_bounded_variation_a7_Fp2</th>\n <th>cwt_bounded_variation_a7_F7</th>\n <th>cwt_bounded_variation_a7_F3</th>\n <th>cwt_bounded_variation_a7_Fz</th>\n <th>cwt_bounded_variation_a7_F4</th>\n <th>cwt_bounded_variation_a7_F8</th>\n <th>cwt_bounded_variation_a7_T7</th>\n <th>cwt_bounded_variation_a7_C3</th>\n <th>cwt_bounded_variation_a7_Cz</th>\n <th>...</th>\n <th>time_zero-crossing-derivative_T8</th>\n <th>time_zero-crossing-derivative_P7</th>\n <th>time_zero-crossing-derivative_P3</th>\n <th>time_zero-crossing-derivative_Pz</th>\n <th>time_zero-crossing-derivative_P4</th>\n <th>time_zero-crossing-derivative_P8</th>\n <th>time_zero-crossing-derivative_O1</th>\n <th>time_zero-crossing-derivative_O2</th>\n <th>time_zero-crossing-derivative_M1</th>\n <th>time_zero-crossing-derivative_M2</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>12.491957</td>\n <td>14.566269</td>\n <td>23.488305</td>\n <td>13.358394</td>\n <td>10.170716</td>\n <td>20.240715</td>\n <td>48.457450</td>\n <td>72.134396</td>\n <td>17.169753</td>\n <td>12.602528</td>\n <td>...</td>\n <td>95.0</td>\n <td>87.0</td>\n <td>100.0</td>\n <td>102.0</td>\n <td>95.0</td>\n <td>93.0</td>\n <td>101.0</td>\n <td>98.0</td>\n <td>83.0</td>\n <td>94.0</td>\n </tr>\n <tr>\n <th>1</th>\n <td>49.169237</td>\n <td>18.202797</td>\n <td>16.328832</td>\n <td>36.492521</td>\n <td>13.102106</td>\n <td>14.456780</td>\n <td>50.510341</td>\n <td>56.999656</td>\n <td>17.496603</td>\n <td>28.061434</td>\n <td>...</td>\n <td>92.0</td>\n <td>93.0</td>\n <td>97.0</td>\n <td>92.0</td>\n <td>99.0</td>\n <td>89.0</td>\n <td>100.0</td>\n <td>98.0</td>\n <td>87.0</td>\n <td>88.0</td>\n </tr>\n <tr>\n <th>2</th>\n <td>9.294037</td>\n <td>7.581290</td>\n <td>33.104700</td>\n <td>18.012985</td>\n <td>9.382596</td>\n <td>10.728926</td>\n <td>13.532158</td>\n <td>84.702942</td>\n <td>8.270255</td>\n <td>8.210823</td>\n <td>...</td>\n <td>94.0</td>\n <td>94.0</td>\n <td>96.0</td>\n <td>100.0</td>\n <td>91.0</td>\n <td>90.0</td>\n <td>95.0</td>\n <td>94.0</td>\n <td>102.0</td>\n <td>92.0</td>\n </tr>\n </tbody>\n</table>\n<p>3 rows × 8463 columns</p>\n</div>"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "fake_train_feats, fake_valid_feats, fake_test_feats = generate_features([fake_train_set, fake_valid_set, \n fake_test_set])",
"execution_count": 100,
"outputs": [
{
"output_type": "stream",
"text": "INFO:root:Started on 2019-03-07 at 16:37:41.840862\nINFO:root:Loading devel (preprocessed)\nINFO:root:Generating features (devel)\nINFO:root:Finished on 2019-03-07 at 16:37:59.955030.\nINFO:root:Started on 2019-03-07 at 16:38:00.470538\nINFO:root:Loading devel (preprocessed)\nINFO:root:Generating features (devel)\nINFO:root:Finished on 2019-03-07 at 16:38:05.723335.\nINFO:root:Started on 2019-03-07 at 16:38:05.823948\nINFO:root:Loading devel (preprocessed)\nINFO:root:Generating features (devel)\nINFO:root:Finished on 2019-03-07 at 16:38:11.449096.\n",
"name": "stderr"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "fake_train_feats.shape, fake_valid_feats.shape, fake_test_feats.shape",
"execution_count": 102,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 102,
"data": {
"text/plain": "((358, 8463), (90, 8463), (80, 8463))"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from sklearn.ensemble import RandomForestClassifier",
"execution_count": 173,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from sklearn.metrics import accuracy_score",
"execution_count": 174,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def evaluate(clf, train_feats, train_y, valid_feats, valid_y, test_feats, test_y):\n clf.fit(train_feats, train_y)\n \n train_preds = clf.predict(train_feats)\n train_acc = accuracy_score(train_y, train_preds)\n \n valid_preds = clf.predict(valid_feats)\n valid_acc = accuracy_score(valid_y, valid_preds)\n \n test_preds = clf.predict(test_feats)\n test_acc = accuracy_score(test_y, test_preds)\n return {\"train\": train_acc, \"valid\": valid_acc, \"test\": test_acc}",
"execution_count": 175,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "real data"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "evaluate(RandomForestClassifier(n_estimators=100, random_state=0),\n train_feats, [y for x ,fs, y in train_set], \n valid_feats, [y for x, fs, y in valid_set], \n test_feats, [y for x, fs, y in test_set])",
"execution_count": 176,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 176,
"data": {
"text/plain": "{'train': 1.0, 'valid': 0.6111111111111112, 'test': 0.625}"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "fake data"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "evaluate(RandomForestClassifier(n_estimators=100, random_state=0),\n fake_train_feats, [y for x ,fs, y in fake_train_set], \n fake_valid_feats, [y for x, fs, y in fake_valid_set], \n fake_test_feats, [y for x, fs, y in fake_test_set])",
"execution_count": 177,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 177,
"data": {
"text/plain": "{'train': 1.0, 'valid': 0.7, 'test': 0.675}"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "train on real data, predict fake data"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "evaluate(RandomForestClassifier(n_estimators=100, random_state=0),\n train_feats, [y for x ,fs, y in train_set], \n fake_valid_feats, [y for x, fs, y in fake_valid_set], \n fake_test_feats, [y for x, fs, y in fake_test_set])",
"execution_count": 178,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 178,
"data": {
"text/plain": "{'train': 1.0, 'valid': 0.6111111111111112, 'test': 0.5875}"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "train on fake data, predict real data"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "evaluate(RandomForestClassifier(n_estimators=100, random_state=0),\n fake_train_feats, [y for x ,fs, y in fake_train_set], \n valid_feats, [y for x, fs, y in valid_set], \n test_feats, [y for x, fs, y in test_set])",
"execution_count": 179,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 179,
"data": {
"text/plain": "{'train': 1.0, 'valid': 0.6111111111111112, 'test': 0.5375}"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"language_info": {
"pygments_lexer": "ipython3",
"codemirror_mode": {
"version": 3,
"name": "ipython"
},
"version": "3.6.7",
"nbconvert_exporter": "python",
"mimetype": "text/x-python",
"name": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "brainfeatures",
"display_name": "brainfeatures",
"language": "python"
},
"gist_id": "e5f7615fafd65b4ab54c77be2c7df52c"
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment