Skip to content

Instantly share code, notes, and snippets.

@dengemann
Created December 29, 2020 11:45
Show Gist options
  • Save dengemann/558cbed8e43e81b9fa9f51d136297e04 to your computer and use it in GitHub Desktop.
Save dengemann/558cbed8e43e81b9fa9f51d136297e04 to your computer and use it in GitHub Desktop.
# License: BSD (3-clause)
# Author: Denis A. Engemann <[email protected]>
# Based on: (original reference not preserved in this copy)
import datetime
import os
import platform
from time import perf_counter
from time import time

import mne
import numpy as np
import pandas as pd
import psutil
from autoreject import AutoReject
from mne.channels import find_ch_adjacency
from mne.preprocessing import ICA
from mne.stats import spatio_temporal_cluster_test
# Root directory of the MNE "sample" dataset, as reported by MNE.
data_path = mne.datasets.sample.data_path()
# Unprocessed sample recording that MnePipelineBench.load_data() reads.
sample_data_raw_file = os.path.join(
    data_path,
    'MEG',
    'sample',
    'sample_audvis_raw.fif',
)
class MnePipelineBench(object):
    """Time each step of a typical MNE-Python MEG analysis pipeline.

    :meth:`run` executes the steps (load, resample, filter, ICA, epoching,
    autoreject, cluster permutation test) in order ``n_runs`` times. It
    records the duration of every step and writes the timings, together with
    basic machine information, to a CSV file.

    Parameters
    ----------
    n_jobs : int
        Number of parallel jobs forwarded to the MNE / autoreject calls that
        accept ``n_jobs``.
    n_runs : int
        How many times the full pipeline is repeated.
    """

    def __init__(self, n_jobs=4, n_runs=5):
        self.n_jobs = n_jobs
        self.n_runs = n_runs

    def load_data(self):
        """Read the sample raw recording and select the magnetometer channels.

        Sets ``self.raw`` (preloaded) and ``self.picks`` (magnetometer
        indices) and returns ``self``.
        """
        raw = mne.io.read_raw_fif(sample_data_raw_file, preload=True)
        self.picks = mne.pick_types(raw.info, meg='mag', eeg=False)
        self.raw = raw
        return self

    def resample(self):
        """Resample ``self.raw`` in place to 150 Hz and return ``self``."""
        self.raw.resample(150)
        # Return self like every other step, so the steps can be chained.
        return self

    def filter(self):
        """Band-pass filter the magnetometers of ``self.raw`` to 1-40 Hz."""
        self.raw.filter(l_freq=1, h_freq=40, n_jobs=self.n_jobs,
                        picks=self.picks)
        return self

    def ica(self):
        """Fit ICA on the magnetometers and remove EOG/ECG-related components.

        Components flagged by ``find_bads_eog`` and ``find_bads_ecg`` are
        excluded, and ``ica.apply`` removes them from ``self.raw``.
        """
        # NOTE(review): presumably one component per magnetometer (the sample
        # data has 102) — confirm this matches the data rank after SSP.
        ica = ICA(n_components=102, random_state=97)
        ica.fit(self.raw, picks=self.picks)
        ica.exclude = []
        eog_indices, _ = ica.find_bads_eog(self.raw)
        ica.exclude += eog_indices
        ecg_indices, _ = ica.find_bads_ecg(self.raw, method='correlation',
                                           threshold='auto')
        ica.exclude += ecg_indices
        ica.apply(self.raw)
        return self

    def make_epochs(self):
        """Cut ``self.raw`` into equalized epochs for the four conditions.

        Sets ``self.event_id`` and ``self.epochs`` and returns ``self``.
        """
        # os.path.join accepts both str and pathlib.Path; data_path may be a
        # Path (recent MNE releases), for which ``data_path + '/...'`` raises
        # TypeError.
        event_fname = os.path.join(data_path, 'MEG', 'sample',
                                   'sample_audvis_filt-0-40_raw-eve.fif')
        self.event_id = {'Aud/L': 1, 'Aud/R': 2, 'Vis/L': 3, 'Vis/R': 4}
        tmin, tmax = -0.2, 0.5
        # NOTE(review): these events come from the filt-0-40 file, while
        # self.raw is the original recording resampled to 150 Hz; sample
        # indices may not line up exactly — confirm this is acceptable for a
        # timing benchmark.
        events = mne.read_events(event_fname)
        epochs = mne.Epochs(self.raw, events, self.event_id, tmin, tmax,
                            picks=self.picks,
                            baseline=None, reject=None, preload=True)
        epochs.equalize_event_counts(self.event_id)
        self.epochs = epochs
        return self

    def autoreject(self):
        """Clean ``self.epochs`` with AutoReject into ``self.epochs_clean``."""
        ar = AutoReject(n_jobs=self.n_jobs)
        self.epochs_clean = ar.fit_transform(self.epochs)
        self.ar = ar
        return self

    def permutation_clustering(self):
        """Run a spatio-temporal cluster permutation test across conditions.

        The result of ``spatio_temporal_cluster_test`` is stored in
        ``self.cluster_stats`` (previously it was computed and discarded), and
        ``self`` is returned.
        """
        adjacency, _ = find_ch_adjacency(self.epochs_clean.info,
                                         ch_type='mag')
        # get_data() yields (epochs, channels, times); the cluster test takes
        # (observations, times, channels), hence the transpose.
        X = [np.transpose(self.epochs_clean[k].get_data(), (0, 2, 1))
             for k in self.event_id]
        self.cluster_stats = spatio_temporal_cluster_test(
            X, n_permutations=1000,
            threshold=3.0, tail=1,
            n_jobs=self.n_jobs, buffer_size=None,
            adjacency=adjacency)
        return self

    def _bench(self, function):
        """Call ``function()`` and return the elapsed time in seconds."""
        # perf_counter is monotonic and high-resolution; time() can jump if
        # the system clock is adjusted during a measurement.
        start = perf_counter()
        function()
        return perf_counter() - start

    def run(self, out_dir='mne'):
        """Benchmark every pipeline step ``n_runs`` times and save a CSV.

        Parameters
        ----------
        out_dir : str
            Directory for the ``bench-<timestamp>.csv`` file. It is created if
            missing. Defaults to ``'mne'``, the previously hard-coded location.

        Returns
        -------
        self
            With ``self.results`` set to the timing DataFrame.
        """
        steps = (
            'load_data',
            'resample',
            'filter',
            'ica',
            'make_epochs',
            'autoreject',
            'permutation_clustering',
        )
        results = []
        # load_data is the first step, so every run starts from a fresh raw.
        for run_idx in range(self.n_runs):
            for step in steps:
                print(f"run: {run_idx}, {step}")
                elapsed = self._bench(getattr(self, step))
                results.append({"run": run_idx, "step": step, "time": elapsed})
        results_df = pd.DataFrame(results)
        results_df['ram'] = str(
            round(psutil.virtual_memory().total / (1024.0 ** 3))) + " GB"
        results_df['arch'] = platform.machine()
        results_df['version'] = platform.platform()
        results_df['py'] = platform.python_version()
        # to_csv does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        stamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
        results_df.to_csv(os.path.join(out_dir, 'bench-%s.csv' % stamp))
        self.results = results_df
        return self
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    MnePipelineBench().run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment