-
-
Save iso-p/d34b62ab8e397970883b07163330fef2 to your computer and use it in GitHub Desktop.
Braindecode on CHB-MIT Scalp EEG Database
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"nbsphinx": "hidden" | |
}, | |
"outputs": [], | |
"source": [ | |
"%load_ext autoreload\n", | |
"%autoreload 2" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# CHB MIT Scalp Seizure Dataset" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"1. First download dataset from https://www.physionet.org/pn6/chbmit/ (for this notebook only chb01 is needed)\n", | |
"2. Change the `base_path` in the code below" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Load Data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import glob\n", | |
"import os.path\n", | |
"subject_id = 1\n", | |
"base_path = \"/data/schirrmr/physionet.org/physiobank/database/chbmit/\"\n", | |
"edf_file_names = sorted(glob.glob(os.path.join(base_path, \"chb{:02d}/*.edf\".format(subject_id))))\n", | |
"summary_file = os.path.join(base_path, \"chb{:02d}/chb{:02d}-summary.txt\".format(subject_id, subject_id))\n", | |
"\n", | |
"\n", | |
"summary_content = open(summary_file,'r').read()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import re\n", | |
"import mne\n", | |
"import numpy as np\n", | |
"def extract_data_and_labels(edf_filename, summary_text):\n", | |
" folder, basename = os.path.split(edf_filename)\n", | |
" \n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n", | |
" X = edf.get_data().astype(np.float32) * 1e6 # to mV\n", | |
" y = np.zeros(X.shape[1], dtype=np.int64)\n", | |
" i_text_start = summary_text.index(basename)\n", | |
"\n", | |
" if 'File Name' in summary_text[i_text_start:]:\n", | |
" i_text_stop = summary_text.index('File Name', i_text_start)\n", | |
" else:\n", | |
" i_text_stop = len(summary_text)\n", | |
" assert i_text_stop > i_text_start\n", | |
"\n", | |
" file_text = summary_text[i_text_start:i_text_stop]\n", | |
" if 'Seizure Start' in file_text:\n", | |
" start_sec = int(re.search(r\"Seizure Start Time: ([0-9]*) seconds\", file_text).group(1))\n", | |
" end_sec = int(re.search(r\"Seizure End Time: ([0-9]*) seconds\", file_text).group(1))\n", | |
" i_seizure_start = int(round(start_sec * edf.info['sfreq']))\n", | |
" i_seizure_stop = int(round((end_sec + 1) * edf.info['sfreq']))\n", | |
" y[i_seizure_start:i_seizure_stop] = 1\n", | |
" assert X.shape[1] == len(y)\n", | |
" return X,y" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_01.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_02.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_03.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_04.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_05.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_06.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_07.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_08.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_09.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_10.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_11.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_12.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_13.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_14.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_15.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_16.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_17.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_18.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_19.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_20.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_21.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_22.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_23.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_24.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_25.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_26.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_27.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n", | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_29.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n", | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_30.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_31.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_32.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_33.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_34.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_36.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_37.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_38.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_39.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_40.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_41.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_42.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_43.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Extracting edf Parameters from /data/schirrmr/physionet.org/physiobank/database/chbmit/chb01/chb01_46.edf...\n", | |
"EDF file detected\n", | |
"Setting channel info structure...\n", | |
"Created Raw.info structure...\n", | |
"Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
"Ready.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"<ipython-input-3-ae37fa0d04aa>:7: RuntimeWarning: Channel names are not unique, found duplicates for: {'T8-P8'}. Applying running numbers for duplicates.\n", | |
" edf = mne.io.read_raw_edf(edf_filename,stim_channel=None)\n" | |
] | |
} | |
], | |
"source": [ | |
"all_X = []\n", | |
"all_y = []\n", | |
"for edf_file_name in edf_file_names:\n", | |
" X, y = extract_data_and_labels(edf_file_name, summary_content)\n", | |
" all_X.append(X)\n", | |
" all_y.append(y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dummy_X = all_X#[x[:,:2000] for x in all_X]\n", | |
"\n", | |
"dummy_y = all_y#[y[:2000] for y in all_y]\n", | |
"\n", | |
"from braindecode.datautil.signal_target import SignalAndTarget\n", | |
"from braindecode.datautil.splitters import split_into_two_sets\n", | |
"whole_set = SignalAndTarget(dummy_X, dummy_y)\n", | |
"train_set, test_set = split_into_two_sets(whole_set,0.5)\n", | |
"train_set, valid_set = split_into_two_sets(train_set, 0.7)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Create model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n", | |
"from braindecode.models.deep4 import Deep4Net\n", | |
"from torch import nn\n", | |
"from braindecode.torch_ext.util import set_random_seeds\n", | |
"from braindecode.models.util import to_dense_prediction_model\n", | |
"\n", | |
"# Set if you want to use GPU\n", | |
"# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.\n", | |
"cuda = True\n", | |
"set_random_seeds(seed=20170629, cuda=cuda)\n", | |
"\n", | |
"# This will determine how many crops are processed in parallel\n", | |
"input_time_length = 1200\n", | |
"n_classes = 2\n", | |
"in_chans = train_set.X[0].shape[0]\n", | |
"# final_conv_length determines the size of the receptive field of the ConvNet\n", | |
"#model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,\n", | |
"# final_conv_length=12).create_network()\n", | |
"model = Deep4Net(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,\n", | |
" final_conv_length=2, stride_before_pool=True).create_network()\n", | |
"to_dense_prediction_model(model)\n", | |
"\n", | |
"if cuda:\n", | |
" model.cuda()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from torch import optim\n", | |
"\n", | |
"optimizer = optim.Adam(model.parameters())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"519 predictions per input/trial\n" | |
] | |
} | |
], | |
"source": [ | |
"from braindecode.torch_ext.util import np_to_var\n", | |
"# determine output size\n", | |
"test_input = np_to_var(np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))\n", | |
"if cuda:\n", | |
" test_input = test_input.cuda()\n", | |
"out = model(test_input)\n", | |
"n_preds_per_input = out.cpu().data.numpy().shape[2]\n", | |
"print(\"{:d} predictions per input/trial\".format(n_preds_per_input))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Monitor that computes multiple metrics\n", | |
"\n", | |
"You can adapt it to compute other metrics you are interested in." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from braindecode.experiments.monitors import compute_preds_per_trial_for_set\n", | |
"\n", | |
"from sklearn.metrics import roc_auc_score\n", | |
"\n", | |
"\n", | |
"class SeizureMonitor(object):\n", | |
" \"\"\"\n", | |
" Compute seizure sensitivity and ROC AUC from predictions for crops.\n", | |
" \n", | |
" Parameters\n", | |
" ----------\n", | |
" input_time_length: int\n", | |
" Temporal length of one input to the model.\n", | |
" \"\"\"\n", | |
" def __init__(self, input_time_length=None):\n", | |
" self.input_time_length = input_time_length\n", | |
"\n", | |
" def monitor_epoch(self,):\n", | |
" return\n", | |
"\n", | |
" def monitor_set(self, setname, all_preds, all_losses,\n", | |
" all_batch_sizes, all_targets, dataset):\n", | |
" \"\"\"Assuming one hot encoding for now\"\"\"\n", | |
" assert self.input_time_length is not None, \"Need to know input time length...\"\n", | |
" # this will be timeseries of predictions\n", | |
" # for each trial\n", | |
" preds_per_trial = compute_preds_per_trial_for_set(all_preds, self.input_time_length, \n", | |
" dataset)\n", | |
" seizure_preds = []\n", | |
" all_preds = []\n", | |
" all_y = []\n", | |
" for i_trial in range(len(preds_per_trial)):\n", | |
" this_y = dataset.y[i_trial]\n", | |
" this_preds = preds_per_trial[i_trial]\n", | |
" this_preds = np.exp(this_preds[1])\n", | |
" n_missing_preds = len(this_y) - len(this_preds)\n", | |
" this_preds = np.concatenate((np.zeros(n_missing_preds, dtype=this_preds.dtype),\n", | |
" this_preds))\n", | |
" all_preds.extend(this_preds)\n", | |
" all_y.extend(this_y)\n", | |
" if np.any(this_y == 1):\n", | |
" seizure_preds.append(this_preds[this_y == 1])\n", | |
" if len(seizure_preds) > 0:\n", | |
" max_seiz_preds = np.array([np.max(p) for p in seizure_preds])\n", | |
" sensitivity = np.mean(max_seiz_preds > 0.5)\n", | |
" else:\n", | |
" sensitivity = np.nan\n", | |
" \n", | |
" sensitivity_name = \"{:s}_sensitivity\".format(setname)\n", | |
" if len(np.unique(all_y)) > 1:\n", | |
" auc = roc_auc_score(all_y, all_preds)\n", | |
" else:\n", | |
" auc = np.nan\n", | |
" auc_name = \"{:s}_auc\".format(setname)\n", | |
" return {sensitivity_name: float(sensitivity),\n", | |
" auc_name: float(auc)}" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Setup Experiment and Run" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from braindecode.torch_ext.losses import log_categorical_crossentropy\n", | |
"from braindecode.experiments.experiment import Experiment\n", | |
"from braindecode.datautil.iterators import CropsFromTrialsIterator\n", | |
"from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, CroppedTrialMisclassMonitor, MisclassMonitor\n", | |
"from braindecode.experiments.stopcriteria import MaxEpochs\n", | |
"import torch.nn.functional as F\n", | |
"import torch as th\n", | |
"from braindecode.torch_ext.modules import Expression\n", | |
"# Iterator is used to iterate over datasets both for training\n", | |
"# and evaluation\n", | |
"iterator = CropsFromTrialsIterator(batch_size=32,input_time_length=input_time_length,\n", | |
" n_preds_per_input=n_preds_per_input)\n", | |
"\n", | |
"# Loss function takes predictions as they come out of the network and the targets\n", | |
"# and returns a loss\n", | |
"loss_function = lambda preds, targets: log_categorical_crossentropy(preds, targets)\n", | |
"# Could be used to apply some constraint on the models, then should be object\n", | |
"# with apply method that accepts a module\n", | |
"model_constraint = None\n", | |
"# Monitors log the training progress\n", | |
"monitors = [LossMonitor(), MisclassMonitor(col_suffix='misclass'),\n", | |
" SeizureMonitor(input_time_length),\n", | |
" RuntimeMonitor(),]\n", | |
"# Stop criterion determines when the first stop happens\n", | |
"stop_criterion = MaxEpochs(5)\n", | |
"exp = Experiment(model, train_set, valid_set, test_set, iterator, loss_function, optimizer, model_constraint,\n", | |
" monitors, stop_criterion, remember_best_column='valid_misclass',\n", | |
" run_after_early_stop=True, batch_modifier=None, cuda=cuda)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2017-10-03 17:25:58,068 INFO : Run until first stop...\n", | |
"2017-10-03 17:27:30,965 INFO : Epoch 0\n", | |
"2017-10-03 17:27:30,966 INFO : train_loss 0.74135\n", | |
"2017-10-03 17:27:30,967 INFO : valid_loss 1.46873\n", | |
"2017-10-03 17:27:30,968 INFO : test_loss 0.57399\n", | |
"2017-10-03 17:27:30,969 INFO : train_misclass 0.03220\n", | |
"2017-10-03 17:27:30,970 INFO : valid_misclass 0.01954\n", | |
"2017-10-03 17:27:30,971 INFO : test_misclass 0.02823\n", | |
"2017-10-03 17:27:30,972 INFO : train_sensitivity 1.00000\n", | |
"2017-10-03 17:27:30,973 INFO : train_auc 0.14261\n", | |
"2017-10-03 17:27:30,974 INFO : valid_sensitivity 1.00000\n", | |
"2017-10-03 17:27:30,975 INFO : valid_auc 0.11537\n", | |
"2017-10-03 17:27:30,976 INFO : test_sensitivity 1.00000\n", | |
"2017-10-03 17:27:30,977 INFO : test_auc 0.14126\n", | |
"2017-10-03 17:27:30,978 INFO : runtime 0.00000\n", | |
"2017-10-03 17:27:30,979 INFO : \n", | |
"2017-10-03 17:27:30,986 INFO : New best valid_misclass: 0.019545\n", | |
"2017-10-03 17:27:30,987 INFO : \n", | |
"2017-10-03 17:28:51,310 INFO : Time only for training updates: 75.69s\n", | |
"2017-10-03 17:30:23,441 INFO : Epoch 1\n", | |
"2017-10-03 17:30:23,443 INFO : train_loss 0.00781\n", | |
"2017-10-03 17:30:23,444 INFO : valid_loss 0.07895\n", | |
"2017-10-03 17:30:23,445 INFO : test_loss 0.00743\n", | |
"2017-10-03 17:30:23,445 INFO : train_misclass 0.00143\n", | |
"2017-10-03 17:30:23,446 INFO : valid_misclass 0.01001\n", | |
"2017-10-03 17:30:23,447 INFO : test_misclass 0.00126\n", | |
"2017-10-03 17:30:23,448 INFO : train_sensitivity 1.00000\n", | |
"2017-10-03 17:30:23,449 INFO : train_auc 0.92865\n", | |
"2017-10-03 17:30:23,449 INFO : valid_sensitivity 1.00000\n", | |
"2017-10-03 17:30:23,450 INFO : valid_auc 0.83478\n", | |
"2017-10-03 17:30:23,451 INFO : test_sensitivity 1.00000\n", | |
"2017-10-03 17:30:23,452 INFO : test_auc 0.92489\n", | |
"2017-10-03 17:30:23,453 INFO : runtime 173.24244\n", | |
"2017-10-03 17:30:23,453 INFO : \n", | |
"2017-10-03 17:30:23,458 INFO : New best valid_misclass: 0.010006\n", | |
"2017-10-03 17:30:23,459 INFO : \n", | |
"2017-10-03 17:31:44,107 INFO : Time only for training updates: 76.02s\n", | |
"2017-10-03 17:33:16,295 INFO : Epoch 2\n", | |
"2017-10-03 17:33:16,297 INFO : train_loss 0.00128\n", | |
"2017-10-03 17:33:16,298 INFO : valid_loss 0.04533\n", | |
"2017-10-03 17:33:16,299 INFO : test_loss 0.00424\n", | |
"2017-10-03 17:33:16,300 INFO : train_misclass 0.00021\n", | |
"2017-10-03 17:33:16,300 INFO : valid_misclass 0.00579\n", | |
"2017-10-03 17:33:16,301 INFO : test_misclass 0.00115\n", | |
"2017-10-03 17:33:16,302 INFO : train_sensitivity 1.00000\n", | |
"2017-10-03 17:33:16,303 INFO : train_auc 0.99708\n", | |
"2017-10-03 17:33:16,304 INFO : valid_sensitivity 1.00000\n", | |
"2017-10-03 17:33:16,304 INFO : valid_auc 0.95073\n", | |
"2017-10-03 17:33:16,305 INFO : test_sensitivity 1.00000\n", | |
"2017-10-03 17:33:16,306 INFO : test_auc 0.98772\n", | |
"2017-10-03 17:33:16,307 INFO : runtime 172.79666\n", | |
"2017-10-03 17:33:16,308 INFO : \n", | |
"2017-10-03 17:33:16,312 INFO : New best valid_misclass: 0.005794\n", | |
"2017-10-03 17:33:16,313 INFO : \n", | |
"2017-10-03 17:34:36,847 INFO : Time only for training updates: 75.93s\n", | |
"2017-10-03 17:36:09,110 INFO : Epoch 3\n", | |
"2017-10-03 17:36:09,112 INFO : train_loss 0.00071\n", | |
"2017-10-03 17:36:09,113 INFO : valid_loss 0.03611\n", | |
"2017-10-03 17:36:09,114 INFO : test_loss 0.00392\n", | |
"2017-10-03 17:36:09,114 INFO : train_misclass 0.00021\n", | |
"2017-10-03 17:36:09,115 INFO : valid_misclass 0.00533\n", | |
"2017-10-03 17:36:09,116 INFO : test_misclass 0.00100\n", | |
"2017-10-03 17:36:09,117 INFO : train_sensitivity 1.00000\n", | |
"2017-10-03 17:36:09,118 INFO : train_auc 0.99982\n", | |
"2017-10-03 17:36:09,119 INFO : valid_sensitivity 1.00000\n", | |
"2017-10-03 17:36:09,119 INFO : valid_auc 0.95862\n", | |
"2017-10-03 17:36:09,120 INFO : test_sensitivity 1.00000\n", | |
"2017-10-03 17:36:09,121 INFO : test_auc 0.99435\n", | |
"2017-10-03 17:36:09,122 INFO : runtime 172.74023\n", | |
"2017-10-03 17:36:09,123 INFO : \n", | |
"2017-10-03 17:36:09,126 INFO : New best valid_misclass: 0.005333\n", | |
"2017-10-03 17:36:09,127 INFO : \n", | |
"2017-10-03 17:37:29,685 INFO : Time only for training updates: 75.95s\n", | |
"2017-10-03 17:39:01,783 INFO : Epoch 4\n", | |
"2017-10-03 17:39:01,784 INFO : train_loss 0.00078\n", | |
"2017-10-03 17:39:01,785 INFO : valid_loss 0.05229\n", | |
"2017-10-03 17:39:01,786 INFO : test_loss 0.00317\n", | |
"2017-10-03 17:39:01,787 INFO : train_misclass 0.00019\n", | |
"2017-10-03 17:39:01,788 INFO : valid_misclass 0.00564\n", | |
"2017-10-03 17:39:01,789 INFO : test_misclass 0.00079\n", | |
"2017-10-03 17:39:01,789 INFO : train_sensitivity 1.00000\n", | |
"2017-10-03 17:39:01,790 INFO : train_auc 0.99982\n", | |
"2017-10-03 17:39:01,791 INFO : valid_sensitivity 1.00000\n", | |
"2017-10-03 17:39:01,792 INFO : valid_auc 0.95368\n", | |
"2017-10-03 17:39:01,793 INFO : test_sensitivity 1.00000\n", | |
"2017-10-03 17:39:01,793 INFO : test_auc 0.99381\n", | |
"2017-10-03 17:39:01,794 INFO : runtime 172.83778\n", | |
"2017-10-03 17:39:01,795 INFO : \n", | |
"2017-10-03 17:40:22,336 INFO : Time only for training updates: 75.95s\n", | |
"2017-10-03 17:41:54,442 INFO : Epoch 5\n", | |
"2017-10-03 17:41:54,443 INFO : train_loss 0.00117\n", | |
"2017-10-03 17:41:54,444 INFO : valid_loss 0.06295\n", | |
"2017-10-03 17:41:54,445 INFO : test_loss 0.00552\n", | |
"2017-10-03 17:41:54,446 INFO : train_misclass 0.00019\n", | |
"2017-10-03 17:41:54,447 INFO : valid_misclass 0.00617\n", | |
"2017-10-03 17:41:54,447 INFO : test_misclass 0.00123\n", | |
"2017-10-03 17:41:54,448 INFO : train_sensitivity 1.00000\n", | |
"2017-10-03 17:41:54,449 INFO : train_auc 0.99911\n", | |
"2017-10-03 17:41:54,450 INFO : valid_sensitivity 1.00000\n", | |
"2017-10-03 17:41:54,451 INFO : valid_auc 0.88697\n", | |
"2017-10-03 17:41:54,451 INFO : test_sensitivity 1.00000\n", | |
"2017-10-03 17:41:54,452 INFO : test_auc 0.98754\n", | |
"2017-10-03 17:41:54,453 INFO : runtime 172.65101\n", | |
"2017-10-03 17:41:54,454 INFO : \n", | |
"2017-10-03 17:41:54,455 INFO : Setup for second stop...\n", | |
"2017-10-03 17:41:54,459 INFO : Train loss to reach 0.00071\n", | |
"2017-10-03 17:41:54,459 INFO : Run until second stop...\n", | |
"2017-10-03 17:43:39,735 INFO : Epoch 4\n", | |
"2017-10-03 17:43:39,737 INFO : train_loss 0.01051\n", | |
"2017-10-03 17:43:39,738 INFO : valid_loss 0.03611\n", | |
"2017-10-03 17:43:39,739 INFO : test_loss 0.00392\n", | |
"2017-10-03 17:43:39,739 INFO : train_misclass 0.00163\n", | |
"2017-10-03 17:43:39,740 INFO : valid_misclass 0.00533\n", | |
"2017-10-03 17:43:39,741 INFO : test_misclass 0.00100\n", | |
"2017-10-03 17:43:39,742 INFO : train_sensitivity 1.00000\n", | |
"2017-10-03 17:43:39,743 INFO : train_auc 0.96786\n", | |
"2017-10-03 17:43:39,743 INFO : valid_sensitivity 1.00000\n", | |
"2017-10-03 17:43:39,744 INFO : valid_auc 0.95862\n", | |
"2017-10-03 17:43:39,745 INFO : test_sensitivity 1.00000\n", | |
"2017-10-03 17:43:39,746 INFO : test_auc 0.99435\n", | |
"2017-10-03 17:43:39,747 INFO : runtime 92.12272\n", | |
"2017-10-03 17:43:39,747 INFO : \n", | |
"2017-10-03 17:45:31,098 INFO : Time only for training updates: 104.97s\n", | |
"2017-10-03 17:47:16,342 INFO : Epoch 5\n", | |
"2017-10-03 17:47:16,344 INFO : train_loss 0.00552\n", | |
"2017-10-03 17:47:16,345 INFO : valid_loss 0.01261\n", | |
"2017-10-03 17:47:16,345 INFO : test_loss 0.00146\n", | |
"2017-10-03 17:47:16,346 INFO : train_misclass 0.00161\n", | |
"2017-10-03 17:47:16,347 INFO : valid_misclass 0.00357\n", | |
"2017-10-03 17:47:16,348 INFO : test_misclass 0.00034\n", | |
"2017-10-03 17:47:16,349 INFO : train_sensitivity 1.00000\n", | |
"2017-10-03 17:47:16,349 INFO : train_auc 0.99764\n", | |
"2017-10-03 17:47:16,350 INFO : valid_sensitivity 1.00000\n", | |
"2017-10-03 17:47:16,351 INFO : valid_auc 0.99753\n", | |
"2017-10-03 17:47:16,352 INFO : test_sensitivity 1.00000\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2017-10-03 17:47:16,353 INFO : test_auc 0.98500\n", | |
"2017-10-03 17:47:16,353 INFO : runtime 216.63896\n", | |
"2017-10-03 17:47:16,354 INFO : \n", | |
"2017-10-03 17:49:07,741 INFO : Time only for training updates: 105.01s\n", | |
"2017-10-03 17:50:52,914 INFO : Epoch 6\n", | |
"2017-10-03 17:50:52,916 INFO : train_loss 0.00131\n", | |
"2017-10-03 17:50:52,917 INFO : valid_loss 0.00319\n", | |
"2017-10-03 17:50:52,918 INFO : test_loss 0.00525\n", | |
"2017-10-03 17:50:52,919 INFO : train_misclass 0.00031\n", | |
"2017-10-03 17:50:52,919 INFO : valid_misclass 0.00060\n", | |
"2017-10-03 17:50:52,920 INFO : test_misclass 0.00228\n", | |
"2017-10-03 17:50:52,921 INFO : train_sensitivity 1.00000\n", | |
"2017-10-03 17:50:52,922 INFO : train_auc 0.99926\n", | |
"2017-10-03 17:50:52,923 INFO : valid_sensitivity 1.00000\n", | |
"2017-10-03 17:50:52,923 INFO : valid_auc 0.99910\n", | |
"2017-10-03 17:50:52,924 INFO : test_sensitivity 1.00000\n", | |
"2017-10-03 17:50:52,925 INFO : test_auc 0.99784\n", | |
"2017-10-03 17:50:52,926 INFO : runtime 216.64343\n", | |
"2017-10-03 17:50:52,927 INFO : \n" | |
] | |
} | |
], | |
"source": [ | |
"# Python logging must be set up first; otherwise no progress output is shown\n", | |
"import logging\n", | |
"import sys\n", | |
"logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n", | |
" level=logging.DEBUG, stream=sys.stdout)\n", | |
"exp.run()" | |
] | |
} | |
], | |
"metadata": { | |
"celltoolbar": "Edit Metadata", | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment