Created
November 29, 2018 10:04
-
-
Save gemeinl/da26b86c5f48b52cdba95118c85c2d08 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "from mne.io import read_raw_edf\nfrom glob import glob\nimport numpy as np\nimport h5py\nimport json\nimport os\nimport re", | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "/home/gemeinl/venvs/auto-eeg-diag/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n from ._conv import register_converters as _register_converters\n", | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
def natural_key(string):
    """Provide a human-like ("natural") sorting key for a string.

    Splitting on runs of digits yields alternating text/digit tokens;
    digit tokens are compared as ints, so e.g. 's2' sorts before 's10'.

    :param string: string to build the sorting key from
    :return: list of str and int tokens usable as a sorting key
    """
    p = r'(\d+)'
    # keep the text tokens themselves (the previous version mapped them to
    # None, which mis-sorts strings with differing text parts and raises
    # TypeError when a None key element is compared against an int)
    key = [int(t) if t.isdigit() else t for t in re.split(p, string)]
    return key
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
def parse_age_and_gender_from_edf_header(file_path):
    """Parse age and sex of the patient from the patient_id field in the
    fixed-size header of an EDF file.

    :param file_path: path of the recording
    :return: age (int) and gender (single character, e.g. M, X, F) of the
        patient
    """
    assert os.path.exists(file_path), "file not found {}".format(file_path)
    # the EDF header stores the 80-char patient id in bytes 8..88;
    # 'with' guarantees the handle is closed even if decoding fails
    with open(file_path, 'rb') as f:
        content = f.read(88)
    patient_id = content[8:88].decode('ascii')
    # raw strings avoid invalid-escape warnings for \d and \s;
    # single-element unpacking asserts exactly one match each
    [age] = re.findall(r"Age:(\d+)", patient_id)
    [gender] = re.findall(r"\s(\w)\s", patient_id)
    return int(age), gender
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
def numpy_load(path):
    """Load an array from a .npy file.

    For 2d arrays, transpose when there are more rows than columns so the
    shorter axis always comes first.

    :param path: path of the .npy file
    :return: loaded array as float64
    """
    assert os.path.exists(path), "file not found {}".format(path)
    arr = np.load(path)
    if arr.ndim == 2 and arr.shape[0] > arr.shape[1]:
        arr = arr.T
    return arr.astype(np.float64)
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
def h5_load(path):
    """Load signals from the 'signals' dataset of an hdf5 file.

    For 2d arrays, transpose when there are more rows than columns so the
    shorter axis always comes first.

    :param path: path of the .h5 file
    :return: loaded signals as float64 array
    """
    assert os.path.exists(path), "file not found {}".format(path)
    # context manager guarantees the file is closed even if the
    # 'signals' dataset is missing or reading fails
    with h5py.File(path, "r") as f:
        x = f["signals"][:]
    if len(x.shape) == 2:
        xdim, ydim = x.shape
        if xdim > ydim:
            x = x.T
    return x.astype(np.float64)
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
def json_load(path):
    """Read and return the deserialized contents of a json file.

    :param path: path of the json file
    :return: the loaded object
    """
    assert os.path.exists(path), "file not found {}".format(path)
    with open(path, "r") as json_file:
        return json.load(json_file)
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
def mne_load_signals_and_fs_from_edf(file_, wanted_chs, ch_name_pattern=None,
                                     factor=1e6):
    """ read an edf file with mne, pick channels, scale with factor and
    return the signals as well as the sampling frequency

    :param file_: path of the edf recording
    :param wanted_chs: electrode names to pick, e.g. ['A1', 'C3', ...]
    :param ch_name_pattern: optional format string (e.g. "EEG {}-REF") used
        to map electrode names to the channel names stored in the edf file
    :param factor: scaling applied to the signals; 1e6 presumably converts
        volts to microvolts -- TODO confirm against mne's get_data units
    :return: (signals, fs) -- picked, scaled signals and sampling frequency
    """
    assert os.path.exists(file_), "file not found {}".format(file_)
    raw = read_raw_edf(file_, verbose="error")
    fs = raw.info["sfreq"]
    raw = raw.load_data()
    # translate electrode names to full channel names if a pattern is given
    if ch_name_pattern is not None:
        chs = [ch_name_pattern.format(wanted_elec) for wanted_elec in wanted_chs]
    else:
        chs = wanted_chs
    raw = raw.reorder_channels(chs)
    # achieves two things: asserts that channels are sorted and picked
    # channels are in same order (note: this requires chs itself to be
    # sorted, which holds for the sorted default channel list in TuhAbnormal)
    assert raw.ch_names == sorted(chs), \
        "actual channel names: {}, wanted channels names: {}".format(
            ', '.join(raw.ch_names), ', '.join(chs))

    signals = raw.get_data()
    # factor=0 or None skips scaling
    if factor:
        signals = signals * factor
    return signals, fs
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
def property_in_path(curr_path, property_):
    """Check whether property_ occurs as a '/'-separated component of
    curr_path (e.g. 'abnormal' in 'v2.0.0/edf/eval/abnormal/...').

    :param curr_path: path to inspect
    :param property_: path component to look for
    :return: True if property_ is a component of curr_path
    """
    return property_ in curr_path.split("/")
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
def replace_extension(path, new_extension):
    """Replace the file extension of path with new_extension.

    :param path: file path whose extension should be replaced
    :param new_extension: new extension including the leading dot
    :return: path with the trailing extension swapped
    """
    assert new_extension.startswith(".")
    # only swap the trailing extension; the previous str.replace approach
    # rewrote every occurrence of the old extension anywhere in the path,
    # and for extension-less paths replaced the empty string everywhere
    root = os.path.splitext(path)[0]
    return root + new_extension
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
# check whether this can be replaced by natural key
def _session_key(string):
    """ sort the file name by session: extracts tokens like 's003' that
    precede an underscore """
    p = r'(s\d*)_'
    return re.findall(p, string)


def _time_key(file_name):
    """ provides a time-based sorting key built from the recording date,
    session id and recording id encoded in the file path """
    # the splits are specific to tuh abnormal eeg data set
    splits = file_name.split('/')
    # parent directory is expected to look like 's003_2012_04_06'
    p = r'(\d{4}_\d{2}_\d{2})'
    [date] = re.findall(p, splits[-2])
    date_id = [int(token) for token in date.split('_')]
    recording_id = natural_key(splits[-1])
    session_id = _session_key(splits[-2])
    return date_id + session_id + recording_id


def _read_all_file_names(path, extension, key="time"):
    """ read all files with specified extension from given path
    :param path: parent directory holding the files directly or in subdirectories
    :param extension: the type of the file, e.g. '.txt' or '.edf'
    :param key: the sorting of the files. natural e.g. 1, 2, 12, 21 (machine 1, 12, 2, 21) or by time since this is
    important for cv. time is specified in the edf file names
    """
    assert key in ["natural", "time"], "unknown sorting key"
    # recursive glob: path is expected to end with '/'
    file_paths = glob(path + '**/*' + extension, recursive=True)
    if key == "time":
        sorting_key = _time_key
    else:
        sorting_key = natural_key
    file_names = sorted(file_paths, key=sorting_key)

    assert len(file_names) > 0, \
        "something went wrong. Found no {} files in {}".format(
            extension, path)
    return file_names


class TuhAbnormal():
    """TUH abnormal EEG data set. File names are expected to follow the
    pattern shown below (the 'abnormal'/'normal' and 'train'/'eval' path
    components encode the label and the subset)."""
    # v2.0.0/edf/eval/abnormal/01_tcp_ar/007/00000768/s003_2012_04_06/00000768_s003_t000.edf
    def __init__(self, data_path, extension, subset="train", channels=sorted([
        'A1', 'A2', 'C3', 'C4', 'CZ', 'F3', 'F4', 'F7', 'F8', 'FP1', 'FP2',
        'FZ', 'O1', 'O2', 'P3', 'P4', 'PZ', 'T3', 'T4', 'T5', 'T6']),
        key="time", n_recordings=None, target="pathological",
        max_recording_mins=None, ch_name_pattern="EEG {}-REF"):
        # note: the default channel list is built with sorted(...), which
        # mne_load_signals_and_fs_from_edf relies on
        super(TuhAbnormal, self).__init__()
        self.max_recording_mins = max_recording_mins  # reject longer recordings (raw .edf only)
        self.ch_name_pattern = ch_name_pattern  # maps electrode names to edf channel names
        self.n_recordings = n_recordings  # optional cap on number of recordings
        self.extension = extension  # ".edf" (raw) or ".npy"/".h5" (preprocessed)
        self.data_path = data_path
        self.channels = channels
        self.target = target  # "pathological", "age" or "gender"
        self.subset = subset  # "train" or "eval"
        self.key = key  # file sorting: "time" or "natural"

        # filled by load(), one entry per picked recording
        self.pathologicals = []
        self.file_names = []
        self.genders = []
        self.targets = []
        self.sfreqs = []
        self.ages = []

        if self.subset == "eval":
            assert self.max_recording_mins is None, "do not reject eval recordings"

    def load(self):
        """Collect file names and per-recording metadata (target, age,
        gender, pathological status and -- for preprocessed files --
        sampling frequency). Must be called before indexing."""
        # read all file names in path with given extension sorted by key
        self.file_names = _read_all_file_names(
            self.data_path, self.extension, self.key)

        # prune this file names to train or eval subset
        self.file_names = [file_name for file_name in self.file_names
                           if self.subset in file_name.split('/')]

        n_picked_recs = 0
        files_to_delete = []
        for file_name in self.file_names:
            if self.n_recordings is not None:
                if n_picked_recs == self.n_recordings:
                    break

            # if this is raw version of data set, reject too long recordings
            if self.extension == ".edf":
                if self.max_recording_mins is not None:
                    # reject recordings that are too long
                    # NOTE(review): reject_too_long_recording is not defined
                    # anywhere in this notebook -- this branch would raise
                    # NameError if max_recording_mins were set; confirm the
                    # helper exists in the full project
                    rejected, duration = reject_too_long_recording(
                        file_name, self.max_recording_mins)
                    if rejected:
                        files_to_delete.append(file_name)
                        continue
            n_picked_recs += 1

            assert self.target in ["pathological", "age", "gender"], \
                "unknown target {}".format(self.target)
            assert self.extension in [".edf", ".npy", ".h5"], \
                "unknown file format {}".format(self.extension)
            if self.extension == ".edf":
                # get pathological status, age and gender for edf file
                pathological = property_in_path(file_name, "abnormal")
                age, gender = parse_age_and_gender_from_edf_header(file_name)
            else:
                # load info json file of clean recording
                # get pathological status, age, gender and sfreq for clean file
                new_file_name = replace_extension(file_name, ".json")
                info = json_load(new_file_name)
                pathological = info["pathological"]
                age = int(info["age"])
                gender = info["gender"]
                self.sfreqs.append(info["sfreq"])

            targets = {"pathological": pathological, "age": age, "gender": gender}
            self.targets.append(targets[self.target])
            self.ages.append(age)
            self.genders.append(gender)
            self.pathologicals.append(pathological)

        if self.max_recording_mins is not None:
            # prune list of all file names to n_recordings
            for file_name in files_to_delete:
                self.file_names.remove(file_name)

        if self.n_recordings is not None:
            self.file_names = self.file_names[:self.n_recordings]

        # sanity checks: metadata lists must line up with file names
        assert len(self.file_names) == len(self.targets), "lengths differ"
        if self.n_recordings is not None:
            assert len(self.file_names) == self.n_recordings, \
                "less recordings picked than desired"
        assert len(np.intersect1d(self.file_names, files_to_delete)) == 0, \
            "deleting unwanted file names failed"

    def __getitem__(self, index):
        """Return (signals, sampling_frequency, label) for the recording at
        index. Raw .edf files are read with mne; .npy/.h5 files are loaded
        directly and use the sfreq recorded in their info json."""
        file_ = self.file_names[index]
        label = self.targets[index]
        if self.extension == ".edf":
            signals, sfreq = mne_load_signals_and_fs_from_edf(
                file_, self.channels, self.ch_name_pattern)
            return signals, sfreq, label
        elif self.extension == ".npy":
            data = numpy_load(file_)
            return data, self.sfreqs[index], label
        elif self.extension == ".h5":
            data = h5_load(file_)
            return data, self.sfreqs[index], label
        # NOTE(review): any other extension falls through and returns None,
        # though load() already asserts the extension is one of the above

    def __len__(self):
        # number of (picked) recordings
        return len(self.file_names)
"execution_count": 10, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
# parent directory of the raw TUH abnormal EEG training recordings
# NOTE(review): cluster-specific absolute path -- adjust for your machine
data_path = "/data/schirrmr/gemeinl/tuh-abnormal-eeg/raw/v2.0.0/edf/train/"
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
# create the data set over the raw edf recordings (nothing is read yet)
d = TuhAbnormal(data_path, extension=".edf")
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
# collect file names and per-recording metadata (targets, ages, genders)
d.load()
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
# number of recordings picked for the train subset
len(d)
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"execution_count": 14, | |
"data": { | |
"text/plain": "2717" | |
}, | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
# inspect the first three recordings: signal shape (channels x samples),
# sampling frequency and pathological label
for i, (x, fs, y) in enumerate(d):
    if i == 3:
        break
    print(x.shape, fs, y)
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "(21, 296750) 250.0 False\n(21, 752250) 250.0 False\n(21, 312000) 250.0 True\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "auto-eeg-diag", | |
"display_name": "auto-eeg-diag", | |
"language": "python" | |
}, | |
"language_info": { | |
"pygments_lexer": "ipython3", | |
"mimetype": "text/x-python", | |
"version": "3.5.2", | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"nbconvert_exporter": "python", | |
"name": "python" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment