Created May 14, 2020 18:09
Save masterdezign/f7941bfb99249da205faf733eeb7bbec to your computer and use it in GitHub Desktop.
BCI 101: MNE workshop by David Haslacher
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#%%
# Imports used throughout the workshop.
# NOTE(review): compute_current_source_density is imported but not used below.
import os
import numpy as np
import mne
from mne.preprocessing import compute_current_source_density
# %matplotlib qt

# Load the MNE "sample" recording (0-40 Hz filtered, judging by the file name)
# with its data preloaded into memory.
sample_data_folder = mne.datasets.sample.data_path()
_raw_file_parts = ('MEG', 'sample', 'sample_audvis_filt-0-40_raw.fif')
sample_data_raw_file = os.path.join(sample_data_folder, *_raw_file_parts)
raw = mne.io.read_raw_fif(sample_data_raw_file, preload=True)
#%%
#%%
# Keep only the channel types used later. MEG is dropped. EEG is kept for
# analysis, the stimulus channel for event extraction (find_events on
# 'STI 014'), and EOG for amplitude-based epoch rejection.
raw.pick_types(meg=False, eeg=True, eog=True, stim=True)
#%%
#%%
# Power spectral density between 1 and 20 Hz, traces colour-coded by sensor
# location.
psd_options = dict(fmin=1, fmax=20, n_fft=2 ** 10, spatial_colors=True)
raw.plot_psd(**psd_options)
#%%
#%%
# Browse the first five seconds of the recording, 30 channels per page.
raw.plot(show=True, n_channels=30, duration=5)
#%%
#%%
# Fit an independent component analysis (ICA) with 10 components and show
# their scalp maps. With an interactive backend (e.g. ``%matplotlib qt``), MNE
# lets you click a component's title to toggle its exclusion.
ica = mne.preprocessing.ICA(n_components=10, random_state=97, max_iter=800)
ica.fit(raw)
ica.plot_components()
#%%
#%%
# Keep an untouched copy of the data, then remove artifact components from
# ``raw`` (ICA.apply modifies ``raw`` in place).
orig_raw = raw.copy()
# Without an explicit selection, ``ica.exclude`` stays empty when the script
# runs non-interactively. In that case apply() would reconstruct the data
# unchanged. So add the components that correlate with the EOG channel (kept
# by the channel selection above), on top of any components marked by hand.
eog_indices, eog_scores = ica.find_bads_eog(raw)
ica.exclude = sorted(set(ica.exclude) | set(eog_indices))
print('Excluding ICA components:', ica.exclude)
ica.apply(raw)
#%%
#%%
# Compare the original data (non-blocking window) with the ICA-cleaned data
# (blocking window).
orig_raw.plot(start=0, duration=5, block=False)
raw.plot(start=0, duration=5, block=True)
#%%
#%%
# Extract the trigger events from the stimulus channel and show the first rows.
events = mne.find_events(raw, stim_channel='STI 014')
print(events[:5])
#%%
#%%
# Human-readable names for the trigger codes. The '/' separator lets later
# code select whole groups, e.g. epochs['auditory'].
event_dict = {
    'auditory/left': 1,
    'auditory/right': 2,
    'visual/left': 3,
    'visual/right': 4,
    'smiley': 5,
    'buttonpress': 32,
}
#%%
#%%
# Overview of when each event type occurs across the whole recording.
fig = mne.viz.plot_events(
    events,
    sfreq=raw.info['sfreq'],
    first_samp=raw.first_samp,
    event_id=event_dict,
)
#%%
#%%
# Amplitude thresholds (in volts) above which a signal is considered too large
# to originate from the brain. Epochs exceeding them are excluded.
reject_criteria = {'eeg': 150e-6,  # 150 µV
                   'eog': 250e-6}  # 250 µV
#%%
#%%
# Cut epochs from -0.2 s to 0.5 s around each named event, dropping those that
# exceed the rejection thresholds.
epochs = mne.Epochs(raw, events, event_id=event_dict,
                    tmin=-0.2, tmax=0.5,
                    reject=reject_criteria, preload=True)
#%%
#%%
# Drop trials so the four stimulus conditions have equal counts (a fair
# comparison), then split the epochs into auditory and visual subsets by tag.
conditions = ['auditory/left', 'auditory/right', 'visual/left', 'visual/right']
epochs.equalize_event_counts(conditions)
aud_epochs, vis_epochs = epochs['auditory'], epochs['visual']
#%%
#%%
# Single-trial image for one EEG channel per modality.
aud_epochs.plot_image(picks=['EEG 021'])
vis_epochs.plot_image(picks=['EEG 059'])
#%%
#%%
# Morlet-wavelet power of the auditory epochs.
# Settings: 7-28 Hz in 3 Hz steps, 2 cycles per wavelet, decimation factor 3.
# Inter-trial coherence is not returned.
frequencies = np.arange(7, 30, 3)
power = mne.time_frequency.tfr_morlet(
    aud_epochs,
    freqs=frequencies,
    n_cycles=2,
    decim=3,
    return_itc=False,
)
power.plot(picks=['EEG 021'])
#%%
#%%
# Average the trials of each modality into evoked responses.
aud_evoked = aud_epochs.average()
vis_evoked = vis_epochs.average()
# Latencies (in seconds) at which both modalities are drawn as scalp maps.
# They are shared so the auditory and visual topomaps stay directly comparable.
topomap_times = [0., 0.08, 0.1, 0.12, 0.2]
#%%
#%%
# Auditory response: all EEG sensors overlaid, combined with scalp maps.
aud_evoked.plot_joint(picks='eeg')
#%%
#%%
# Auditory response as scalp maps at the selected latencies.
aud_evoked.plot_topomap(times=topomap_times, ch_type='eeg')
#%%
#%%
# Visual response: all EEG sensors overlaid, combined with scalp maps.
vis_evoked.plot_joint(picks='eeg')
#%%
#%%
# Visual response as scalp maps at the same latencies.
vis_evoked.plot_topomap(times=topomap_times, ch_type='eeg')
#%%
#%%
# Time course of every sensor, each drawn at its position on the head
# (plot_topo), in red and without a legend.
vis_evoked.plot_topo(color='r', legend=False)
#%%
#%%
# Try different reference channels to see how they influence visual evoked
# responses. This compares the current reference with an average reference.
# set_eeg_reference modifies the instance it is called on, so re-reference a
# copy. That keeps ``vis_evoked`` intact and makes this cell safe to re-run.
# NOTE(review): if an average-reference projector is already applied to these
# data, the two plots may look identical. Confirm, and consider a different
# reference for a visible contrast.
vis_evoked.plot_joint(picks='eeg')
vis_evoked_avg_ref = vis_evoked.copy().set_eeg_reference(ref_channels='average')
vis_evoked_avg_ref.plot_joint(picks='eeg')
#%%
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#%%
# Imports for the motor-imagery decoding example.
# NOTE(review): Pipeline and cross_val_score are imported but not used below.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.pipeline import Pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit, cross_val_score
from mne import Epochs, pick_types, events_from_annotations
from mne.channels import make_standard_montage
from mne.io import concatenate_raws, read_raw_edf
from mne.datasets import eegbci
from mne.decoding import CSP
# %matplotlib qt

# Obtain the EDF files of subject 1, runs 6, 10 and 14, read each one fully
# into memory, and join them into a single continuous recording.
subject = 1
runs = [6, 10, 14]
raw_fnames = eegbci.load_data(subject, runs)
run_raws = [read_raw_edf(edf_path, preload=True) for edf_path in raw_fnames]
raw = concatenate_raws(run_raws)
#%%
#%%
# Standardise the EEGBCI channel names so they match the 'standard_1005'
# montage, then attach those default sensor locations.
eegbci.standardize(raw)
montage = make_standard_montage('standard_1005')
raw.set_montage(montage)
#%%
#%%
# Band-pass 7-30 Hz to keep the motor-related mu (alpha) and beta rhythms.
# Then map annotations T1/T2 to event codes 0/1 and select the good EEG
# channels only.
raw.filter(7., 30.)
events, _ = events_from_annotations(raw, event_id={'T1': 0, 'T2': 1})
picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                   exclude='bads')
#%%
#%%
# Epoch every T1/T2 event from -1 s to 4 s, without baseline correction.
# NOTE(review): the original comment said "left hand and right hand". In the
# PhysioNet EEG Motor Movement/Imagery dataset, runs 6, 10 and 14 are imagined
# both-fists (T1) vs. both-feet (T2). Confirm.
tmin, tmax = -1., 4.
epochs = Epochs(raw, events, event_id=None, tmin=tmin, tmax=tmax, proj=True,
                picks=picks, baseline=None, preload=True)
# Training data: the 1-2 s segment after event onset. Class labels are the
# event codes (0 or 1).
epochs_train = epochs.copy().crop(tmin=1., tmax=2.)
labels = epochs.events[:, -1]
#%%
#%%
# Cross-validation uses 10 random 80/20 train/test splits (ShuffleSplit, not
# K-fold). The splits are materialized as a list so they can be iterated more
# than once. A bare ``cv.split(...)`` generator is exhausted after the first
# run of the classification loop, and re-running that cell would then
# silently evaluate nothing.
epochs_data = epochs.get_data()
epochs_data_train = epochs_train.get_data()
cv = ShuffleSplit(10, test_size=0.2, random_state=42)
cv_split = list(cv.split(epochs_data_train))
#%%
#%%
# Common Spatial Patterns (CSP) extracts features: 4 components,
# log-transformed. Linear Discriminant Analysis (LDA) classifies those
# features.
lda = LinearDiscriminantAnalysis()
csp = CSP(n_components=4, reg=None, log=True, norm_trace=False)
#%%
#%%
# Fit CSP (a spatial filter, not a classifier) on all epochs, full -1..4 s
# window, only to visualize its patterns (inverse of the spatial filters).
# The transformed output is not needed, so fit() suffices.
csp.fit(epochs_data, labels)
csp.plot_patterns(epochs.info, ch_type='eeg', units='Patterns (AU)', size=1.5)
#%%
#%%
# Prepare to classify the data in a sliding window. Windows are 0.5 s long and
# advance in 0.1 s steps across the epoch.
sfreq = raw.info['sfreq']
w_length = int(sfreq * 0.5)  # running classifier: window length
w_step = int(sfreq * 0.1)  # running classifier: window step size
w_start = np.arange(0, epochs_data.shape[2] - w_length, w_step)
scores_windows = []
#%%
#%%
# Cross-validated sliding-window classification. For each split, fit CSP+LDA
# on the 1-2 s training crop, then score the trained model on every window of
# the held-out full-length epochs.
# Splits are drawn afresh from ``cv`` (fixed random_state, so they are the
# same every time) and the score list is reset here. The cell can therefore be
# re-run on its own without hitting an exhausted generator or appending to an
# earlier run's results.
scores_windows = []
for train_idx, test_idx in cv.split(epochs_data_train):
    y_train, y_test = labels[train_idx], labels[test_idx]
    X_train = csp.fit_transform(epochs_data_train[train_idx], y_train)
    # fit classifier
    lda.fit(X_train, y_train)
    # running classifier: test classifier on sliding window
    # (index the held-out epochs once instead of once per window)
    test_data = epochs_data[test_idx]
    score_this_window = [
        lda.score(csp.transform(test_data[:, :, n:(n + w_length)]), y_test)
        for n in w_start
    ]
    scores_windows.append(score_this_window)
#%%
#%%
# Plot the mean accuracy across splits against the centre time of each window.
w_times = (w_start + w_length / 2.) / sfreq + epochs.tmin
plt.figure()
plt.plot(w_times, np.mean(scores_windows, axis=0), label='Score')
plt.axvline(0, linestyle='--', color='k', label='Onset')
plt.axhline(0.5, linestyle='-', color='k', label='Chance')
plt.xlabel('time (s)')
plt.ylabel('classification accuracy')
plt.title('Classification score over time')
plt.legend(loc='lower right')
plt.show()
#%%
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.