Skip to content

Instantly share code, notes, and snippets.

@stsievert
Last active June 17, 2018 22:16
Show Gist options
  • Save stsievert/10cce35465ebe8c9d6f9b560698820b6 to your computer and use it in GitHub Desktop.
Save stsievert/10cce35465ebe8c9d6f9b560698820b6 to your computer and use it in GitHub Desktop.
Adaptive parameter search
import numpy as np
from sklearn.model_selection import cross_validate
from sklearn.linear_model import SGDClassifier
from sklearn.datasets import make_classification
import scipy.stats as stats
from sklearn.base import clone
from pprint import pprint
import sklearn.model_selection
class BaseSearchCV:
    """Skeleton for an adaptive hyperparameter search.

    Subclasses implement ``nominate_params`` to propose new candidate
    parameter dicts from the results observed so far; ``perform_search``
    alternates nomination and evaluation until the subclass nominates
    nothing.
    """

    def __init__(self, estimator, params, **kwargs):
        # The estimator is cloned per candidate in evaluate_params,
        # never fit in place.
        self.estimator = estimator
        self.params = params
        self.history = []
        # rest of BaseSearchCV.__init__

    def fit(self, X, y):
        """Run the adaptive search on ``(X, y)`` and return ``self``."""
        # warnings/etc
        history = self.perform_search(X, y)
        # formatting history, setting best/etc
        return self

    def perform_search(self, X, y):
        """Alternate nominate/evaluate until no candidates are proposed.

        Returns the list of all cross-validation result dicts observed.
        """
        # BUG FIX: the original bound `results` and `candidate_params` to
        # the *same* list object (`results = candidate_params = []`);
        # bind them separately to avoid accidental aliasing.
        results = []
        candidate_params = []
        history = []
        while True:
            candidate_params = self.nominate_params(results, candidate_params)
            if not candidate_params:
                break
            # BUG FIX: pass X and y explicitly -- the original called
            # evaluate_params(candidate_params) and relied on module-level
            # globals named X and y, which only exist when the file is run
            # as a script.
            results = self.evaluate_params(candidate_params, X, y)
            history += results
        return history

    def evaluate_params(self, params, X=None, y=None):
        """Cross-validate one clone of the estimator per parameter dict.

        ``X`` and ``y`` default to None only to keep the old signature
        callable; callers should always pass them explicitly.
        """
        candidates = [clone(self.estimator).set_params(**param)
                      for param in params]
        return [cross_validate(est, X, y) for est in candidates]

    def nominate_params(self, results, params):
        """Return the next candidate parameter dicts, or [] to stop.

        BUG FIX: the abstract signature now matches the call site in
        ``perform_search``, which passes both the latest ``results`` and
        the previously nominated ``params``.
        """
        raise NotImplementedError
class SimpleSearchCV(BaseSearchCV):
    """Adaptively search a single parameter (``alpha``).

    SimpleSearchCV initially evaluates 3 parameters, then nominates a
    parameter based on historical evaluations. It nominates

    * ``best_param / 2`` if ``best_param == min(historical_params)``
    * ``best_param * 2`` if ``best_param == max(historical_params)``

    The search stops when the best scoring parameter is surrounded by
    worse scoring parameters.

    This is the hyperparameter optimization used to tune step size in [1]_.

    References
    ----------
    1. A. C. Wilson, R. Roelofs, M. Stern, N. Srebro, B. Recht.
       "The Marginal Value of Adaptive Gradient Methods in Machine
       Learning", Section 4.1. https://arxiv.org/abs/1705.08292
    """

    def __init__(self, *args, **kwargs):
        self._counter = 1        # number of nominate_params calls made
        self._last_score = None  # reserved; currently never assigned
        self._alg_history = {}   # param value -> mean CV test score
        super().__init__(*args, **kwargs)

    def nominate_params(self, results, params):
        """Nominate the next alpha value(s), or [] to stop the search.

        Parameters
        ----------
        results : list of cross_validate result dicts for ``params``
        params : list of the parameter dicts evaluated last round
        """
        self._counter += 1
        # First round: no results yet -> seed with three initial alphas.
        # BUG FIX: dropped the dead `and self._last_score is None` clause;
        # _last_score is never assigned anywhere, so it was always True.
        if not results:
            return [{'alpha': a} for a in [1, 2, 4]]
        scores = {a['alpha']: r['test_score'].mean()
                  for a, r in zip(params, results)}
        self._alg_history.update(scores)
        best_param = max(self._alg_history, key=self._alg_history.get)
        edge_params = (min(self._alg_history), max(self._alg_history))
        if best_param not in edge_params:
            # Best alpha is interior (surrounded by worse scores): done.
            return []
        if best_param == edge_params[0]:
            return [{'alpha': best_param / 2}]
        if best_param == edge_params[1]:
            return [{'alpha': best_param * 2}]
if __name__ == "__main__":
    # Reproducible demo: synthetic classification problem + SGD classifier.
    np.random.seed(42)
    # NOTE: X and y intentionally stay module-level with these exact names;
    # the search implementation may look them up as globals.
    X, y = make_classification(random_state=0, n_features=20, n_samples=1000)

    base_estimator = SGDClassifier(max_iter=5)
    search_space = {'alpha': stats.uniform(0, 1)}

    searcher = SimpleSearchCV(base_estimator, search_space)
    searcher.fit(X, y)
    # Show every (alpha, mean CV score) pair the adaptive search evaluated.
    pprint(searcher._alg_history)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment