# Gist dengemann/558cbed8e43e81b9fa9f51d136297e04 (created December 29, 2020 11:45).
# Save it to your computer and use it in GitHub Desktop.
# Note: this file may contain bidirectional Unicode text that is interpreted or
# compiled differently than it appears; review it in an editor that reveals
# hidden Unicode characters.
# License: BSD (3-clause)
# Author: Denis A. Engemann <[email protected]>
# Based on: (reference not given in the original)
import datetime
import os
import platform
from time import perf_counter, time

import numpy as np
import pandas as pd
import psutil

import mne
from autoreject import AutoReject
from mne.channels import find_ch_adjacency
from mne.preprocessing import ICA
from mne.stats import spatio_temporal_cluster_test
# Location of MNE's "sample" dataset (MNE may download it on first call)
# and of the raw MEG/EEG recording inside it used by the benchmark.
data_path = mne.datasets.sample.data_path()
sample_data_raw_file = os.path.join(
    data_path,
    'MEG',
    'sample',
    'sample_audvis_raw.fif',
)
class MnePipelineBench(object):
    """Time a typical MNE-Python MEG analysis pipeline, step by step.

    Each run reloads the sample raw recording and then executes, in order:
    resampling, band-pass filtering, ICA-based EOG/ECG artifact removal,
    epoching, AutoReject cleaning and a spatio-temporal cluster permutation
    test. The duration of every (run, step) pair is recorded.

    Parameters
    ----------
    n_jobs : int
        Number of parallel jobs forwarded to the calls that accept
        ``n_jobs`` (filtering, AutoReject, cluster test).
    n_runs : int
        How many times the full pipeline is repeated.
    """

    def __init__(self, n_jobs=4, n_runs=5):
        self.n_jobs = n_jobs
        self.n_runs = n_runs

    def load_data(self):
        """Load the sample raw file into memory and select magnetometers."""
        raw = mne.io.read_raw_fif(sample_data_raw_file, preload=True)
        self.picks = mne.pick_types(raw.info, meg='mag', eeg=False)
        self.raw = raw
        return self

    def resample(self):
        """Resample the raw data in place to 150 Hz."""
        self.raw.resample(150)
        # Return self like every other step so calls can be chained.
        return self

    def filter(self):
        """Band-pass filter the picked channels between 1 and 40 Hz."""
        self.raw.filter(l_freq=1, h_freq=40, n_jobs=self.n_jobs,
                        picks=self.picks)
        return self

    def ica(self):
        """Fit ICA on the picked channels and remove EOG/ECG components."""
        # presumably 102 = one component per magnetometer in the sample data
        ica = ICA(n_components=102, random_state=97)
        ica.fit(self.raw, picks=self.picks)
        ica.exclude = []
        eog_indices, _ = ica.find_bads_eog(self.raw)
        ica.exclude += eog_indices
        ecg_indices, _ = ica.find_bads_ecg(self.raw, method='correlation',
                                           threshold='auto')
        ica.exclude += ecg_indices
        ica.apply(self.raw)
        return self

    def make_epochs(self):
        """Cut -0.2..0.5 s epochs around the four auditory/visual events.

        Event counts are equalized across conditions so that the cluster
        test later compares equally sized groups.
        """
        # os.path.join (as used for sample_data_raw_file) works whether
        # data_path is a str or a pathlib.Path; ``Path + str`` raises.
        event_fname = os.path.join(data_path, 'MEG', 'sample',
                                   'sample_audvis_filt-0-40_raw-eve.fif')
        self.event_id = {'Aud/L': 1, 'Aud/R': 2, 'Vis/L': 3, 'Vis/R': 4}
        tmin = -0.2
        tmax = 0.5
        events = mne.read_events(event_fname)
        epochs = mne.Epochs(self.raw, events, self.event_id, tmin, tmax,
                            picks=self.picks,
                            baseline=None, reject=None, preload=True)
        epochs.equalize_event_counts(self.event_id)
        self.epochs = epochs
        return self

    def autoreject(self):
        """Clean the epochs with AutoReject and keep the fitted model."""
        ar = AutoReject(n_jobs=self.n_jobs)
        self.epochs_clean = ar.fit_transform(self.epochs)
        self.ar = ar
        return self

    def permutation_clustering(self):
        """Run a spatio-temporal cluster permutation test across conditions.

        Stores the raw test output in ``self.cluster_stats`` and the indices
        of clusters with p < 0.01 in ``self.good_cluster_inds``.
        """
        adjacency, _ = find_ch_adjacency(self.epochs_clean.info,
                                         ch_type='mag')
        # One array per condition, reordered from (epochs, channels, times)
        # to (epochs, times, channels) as the cluster test expects.
        X = [np.transpose(self.epochs_clean[k].get_data(), (0, 2, 1))
             for k in self.event_id]
        threshold = 3.0
        p_accept = 0.01
        cluster_stats = spatio_temporal_cluster_test(
            X, n_permutations=1000,
            threshold=threshold, tail=1,
            n_jobs=self.n_jobs, buffer_size=None,
            adjacency=adjacency)
        # Keep the result instead of discarding it, so it can be inspected.
        self.cluster_stats = cluster_stats
        _, _, cluster_p_values, _ = cluster_stats
        self.good_cluster_inds = np.where(cluster_p_values < p_accept)[0]
        return self

    def _bench(self, function):
        """Call ``function()`` and return its duration in seconds.

        Uses ``perf_counter``: it is monotonic and high-resolution, unlike
        wall-clock ``time()``, which can jump when the system clock changes.
        """
        start = perf_counter()
        function()
        return perf_counter() - start

    def run(self):
        """Run every pipeline step ``n_runs`` times and save the timings.

        Results (one row per run and step, plus machine metadata) are
        written to ``mne/bench-<timestamp>.csv`` and kept on
        ``self.results``.
        """
        steps = (
            'load_data',
            'resample',
            'filter',
            'ica',
            'make_epochs',
            'autoreject',
            'permutation_clustering'
        )
        results = []
        for run_idx in range(self.n_runs):
            for step in steps:
                print(f"run: {run_idx}, {step}")
                elapsed = self._bench(getattr(self, step))
                results.append(
                    {"run": run_idx, "step": step, "time": elapsed}
                )
        results_df = pd.DataFrame(results)
        results_df['ram'] = str(
            round(psutil.virtual_memory().total / (1024.0 ** 3))) + " GB"
        results_df['arch'] = platform.machine()
        results_df['version'] = platform.platform()
        results_df['py'] = platform.python_version()
        date = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")
        # Create the output directory up front; otherwise to_csv fails only
        # after every timed run has finished and all results are lost.
        os.makedirs('mne', exist_ok=True)
        results_df.to_csv('mne/bench-%s.csv' % date[:-7])
        self.results = results_df
        return self
if __name__ == "__main__":
    # Guard the entry point so that merely importing this module (from
    # another script or a spawned worker process) does not launch the
    # full, long-running benchmark.
    MnePipelineBench().run()
# Sign up for free on GitHub to join this conversation; if you already have
# an account, sign in to comment.