Skip to content

Instantly share code, notes, and snippets.

@larsoner
Created July 20, 2020 14:05
Show Gist options
  • Save larsoner/da24fbc58c39ffd9c09e278aa0a26c32 to your computer and use it in GitHub Desktop.
Save larsoner/da24fbc58c39ffd9c09e278aa0a26c32 to your computer and use it in GitHub Desktop.
import os

import numpy as np
import mne
# One data instance (Raw or Epochs) per dataset; every entry is processed
# identically by the benchmarking loop further down.
raws = list()
# kiloword
# ``data_path()`` may return ``str`` (older MNE) or ``pathlib.Path`` (newer
# MNE); ``os.path.join`` handles both, whereas ``str + '/...'`` concatenation
# raises TypeError for a Path.
epo = mne.read_epochs(os.path.join(
    mne.datasets.kiloword.data_path(), 'kword_metadata-epo.fif'))
epo.pick_types(meg=False, eeg=True)
# XXX the stored channel locations are scaled 1000x too large (they should be
# in meters), hence the / 1000. below. Also info['dig'] should not be empty,
# so a montage is built here from the per-channel 'loc' positions instead.
montage = mne.channels.make_dig_montage(
    ch_pos={ch['ch_name']: ch['loc'][:3] / 1000. for ch in epo.info['chs']})
epo.set_montage(montage)
raws.append(epo)
# LIMO: data for subject 1, appended exactly as the loader returns it
epo = mne.datasets.limo.load_data(subject=1)
raws += [epo]
# EEGBCI: fetch runs 6, 10 and 14 of subject 1, read each EDF file into
# memory, and join them end to end into a single Raw.
subject = 1
runs = [6, 10, 14]
raw_fnames = mne.datasets.eegbci.load_data(subject, runs)
raw = mne.concatenate_raws(
    [mne.io.read_raw_edf(fname, preload=True) for fname in raw_fnames])
# Normalize the channel names before applying the standard_1005 template.
mne.datasets.eegbci.standardize(raw)
montage = mne.channels.make_standard_montage('standard_1005')
raw.set_montage(montage)
raws.append(raw)
del raw  # drop the module-level name; the object lives on inside ``raws``
# Sample: read lazily here (no preload); the benchmarking loop calls
# ``load_data()`` on it later.
# ``os.path.join`` copes with ``data_path()`` returning either ``str`` or
# ``pathlib.Path``; ``str + '/...'`` concatenation fails on a Path.
raw = mne.io.read_raw_fif(os.path.join(
    mne.datasets.sample.data_path(), 'MEG', 'sample',
    'sample_audvis_raw.fif'))
raws.append(raw)
# Run it
# Interpolation settings to compare. Each entry is (method, miss): the loop
# below passes ``method`` as ``method=dict(eeg=method)`` and ``miss`` as the
# ``miss=`` argument of ``interpolate_bads``. ``miss`` looks like a
# (regularization mode, amount) pair -- its exact semantics belong to
# ``interpolate_bads``; TODO confirm.
misses = (
    [('spline', (None, 1e-5))]
    + [('MNE', ('tsvd', x)) for x in [0.5, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]]
    + [('MNE', ('tikhonov', x))
       for x in [1e-0, 0.5, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]]
)
# Median similarity for each dataset (rows) and setting (columns).
corr_meds = np.empty((len(raws), len(misses)))
# Column labels carry the mode (or method) as well as the value: the value
# alone is ambiguous, e.g. 1e-05 is used for spline, tsvd and tikhonov alike.
labels = [f'{miss[0] or method} {miss[1]:0.0e}' for method, miss in misses]
# Column width: the widest label, never narrower than a '0.0000' cell.
n = max([6] + [len(label) for label in labels])
print('| ' + ' | '.join(label.center(n) for label in labels) + ' |')
print('|' + '|'.join(['-' * (n + 2)] * len(misses)) + '|')
for ri, raw in enumerate(raws):
    # Keep only EEG channels *before* loading. For a recording that is not
    # preloaded (the sample raw), this avoids reading every other channel
    # into memory just to discard it. Already-loaded data is unaffected.
    raw.pick_types(meg=False, eeg=True).load_data()
    raw.set_eeg_reference(projection=True).apply_proj()
    n_ch = len(raw.ch_names)
    # Original signal of each channel, flattened across epochs when present.
    # It does not depend on the interpolation setting, so extract it once
    # here rather than once per (setting, channel) pair.
    olds = [raw.get_data(picks=[ii])[..., 0, :].ravel() for ii in range(n_ch)]
    corrs = np.empty((len(misses), n_ch))
    for mi, (meth, miss_arg) in enumerate(misses):
        for ii, old in enumerate(olds):
            # Leave-one-out: mark channel ii bad and interpolate it. A fresh
            # copy is needed each time because the interpolated values are
            # written into ``this_raw`` itself.
            this_raw = raw.copy()
            this_raw.info['bads'] = [this_raw.ch_names[ii]]
            this_raw.interpolate_bads(
                verbose=False, method=dict(eeg=meth), miss=miss_arg)
            new = this_raw.get_data(picks=[ii])[..., 0, :].ravel()
            # Absolute cosine similarity between original and interpolated
            # signals (no mean removal, so not a Pearson correlation).
            corrs[mi, ii] = np.abs(np.dot(old, new)) / (
                np.linalg.norm(old) * np.linalg.norm(new))
    corr_meds[ri] = np.median(corrs, axis=1)
    # Center each cell to the header width so rows stay aligned.
    print('| ' + ' | '.join(f'{c:0.4f}'.center(n) for c in corr_meds[ri])
          + ' |')
print('|' + '|'.join(['-' * (n + 2)] * len(misses)) + '|')
# Summary row: mean across datasets of the per-dataset medians.
print('| ' + ' | '.join(f'{c:0.4f}'.center(n)
                        for c in np.mean(corr_meds, axis=0)) + ' |')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment