Created
June 2, 2014 09:00
-
-
Save kingjr/7defd2b5c0841398cb68 to your computer and use it in GitHub Desktop.
Why are probabilistic outputs better than non-probabilistic ones?
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
========================== | |
Better with probabilities? | |
========================== | |
Comparing classification performance of SVC versus SVC+Platt | |
using an MEG example from MNE-python. | |
""" | |
# Authors: Jean-Remi King <[email protected]> | |
# | |
# License: BSD (3-clause) | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import mne | |
from mne.datasets import sample | |
from mne.decoding import time_generalization | |
from sklearn.svm import SVC | |
from sklearn.pipeline import Pipeline | |
from sklearn.preprocessing import StandardScaler | |
from sklearn.metrics import make_scorer, roc_auc_score | |
# Load and preprocess data
# str() keeps the string concatenation below working whether
# sample.data_path() returns a str (older MNE) or a pathlib.Path (newer MNE).
data_path = str(sample.data_path())
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
events_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
raw = mne.io.Raw(raw_fname, preload=True)
# Keep MEG channels only, excluding channels listed as bad in raw.info.
picks = mne.pick_types(raw.info, meg=True, exclude='bads')
# 1-30 Hz band-pass using an IIR filter.
raw.filter(1, 30, method='iir')
events = mne.read_events(events_fname)
event_id = {'AudL': 1, 'VisL': 3, 'AudR': 2, 'VisR': 4}
# Epochs from -0.1 to 0.5 around each event, no baseline correction,
# magnetometer rejection threshold 5e-12, time samples decimated by 4.
epochs = mne.Epochs(raw, events, event_id, -0.1, 0.5, proj=True,
                    picks=picks, baseline=None, preload=True,
                    reject=dict(mag=5e-12), decim=4)
# Binary decoding problem: left auditory vs. left visual condition.
epochs_list = [epochs[k] for k in ['AudL', 'VisL']]
# Balance the trial counts of the two conditions (the return value is
# ignored, so epochs_list is relied upon to be updated in place).
mne.epochs.equalize_epoch_counts(epochs_list)
# Decoding parameters
scaler = StandardScaler()


def decod(svc, scorer):
    """Run temporal-generalization decoding of the two conditions.

    Parameters
    ----------
    svc : sklearn estimator
        Classifier applied after feature standardization.
    scorer : callable
        scikit-learn scorer forwarded to ``time_generalization`` as
        ``scoring``.

    Returns
    -------
    scores : array
        ``results['scores']`` as returned by ``time_generalization``.
    train_times : array
        ``results['train_times']`` multiplied by 1e3 (presumably seconds to
        milliseconds, matching the "ms" axis labels used when plotting).
    """
    # Build a fresh scaler configured like the module-level ``scaler`` so
    # successive calls never share (and re-fit) the same transformer object.
    clf = Pipeline([('scaler', StandardScaler(**scaler.get_params())),
                    ('svc', svc)])
    results = time_generalization(epochs_list, clf=clf, scoring=scorer,
                                  n_jobs=1)
    return results['scores'], 1e3 * results['train_times']
# Scores on decision_function
svc = SVC(C=1, kernel='linear')  # normal SVC
# needs_threshold=True makes the scorer pass decision_function output to
# roc_auc_score. With make_scorer's default the scorer would call predict()
# and compute the AUC of hard class labels rather than of distances.
scorer = make_scorer(roc_auc_score, needs_threshold=True)
scores_distance, times = decod(svc, scorer)

# Scores on probabilities
svc = SVC(C=1, kernel='linear', probability=True)  # SVC + Platt


def roc_auc_scorer(y_true, y_pred):
    """Return the ROC AUC computed on column 1 of ``y_pred``.

    ``y_pred`` is the 2-D probability matrix handed over by a
    ``needs_proba=True`` scorer; column 1 is presumably the positive class.
    """
    return roc_auc_score(y_true, y_pred[:, 1])


scorer = make_scorer(roc_auc_scorer, needs_proba=True)
scores_proba, times = decod(svc, scorer)
# Visualize
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
# subplots(1, 2) returns a 1-D array of two Axes, so unpack it directly.
ax1, ax2 = ax
def plot_time_gen(ax, scores, title):
    """Draw a temporal-generalization score matrix on ``ax``.

    Rows of ``scores`` run along the y axis (training times) and columns
    along the x axis (testing times). Both axes span the module-level
    ``times`` array. The color scale is fixed to [0, 1], the range of the
    ROC AUC scores computed above. Black lines mark time 0 on both axes,
    and a colorbar is attached to ``ax``.
    """
    im = ax.imshow(scores, interpolation='nearest', origin='lower',
                   extent=[times[0], times[-1], times[0], times[-1]],
                   vmin=0., vmax=1.)
    ax.set_xlabel('Times Test (ms)')
    ax.set_ylabel('Times Train (ms)')
    ax.set_title(title)
    ax.axvline(0, color='k')
    ax.axhline(0, color='k')
    plt.colorbar(im, ax=ax)
# Side-by-side comparison: distance-based scores (left) vs. Platt
# probabilities (right).
plot_time_gen(ax1, scores_distance, 'Distance')
plot_time_gen(ax2, scores_proba, 'Probabilities')
mne.viz.tight_layout(fig=fig)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment