Last active
July 18, 2018 08:57
-
-
Save joshlk/51bf180b433b2c91595bb6a207fe1e39 to your computer and use it in GitHub Desktop.
Scikit-learn cross-validation that returns Keras results
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import warnings | |
import numbers | |
import time | |
import numpy as np | |
from sklearn.base import is_classifier, clone | |
from sklearn.utils import indexable | |
from sklearn.utils.validation import _num_samples | |
from sklearn.utils.metaestimators import _safe_split | |
from sklearn.externals.joblib import Parallel, delayed, logger | |
from sklearn.externals.six.moves import zip | |
from sklearn.metrics.scorer import _check_multimetric_scoring | |
from sklearn.exceptions import FitFailedWarning | |
from sklearn.model_selection._split import check_cv | |
from sklearn.model_selection._validation import _aggregate_score_dicts, _score, _index_param_value | |
def cross_validate_keras(estimator, X, y=None, groups=None, scoring=None, cv=None,
                         n_jobs=1, verbose=0, fit_params=None,
                         pre_dispatch='2*n_jobs'):
    """Cross-validate `estimator`, keeping each fold's fitted estimator and
    the object returned by its Keras-style ``fit`` call (the ``History``).

    Mirrors ``sklearn.model_selection.cross_validate`` but the returned dict
    additionally contains an ``'estimator'`` tuple (one fitted clone per
    fold) and a ``'keras_results'`` tuple (one ``fit`` return value per fold).

    Parameters are identical in meaning to ``cross_validate``:
    ``scoring`` may be a single scorer or a multimetric spec, ``cv`` any
    cross-validation splitter/int, ``fit_params`` extra kwargs passed to
    ``fit`` on each fold.
    """
    X, y, groups = indexable(X, y, groups)
    splitter = check_cv(cv, y, classifier=is_classifier(estimator))
    scorers, _ = _check_multimetric_scoring(estimator, scoring=scoring)

    # Each fold fits a fresh clone so folds are independent (and pickle-able
    # for joblib workers).
    runner = Parallel(n_jobs=n_jobs, verbose=verbose,
                      pre_dispatch=pre_dispatch)
    fold_outputs = runner(
        delayed(_fit_and_score_keras)(
            clone(estimator), X, y, scorers, train_idx, test_idx, verbose,
            None, fit_params, return_times=True)
        for train_idx, test_idx in splitter.split(X, y, groups))

    test_scores, fit_times, score_times, estimators, keras_results = zip(*fold_outputs)
    per_metric = _aggregate_score_dicts(test_scores)

    results = {
        'fit_time': np.array(fit_times),
        'score_time': np.array(score_times),
        'estimator': estimators,
        'keras_results': keras_results,
    }
    for metric_name in scorers:
        results['test_%s' % metric_name] = np.array(per_metric[metric_name])
    return results
def _fit_and_score_keras(estimator, X, y, scorer, train, test, verbose,
                         parameters, fit_params, return_train_score=False,
                         return_parameters=False, return_n_test_samples=False,
                         return_times=False, error_score='raise'):
    """Fit `estimator` on one train split and score it on the test split.

    Variant of sklearn's private ``_fit_and_score`` that (a) passes the test
    fold to ``fit`` as ``validation_data`` (Keras-style) and (b) appends the
    fitted estimator and the return value of ``fit`` (the Keras ``History``,
    or ``None`` when fitting failed) to the returned list.

    Parameters mirror sklearn's ``_fit_and_score``: ``scorer`` is either a
    single callable or a dict of scorers (multimetric); ``error_score`` is
    the string ``'raise'`` or a numeric fallback score recorded when ``fit``
    raises.

    Returns a list: ``[train_scores?] + [test_scores] + [n_test_samples?]
    + [fit_time, score_time?] + [parameters?] + [estimator, keras_results]``
    where the optional entries are controlled by the ``return_*`` flags.

    Raises the fitting exception unchanged when ``error_score == 'raise'``;
    raises ``ValueError`` for any other non-numeric ``error_score``.
    """
    if verbose > 1:
        if parameters is None:
            msg = ''
        else:
            msg = '%s' % (', '.join('%s=%s' % (k, v)
                                    for k, v in parameters.items()))
        print("[CV] %s %s" % (msg, (64 - len(msg)) * '.'))

    # Adjust length of sample weights (and any other indexable fit params)
    # to the training fold.
    fit_params = fit_params if fit_params is not None else {}
    fit_params = dict([(k, _index_param_value(X, v, train))
                       for k, v in fit_params.items()])

    train_scores = {}
    if parameters is not None:
        estimator.set_params(**parameters)

    start_time = time.time()

    X_train, y_train = _safe_split(estimator, X, y, train)
    X_test, y_test = _safe_split(estimator, X, y, test, train)

    is_multimetric = not callable(scorer)
    n_scorers = len(scorer.keys()) if is_multimetric else 1

    try:
        if y_train is None:
            keras_results = estimator.fit(X_train, validation_data=(X_test, y_test), **fit_params)
        else:
            keras_results = estimator.fit(X_train, y_train, validation_data=(X_test, y_test), **fit_params)
    except Exception as e:
        # Note fit time as time until error
        fit_time = time.time() - start_time
        score_time = 0.0
        # BUG FIX: keras_results must be defined on the failure path too,
        # otherwise ret.append(keras_results) below raises NameError and
        # defeats the numeric error_score fallback.
        keras_results = None
        if error_score == 'raise':
            raise
        elif isinstance(error_score, numbers.Number):
            if is_multimetric:
                test_scores = dict(zip(scorer.keys(),
                                   [error_score, ] * n_scorers))
                if return_train_score:
                    train_scores = dict(zip(scorer.keys(),
                                        [error_score, ] * n_scorers))
            else:
                test_scores = error_score
                if return_train_score:
                    train_scores = error_score
            warnings.warn("Classifier fit failed. The score on this train-test"
                          " partition for these parameters will be set to %f. "
                          "Details: \n%r" % (error_score, e), FitFailedWarning)
        else:
            raise ValueError("error_score must be the string 'raise' or a"
                             " numeric value. (Hint: if using 'raise', please"
                             " make sure that it has been spelled correctly.)")
    else:
        fit_time = time.time() - start_time
        # _score will return dict if is_multimetric is True
        test_scores = _score(estimator, X_test, y_test, scorer, is_multimetric)
        score_time = time.time() - start_time - fit_time
        if return_train_score:
            train_scores = _score(estimator, X_train, y_train, scorer,
                                  is_multimetric)

    # NOTE(review): verbose > 2 implies verbose > 1, so `msg` is defined
    # whenever these branches run.
    if verbose > 2:
        if is_multimetric:
            for scorer_name, score in test_scores.items():
                msg += ", %s=%s" % (scorer_name, score)
        else:
            msg += ", score=%s" % test_scores
    if verbose > 1:
        total_time = score_time + fit_time
        end_msg = "%s, total=%s" % (msg, logger.short_format_time(total_time))
        print("[CV] %s %s" % ((64 - len(end_msg)) * '.', end_msg))

    ret = [train_scores, test_scores] if return_train_score else [test_scores]

    if return_n_test_samples:
        ret.append(_num_samples(X_test))
    if return_times:
        ret.extend([fit_time, score_time])
    if return_parameters:
        ret.append(parameters)
    ret.append(estimator)
    ret.append(keras_results)
    return ret
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment