Skip to content

Instantly share code, notes, and snippets.

@erap129
Last active September 13, 2024 09:57
Show Gist options
  • Save erap129/7766d911dad632d5bf0954027cada13e to your computer and use it in GitHub Desktop.
Save erap129/7766d911dad632d5bf0954027cada13e to your computer and use it in GitHub Desktop.
EEG similarity exercise
from copy import deepcopy
import seaborn as sns
from functools import reduce
from tqdm import tqdm
import pandas as pd
from mne.datasets import eegbci
from argparse import ArgumentParser
import logging
from mne.io import concatenate_raws, read_raw_edf
from mne.channels import make_standard_montage
from mne import Epochs, pick_types, concatenate_epochs
import numpy as np
import matplotlib.pyplot as plt
import umap
# Configure the root logger at INFO level so the logging.info(...) progress
# messages emitted throughout this script are actually shown.
logging.basicConfig(level=logging.INFO)
def get_epochs(n_subjects,
               sfreq,
               runs,
               tmin,
               tmax):
    """Load, filter and epoch EEGBCI "hands" vs "feet" trials for many subjects.

    Subjects ``1..n_subjects`` (inclusive) are processed in order. For each,
    the requested runs are fetched with ``eegbci.load_data`` and concatenated;
    subjects whose sampling frequency differs from ``sfreq`` are skipped. Kept
    recordings get standardized channel names, the ``standard_1005`` montage,
    an average-reference projector and a 7-30 Hz FIR band-pass filter, and are
    then cut into epochs for the ``hands`` (T1) / ``feet`` (T2) annotations.

    Every subject's epochs receive metadata with ``subject`` and ``label``
    columns, so callers can select with queries such as
    ``'subject == 3 and label == "hands"'``. For subject 1, a preview of
    channels C3/C4 is saved to ``eeg_similarity_outputs/epochs.png``.

    Parameters
    ----------
    n_subjects : int
        Highest subject number to process (numbering starts at 1).
    sfreq : float
        Required sampling frequency in Hz; other subjects are skipped.
    runs : list of int
        EEGBCI run numbers to load for every subject.
    tmin, tmax : float
        Epoch window in seconds relative to each event.

    Returns
    -------
    mne.Epochs
        The epochs of all kept subjects, concatenated.

    Raises
    ------
    ValueError
        If no subject was kept, so there is nothing to concatenate.
    """
    montage = make_standard_montage("standard_1005")
    epochs_list = []
    event_mapping = dict(hands=2, feet=3)
    event_mapping_inverse = {v: k for k, v in event_mapping.items()}
    # range() excludes its stop value; +1 so subject `n_subjects` is included.
    for subject in tqdm(range(1, n_subjects + 1), desc='collecting subject data'):
        raw_fnames = eegbci.load_data(subject, runs)
        raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
        if raw.info['sfreq'] != sfreq:
            logging.info(f'Skipping subject {subject} because of incorrect sampling frequency')
            continue
        eegbci.standardize(raw)  # set channel names
        raw.set_montage(montage)
        raw.annotations.rename(dict(T1="hands", T2="feet"))
        raw.set_eeg_reference(projection=True)
        raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge")
        picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")
        # Read epochs
        epochs = Epochs(
            raw,
            event_id=event_mapping,
            tmin=tmin,
            tmax=tmax,
            proj=True,
            picks=picks,
            baseline=None,
            preload=True,
        )
        if subject == 1:
            # pick_channels() modifies the object in place. Work on a copy so
            # subject 1 keeps every EEG channel; otherwise its epochs would no
            # longer match the other subjects' channels in concatenate_epochs.
            (epochs
             .copy()
             .pick_channels(['C3', 'C4'])
             .plot(events=True,
                   show_scrollbars=False,
                   n_epochs=5,
                   title="Raw EEG data for subject 1")
             )
            plt.savefig('eeg_similarity_outputs/epochs.png')
            plt.close()
        # Create metadata with subject information
        num_epochs = len(epochs)
        logging.info(f'Loaded {num_epochs} epochs for subject {subject}')
        epochs.metadata = pd.DataFrame({'subject': [subject] * num_epochs,
                                        'label': [event_mapping_inverse[e] for e in epochs.events[:, -1]]})
        epochs_list.append(epochs)
    if not epochs_list:
        raise ValueError(f'No subject in 1..{n_subjects} has sampling frequency {sfreq}')
    # Concatenate all epochs
    return concatenate_epochs(epochs_list)
def get_hands_feet_ratio(df):
    """Return the ratio of ``hands`` rows to ``feet`` rows in *df*.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``label`` column; only the values ``'hands'`` and
        ``'feet'`` are counted.

    Returns
    -------
    pandas.Series
        One entry keyed ``'hands_feed_ratio'`` (spelling kept because callers
        read this key). The value is NaN when *df* has no ``feet`` rows,
        instead of raising ZeroDivisionError.
    """
    labels = df['label']
    n_hands = int((labels == 'hands').sum())
    n_feet = int((labels == 'feet').sum())
    ratio = n_hands / n_feet if n_feet else float('nan')
    return pd.Series({'hands_feed_ratio': ratio})
def analyze_hands_feet_ratio(epochs):
    """Print summary statistics of the per-subject hands/feet trial ratio.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs whose ``metadata`` has ``subject`` and ``label`` columns.

    Returns
    -------
    pandas.DataFrame
        Indexed by subject, with a single ``'hands_feed_ratio'`` column.
    """
    # Pass only the column the helper reads: pandas >= 2.2 deprecates
    # DataFrameGroupBy.apply operating on the grouping column itself.
    hands_feet_ratio = (epochs.metadata
                        .groupby('subject')[['label']]
                        .apply(get_hands_feet_ratio))
    ratios = hands_feet_ratio['hands_feed_ratio']
    print(f'Hands/Feet ratio mean: {ratios.mean()}')
    print(f'Hands/Feet ratio min: {ratios.min()}')
    print(f'Hands/Feet ratio max: {ratios.max()}')
    return hands_feet_ratio
def get_naive_correlation_map(epochs, channel, label):
    """Correlate subjects by the time course of their average response on one channel.

    For each subject, the epochs with the given ``label`` are restricted to
    ``channel`` and averaged. The per-subject averages are aligned on their
    ``time`` column and Pearson-correlated pairwise across subjects. The five
    distinct subject pairs with the largest absolute correlation are logged,
    and a heatmap is saved to
    ``eeg_similarity_outputs/correlation_heatmap_{channel}_{label}.png``.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs with ``subject`` and ``label`` metadata columns.
    channel : str
        Channel name to analyse, e.g. ``'C3'``.
    label : str
        Trial label to select, e.g. ``'hands'``.

    Returns
    -------
    pandas.DataFrame
        Subject-by-subject correlation matrix whose rows/columns are named
        ``'{subject}_{channel}'``.
    """
    average_subject_dfs = []
    for subject in tqdm(epochs.metadata['subject'].unique(), desc=f'collecting subject data for channel: {channel}, label: {label}'):
        average_subject_dfs.append(epochs
                                   [f'subject == {subject} and label == "{label}"']
                                   .pick_channels([channel])
                                   .average()
                                   .to_data_frame()
                                   .rename(columns={channel: f'{subject}_{channel}'})
                                   )
    df_merged = (reduce(lambda left, right: pd.merge(left, right, on=['time'], how='outer'),
                        average_subject_dfs)
                 .drop(columns=['time'])
                 )
    correlation_matrix = df_merged.corr()

    logging.info(f'Top 5 correlations for channel: {channel}, label: {label}')
    # Rank each unordered subject pair exactly once (strict upper triangle):
    # the diagonal is always 1 and the lower triangle mirrors the upper one,
    # so neither may appear among the "top" pairs, even with < 6 subjects.
    names = correlation_matrix.columns
    rows, cols = np.triu_indices(len(names), k=1)
    pair_values = correlation_matrix.values[rows, cols]
    strength = np.abs(pair_values)
    # Descending |r|; NaN pairs (e.g. from a constant series) are skipped.
    ranked = [i for i in np.argsort(-strength) if not np.isnan(strength[i])]
    for idx in ranked[:5]:
        subject_pair = (names[rows[idx]], names[cols[idx]])
        logging.info(f'Subject pair: {subject_pair}, correlation: {pair_values[idx]}')

    plt.figure(figsize=(10, 10))
    plt.title(f'Correlation Matrix - based on average raw values from channel {channel} for label {label}')
    sns.heatmap(correlation_matrix, cmap='coolwarm', annot=False, fmt=".2f", linewidths=0.5)
    plt.savefig(f'eeg_similarity_outputs/correlation_heatmap_{channel}_{label}.png')
    plt.close()
    return correlation_matrix
def get_correlation_map_of_correlation_maps(epochs, label):
    """Correlate subjects by the similarity of their channel-correlation maps.

    For each subject, the ``label`` epochs are averaged. The channel-by-channel
    Pearson correlation matrix of that average is then flattened into a
    feature vector, and these vectors are correlated across subjects. The five
    most positively correlated distinct subject pairs are logged, and a heatmap
    labelled with the real subject numbers is saved to
    ``eeg_similarity_outputs/correlation_heatmap_{label}.png``.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs with ``subject`` and ``label`` metadata columns.
    label : str
        Trial label to select, e.g. ``'hands'``.

    Returns
    -------
    correlation_matrix : numpy.ndarray
        Square subject-by-subject matrix, ordered like
        ``epochs.metadata['subject'].unique()``.
    subject_correlation_maps : list of numpy.ndarray
        The flattened per-subject channel-correlation matrices, same order.
    """
    subjects = epochs.metadata['subject'].unique()
    subject_correlation_maps = []
    for subject in tqdm(subjects, desc=f'collecting subject data for label: {label}'):
        channel_correlation = (epochs
                               [f'subject == {subject} and label == "{label}"']
                               .average()
                               .to_data_frame()
                               # 'time' is a column of the evoked frame, not a
                               # channel: keep it out of the channel correlations.
                               .drop(columns=['time'])
                               .corr())
        subject_correlation_maps.append(channel_correlation.values.flatten())
    # atleast_2d: np.corrcoef collapses to a 0-d value for a single subject.
    correlation_matrix = np.atleast_2d(np.corrcoef(subject_correlation_maps))

    logging.info(f'Top 5 correlations for correlation of correlations, label: {label}')
    # Strict upper triangle: skip the all-ones diagonal and mirrored duplicates.
    rows, cols = np.triu_indices(len(subjects), k=1)
    pair_values = correlation_matrix[rows, cols]
    ranked = [i for i in np.argsort(-pair_values) if not np.isnan(pair_values[i])]
    for idx in ranked[:5]:
        # Report real subject ids: skipped subjects make "position + 1" wrong.
        logging.info(f'Subject pairs: {subjects[rows[idx]]}, {subjects[cols[idx]]}, '
                     f'correlation: {pair_values[idx]}')

    plt.figure(figsize=(10, 10))
    plt.title(f'Correlation Matrix - based on correlation matrices for label {label}')
    # A labelled DataFrame makes seaborn use the real subject ids as tick labels.
    sns.heatmap(pd.DataFrame(correlation_matrix, index=subjects, columns=subjects),
                cmap='coolwarm', annot=False, fmt=".2f", linewidths=0.5)
    plt.savefig(f'eeg_similarity_outputs/correlation_heatmap_{label}.png')
    plt.close()
    return correlation_matrix, subject_correlation_maps
def visualize_umap_embeddings(embeddings, labels, title, filename,
                              *, n_neighbors=5, min_dist=0.3):
    """Project *embeddings* to 2-D with UMAP and save a labelled scatter plot.

    Parameters
    ----------
    embeddings : array-like
        One feature vector per point (one row per point).
    labels : iterable of str
        Text drawn next to each point, in the same order as ``embeddings``.
        If the count differs from the number of points, a warning is logged
        and only the overlapping prefix is labelled.
    title : str
        Figure title.
    filename : str
        Output name without extension; the figure is saved to
        ``eeg_similarity_outputs/{filename}.png``.
    n_neighbors, min_dist : keyword-only, optional
        Forwarded to ``umap.UMAP`` (``metric`` is fixed to ``'correlation'``).
    """
    labels = list(labels)
    umap_embeddings = umap.UMAP(n_neighbors=n_neighbors,
                                min_dist=min_dist,
                                metric='correlation').fit_transform(embeddings)
    if len(labels) != len(umap_embeddings):
        logging.warning(f'{len(labels)} labels given for {len(umap_embeddings)} points in {filename}; '
                        'labelling only the overlapping prefix')
    plt.figure(figsize=(10, 10))
    plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1])
    # zip() stops at the shorter sequence, so a count mismatch cannot raise IndexError.
    for (x, y), text in zip(umap_embeddings, labels):
        plt.text(x + 0.01, y + 0.01, text, fontsize=12)
    plt.gca().set_aspect('equal', 'datalim')
    plt.title(title)
    plt.savefig(f'eeg_similarity_outputs/{filename}.png')
    plt.close()
def get_psd_correlation_map(epochs, label, correlations_of_correlations=False):
    """Correlate subjects by their 7-30 Hz power spectral density.

    For each subject, the PSD of the ``label`` epochs (7-30 Hz) is computed
    once and averaged over epochs. Each subject is then represented by one of
    two flattened arrays:

    * the averaged PSD itself (default);
    * the channel-by-channel correlation matrix of that PSD, when
      ``correlations_of_correlations`` is true.

    These vectors are correlated across subjects. The top five distinct
    subject pairs are logged, and a heatmap labelled with the real subject ids
    is saved. For subject 1, the PSD itself is also plotted to
    ``eeg_similarity_outputs/psd_{label}_subject_1.png``.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs with ``subject`` and ``label`` metadata columns.
    label : str
        Trial label to select, e.g. ``'hands'``.
    correlations_of_correlations : bool, optional
        Use PSD channel-correlation matrices instead of raw PSDs as features.

    Returns
    -------
    correlation_matrix : numpy.ndarray
        Square subject-by-subject matrix, ordered like
        ``epochs.metadata['subject'].unique()``.
    subject_spectrums : list of numpy.ndarray
        The per-subject feature vectors, same order.
    """
    subjects = epochs.metadata['subject'].unique()
    subject_spectrums = []
    for subject in tqdm(subjects, desc='collecting subject data for psd'):
        # Compute the spectrum once; it feeds both the preview plot and the features.
        spectrum = (epochs[f'subject == {subject} and label == "{label}"']
                    .compute_psd(fmin=7, fmax=30))
        if subject == 1:
            spectrum.plot(picks="data", exclude="bads", amplitude=False)
            plt.savefig(f'eeg_similarity_outputs/psd_{label}_subject_{subject}.png')
            plt.close()
        # Average over epochs (axis 0); assumes MNE's (epochs, channels, freqs) layout.
        mean_psd = spectrum.get_data().mean(axis=0)
        if correlations_of_correlations:
            subject_spectrums.append(np.corrcoef(mean_psd).flatten())
        else:
            subject_spectrums.append(mean_psd.flatten())
    # atleast_2d: np.corrcoef collapses to a 0-d value for a single subject.
    correlation_matrix = np.atleast_2d(np.corrcoef(subject_spectrums))

    logging.info('Top 5 correlations for psd')
    # Strict upper triangle: skip the all-ones diagonal and mirrored duplicates.
    rows, cols = np.triu_indices(len(subjects), k=1)
    pair_values = correlation_matrix[rows, cols]
    ranked = [i for i in np.argsort(-pair_values) if not np.isnan(pair_values[i])]
    for idx in ranked[:5]:
        # Report real subject ids: skipped subjects make "position + 1" wrong.
        logging.info(f'Subject pairs: {subjects[rows[idx]]}, {subjects[cols[idx]]}, '
                     f'correlation: {pair_values[idx]}')

    plt.figure(figsize=(10, 10))
    plt.title('Correlation Matrix - based on PSD')
    # A labelled DataFrame makes seaborn use the real subject ids as tick labels.
    sns.heatmap(pd.DataFrame(correlation_matrix, index=subjects, columns=subjects),
                cmap='coolwarm', annot=False, fmt=".2f", linewidths=0.5)
    if correlations_of_correlations:
        plt.savefig(f'eeg_similarity_outputs/correlation_heatmap_psd_correlation_{label}.png')
    else:
        plt.savefig(f'eeg_similarity_outputs/correlation_heatmap_psd_{label}.png')
    plt.close()
    return correlation_matrix, subject_spectrums
def main(n_subjects,
         sfreq,
         runs,
         tmin,
         tmax):
    """Run the EEG subject-similarity analysis end to end.

    Steps, all for the ``hands`` label:

    1. Load the epochs and report the hands/feet trial balance.
    2. Plot the sensor layout.
    3. Produce naive (C4, C3) correlation maps.
    4. Produce correlation-of-correlations and PSD-based similarity maps.
    5. Draw UMAP projections of the similarity features.

    All figures are written to ``eeg_similarity_outputs/``, which is created
    if missing.

    Parameters
    ----------
    n_subjects : int
        Highest EEGBCI subject number to load (numbering starts at 1).
    sfreq : float
        Required sampling frequency; subjects recorded at another rate are skipped.
    runs : list of int
        EEGBCI run numbers to load per subject.
    tmin, tmax : float
        Epoch window in seconds relative to each event.
    """
    import os

    # Every plotting step saves into this directory; savefig fails if it is missing.
    os.makedirs('eeg_similarity_outputs', exist_ok=True)
    epochs = get_epochs(n_subjects, sfreq, runs, tmin, tmax)
    print(f'Number of epochs per subject: {epochs.metadata.groupby("subject").size().unique()}')
    analyze_hands_feet_ratio(epochs)
    epochs.plot_sensors(show_names=True)
    plt.savefig('eeg_similarity_outputs/sensors.png')
    plt.close()
    get_naive_correlation_map(epochs, 'C4', 'hands')
    get_naive_correlation_map(epochs, 'C3', 'hands')
    # One text label per loaded subject, in the same order the analysis
    # functions iterate (metadata 'subject' unique()). Some subjects may be
    # skipped, so a fixed range of ids would neither cover nor match them.
    subject_labels = [str(subject) for subject in epochs.metadata['subject'].unique()]
    _, subject_correlation_maps = get_correlation_map_of_correlation_maps(epochs, 'hands')
    visualize_umap_embeddings(subject_correlation_maps,
                              subject_labels,
                              'UMAP embeddings of correlation matrices for label: hands',
                              'umap_embeddings_hands')
    _, subject_spectrums = get_psd_correlation_map(epochs, 'hands')
    visualize_umap_embeddings(subject_spectrums,
                              subject_labels,
                              'UMAP embeddings of PSD for label: hands',
                              'umap_embeddings_psd_hands')
    _, subject_spectrum_correlations = get_psd_correlation_map(epochs, 'hands', correlations_of_correlations=True)
    visualize_umap_embeddings(subject_spectrum_correlations,
                              subject_labels,
                              'UMAP embeddings of PSD correlation for label: hands',
                              'umap_embeddings_psd_correlation_hands')
if __name__ == '__main__':
    # Command-line entry point. Each option's dest (n_subjects, sfreq, runs,
    # tmin, tmax) equals a main() parameter name, so the parsed namespace can
    # be forwarded as keyword arguments.
    cli = ArgumentParser()
    cli.add_argument('--n_subjects', type=int, default=109)
    cli.add_argument('--sfreq', type=int, default=160)
    cli.add_argument('--runs', nargs='+', type=int, default=[6, 10, 14])
    cli.add_argument('--tmin', type=float, default=-1.0)
    cli.add_argument('--tmax', type=float, default=4.0)
    cli_options = vars(cli.parse_args())
    main(**cli_options)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment