Last active
June 17, 2018 22:16
-
-
Save stsievert/10cce35465ebe8c9d6f9b560698820b6 to your computer and use it in GitHub Desktop.
Adaptive parameter search
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sklearn.model_selection import cross_validate | |
from sklearn.linear_model import SGDClassifier | |
from sklearn.datasets import make_classification | |
import scipy.stats as stats | |
from sklearn.base import clone | |
from pprint import pprint | |
import sklearn.model_selection | |
class BaseSearchCV:
    """Base class for iterative hyperparameter searches.

    Subclasses implement ``nominate_params`` to propose new parameter
    settings from the previous round's cross-validation results; the
    loop in ``perform_search`` alternates nominate/evaluate rounds
    until no new candidates are nominated.
    """

    def __init__(self, estimator, params, **kwargs):
        # The estimator is never fit directly; it is cloned once per
        # candidate parameter setting in evaluate_params.
        self.estimator = estimator
        self.params = params
        self.history = []
        # rest of BaseSearchCV.__init__

    def fit(self, X, y):
        # warnings/etc
        history = self.perform_search(X, y)
        # formatting history, setting best/etc
        return self

    def perform_search(self, X, y):
        """Run nominate/evaluate rounds; return the list of all results."""
        # NOTE: the original bound `results` and `candidate_params` to the
        # *same* list object (`results = candidate_params = []`). Both are
        # rebound each iteration so it was harmless, but two independent
        # lists state the intent clearly.
        results = []
        candidate_params = []
        history = []
        while True:
            candidate_params = self.nominate_params(results, candidate_params)
            if not candidate_params:
                # An empty nomination list is the stop signal.
                break
            # BUG FIX: pass X and y through explicitly. The original
            # evaluate_params referenced module-level globals named X and y,
            # which only existed because the demo script defined them.
            results = self.evaluate_params(candidate_params, X, y)
            history += results
        return history

    def evaluate_params(self, params, X, y):
        """Cross-validate one cloned estimator per parameter setting.

        Returns a list of ``cross_validate`` result dicts, one per entry
        in ``params``, in the same order.
        """
        candidates = [clone(self.estimator).set_params(**param)
                      for param in params]
        return [cross_validate(est, X, y) for est in candidates]

    def nominate_params(self, results, params):
        # BUG FIX: signature now matches the call in perform_search (and
        # the override in SimpleSearchCV), which passes both the latest
        # results and the params they correspond to.
        raise NotImplementedError
class SimpleSearchCV(BaseSearchCV):
    """
    SimpleSearchCV searches one parameter. It initially evaluates 3 parameters,
    then nominates a parameter based on historical evaluations. It returns

        best_param / 2 if last_best_param == min(historical_params)
        best_param * 2 if last_best_param == max(historical_params)

    The search stops when the best scoring parameter is surrounded by worse
    scoring parameters.

    This is the hyperparameter optimization used to tune step size in [1]_.

    References
    ----------
    1. A. C. Wilson, R. Roelofs, M. Stern, N. Srebro, B. Recht.
       "The Marginal Value of Adaptive Gradient Methods in Machine Learning",
       Section 4.1. https://arxiv.org/abs/1705.08292
    """

    def __init__(self, *args, **kwargs):
        self._counter = 1        # number of nomination rounds so far
        # NOTE(review): _last_score is initialized but never updated in
        # this class — it stays None, so the check below always passes
        # on the first round. Confirm whether a subclass sets it.
        self._last_score = None
        self._alg_history = {}   # maps alpha value -> mean CV test score
        super().__init__(*args, **kwargs)

    def nominate_params(self, results, params):
        """Propose the next alpha value(s) to evaluate.

        ``results`` are cross_validate dicts for the previous round and
        ``params`` the settings they correspond to (same order). Returns
        an empty list to signal the search should stop.
        """
        self._counter += 1
        if not results and self._last_score is None:
            # First round: seed the search with three doubling alphas.
            return [{'alpha': a} for a in [1, 2, 4]]
        # Record the mean test score for each parameter just evaluated.
        scores = {a['alpha']: r['test_score'].mean()
                  for a, r in zip(params, results)}
        self._alg_history.update(scores)
        best_param = max(self._alg_history, key=self._alg_history.get)
        edge_params = (min(self._alg_history), max(self._alg_history))
        if best_param not in edge_params:
            # Best alpha is bracketed by worse-scoring neighbors: stop.
            return []
        if best_param == edge_params[0]:
            # Best alpha is the smallest tried so far: probe smaller.
            return [{'alpha': best_param / 2}]
        if best_param == edge_params[1]:
            # Best alpha is the largest tried so far: probe larger.
            return [{'alpha': best_param * 2}]
if __name__ == "__main__": | |
np.random.seed(42) | |
X, y = make_classification(random_state=0, n_features=20, n_samples=1000) | |
est = SGDClassifier(max_iter=5) | |
params = {'alpha': stats.uniform(0, 1)} | |
search = SimpleSearchCV(est, params) | |
search.fit(X, y) | |
pprint(search._alg_history) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment