Created March 9, 2016 14:21
-
-
Save wmvanvliet/dd17c9d8a9b75435217d to your computer and use it in GitHub Desktop.
Evaluation of bad channel detection method
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mne | |
import numpy as np | |
from scipy.stats import norm | |
from matplotlib import pyplot as plt | |
# Load required sample data. All files come from the MNE sample dataset so the
# script runs on any machine that has downloaded it.
import os.path as op

data_path = mne.datasets.sample.data_path()
subjects_dir = op.join(data_path, 'subjects')
meg_dir = op.join(data_path, 'MEG', 'sample')
evoked = mne.read_evokeds(op.join(meg_dir, 'sample_audvis-ave.fif'))[0]
fwd = mne.read_forward_solution(
    op.join(meg_dir, 'sample_audvis-meg-eeg-oct-6-fwd.fif'),
    surf_ori=False, force_fixed=True)
# The EOG projector is part of the sample dataset; read it from there instead
# of from a user-specific absolute path.
eog_proj = mne.read_proj(op.join(meg_dir, 'sample_audvis_eog_proj.fif'))

# Set the noise covariance to unit: keep the channel bookkeeping of the real
# covariance but replace its values with an identity matrix.
noise_cov = mne.read_cov(op.join(meg_dir, 'sample_audvis-shrunk-cov.fif'))
noise_cov['data'] = np.eye(noise_cov['data'].shape[0])

# Limit channels to EEG only (both the evoked data and the forward model)
evoked_eeg = evoked.pick_types(meg=False, eeg=True, copy=True)
fwd_eeg = mne.pick_types_forward(fwd, meg=False, eeg=False,
                                 include=evoked_eeg.ch_names)

# Make inverse operator for EEG only, with fixed source orientations and no
# loose-orientation or depth weighting.
inv_eeg = mne.minimum_norm.make_inverse_operator(
    info=evoked.info,
    forward=fwd_eeg,
    noise_cov=noise_cov,
    loose=None,
    depth=None,
    fixed=True,
)
def rereference(evoked, baseline_samples=120):
    """Re-apply an average reference and re-baseline ``evoked`` in place.

    Parameters
    ----------
    evoked : Evoked
        Object whose ``data`` attribute is a 2-D array indexed as
        (channel, time). It is modified in place.
    baseline_samples : int
        Number of leading time samples that form the baseline period.
        Defaults to 120, the value used throughout this script.

    Returns
    -------
    evoked : Evoked
        The same object that was passed in.
    """
    # Average reference: subtract the mean across channels at every time point.
    evoked.data -= evoked.data.mean(axis=0)[np.newaxis, :]
    # Baseline correction: subtract each channel's own mean over the baseline
    # period. This must use the data of the evoked being processed (not a
    # global one), otherwise offsets introduced by added noise or by the
    # source-space round trip are never removed.
    evoked.data -= evoked.data[:, :baseline_samples].mean(axis=1)[:, np.newaxis]
    return evoked
# Add different types of noise. Each variant is a corrupted copy of the clean
# EEG evoked, re-referenced and re-baselined after the noise is added.
noise_amp = 5e-6  # 5 uV
bad_idx = 12  # row of evoked_eeg.data that receives the single-channel noise
rng = np.random.RandomState(42)  # fixed seed so the evaluation is reproducible
times = evoked_eeg.times
noisy_evokeds = dict()

# Clean signal
rereference(evoked_eeg)
noisy_evokeds['clean'] = evoked_eeg

# 50 Hz sine
evoked_sine = evoked_eeg.copy()
evoked_sine.data[bad_idx] += np.sin(2 * 50 * np.pi * times) * noise_amp
rereference(evoked_sine)
noisy_evokeds['line_noise'] = evoked_sine

# White noise
evoked_white_noise = evoked_eeg.copy()
evoked_white_noise.data[bad_idx] += rng.randn(len(times)) * noise_amp
rereference(evoked_white_noise)
noisy_evokeds['white_noise'] = evoked_white_noise

# Filtered noise: 40 Hz low-passed white noise, rescaled to unit standard
# deviation so that multiplying by noise_amp gives a meaningful amplitude
evoked_filtered_noise = evoked_eeg.copy()
noise = rng.randn(len(times))
noise = mne.filter.low_pass_filter(noise, evoked_eeg.info['sfreq'], 40,
                                   method='iir')
noise /= np.std(noise)
evoked_filtered_noise.data[bad_idx] += noise * noise_amp
rereference(evoked_filtered_noise)
noisy_evokeds['filtered_noise'] = evoked_filtered_noise

# Simulated blink: a Gaussian time course centred at 0.3 (blink at 300 ms),
# spread over channels by the second-to-last EOG projector vector
evoked_blink = evoked_eeg.copy()
blink_shape = norm(0.3, 0.05).pdf
blink = blink_shape(times)
# NOTE(review): assumes the projector vector has one entry per channel of
# evoked_eeg, in the same order -- confirm against the projector's col_names
blink = blink[:, np.newaxis].dot(eog_proj[-2]['data']['data']).T
blink /= blink.max()  # normalize so the largest value is 1 before scaling
evoked_blink.data += blink * noise_amp
rereference(evoked_blink)
noisy_evokeds['blink'] = evoked_blink
# Start evaluation of the signal quality: project each evoked to source space
# and back, then measure how much of each channel the source model fails to
# explain.
snr = 3.
lambda2 = 1. / snr ** 2  # regularization parameter passed to apply_inverse
errors = dict()
reconstructed = dict()
for name, evoked_noisy in noisy_evokeds.items():
    # Project to source space
    stc_eeg = mne.minimum_norm.apply_inverse(evoked_noisy, inv_eeg, lambda2,
                                             'MNE')
    # Project back to sensor space
    evoked_reconstructed = mne.apply_forward(fwd_eeg, stc_eeg,
                                             evoked_noisy.info)
    rereference(evoked_reconstructed)
    # Compute reconstruction error: per-channel L2 norm of the residual over
    # time, scaled by 1e6 (V -> uV)
    error = np.linalg.norm(evoked_noisy.data - evoked_reconstructed.data,
                           axis=1) * 1e6
    max_error = error.max()
    # Single-argument print produces the same output on Python 2 and 3
    print('Max reconstruction error for %s : %s uV' % (name, max_error))
    errors[name] = max_error
    reconstructed[name] = evoked_reconstructed

# Show the results (fixed order so bar heights and tick labels line up)
names = list(errors)
positions = list(range(len(names)))
plt.figure()
plt.bar(positions, [errors[n] for n in names], align='center')
plt.xticks(positions, names, rotation=45)
plt.ylabel('Max reconstruction error (uV)')
plt.title('Evaluation of signal quality (lower is better)')
plt.tight_layout()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.