-
-
Save dengemann/72ada879df279e16250f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import mne | |
from mne.decoding import GeneralizationAcrossTime as GAT | |
from sklearn.metrics import roc_auc_score | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.svm import SVC | |
from sklearn.pipeline import make_pipeline | |
from sklearn.cross_validation import StratifiedKFold | |
from meeg_preprocessing.utils import setup_provenance | |
import matplotlib.pyplot as plt | |
from pyriemann.estimation import ERPCovariances | |
from pyriemann.tangentspace import TangentSpace | |
# Provenance bookkeeping: report object, run identifier, output directory
# and logger, with outputs written under 'results'.
report, run_id, results_dir, logger = setup_provenance(__file__,
                                                       results_dir='results')

# Condition names selected from the epochs file.
conditions = ['LSGS', 'LSGD', 'LDGS', 'LDGD']

# Each contrast is [conditions labelled 0, conditions labelled 1].
# The "local" contrast splits on the first letter pair (LS* vs LD*) ...
local_cond = [['LSGS', 'LSGD'], ['LDGS', 'LDGD']]
# ... and the "global" contrast on the second letter pair (*GS vs *GD).
global_cond = [['LSGS', 'LDGS'], ['LSGD', 'LDGD']]
class reshape_X(object):
    """Stateless transformer turning a 2-D array into a 3-D array.

    An input of shape ``(n_samples, n_features)`` becomes
    ``(n_samples, n_features // window, window)``.

    NOTE(review): in this script it is presumably used to recover
    ``(n_trials, n_channels, n_times)`` from the flattened slices that
    GeneralizationAcrossTime feeds to the classifier. That assumes features
    are laid out channel-major; confirm against the MNE version in use.

    Parameters
    ----------
    window : int
        Length of the last axis of the output (in time samples).
    """

    def __init__(self, window=1):
        self.window = window  # in time samples

    def fit(self, X, y=None):
        """No-op; return ``self`` (there is nothing to learn)."""
        return self

    def fit_transform(self, X, y=None):
        """Equivalent to ``transform(X)``; ``y`` is ignored."""
        return self.transform(X)

    def transform(self, X):
        """Reshape ``X`` to ``(X.shape[0], X.shape[1] // window, window)``.

        Raises
        ------
        ValueError
            If ``X.shape[1]`` is not a multiple of ``window``.
        """
        n_samples, n_features = X.shape[0], X.shape[1]
        # Use modulo / floor division so the check and the reshape behave the
        # same on Python 2 and 3. Plain "/" yields a float on Python 3, which
        # disables the check and breaks reshape.
        if n_features % self.window != 0:
            raise ValueError('Wrong window size')
        return X.reshape([n_samples, n_features // self.window, self.window])

    def get_params(self, deep=True):
        """Return constructor parameters (sklearn ``clone`` compatibility)."""
        return dict(window=self.window)
class force_predict(object):
    """Wrap a classifier so that ``predict`` returns a chosen output.

    Parameters
    ----------
    clf : estimator
        The wrapped classifier.
    mode : str
        What ``predict`` returns:

        - ``'predict_proba'``: column ``axis`` of ``clf.predict_proba(X)``.
        - ``'decision_function'``: column ``axis`` of
          ``clf.decision_function(X)`` when that output is 2-D, else the
          1-D output unchanged.
        - any other value: ``clf.predict(X)``.
    axis : int
        Column to select from the 2-D outputs above.

    Notes
    -----
    After ``fit``, every attribute in the wrapped object's ``__dict__`` is
    copied onto this wrapper, so fitted attributes are reachable from here.
    If the wrapped object itself stores ``_clf``, ``_mode`` or ``_axis``,
    those overwrite this wrapper's own attributes. ``force_weight`` stores
    ``_clf``, so after fitting ``self._clf`` becomes its inner classifier.
    """

    def __init__(self, clf, mode='predict_proba', axis=0):
        self._mode = mode
        self._axis = axis
        self._clf = clf

    def fit(self, X, y, **kwargs):
        """Fit the wrapped classifier, mirror its attributes, return ``self``."""
        self._clf.fit(X, y, **kwargs)
        self._copyattr()
        # sklearn convention: fit returns the estimator itself.
        return self

    def predict(self, X):
        """Return the output selected by ``mode``/``axis`` for ``X``."""
        if self._mode == 'predict_proba':
            return self._clf.predict_proba(X)[:, self._axis]
        elif self._mode == 'decision_function':
            distances = self._clf.decision_function(X)
            if len(distances.shape) > 1:
                return distances[:, self._axis]
            else:
                return distances
        else:
            return self._clf.predict(X)

    def get_params(self, deep=True):
        """Return constructor parameters (sklearn ``clone`` compatibility)."""
        return dict(clf=self._clf, mode=self._mode, axis=self._axis)

    def _copyattr(self):
        # .items() instead of the Python-2-only .iteritems(). Iterating the
        # wrapped object's dict while writing to *self* is safe.
        for key, value in self._clf.__dict__.items():
            self.__setattr__(key, value)
class force_weight(object):
    """Wrap a classifier so that per-sample weights travel inside ``y``.

    ``fit`` expects ``y`` as a 2-column array-like, e.g.
    ``np.c_[labels, weights]``:

    - column 0 holds the class labels (cast to int);
    - column 1 holds the sample weights, passed to the wrapped classifier's
      ``fit(..., sample_weight=...)``.

    This lets weights flow through APIs that only forward ``X`` and ``y``.

    Parameters
    ----------
    clf : estimator
        Classifier whose ``fit`` accepts a ``sample_weight`` keyword.
    weights : None
        Unused; accepted for backward compatibility.
    """

    def __init__(self, clf, weights=None):
        self._clf = clf

    def fit(self, X, y):
        """Fit ``clf`` on labels ``y[:, 0]`` weighted by ``y[:, 1]``.

        Returns ``self``, per the sklearn estimator convention.
        """
        # Accept plain nested lists as well as arrays.
        y = np.asarray(y)
        self._clf.fit(X, np.array(y[:, 0], dtype=int),
                      sample_weight=np.array(y[:, 1]))
        return self

    def predict(self, X):
        """Forward to ``clf.predict``."""
        return self._clf.predict(X)

    def predict_proba(self, X):
        """Forward to ``clf.predict_proba``.

        This lets ``force_predict`` call it on this wrapper directly.
        """
        return self._clf.predict_proba(X)

    def decision_function(self, X):
        """Forward to ``clf.decision_function``."""
        return self._clf.decision_function(X)

    def get_params(self, deep=True):
        """Return constructor parameters (sklearn ``clone`` compatibility)."""
        return dict(clf=self._clf)
# Weighted, probabilistic linear classifier:
# - standardize the features;
# - fit a linear SVC with per-sample weights carried in y (force_weight);
# - have predict() return column 1 of predict_proba (force_predict, axis=1).
clf = make_pipeline(
    StandardScaler(),
    force_predict(
        force_weight(SVC(kernel='linear', probability=True)),
        axis=1,
    ),
)
# One entry per (subject, contrast) with the GAT score.
results = list()
for subject_name in ['TAJ20081223']:
    # Preproc
    # NOTE(review): the file name is hard-coded rather than derived from
    # subject_name, so adding subjects to the list re-reads the same file.
    epochs = mne.read_epochs('TAJ-epo.fif')
    # this_epochs = epochs[conditions].crop(0.6, None)
    this_epochs = epochs[conditions].crop(0.750, None)

    # Contrast definitions
    # Invert the mapping so event codes can be translated back to names.
    event_id = {v: k for k, v in this_epochs.event_id.items()}
    y_raw = [event_id[k] for k in this_epochs.events[:, 2]]
    # Inverse frequency of each condition. Only the disabled weighted-SVC
    # path below uses it.
    sample_weight = [1. / y_raw.count(k) for k in y_raw]
    y_local = [int(v in local_cond[1]) for v in y_raw]
    y_global = [int(v in global_cond[1]) for v in y_raw]
    iter_contrast = [[y_local, y_global], ['local', 'global']]

    # GAT
    # Stratify on the four raw conditions, not the binary contrast.
    cv = StratifiedKFold(y=y_raw, n_folds=5)  # ensure full stratification
    for y_fit, names_fit in zip(*iter_contrast):
        # Sliding-window alternative (disabled):
        # window_s = 10 / epochs.info['sfreq']  # in seconds
        # test_times = 'diagonal'
        # train_times = dict(length=window_s, step=window_s)

        # Train and test on a single slice covering every sample of the
        # cropped epochs.
        window = len(this_epochs.times)
        train_times = dict(slices=[range(window)], start=this_epochs.times[0],
                           stop=this_epochs.times[-1],
                           times=[this_epochs.times[0]])
        test_times = dict(slices=[[range(window)]],
                          start=this_epochs.times[0],
                          stop=this_epochs.times[-1],
                          times=[this_epochs.times[0]])
        kwargs = dict(test_times=test_times, scorer=roc_auc_score, cv=cv,
                      train_times=train_times)

        # Weighted linear-SVC baseline (disabled):
        # svc_weighted = force_predict(
        #     force_weight(SVC(kernel='linear', probability=True)), axis=1)
        # clf_classic = make_pipeline(StandardScaler(), svc_weighted)
        # gat = GAT(n_jobs=-1, clf=clf_classic, **kwargs)
        # gat.fit(this_epochs, y=np.c_[y_fit, sample_weight])
        # gat.score(this_epochs)

        # Riemannian pipeline: restore a 3-D layout, estimate ERP
        # covariances, project to tangent space, then a probabilistic SVC.
        svc = force_predict(SVC(kernel='linear', probability=True), axis=1)
        clf_rieman = make_pipeline(reshape_X(window=window),
                                   ERPCovariances(estimator='lwf', svd=4),
                                   TangentSpace(metric='logeuclid'), svc)
        gat_rieman = GAT(n_jobs=1, clf=clf_rieman, **kwargs)
        gat_rieman.fit(this_epochs, y=y_fit)
        score = gat_rieman.score(this_epochs, y=y_fit)
        # Function-call form works on both Python 2 and Python 3.
        print(score)
        results.append(dict(subject=subject_name, contrast=names_fit,
                            score=score))

        # Plot (disabled)
        # fig, axes = plt.subplots(3)
        # # fig = gat.plot_diagonal(label='SVC', show=False, chance=.5,
        # #                         color='b', ax=axes[0])
        # gat_rieman.plot_diagonal(label='Riemann+SVC', ax=axes[0], color='r',
        #                          show=False, chance=False)
        # # gat.plot(show=False, ax=axes[1], title='SVC')
        # gat_rieman.plot(show=False, ax=axes[2], title='Riemann')
        # report.add_figs_to_section(fig, names_fit, subject_name)

# NOTE: no figures are added to the report while the plot section is disabled.
report.save()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment