Skip to content

Instantly share code, notes, and snippets.

@iydon
Created December 7, 2022 01:24
Show Gist options
  • Save iydon/5eb0966c79fbb0a870b932d105b36cf8 to your computer and use it in GitHub Desktop.
Save iydon/5eb0966c79fbb0a870b932d105b36cf8 to your computer and use it in GitHub Desktop.
__all__ = ['SearchCV']
import copy
import pathlib as p
import pickle
import typing as t
import numpy as np
import pandas as pd
import tqdm
from sklearn.model_selection._search import BaseSearchCV
from sklearn.base import BaseEstimator
from sklearn.metrics import make_scorer
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, cross_val_score
if t.TYPE_CHECKING:
from typing_extensions import Self
# Maps an estimator's class name to the search object stored for it by
# SearchCV.add_estimator.
Best = t.Dict[str, BaseSearchCV]
# Estimator parameters as a name -> value mapping.
Parameters = t.Dict[str, t.Any]
# Candidate values to try for each parameter name.
ParameterSpace = t.Dict[str, t.Iterable[t.Any]]
# A filesystem location given either as a string or as a pathlib path.
Path = t.Union[str, p.Path]
# The search strategies SearchCV.add_estimator knows how to build.
SearchType = t.Literal['grid', 'random']
class SearchCV:
    '''Fluent helper that runs GridSearchCV or RandomizedSearchCV per estimator.

    Typical usage chains the setters, then registers estimators::

        scv = SearchCV.from_xy(X, y).set_cv(5).set_scoring(metric).set_type('grid')
        scv.add_estimator(estimator, parameter_space)

    Fitted searches are stored under the estimator's class name, so two
    estimators of the same class share one slot (see `add_estimator`).
    '''

    # Default file used by both `save` and `load`.
    _default_path = 'search_cv.pkl'

    def __init__(self, X: np.ndarray, y: np.ndarray) -> None:
        '''Store the data set; every search option starts out unset (None).'''
        self._X, self._y = X, y
        self._cv = self._n = self._random = self._scoring = self._type = None
        # estimator class name -> fitted search object
        self._best = {}

    @classmethod
    def from_xy(cls, X: np.ndarray, y: np.ndarray) -> 'Self':
        '''Alternate constructor, equivalent to `cls(X, y)`.'''
        return cls(X, y)

    @classmethod
    def load(cls, path: 'Path' = _default_path) -> 'Self':
        '''Rebuild an instance from a file previously written by `save`.

        SECURITY: the file is unpickled, and unpickling can execute arbitrary
        code. Only load files that come from a trusted source.
        '''
        obj = pickle.loads(p.Path(path).read_bytes())
        instance = cls.from_xy(obj['X'], obj['y'])
        instance.set_cv(obj['cv'])
        instance.set_n_jobs(obj['n'])
        instance.set_random_state(obj['random'])
        instance.set_type(obj['type'])
        # The scorer is stored already built, so it is restored directly
        # instead of going through `set_scoring` (which would wrap it again).
        instance._scoring = obj['scoring']
        instance._best = obj['best']
        return instance

    @property
    def best(self) -> 'Best':
        '''Fitted search objects keyed by estimator class name.'''
        return self._best

    @property
    def best_estimators(self) -> t.Dict[str, 'BaseEstimator']:
        '''The `best_estimator_` of every stored search.'''
        return {key: search.best_estimator_ for key, search in self.best.items()}

    @property
    def best_parameters(self) -> t.Dict[str, 'Parameters']:
        '''The `best_params_` (searched parameters only) of every stored search.'''
        return {key: search.best_params_ for key, search in self.best.items()}

    @property
    def best_parameters_all(self) -> t.Dict[str, 'Parameters']:
        '''The full `get_params()` of every best estimator, defaults included.'''
        return {
            key: search.best_estimator_.get_params()
            for key, search in self.best.items()
        }

    @property
    def best_scores(self) -> t.Dict[str, float]:
        '''The `best_score_` of every stored search.'''
        return {key: search.best_score_ for key, search in self.best.items()}

    def copy(self) -> 'Self':
        '''Return a deep copy of this object, data and fitted searches included.'''
        return copy.deepcopy(self)

    def save(self, path: 'Path' = _default_path) -> 'Self':
        '''Pickle the data, the search options and the fitted searches to `path`.

        Returns `self` so the call can be chained.
        '''
        obj = {
            'X': self._X, 'y': self._y,
            'cv': self._cv, 'n': self._n, 'random': self._random,
            'scoring': self._scoring, 'type': self._type,
            'best': self._best,
        }
        p.Path(path).write_bytes(pickle.dumps(obj))
        return self

    def cross_validation_scores(
        self,
        greater_is_better_metrics: t.Optional[t.List[t.Callable]] = None,
        less_is_better_metrics: t.Optional[t.List[t.Callable]] = None,
        cv: int = 3,
    ) -> pd.DataFrame:
        '''Cross-validate every best estimator against the given metrics.

        Parameters:
            greater_is_better_metrics: metric functions where larger is better.
            less_is_better_metrics: metric functions where smaller is better;
                `make_scorer(..., greater_is_better=False)` sign-flips these.
            cv: passed to `cross_val_score` as its `cv` argument.

        Returns:
            A DataFrame with one row per metric (named by the function's
            `__name__`) and one column per estimator; each cell holds the
            value returned by `cross_val_score`.
        '''
        # Metrics are keyed by function name; if the same name appears in both
        # lists, the less-is-better scorer wins (it is inserted last).
        scorings = {}
        for metrics, greater_is_better in (
            (greater_is_better_metrics or [], True),
            (less_is_better_metrics or [], False),
        ):
            for metric in metrics:
                scorings[metric.__name__] = make_scorer(
                    metric, greater_is_better=greater_is_better,
                )
        estimators = self.best_estimators
        ans = {
            metric_name: {
                estimator_name: cross_val_score(
                    estimator, self._X, self._y, cv=cv, scoring=scoring,
                )
                for estimator_name, estimator in estimators.items()
            }
            for metric_name, scoring in scorings.items()
        }
        return pd.DataFrame.from_dict(ans, orient='index')

    def set_cv(self, cv: int) -> 'Self':
        '''Set the `cv` argument handed to the search; returns `self`.'''
        self._cv = cv
        return self

    def set_n_jobs(self, n_jobs: int) -> 'Self':
        '''Set the `n_jobs` argument handed to the search; returns `self`.'''
        self._n = n_jobs
        return self

    def set_random_state(self, random: int) -> 'Self':
        '''Set the `random_state` used by the randomized search; returns `self`.'''
        self._random = random
        return self

    def set_scoring(
        self,
        func: t.Callable, greater_is_better: bool = True,
        **kwargs: t.Any,
    ) -> 'Self':
        '''Build the scorer from a metric function via `make_scorer`.

        Extra keyword arguments are forwarded to `make_scorer`. Returns `self`.
        '''
        self._scoring = make_scorer(func, greater_is_better=greater_is_better, **kwargs)
        return self

    def set_type(self, type: 'SearchType' = 'grid') -> 'Self':
        '''Choose the search strategy: 'grid' or 'random'. Returns `self`.

        The value is not validated here; an unknown value is rejected when
        `add_estimator` tries to build the search.
        '''
        self._type = type
        return self

    def set_type_grid(self) -> 'Self':
        '''Shortcut for `set_type('grid')`.'''
        return self.set_type('grid')

    def set_type_random(self) -> 'Self':
        '''Shortcut for `set_type('random')`.'''
        return self.set_type('random')

    def add_estimator(
        self,
        estimator: 'BaseEstimator', parameter_space: 'ParameterSpace',
        overwrite: bool = False,
    ) -> 'Self':
        '''Run the configured search for one estimator and store the result.

        Parameters:
            estimator: the estimator to tune.
            parameter_space: candidate values per parameter name.
            overwrite: when False (default) and a result is already stored
                under this estimator's class name, nothing is recomputed.

        Returns:
            `self`, so calls can be chained.

        Raises:
            AssertionError: the configuration is incomplete (see `_check`).
            ValueError: the search type is neither 'grid' nor 'random'.
        '''
        key = estimator.__class__.__name__
        if key in self._best and not overwrite:
            return self
        self._check()
        shared = dict(scoring=self._scoring, n_jobs=self._n, cv=self._cv)
        if self._type == 'grid':
            search = GridSearchCV(estimator, parameter_space, **shared)
        elif self._type == 'random':
            search = RandomizedSearchCV(
                estimator, parameter_space, random_state=self._random, **shared,
            )
        else:
            # ValueError is still caught by callers handling `Exception`.
            raise ValueError(
                f"unknown search type {self._type!r}; expected 'grid' or 'random'"
            )
        self._best[key] = search.fit(self._X, self._y)
        return self

    def add_estimators(
        self,
        *items: t.Tuple['BaseEstimator', 'ParameterSpace'],
        **kwargs: t.Any,
    ) -> 'Self':
        '''Call `add_estimator` for each (estimator, parameter_space) pair.

        Progress is shown with tqdm; keyword arguments are forwarded to
        `add_estimator`. Returns `self`.
        '''
        for estimator, parameter_space in tqdm.tqdm(items):
            self.add_estimator(estimator, parameter_space, **kwargs)
        return self

    def add_estimators_via_dict(
        self,
        data: t.Dict['BaseEstimator', 'ParameterSpace'],
        **kwargs: t.Any,
    ) -> 'Self':
        '''Same as `add_estimators`, taking an {estimator: parameter_space} dict.'''
        return self.add_estimators(*data.items(), **kwargs)

    def _check(self) -> None:
        '''Verify the object is ready to run a search.

        Raises AssertionError explicitly (not via `assert`) so the validation
        still runs when Python is started with -O.
        '''
        if not (len(self._X.shape) == 2 and len(self._y.shape) == 1):
            raise AssertionError('X must be 2-dimensional and y must be 1-dimensional')
        if self._scoring is None:
            raise AssertionError('scoring is not set; call set_scoring() first')
        if self._type is None:
            raise AssertionError('search type is not set; call set_type() first')
if __name__ == '__main__':
    # Demo: tune three estimators on the iris data set, print the worst
    # cross-validation score per metric/estimator, then save the result.
    import warnings

    # Silence warnings before the scikit-learn imports, as the search below
    # produces many of them.
    warnings.filterwarnings('ignore')

    from sklearn.datasets import load_iris
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVR
    from sklearn.metrics import r2_score, mean_absolute_percentage_error, mean_squared_error

    # load data
    iris = load_iris()

    # search parameter
    scv = SearchCV.from_xy(iris.data, iris.target) \
        .set_cv(5) \
        .set_random_state(20221206) \
        .set_scoring(r2_score, greater_is_better=True) \
        .set_type('grid' if True else 'random')  # flip the constant to try 'random'

    # estimator parameter space
    models = {
        GaussianProcessRegressor(): dict(
            alpha=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
        ),
        # NOTE(review): scikit-learn's 'sag' solver does not support the 'l1'
        # penalty, so those grid points presumably fail to fit and are scored
        # as NaN by the search -- confirm this is intended.
        LogisticRegression(): dict(
            solver=['sag', 'saga'],
            tol=[1e-4, 1e-3, 1e-2],
            penalty=['l2', 'l1'],
        ),
        SVR(): dict(
            kernel=['rbf', 'sigmoid'],
            gamma=['scale', 'auto'],
            C=[1.0, 10.0, 100.0],
            epsilon=[1e-3, 1e-2, 1e-1],
        ),
    }
    scv.add_estimators_via_dict(models, overwrite=True)

    # cross validation scores
    df = scv.cross_validation_scores(
        greater_is_better_metrics=[r2_score],
        less_is_better_metrics=[mean_absolute_percentage_error, mean_squared_error],
        cv=5,
    )
    # Each cell holds the per-fold scores; reduce it to the minimum.
    # DataFrame.applymap is deprecated since pandas 2.1, so go column by
    # column with Series.map, which works on old and new pandas alike.
    print(df.apply(lambda column: column.map(lambda scores: scores.min())))
    scv.save()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment