Last active
September 13, 2024 09:57
-
-
Save erap129/7766d911dad632d5bf0954027cada13e to your computer and use it in GitHub Desktop.
EEG similarity exercise
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy | |
import seaborn as sns | |
from functools import reduce | |
from tqdm import tqdm | |
import pandas as pd | |
from mne.datasets import eegbci | |
from argparse import ArgumentParser | |
import logging | |
from mne.io import concatenate_raws, read_raw_edf | |
from mne.channels import make_standard_montage | |
from mne import Epochs, pick_types, concatenate_epochs | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import umap | |
# Emit INFO-level log records: skipped subjects, per-subject epoch counts and the
# top-5 correlated subject pairs reported by the analysis functions below.
logging.basicConfig(level=logging.INFO)
def get_epochs(n_subjects,
               sfreq,
               runs,
               tmin,
               tmax):
    """Load, preprocess and epoch EEGBCI recordings for subjects 1..n_subjects.

    For every subject ID from 1 to ``n_subjects`` (inclusive) the requested
    ``runs`` are fetched with ``eegbci.load_data`` and concatenated. Subjects
    whose recording sampling frequency differs from ``sfreq`` are skipped.
    The remaining recordings are processed as follows:

    * channel names are standardized and the standard_1005 montage is set;
    * T1/T2 annotations are renamed to ``hands``/``feet``;
    * an EEG reference projector is added;
    * the data are band-pass filtered to 7-30 Hz;
    * the data are epoched on the ``hands``/``feet`` events.

    Each subject's epochs get a ``metadata`` frame with ``subject`` and
    ``label`` columns. The combined object can therefore be queried with
    strings such as ``'subject == 3 and label == "hands"'``.

    For subject 1, channels C3/C4 are also plotted to
    ``eeg_similarity_outputs/epochs.png``.

    Args:
        n_subjects: highest subject ID to load (IDs start at 1).
        sfreq: required sampling frequency in Hz.
        runs: EEGBCI run numbers to load for each subject.
        tmin: epoch start in seconds relative to each event.
        tmax: epoch end in seconds relative to each event.

    Returns:
        The concatenated ``mne.Epochs`` of all retained subjects.

    Raises:
        ValueError: if no subject was kept, so there is nothing to concatenate.
    """
    montage = make_standard_montage("standard_1005")
    epochs_list = []
    event_mapping = dict(hands=2, feet=3)
    event_mapping_inverse = {v: k for k, v in event_mapping.items()}
    # Subject IDs are 1-based; include n_subjects itself so that e.g. the
    # default of 109 really loads 109 subjects.
    for subject in tqdm(range(1, n_subjects + 1), desc='collecting subject data'):
        raw_fnames = eegbci.load_data(subject, runs)
        raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])
        if raw.info['sfreq'] != sfreq:
            logging.info(f'Skipping subject {subject} because of incorrect sampling frequency')
            continue
        eegbci.standardize(raw)  # set channel names
        raw.set_montage(montage)
        raw.annotations.rename(dict(T1="hands", T2="feet"))
        raw.set_eeg_reference(projection=True)
        raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge")
        picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")
        # Read epochs
        epochs = Epochs(
            raw,
            event_id=event_mapping,
            tmin=tmin,
            tmax=tmax,
            proj=True,
            picks=picks,
            baseline=None,
            preload=True,
        )
        if subject == 1:
            # pick_channels operates in place: plot a copy so subject 1 keeps
            # all EEG channels and can be concatenated with the other subjects.
            (epochs
             .copy()
             .pick_channels(['C3', 'C4'])
             .plot(events=True,
                   show_scrollbars=False,
                   n_epochs=5,
                   title="Raw EEG data for subject 1")
             )
            plt.savefig('eeg_similarity_outputs/epochs.png')
            plt.close()
        # Create metadata with subject information
        num_epochs = len(epochs)
        logging.info(f'Loaded {num_epochs} epochs for subject {subject}')
        epochs.metadata = pd.DataFrame({'subject': [subject] * num_epochs,
                                        'label': [event_mapping_inverse[e] for e in epochs.events[:, -1]]})
        epochs_list.append(epochs)
    if not epochs_list:
        raise ValueError(f'No subject among 1..{n_subjects} matched sfreq={sfreq} Hz; '
                         'nothing to concatenate')
    # Concatenate all epochs
    epochs_combined = concatenate_epochs(epochs_list)
    return epochs_combined
def get_hands_feet_ratio(df):
    """Return the ratio of ``hands`` rows to ``feet`` rows in ``df``.

    Args:
        df: DataFrame with a ``label`` column, e.g. one subject's epochs metadata.

    Returns:
        A one-element Series keyed ``'hands_feed_ratio'``. That key spelling is
        kept because callers read it. The value is NaN when ``df`` has no
        ``feet`` rows, instead of raising ZeroDivisionError.
    """
    labels = df['label']
    n_hands = int((labels == 'hands').sum())
    n_feet = int((labels == 'feet').sum())
    ratio = n_hands / n_feet if n_feet else float('nan')
    return pd.Series({'hands_feed_ratio': ratio})
def analyze_hands_feet_ratio(epochs):
    """Print mean/min/max of the per-subject hands-to-feet epoch ratio.

    Args:
        epochs: object whose ``metadata`` DataFrame has ``subject`` and
            ``label`` columns, as built by ``get_epochs``.

    Returns:
        DataFrame indexed by subject with a single ``hands_feed_ratio`` column.
    """
    # Select only 'label' so apply() does not also see the grouping column
    # (deprecated in pandas >= 2.2); get_hands_feet_ratio only reads 'label'.
    hands_feet_ratio = (epochs.metadata
                        .groupby('subject')[['label']]
                        .apply(get_hands_feet_ratio)
                        )
    ratios = hands_feet_ratio["hands_feed_ratio"]
    print(f'Hands/Feet ratio mean: {ratios.mean()}')
    print(f'Hands/Feet ratio min: {ratios.min()}')
    print(f'Hands/Feet ratio max: {ratios.max()}')
    return hands_feet_ratio
def get_naive_correlation_map(epochs, channel, label):
    """Correlate subjects' averaged time courses on a single channel.

    For every subject, the epochs with ``label`` are restricted to ``channel``
    and averaged. The resulting time series (one column per subject, named
    ``"<subject>_<channel>"``) are aligned on ``time`` and correlated pairwise.

    The five distinct subject pairs with the largest absolute correlation are
    logged. The heatmap is saved to
    ``eeg_similarity_outputs/correlation_heatmap_<channel>_<label>.png``.

    Args:
        epochs: combined epochs with ``subject``/``label`` metadata.
        channel: channel name to analyse, e.g. ``'C3'``.
        label: event label to select, e.g. ``'hands'``.

    Returns:
        The subject-by-subject correlation DataFrame.
    """
    average_subject_dfs = []
    for subject in tqdm(epochs.metadata['subject'].unique(), desc=f'collecting subject data for channel: {channel}, label: {label}'):
        # Indexing epochs returns a copy, so pick_channels does not touch `epochs`.
        average_subject_dfs.append(epochs
                                   [f'subject == {subject} and label == "{label}"']
                                   .pick_channels([channel])
                                   .average()
                                   .to_data_frame()
                                   .rename(columns={channel: f'{subject}_{channel}'})
                                   )
    df_merged = (reduce(lambda left, right: pd.merge(left, right, on=['time'], how='outer'),
                        average_subject_dfs)
                 .drop(columns=['time'])
                 )
    correlation_matrix = df_merged.corr()
    logging.info(f'Top 5 correlations for channel: {channel}, label: {label}')
    # Rank only the strict upper triangle: every unordered pair once, no
    # self-pairs. NaN correlations are ranked last rather than first.
    values = correlation_matrix.to_numpy()
    rows, cols = np.triu_indices_from(values, k=1)
    pair_values = values[rows, cols]
    strength = np.abs(pair_values)
    strength = np.where(np.isnan(strength), -np.inf, strength)
    for idx in np.argsort(strength)[::-1][:5]:
        subject_pair = (correlation_matrix.index[rows[idx]], correlation_matrix.columns[cols[idx]])
        logging.info(f'Subject pair: {subject_pair}, correlation: {pair_values[idx]}')
    plt.figure(figsize=(10, 10))
    plt.title(f'Correlation Matrix - based on average raw values from channel {channel} for label {label}')
    sns.heatmap(correlation_matrix, cmap='coolwarm', annot=False, fmt=".2f", linewidths=0.5)
    plt.savefig(f'eeg_similarity_outputs/correlation_heatmap_{channel}_{label}.png')
    plt.close()
    return correlation_matrix
def get_correlation_map_of_correlation_maps(epochs, label):
    """Compare subjects by the similarity of their channel-correlation maps.

    For each subject, the ``label`` epochs are averaged. The channel-by-channel
    Pearson correlation matrix of that average is then flattened into one
    feature vector. The subjects' feature vectors are correlated with each
    other.

    The five most correlated distinct subject pairs are logged. A heatmap
    labelled with the real subject IDs is saved to
    ``eeg_similarity_outputs/correlation_heatmap_<label>.png``.

    Args:
        epochs: combined epochs with ``subject``/``label`` metadata.
        label: event label to select, e.g. ``'hands'``.

    Returns:
        A tuple of:

        * the subject-by-subject correlation ndarray, whose rows and columns
          follow the order of ``epochs.metadata['subject'].unique()``;
        * the list of flattened per-subject correlation maps.
    """
    subjects = epochs.metadata['subject'].unique()
    subject_correlation_maps = []
    for subject in tqdm(subjects, desc=f'collecting subject data for label: {label}'):
        subject_correlation_maps.append(epochs
                                        [f'subject == {subject} and label == "{label}"']
                                        .average()
                                        .to_data_frame()
                                        # 'time' is a regular column here; drop it so
                                        # only channel-vs-channel correlations remain.
                                        .drop(columns=['time'])
                                        .corr()
                                        .values
                                        .flatten()
                                        )
    correlation_matrix = np.corrcoef(subject_correlation_maps)
    logging.info(f'Top 5 correlations for correlation of correlations, label: {label}')
    # Strict upper triangle only: skips the all-ones diagonal (a subject with
    # itself) and the mirrored duplicate of every pair. NaNs rank last.
    rows, cols = np.triu_indices_from(correlation_matrix, k=1)
    pair_values = correlation_matrix[rows, cols]
    ranking_key = np.where(np.isnan(pair_values), -np.inf, pair_values)
    for idx in np.argsort(ranking_key)[::-1][:5]:
        # Report real subject IDs; subjects may have been skipped while loading,
        # so matrix index + 1 is not necessarily the subject ID.
        logging.info(f'Subject pairs: {subjects[rows[idx]]}, {subjects[cols[idx]]}, '
                     f'correlation: {pair_values[idx]}')
    plt.figure(figsize=(10, 10))
    plt.title(f'Correlation Matrix - based on correlation matrices for label {label}')
    # A labelled DataFrame makes seaborn tick the axes with the subject IDs.
    labelled_matrix = pd.DataFrame(correlation_matrix, index=subjects, columns=subjects)
    sns.heatmap(labelled_matrix, cmap='coolwarm', annot=False, fmt=".2f", linewidths=0.5)
    plt.savefig(f'eeg_similarity_outputs/correlation_heatmap_{label}.png')
    plt.close()
    return correlation_matrix, subject_correlation_maps
def visualize_umap_embeddings(embeddings, labels, title, filename):
    """Project ``embeddings`` to 2-D with UMAP and save an annotated scatter plot.

    Args:
        embeddings: one feature vector per item. Callers in this script pass
            a list of equal-length 1-D arrays.
        labels: text drawn next to each point, in the same order as
            ``embeddings``. If the two lengths differ, the surplus is ignored
            instead of raising IndexError.
        title: plot title.
        filename: output file stem; the image is written to
            ``eeg_similarity_outputs/<filename>.png``.
    """
    umap_embeddings = umap.UMAP(n_neighbors=5,
                                min_dist=0.3,
                                metric='correlation').fit_transform(embeddings)
    plt.figure(figsize=(10, 10))
    plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1])
    for (x, y), text in zip(umap_embeddings, labels):
        plt.text(x + 0.01, y + 0.01, text, fontsize=12)
    plt.gca().set_aspect('equal', 'datalim')
    plt.title(title)
    plt.savefig(f'eeg_similarity_outputs/{filename}.png')
    plt.close()
def get_psd_correlation_map(epochs, label, correlations_of_correlations=False):
    """Compare subjects by their 7-30 Hz power spectral density (PSD).

    For each subject, the PSD of the ``label`` epochs is computed and averaged
    over epochs, giving one channel-by-frequency array. That array becomes the
    subject's feature vector in one of two ways:

    * if ``correlations_of_correlations`` is false, it is flattened directly;
    * if true, it is first replaced by its channel-by-channel correlation
      matrix, which is then flattened.

    The subjects' feature vectors are correlated with each other. The five
    most correlated distinct subject pairs are logged and a heatmap labelled
    with subject IDs is saved. Subject 1's PSD is additionally plotted.

    Args:
        epochs: combined epochs with ``subject``/``label`` metadata.
        label: event label to select, e.g. ``'hands'``.
        correlations_of_correlations: use channel-correlation maps of the PSD
            instead of the raw averaged PSD as features.

    Returns:
        A tuple of:

        * the subject-by-subject correlation ndarray, whose rows and columns
          follow the order of ``epochs.metadata['subject'].unique()``;
        * the list of per-subject feature vectors.
    """
    subjects = epochs.metadata['subject'].unique()
    subject_spectrums = []
    for subject in tqdm(subjects, desc='collecting subject data for psd'):
        # Compute the PSD once and reuse it for both the plot and the features.
        spectrum = epochs[f'subject == {subject} and label == "{label}"'].compute_psd(fmin=7, fmax=30)
        if subject == 1:
            spectrum.plot(picks="data", exclude="bads", amplitude=False)
            plt.savefig(f'eeg_similarity_outputs/psd_{label}_subject_{subject}.png')
            plt.close()
        mean_psd = spectrum.get_data().mean(axis=0)  # average over epochs
        if correlations_of_correlations:
            features = np.corrcoef(mean_psd).flatten()
        else:
            features = mean_psd.flatten()
        subject_spectrums.append(features)
    correlation_matrix = np.corrcoef(subject_spectrums)
    logging.info('Top 5 correlations for psd')
    # Strict upper triangle only: skips the all-ones diagonal (a subject with
    # itself) and the mirrored duplicate of every pair. NaNs rank last.
    rows, cols = np.triu_indices_from(correlation_matrix, k=1)
    pair_values = correlation_matrix[rows, cols]
    ranking_key = np.where(np.isnan(pair_values), -np.inf, pair_values)
    for idx in np.argsort(ranking_key)[::-1][:5]:
        # Report real subject IDs; skipped subjects break the index + 1 mapping.
        logging.info(f'Subject pairs: {subjects[rows[idx]]}, {subjects[cols[idx]]}, '
                     f'correlation: {pair_values[idx]}')
    plt.figure(figsize=(10, 10))
    plt.title('Correlation Matrix - based on PSD')
    labelled_matrix = pd.DataFrame(correlation_matrix, index=subjects, columns=subjects)
    sns.heatmap(labelled_matrix, cmap='coolwarm', annot=False, fmt=".2f", linewidths=0.5)
    if correlations_of_correlations:
        plt.savefig(f'eeg_similarity_outputs/correlation_heatmap_psd_correlation_{label}.png')
    else:
        plt.savefig(f'eeg_similarity_outputs/correlation_heatmap_psd_{label}.png')
    plt.close()
    return correlation_matrix, subject_spectrums
def main(n_subjects,
         sfreq,
         runs,
         tmin,
         tmax):
    """Run the EEG subject-similarity analysis and write all figures.

    Images are written under ``eeg_similarity_outputs/``. The directory is
    created if it is missing.

    Args:
        n_subjects: highest EEGBCI subject ID to load.
        sfreq: required sampling frequency in Hz; other subjects are skipped.
        runs: EEGBCI run numbers to load per subject.
        tmin: epoch start in seconds.
        tmax: epoch end in seconds.
    """
    import os

    # Every savefig below (and the one inside get_epochs) targets this
    # directory; matplotlib does not create it.
    os.makedirs('eeg_similarity_outputs', exist_ok=True)
    epochs = get_epochs(n_subjects, sfreq, runs, tmin, tmax)
    print(f'Number of epochs per subject: {epochs.metadata.groupby("subject").size().unique()}')
    analyze_hands_feet_ratio(epochs)
    epochs.plot_sensors(show_names=True)
    plt.savefig('eeg_similarity_outputs/sensors.png')
    plt.close()
    get_naive_correlation_map(epochs, 'C4', 'hands')
    get_naive_correlation_map(epochs, 'C3', 'hands')
    # Label UMAP points with the subject IDs actually loaded, in the same
    # order the analysis functions iterate over them. A hard-coded range
    # breaks when subjects are skipped or n_subjects differs.
    subject_labels = [str(subject) for subject in epochs.metadata['subject'].unique()]
    _, subject_correlation_maps = get_correlation_map_of_correlation_maps(epochs, 'hands')
    visualize_umap_embeddings(subject_correlation_maps,
                              subject_labels,
                              'UMAP embeddings of correlation matrices for label: hands',
                              'umap_embeddings_hands')
    _, subject_spectrums = get_psd_correlation_map(epochs, 'hands')
    visualize_umap_embeddings(subject_spectrums,
                              subject_labels,
                              'UMAP embeddings of PSD for label: hands',
                              'umap_embeddings_psd_hands')
    _, subject_spectrum_correlations = get_psd_correlation_map(epochs, 'hands', correlations_of_correlations=True)
    visualize_umap_embeddings(subject_spectrum_correlations,
                              subject_labels,
                              'UMAP embeddings of PSD correlation for label: hands',
                              'umap_embeddings_psd_correlation_hands')
if __name__ == '__main__':
    # Command-line options; each destination name equals a parameter of main().
    cli = ArgumentParser()
    cli.add_argument('--n_subjects', type=int, default=109)
    cli.add_argument('--sfreq', type=int, default=160)
    cli.add_argument('--runs', nargs='+', type=int, default=[6, 10, 14])
    cli.add_argument('--tmin', type=float, default=-1.0)
    cli.add_argument('--tmax', type=float, default=4.0)
    # Forward the parsed options to main() by keyword.
    main(**vars(cli.parse_args()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment