Created
March 25, 2020 18:14
-
-
Save larsoner/a6110c5afb890d78517a070be3cccabe to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import time | |
from datetime import datetime, timezone, timedelta | |
import mne | |
import numpy as np | |
import h5py | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
import pyedflib | |
# Keep MNE quiet so only the benchmark's own output/plot is produced.
mne.set_log_level(verbose='CRITICAL')
# Enable matplotlib interactive mode for the plot drawn at the end of test().
plt.ion()
def _stamp_to_dt(utc_stamp): | |
"""Convert timestamp to datetime object in Windows-friendly way.""" | |
# The min on windows is 86400 | |
stamp = [int(s) for s in utc_stamp] | |
if len(stamp) == 1: # In case there is no microseconds information | |
stamp.append(0) | |
return (datetime.fromtimestamp(0, tz=timezone.utc) + | |
timedelta(0, stamp[0], stamp[1])) # day, sec, μs | |
def _edf_signal_header(mne_raw, idx, channels, sfreq, dmin, dmax):
    """Return the pyedflib signal-header dict for channel ``idx``.

    Header values are copied from the private ``mne_raw._raw_extras[0]``
    when it carries per-channel EDF fields (presumably only for data that
    were read from EDF/BDF -- TODO confirm). Otherwise they are derived
    from ``sfreq``, the global min/max of ``channels`` and the digital
    range ``dmin``/``dmax`` of the target file type.
    """
    label = mne_raw.ch_names[idx]
    try:
        extras = mne_raw._raw_extras[0]
        return {'label': label,
                'dimension': 'uV',
                'sample_rate': extras['n_samps'][idx],
                'physical_min': extras['physical_min'][idx],
                'physical_max': extras['physical_max'][idx],
                'digital_min': extras['digital_min'][idx],
                'digital_max': extras['digital_max'][idx],
                'transducer': '',
                'prefilter': ''}
    except (AttributeError, IndexError, KeyError, TypeError):
        # No usable EDF header info on this Raw: derive one from the data.
        return {'label': label,
                'dimension': 'uV',
                'sample_rate': sfreq,
                'physical_min': channels.min(),
                'physical_max': channels.max(),
                'digital_min': dmin,
                'digital_max': dmax,
                'transducer': '',
                'prefilter': ''}


def write_mne_edf(mne_raw, fname, picks=None, tmin=0, tmax=None,
                  overwrite=False):
    """Save the data of an mne.io.Raw (or subclass) to an EDF+/BDF+ file.

    Modified from https://gist.github.com/skjerns/bc660ef59dca0dbd53f00ed38c42f6be

    pyEDFlib is used to write the raw contents to disk. Data are multiplied
    by 1e6 (converted to microvolts to scale up precision) and labelled
    ``'uV'``.

    Parameters
    ----------
    mne_raw : mne.io.BaseRaw
        An object with super class mne.io.BaseRaw that contains the data
        to save.
    fname : str
        File name of the new dataset. Names ending in ``.edf`` are written
        as EDF+ (16-bit); any other extension is written as BDF+ (24-bit).
    picks : array-like of int | None
        Indices of channels to include. If None all channels are kept.
    tmin : float | None
        Time in seconds of the first sample to save. If None the first
        sample is used.
    tmax : float | None
        Time in seconds at which to stop (passed as the exclusive ``stop``
        of ``get_data``). If None the data run to the last sample.
    overwrite : bool
        If True, the destination file (if it exists) will be overwritten.
        If False (default), an error will be raised if the file exists.

    Returns
    -------
    success : bool
        True if the file was written. False if pyedflib failed; the error
        is printed.

    Raises
    ------
    TypeError
        If ``mne_raw`` is not an mne.io.BaseRaw.
    OSError
        If ``fname`` exists and ``overwrite`` is False.
    """
    if not isinstance(mne_raw, mne.io.BaseRaw):
        raise TypeError('Must be mne.io.Raw type')
    if not overwrite and os.path.exists(fname):
        raise OSError('File already exists. No overwrite.')

    # static settings: file type and digital sample range per format
    if os.path.splitext(fname)[-1] == '.edf':
        file_type = pyedflib.FILETYPE_EDFPLUS
        dmin, dmax = -32768, 32767
    else:
        file_type = pyedflib.FILETYPE_BDFPLUS
        dmin, dmax = -8388608, 8388607

    sfreq = mne_raw.info['sfreq']
    date = mne_raw.info['meas_date']
    # tmin/tmax are documented as optional; None means "from the start" /
    # "to the end" (int(sfreq * None) would raise TypeError).
    first_sample = int(sfreq * tmin) if tmin is not None else 0
    last_sample = int(sfreq * tmax) if tmax is not None else None
    channels = mne_raw.get_data(picks, start=first_sample, stop=last_sample)
    # convert to microvolts to scale up precision
    channels *= 1e6
    n_channels = len(channels)
    ch_idx = range(n_channels) if picks is None else picks

    writer = None
    try:
        writer = pyedflib.EdfWriter(fname, n_channels=n_channels,
                                    file_type=file_type)
        channel_info = [
            _edf_signal_header(mne_raw, i, channels, sfreq, dmin, dmax)
            for i in ch_idx]
        writer.setSignalHeaders(channel_info)
        # meas_date may be None (e.g. a RawArray without a date); only
        # write a start time when there is one.
        if date is not None:
            writer.setStartdatetime(date)
        writer.writeSamples(channels)
    except Exception as e:
        print(e)
        return False
    finally:
        # writer stays None if EdfWriter() itself raised; closing it
        # unconditionally would raise UnboundLocalError and hide the error.
        if writer is not None:
            writer.close()
    return True
def create_mne_raw(n_channels, n_times, sfreq, savedir=None, save=True):
    """Create an mne.io.RawArray of random data and optionally save it.

    The same data are saved as .fif, .edf and .hdf5 (dataset
    ``'fake_raw'``) so that read speed can be compared across formats.

    Parameters
    ----------
    n_channels : int
        Number of EEG channels (named ``ch0``, ``ch1``, ...).
    n_times : int
        Number of samples per channel.
    sfreq : float
        Sampling frequency in Hz.
    savedir : str | None
        Directory the files are written to. None means the current
        directory.
    save : bool
        If False, nothing is written; only the Raw object and the file
        names are returned.

    Returns
    -------
    raw : mne.io.RawArray
        The generated data, with ``meas_date`` set to the current UTC time.
    raw_fname, edf_fname, h5_fname : str
        Paths of the .fif, .edf and .hdf5 files.
    """
    if savedir is None:
        # os.path.join(None, ...) raises TypeError; use the current dir.
        savedir = os.curdir
    data = np.random.rand(n_channels, n_times)
    ch_names = [f'ch{i}' for i in range(n_channels)]
    ch_types = ['eeg'] * n_channels
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
    raw = mne.io.RawArray(data, info)
    # Current MNE forbids assigning info['meas_date'] directly;
    # set_meas_date is the supported way to set it.
    raw.set_meas_date(datetime.now(tz=timezone.utc))
    raw_fname = os.path.join(savedir, 'fake_eeg_raw.fif')
    edf_fname = raw_fname.replace('_raw.fif', '.edf')
    h5_fname = raw_fname.replace('_raw.fif', '.hdf5')
    if save:
        # Save .fif
        raw.save(raw_fname, overwrite=True)
        # Save .edf
        write_mne_edf(raw, edf_fname, picks=None, tmin=0, tmax=None,
                      overwrite=True)
        # Save .hdf5 in the data's own float64 dtype: 'f16' (extended
        # precision) is not available on every platform and only doubles
        # the file size without adding information.
        with h5py.File(h5_fname, 'w') as f:
            f.create_dataset('fake_raw', data=raw.get_data())
    return raw, raw_fname, edf_fname, h5_fname
def raw_to_epochs(raw, win_len_s, win_overlap_s, preload=False):
    """Cut ``raw`` into fixed-length, possibly overlapping, epochs.

    Parameters
    ----------
    raw : mne.io.BaseRaw
        Continuous data to window.
    win_len_s : float
        Window length in seconds.
    win_overlap_s : float
        Overlap between consecutive windows in seconds.
    preload : bool
        Passed to ``mne.Epochs``.

    Returns
    -------
    epochs : mne.Epochs
        One epoch per window (event id 1). The metadata table is a
        placeholder filled with zeros, with columns subject/session/run/
        age/label.
    start_end_inds : ndarray, shape (n_windows, 2)
        ``[start, stop)`` sample indices of each window relative to the
        first sample of the data, so they can slice arrays obtained from
        ``raw.get_data()`` (or copies of them, e.g. in HDF5).
    """
    sfreq = raw.info['sfreq']
    events = mne.make_fixed_length_events(
        raw, id=1, start=0, stop=None, duration=win_len_s, first_samp=True,
        overlap=win_overlap_s)
    md_columns = ['subject', 'session', 'run', 'age', 'label']
    metadata = pd.DataFrame(
        np.zeros((events.shape[0], len(md_columns))), columns=md_columns)
    # Last sample of a window lies one sample before win_len_s.
    tmax = win_len_s - 1. / sfreq
    epochs = mne.Epochs(
        raw, events, event_id=None, tmin=0, tmax=tmax, baseline=None,
        preload=preload, metadata=metadata)
    # Event samples include raw.first_samp (first_samp=True above), so
    # subtract it to address the 0-based data array. Round instead of
    # truncating so e.g. 0.29 s * 100 Hz yields 29 samples (not 28),
    # matching the number of samples in each epoch.
    n_win = int(round(win_len_s * sfreq))
    starts = events[:, 0] - raw.first_samp
    start_end_inds = np.column_stack((starts, starts + n_win))
    return epochs, start_end_inds
def test():
    """Benchmark per-window read time of the same data as FIF, EDF and HDF5.

    Generates 8.5 h of random 8-channel data at 256 Hz. The files are
    written only if one of them is missing in ``savedir``. The data are
    then read back in 30 s windows from each format, checked for
    agreement across formats, and the per-window read durations are
    plotted.
    """
    savedir = '.'
    # Fake data parameters
    n_channels = 8
    sfreq = 256
    n_times = int(8.5 * 60 * 60 * sfreq)
    win_len_s = 30
    win_overlap_s = 0

    # Reuse existing files to avoid regenerating them every run, but
    # create them when any is missing (otherwise the reads below fail).
    _, fif_fname, edf_fname, hdf5_fname = create_mne_raw(
        n_channels, n_times, sfreq, savedir=savedir, save=False)
    if not all(os.path.exists(fname)
               for fname in (fif_fname, edf_fname, hdf5_fname)):
        create_mne_raw(n_channels, n_times, sfreq, savedir=savedir,
                       save=True)

    # Create epochs from fif
    raw_fif = mne.io.read_raw_fif(fif_fname, preload=False, verbose=None)
    epochs_fif, start_end_inds = raw_to_epochs(
        raw_fif, win_len_s, win_overlap_s, preload=False)
    epochs_fif = epochs_fif.drop_bad(reject=None, flat=None)
    n_events = epochs_fif.events.shape[0]

    # Create epochs from edf
    raw_edf = mne.io.read_raw_edf(edf_fname, preload=False, verbose=None)
    epochs_edf, start_end_inds = raw_to_epochs(
        raw_edf, win_len_s, win_overlap_s, preload=False)
    epochs_edf = epochs_edf.drop_bad(reject=None, flat=None)

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which matters for sub-millisecond timings.
    times = list()
    ### FIF and EDF ###
    for fmt, epochs in (('fif', epochs_fif), ('edf', epochs_edf)):
        for i in range(n_events):
            t0 = time.perf_counter()
            epochs.get_data(item=i)[0]
            times.append(
                {'dur': time.perf_counter() - t0, 'format': fmt, 'idx': i})

    ### HDF5 ###
    with h5py.File(hdf5_fname, 'r') as hf:
        for i in range(n_events):
            t0 = time.perf_counter()
            hf['fake_raw'][:, start_end_inds[i, 0]:start_end_inds[i, 1]]
            times.append(
                {'dur': time.perf_counter() - t0, 'format': 'hdf5',
                 'idx': i})

    # Make sure all methods return the same thing
    x1 = epochs_fif.get_data()
    x2 = epochs_edf.get_data()
    # EDF stores 16-bit samples, hence the looser tolerance.
    assert np.allclose(x1, x2, atol=2e-5)
    with h5py.File(hdf5_fname, 'r') as hf:
        for i in range(n_events):
            assert np.allclose(
                epochs_fif.get_data(item=i),
                hf['fake_raw'][:, start_end_inds[i, 0]:start_end_inds[i, 1]],
                atol=1e-10)

    durations_df = pd.DataFrame(times)
    sns.lineplot(data=durations_df, x='idx', y='dur', hue='format')
    plt.ylim(0, 0.0025)
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment