Skip to content

Instantly share code, notes, and snippets.

@dengemann
Forked from kingjr/riemann_taj.py
Last active May 15, 2017 20:37
Show Gist options
Save dengemann/72ada879df279e16250f to your computer and use it in GitHub Desktop.
import numpy as np
import mne
from mne.decoding import GeneralizationAcrossTime as GAT
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.cross_validation import StratifiedKFold
from meeg_preprocessing.utils import setup_provenance
import matplotlib.pyplot as plt
from pyriemann.estimation import ERPCovariances
from pyriemann.tangentspace import TangentSpace
# Provenance bookkeeping for this run: the helper receives this script's path
# and the name of the output directory, and returns four objects that are
# unpacked here (the report is saved at the very end of the script).
report, run_id, results_dir, logger = setup_provenance(__file__,
                                                       results_dir='results')
# Condition labels selected from the epochs file.
conditions = ['LSGS', 'LSGD', 'LDGS', 'LDGD']

# Binary contrasts, each written as [class-0 conditions, class-1 conditions]
# (the targets below mark a trial as 1 when it belongs to element [1]).
# The local contrast splits on the 'LS'/'LD' prefix, the global contrast on
# the 'GS'/'GD' suffix.
# NOTE(review): presumably S/D stand for standard/deviant -- confirm.
local_cond = [['LSGS', 'LSGD'], ['LDGS', 'LDGD']]
global_cond = [['LSGS', 'LDGS'], ['LSGD', 'LDGD']]
class reshape_X(object):
    """Reshape a 2-D feature matrix into 3-D windows (stateless transformer).

    ``transform`` turns ``X`` of shape ``(n_samples, n_features)`` into
    ``(n_samples, n_features // window, window)``. In this script it is the
    first pipeline step, ahead of the covariance estimator.

    NOTE(review): presumably ``X`` arrives flattened as
    ``(n_trials, n_channels * n_times)`` in channel-major order, so that
    ``window == n_times`` recovers ``(n_trials, n_channels, n_times)`` --
    confirm against the GAT version in use.

    Parameters
    ----------
    window : int
        Length of the last output axis, in time samples.
    """

    def __init__(self, window=1):
        self.window = window  # in time samples

    def fit(self, X, y=None):
        """Learn nothing; return ``self``."""
        return self

    def fit_transform(self, X, y=None):
        """Same as ``transform(X)``; ``y`` is ignored."""
        return self.transform(X)

    def transform(self, X):
        """Reshape ``X`` to ``(n_samples, n_features // window, window)``.

        Raises
        ------
        ValueError
            If ``n_features`` is not a multiple of ``window``.
        """
        n_features = X.shape[1]
        # Modulo and floor division yield integers on Python 2 and 3 alike.
        # True division on Python 3 produced a float dimension (rejected by
        # reshape) and made the old divisibility check always pass.
        if n_features % self.window:
            raise ValueError('Wrong window size')
        return X.reshape([X.shape[0], n_features // self.window,
                          self.window])

    def get_params(self, deep=True):
        """Return the constructor parameters (scikit-learn convention)."""
        return dict(window=self.window)
class force_predict(object):
    """Make a classifier's ``predict`` return a chosen continuous output.

    Used as the final pipeline step. Depending on ``mode``, ``predict``
    returns:

    * ``'predict_proba'`` -- column ``axis`` of ``clf.predict_proba(X)``;
    * ``'decision_function'`` -- column ``axis`` of
      ``clf.decision_function(X)`` when that output has more than one
      dimension, otherwise the output unchanged;
    * any other value -- ``clf.predict(X)``.

    Parameters
    ----------
    clf : object
        Wrapped estimator; must implement ``fit`` and the method selected by
        ``mode``.
    mode : str
        Which output ``predict`` returns (see above).
    axis : int
        Column selected from multi-dimensional outputs.
    """

    def __init__(self, clf, mode='predict_proba', axis=0):
        self._mode = mode
        self._axis = axis
        self._clf = clf

    def fit(self, X, y, **kwargs):
        """Fit the wrapped estimator, mirror its attributes, return ``self``."""
        self._clf.fit(X, y, **kwargs)
        self._copyattr()
        # Return self, following the scikit-learn ``fit`` convention.
        return self

    def predict(self, X):
        """Return the output selected by ``mode`` (see class docstring)."""
        if self._mode == 'predict_proba':
            return self._clf.predict_proba(X)[:, self._axis]
        if self._mode == 'decision_function':
            distances = self._clf.decision_function(X)
            if len(distances.shape) > 1:
                return distances[:, self._axis]
            return distances
        return self._clf.predict(X)

    def get_params(self, deep=True):
        """Return the constructor parameters (scikit-learn convention)."""
        return dict(clf=self._clf, mode=self._mode, axis=self._axis)

    def _copyattr(self):
        """Copy the wrapped estimator's instance attributes onto ``self``.

        The wrapper's own private state is never overwritten: a nested
        wrapper that stores its own ``_clf`` would otherwise replace
        ``self._clf`` with its inner estimator, so later ``get_params`` calls
        and re-fits would silently bypass that wrapper.
        """
        own = ('_clf', '_mode', '_axis')
        # .items() works on Python 2 and 3 (iteritems is Python-2-only).
        for key, value in self._clf.__dict__.items():
            if key in own:
                continue
            setattr(self, key, value)
class force_weight(object):
    """Pass per-sample weights to a classifier through the ``y`` argument.

    ``fit`` expects ``y`` as a two-column array-like: column 0 holds the class
    labels (cast to int) and column 1 the per-sample weights, which are
    forwarded to the wrapped estimator as ``sample_weight``.

    Parameters
    ----------
    clf : object
        Wrapped estimator; its ``fit`` must accept ``sample_weight``.
    weights : ignored
        Accepted for backward compatibility but neither stored nor used; the
        weights are always read from ``y`` in ``fit``.
    """

    def __init__(self, clf, weights=None):
        self._clf = clf

    def fit(self, X, y):
        """Fit ``clf`` on labels ``y[:, 0]`` with weights ``y[:, 1]``.

        ``y`` may be a NumPy array or any nested sequence of
        ``(label, weight)`` pairs. Returns ``self``.
        """
        # asarray so plain lists of pairs support the column indexing below.
        y = np.asarray(y)
        self._clf.fit(X, np.array(y[:, 0], dtype=int),
                      sample_weight=np.array(y[:, 1]))
        # Return self, following the scikit-learn ``fit`` convention.
        return self

    def predict(self, X):
        """Delegate to the wrapped estimator's ``predict``."""
        return self._clf.predict(X)

    def predict_proba(self, X):
        """Delegate to the wrapped estimator's ``predict_proba``."""
        return self._clf.predict_proba(X)

    def decision_function(self, X):
        """Delegate to the wrapped estimator's ``decision_function``."""
        return self._clf.decision_function(X)

    def get_params(self, deep=True):
        """Return the constructor parameters (scikit-learn convention)."""
        return dict(clf=self._clf)
# Weighted probabilistic linear classifier: standardise the features, then a
# linear SVC fitted with per-sample weights (force_weight reads labels from
# column 0 and weights from column 1 of y); force_predict makes predict
# return column 1 of predict_proba.
# NOTE(review): only the disabled weighted-SVC baseline below refers to it.
clf = make_pipeline(StandardScaler(),
                    force_predict(force_weight(SVC(
                        kernel='linear', probability=True)), axis=1))

# One entry per (subject, contrast) holding the GAT score.
results = list()
for subject_name in ['TAJ20081223']:
    # Preproc
    # NOTE(review): the epochs file name is hard-coded and does not depend on
    # subject_name.
    epochs = mne.read_epochs('TAJ-epo.fif')
    # Keep the four conditions from t = 0.750 s onwards
    # (an earlier variant started at 0.6 s: crop(0.6, None)).
    this_epochs = epochs[conditions].crop(0.750, None)

    # Contrast definitions: map event codes back to condition names, then
    # derive one binary target per contrast (1 = trial in class-1 conditions).
    event_id = {v: k for k, v in this_epochs.event_id.items()}
    y_raw = [event_id[k] for k in this_epochs.events[:, 2]]
    # Inverse condition-frequency weights; only the disabled weighted-SVC
    # baseline uses them.
    sample_weight = [1. / y_raw.count(k) for k in y_raw]
    y_local = [int(v in local_cond[1]) for v in y_raw]
    y_global = [int(v in global_cond[1]) for v in y_raw]
    iter_contrast = [[y_local, y_global], ['local', 'global']]

    # GAT
    # Folds are stratified on the four raw conditions rather than on the
    # binary targets, to keep all conditions balanced across folds.
    cv = StratifiedKFold(y=y_raw, n_folds=5)
    for y_fit, names_fit in zip(*iter_contrast):
        # A single train/test window spanning the whole cropped epoch.
        # list(range(...)) keeps the slices real lists on Python 3 as well.
        window = len(this_epochs.times)
        train_times = dict(slices=[list(range(window))],
                           start=this_epochs.times[0],
                           stop=this_epochs.times[-1],
                           times=[this_epochs.times[0]])
        # Test slices are nested one level deeper (presumably one list of
        # test slices per training slice).
        test_times = dict(slices=[[list(range(window))]],
                          start=this_epochs.times[0],
                          stop=this_epochs.times[-1],
                          times=[this_epochs.times[0]])
        # Sliding-window alternative (disabled), 10-sample windows:
        #     window_s = 10 / epochs.info['sfreq']  # in seconds
        #     train_times = dict(length=window_s, step=window_s)
        #     test_times = 'diagonal'
        kwargs = dict(test_times=test_times, scorer=roc_auc_score, cv=cv,
                      train_times=train_times)

        # Weighted-SVC baseline (disabled):
        #     gat = GAT(n_jobs=-1, clf=clf, **kwargs)
        #     gat.fit(this_epochs, y=np.c_[y_fit, sample_weight])
        #     gat.score(this_epochs)

        # Riemannian pipeline: reshape to 3-D, estimate ERP covariances,
        # project to the tangent space, then a linear SVC whose predict
        # returns column 1 of predict_proba.
        svc = force_predict(SVC(kernel='linear', probability=True), axis=1)
        clf_rieman = make_pipeline(reshape_X(window=window),
                                   ERPCovariances(estimator='lwf', svd=4),
                                   TangentSpace(metric='logeuclid'), svc)
        gat_rieman = GAT(n_jobs=1, clf=clf_rieman, **kwargs)
        gat_rieman.fit(this_epochs, y=y_fit)
        score = gat_rieman.score(this_epochs, y=y_fit)
        print(score)
        results.append(dict(subject=subject_name, contrast=names_fit,
                            score=score))

        # Plot (disabled)
        # fig, axes = plt.subplots(3)
        # # fig = gat.plot_diagonal(label='SVC', show=False, chance=.5,
        # #                         color='b', ax=axes[0])
        # gat_rieman.plot_diagonal(label='Riemann+SVC', ax=axes[0], color='r',
        #                          show=False, chance=False)
        # # gat.plot(show=False, ax=axes[1], title='SVC')
        # gat_rieman.plot(show=False, ax=axes[2], title='Riemann')
        # report.add_figs_to_section(fig, names_fit, subject_name)
report.save()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment