Created
December 7, 2022 01:24
-
-
Save iydon/5eb0966c79fbb0a870b932d105b36cf8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
__all__ = ['SearchCV'] | |
import copy | |
import pathlib as p | |
import pickle | |
import typing as t | |
import numpy as np | |
import pandas as pd | |
import tqdm | |
from sklearn.model_selection._search import BaseSearchCV | |
from sklearn.base import BaseEstimator | |
from sklearn.metrics import make_scorer | |
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, cross_val_score | |
if t.TYPE_CHECKING: | |
from typing_extensions import Self | |
# Mapping from an estimator's class name to its fitted search object.
Best = t.Dict[str, BaseSearchCV]
# One concrete hyper-parameter assignment (e.g. the value of `best_params_`).
Parameters = t.Dict[str, t.Any]
# Candidate values per hyper-parameter, as fed to Grid/RandomizedSearchCV.
ParameterSpace = t.Dict[str, t.Iterable[t.Any]]
# Anything acceptable as a filesystem path.
Path = t.Union[str, p.Path]
# The two supported search strategies.
SearchType = t.Literal['grid', 'random']
class SearchCV:
    '''Fluent wrapper around GridSearchCV and RandomizedSearchCV.

    Holds a dataset ``(X, y)`` plus search configuration (cv folds, n_jobs,
    random state, scorer, search type), fits one hyper-parameter search per
    estimator, and caches the fitted searches keyed by the estimator's class
    name.  All ``set_*`` methods return ``self`` so calls can be chained.
    '''

    # Default pickle path used by `save` and `load`.
    _default_path = 'search_cv.pkl'

    def __init__(self, X: np.ndarray, y: np.ndarray) -> None:
        '''Store the data; every other setting starts unset (None).'''
        self._X, self._y = X, y
        # cv folds, n_jobs, random_state, scorer and search type are all
        # configured later through the fluent `set_*` methods.
        self._cv = self._n = self._random = self._scoring = self._type = None
        self._best: Best = {}

    @classmethod
    def from_xy(cls, X: np.ndarray, y: np.ndarray) -> 'Self':
        '''Alternate constructor mirroring `__init__`.'''
        return cls(X, y)

    @classmethod
    def load(cls, path: Path = _default_path) -> 'Self':
        '''Restore an instance previously written by `save`.

        NOTE(security): `pickle.loads` can execute arbitrary code while
        deserializing -- only load files you created yourself.
        '''
        obj = pickle.loads(p.Path(path).read_bytes())
        self = cls.from_xy(obj['X'], obj['y']) \
            .set_cv(obj['cv']) \
            .set_n_jobs(obj['n']) \
            .set_random_state(obj['random']) \
            .set_type(obj['type'])
        self._scoring = obj['scoring']
        self._best = obj['best']
        return self

    @property
    def best(self) -> Best:
        '''Fitted search objects, keyed by estimator class name.'''
        return self._best

    @property
    def best_estimators(self) -> t.Dict[str, BaseEstimator]:
        '''Refitted best estimator of each search.'''
        return {
            key: value.best_estimator_
            for key, value in self.best.items()
        }

    @property
    def best_parameters(self) -> t.Dict[str, Parameters]:
        '''Only the searched hyper-parameters of each best estimator.'''
        return {
            key: value.best_params_
            for key, value in self.best.items()
        }

    @property
    def best_parameters_all(self) -> t.Dict[str, Parameters]:
        '''Full parameter set (searched and default) of each best estimator.'''
        return {
            key: value.best_estimator_.get_params()
            for key, value in self.best.items()
        }

    @property
    def best_scores(self) -> t.Dict[str, float]:
        '''Mean cross-validated score of each search's best candidate.'''
        return {
            key: value.best_score_
            for key, value in self.best.items()
        }

    def copy(self) -> 'Self':
        '''Return an independent deep copy of this instance.'''
        return copy.deepcopy(self)

    def save(self, path: Path = _default_path) -> 'Self':
        '''Pickle the data, configuration and fitted searches to `path`.'''
        obj = {
            'X': self._X, 'y': self._y,
            'cv': self._cv, 'n': self._n, 'random': self._random,
            'scoring': self._scoring, 'type': self._type,
            'best': self._best,
        }
        p.Path(path).write_bytes(pickle.dumps(obj))
        return self

    def cross_validation_scores(
        self,
        greater_is_better_metrics: t.Optional[t.List[t.Callable]] = None,
        less_is_better_metrics: t.Optional[t.List[t.Callable]] = None,
        cv: int = 3,
    ) -> pd.DataFrame:
        '''Cross-validate every best estimator against every given metric.

        Returns a DataFrame indexed by metric name with one column per
        estimator; each cell holds the array of per-fold scores.
        '''
        def make_scorings(metrics: t.List[t.Callable], flag: bool) -> t.Dict[str, t.Any]:
            # Wrap plain metric callables as sklearn scorers.
            return {
                metric.__name__: make_scorer(metric, greater_is_better=flag)
                for metric in metrics
            }

        scorings = {
            **make_scorings(greater_is_better_metrics or [], True),
            **make_scorings(less_is_better_metrics or [], False),
        }
        # Distinct loop names (metric vs. estimator) -- the original shadowed
        # `name` in the nested comprehension, which worked but read badly.
        ans = {
            metric_name: {
                estimator_name: cross_val_score(
                    estimator, self._X, self._y, cv=cv, scoring=scoring,
                ) for estimator_name, estimator in self.best_estimators.items()
            } for metric_name, scoring in scorings.items()
        }
        return pd.DataFrame.from_dict(ans, orient='index')

    def set_cv(self, cv: int) -> 'Self':
        '''Set the number of cross-validation folds.'''
        self._cv = cv
        return self

    def set_n_jobs(self, n_jobs: int) -> 'Self':
        '''Set the parallelism passed through to the searches.'''
        self._n = n_jobs
        return self

    def set_random_state(self, random: int) -> 'Self':
        '''Set the random state used by RandomizedSearchCV.'''
        self._random = random
        return self

    def set_scoring(
        self,
        func: t.Callable, greater_is_better: bool = True,
        **kwargs: t.Any,
    ) -> 'Self':
        '''Build the scorer from a plain metric callable.'''
        self._scoring = make_scorer(func, greater_is_better=greater_is_better, **kwargs)
        return self

    def set_type(self, type: SearchType = 'grid') -> 'Self':
        '''Choose the search strategy ('grid' or 'random').'''
        self._type = type
        return self

    def set_type_grid(self) -> 'Self':
        '''Shortcut for ``set_type('grid')``.'''
        return self.set_type('grid')

    def set_type_random(self) -> 'Self':
        '''Shortcut for ``set_type('random')``.'''
        return self.set_type('random')

    def add_estimator(
        self,
        estimator: BaseEstimator, parameter_space: ParameterSpace,
        overwrite: bool = False,
    ) -> 'Self':
        '''Run the configured search for `estimator` and cache the result.

        The cache key is the estimator's class name; an existing entry is
        kept unless `overwrite` is true.  Raises ValueError when the
        configuration is incomplete (see `_check`).
        '''
        key = estimator.__class__.__name__
        if key in self._best and not overwrite:
            return self
        #
        self._check()
        if self._type == 'grid':
            clf = GridSearchCV(
                estimator, parameter_space,
                scoring=self._scoring, n_jobs=self._n, cv=self._cv,
            )
        elif self._type == 'random':
            clf = RandomizedSearchCV(
                estimator, parameter_space,
                scoring=self._scoring, n_jobs=self._n, cv=self._cv, random_state=self._random,
            )
        else:
            # Unreachable once `_check` passed, but fail loudly anyway
            # (the original raised a bare, message-less Exception here).
            raise ValueError(f'unknown search type: {self._type!r}')
        self._best[key] = clf.fit(self._X, self._y)
        return self

    def add_estimators(
        self,
        *items: t.Tuple[BaseEstimator, ParameterSpace],
        **kwargs: t.Any,
    ) -> 'Self':
        '''Fit several (estimator, parameter_space) pairs with a progress bar.'''
        for estimator, parameter_space in tqdm.tqdm(items):
            self.add_estimator(estimator, parameter_space, **kwargs)
        return self

    def add_estimators_via_dict(
        self,
        data: t.Dict[BaseEstimator, ParameterSpace],
        **kwargs: t.Any,
    ) -> 'Self':
        '''Like `add_estimators`, taking an {estimator: space} mapping.'''
        return self.add_estimators(*data.items(), **kwargs)

    def _check(self) -> None:
        '''Validate data and configuration before fitting.

        Raises ValueError instead of using `assert` (which is stripped
        under ``python -O``) as the original did.
        '''
        if len(self._X.shape) != 2 or len(self._y.shape) != 1:
            raise ValueError('X must be 2-dimensional and y 1-dimensional')
        if self._scoring is None:
            raise ValueError('scoring is not set: call set_scoring first')
        if self._type is None:
            raise ValueError('search type is not set: call set_type first')
if __name__ == '__main__':
    __import__('warnings').filterwarnings('ignore')

    from sklearn.datasets import load_iris
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVR
    from sklearn.metrics import r2_score, mean_absolute_percentage_error, mean_squared_error

    # Load the demo dataset.
    iris = load_iris()

    # Configure the search (the `if True` toggle picks grid vs. random).
    scv = SearchCV.from_xy(iris.data, iris.target) \
        .set_cv(5) \
        .set_random_state(20221206) \
        .set_scoring(r2_score, greater_is_better=True) \
        .set_type('grid' if True else 'random')

    # Per-estimator hyper-parameter spaces.
    models = {
        GaussianProcessRegressor(): dict(
            alpha=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
        ),
        # BUGFIX: the 'sag' solver supports only the l2 penalty, so the
        # original grid (solver x penalty cross product) produced invalid
        # combinations that GridSearchCV scored as errors/NaN.  A list of
        # sub-grids (accepted by Grid/RandomizedSearchCV) keeps each solver
        # paired only with penalties it supports.
        LogisticRegression(): [
            dict(solver=['sag', 'saga'], tol=[1e-4, 1e-3, 1e-2], penalty=['l2']),
            dict(solver=['saga'], tol=[1e-4, 1e-3, 1e-2], penalty=['l1']),
        ],
        SVR(): dict(
            kernel=['rbf', 'sigmoid'],
            gamma=['scale', 'auto'],
            C=[1.0, 10.0, 100.0],
            epsilon=[1e-3, 1e-2, 1e-1],
        ),
    }
    scv.add_estimators_via_dict(models, overwrite=True)

    # Cross-validation scores: each cell is an array of per-fold scores,
    # so report the worst fold per (metric, estimator) pair.
    df = scv.cross_validation_scores(
        greater_is_better_metrics=[r2_score],
        less_is_better_metrics=[mean_absolute_percentage_error, mean_squared_error],
        cv=5,
    )
    # NOTE: DataFrame.applymap is deprecated since pandas 2.1 in favour of
    # DataFrame.map; kept here for compatibility with older pandas.
    print(df.applymap(lambda x: x.min()))
    scv.save()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment