Created July 20, 2020 14:05
-
-
Save larsoner/da24fbc58c39ffd9c09e278aa0a26c32 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import mne
import numpy as np
# ``raws`` collects every dataset to benchmark. Despite its name it also
# holds Epochs objects (e.g. the result of mne.read_epochs below).
raws = []

# kiloword
# os.path.join accepts both str and path-like return values of data_path().
epo = mne.read_epochs(
    os.path.join(mne.datasets.kiloword.data_path(), 'kword_metadata-epo.fif'))
epo.pick_types(meg=False, eeg=True)
# XXX: the channel locations stored in the kiloword file are not in meters
# (they should be), hence the ``/ 1000.`` rescaling below. Its info['dig']
# is also empty, so a montage is built from the per-channel 'loc' entries.
montage = mne.channels.make_dig_montage(
    ch_pos={ch['ch_name']: ch['loc'][:3] / 1000. for ch in epo.info['chs']})
epo.set_montage(montage)
raws.append(epo)
# LIMO dataset, subject 1
epo = mne.datasets.limo.load_data(subject=1)
raws.append(epo)
# EEGBCI: subject 1, runs 6, 10 and 14 concatenated into a single Raw
subject = 1
runs = [6, 10, 14]
raw_fnames = mne.datasets.eegbci.load_data(subject, runs)
raw = mne.concatenate_raws(
    [mne.io.read_raw_edf(fname, preload=True) for fname in raw_fnames])
mne.datasets.eegbci.standardize(raw)  # set channel names
# Attach positions from the built-in 'standard_1005' montage.
montage = mne.channels.make_standard_montage('standard_1005')
raw.set_montage(montage)
raws.append(raw)
del raw
# MNE "sample" dataset
# os.path.join accepts both str and path-like return values of data_path().
raw = mne.io.read_raw_fif(
    os.path.join(mne.datasets.sample.data_path(),
                 'MEG', 'sample', 'sample_audvis_raw.fif'))
raws.append(raw)
# Run it: for every dataset, mark each EEG channel bad in turn, interpolate
# it back with every (method, miss) setting, and compare the interpolated
# signal to the original one. Results are printed as a Markdown table with
# one row per dataset plus a final row holding the mean across datasets.


def _table_row(cells):
    """Return the strings in ``cells`` formatted as one Markdown table row."""
    return '| ' + ' | '.join(cells) + ' |'


def _table_sep(n_cols, width):
    """Return a Markdown separator row for ``n_cols`` columns of ``width``."""
    return '|' + '|'.join(['-' * (width + 2)] * n_cols) + '|'


def _abs_cosine(a, b):
    """Return |a . b| / (||a|| ||b||), the absolute cosine similarity.

    This is an uncentered correlation: the signals are not mean-removed.
    """
    return np.abs(np.dot(a, b)) / (np.linalg.norm(a) * np.linalg.norm(b))


# Each entry is (method, miss): the EEG interpolation method and the value
# forwarded verbatim as ``interpolate_bads(miss=...)``.
# NOTE(review): ``miss`` is not a documented interpolate_bads parameter in
# released MNE; this presumably targets a development branch — confirm.
misses = (
    [('spline', (None, 1e-5))]
    + [('MNE', ('tsvd', x)) for x in [0.5, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]]
    + [('MNE', ('tikhonov', x))
       for x in [1e-0, 0.5, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]]
)
corr_meds = np.empty((len(raws), len(misses)))
n = 6  # width each header cell is centered to
# The header shows only the second element of each ``miss`` tuple; columns
# follow the order of ``misses``.
print(_table_row(f'{miss[1][1]:0.0e}'.center(n) for miss in misses))
print(_table_sep(len(misses), n))
for ri, raw in enumerate(raws):
    raw.load_data().pick_types(meg=False, eeg=True)
    raw.set_eeg_reference(projection=True).apply_proj()
    corrs = np.empty((len(misses), len(raw.ch_names)))
    for mi, (method, miss) in enumerate(misses):
        for ii, ch_name in enumerate(raw.ch_names):
            # interpolate_bads operates in place, so every channel gets a
            # fresh copy of the unmodified data.
            this_raw = raw.copy()
            this_raw.info['bads'] = [ch_name]
            # [..., 0, :] takes the single picked channel for both Raw
            # (channels x times) and Epochs (epochs x channels x times) data;
            # ravel then concatenates epochs into one vector.
            old = this_raw.get_data(picks=[ii])[..., 0, :].ravel()
            this_raw.interpolate_bads(
                verbose=False, method=dict(eeg=method), miss=miss)
            new = this_raw.get_data(picks=[ii])[..., 0, :].ravel()
            corrs[mi, ii] = _abs_cosine(old, new)
    corr_meds[ri] = np.median(corrs, axis=1)
    print(_table_row(f'{c:0.4f}' for c in corr_meds[ri]))
print(_table_sep(len(misses), n))
print(_table_row(f'{c:0.4f}' for c in np.mean(corr_meds, axis=0)))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.