A wrapper for functions so that they can be parametrized with get_params and set_params in scikit-learn: proof of concept
from collections import defaultdict

import pandas as pd


class parametrized_function:
    def __init__(self, _func, **kwargs):
        self._func = _func
        self.__doc__ = self._func.__doc__
        self.__name__ = self._func.__name__
        # TODO use inspect to automatically find parameters with defaults
        self._params = kwargs

    def __call__(self, *args, **kwargs):
        # copy so that call-time keyword arguments do not overwrite the stored params
        kw = dict(self._params)
        kw.update(kwargs)
        return self._func(*args, **kw)

    def get_params(self, deep=False):
        out = self._params.copy()
        out['_func'] = self._func
        if deep:
            # expand nested objects that themselves support get_params
            for key, value in list(out.items()):
                if hasattr(value, 'get_params'):
                    deep_items = value.get_params().items()
                    out.update((key + '__' + k, val) for k, val in deep_items)
        return out

    def set_params(self, **params):
        if not params:
            # Simple optimization to gain speed (inspect is slow)
            return self
        valid_params = self.get_params(deep=True)

        nested_params = defaultdict(dict)  # grouped by prefix
        for key, value in params.items():
            key, delim, sub_key = key.partition('__')
            if key not in valid_params:
                raise ValueError('Invalid parameter %s for estimator %s. '
                                 'Check the list of available parameters '
                                 'with `estimator.get_params().keys()`.' %
                                 (key, self))
            if delim:
                nested_params[key][sub_key] = value
            else:
                self._params[key] = value
                valid_params[key] = value

        for key, sub_params in nested_params.items():
            valid_params[key].set_params(**sub_params)

        return self


if __name__ == '__main__':
    from sklearn.feature_selection import mutual_info_regression, SelectKBest
    from sklearn.model_selection import GridSearchCV
    from sklearn.pipeline import make_pipeline
    from sklearn.linear_model import LinearRegression
    from sklearn.datasets import make_regression

    mutual_info_regression = parametrized_function(mutual_info_regression,
                                                   n_neighbors=3)
    X, y = make_regression()
    gs = GridSearchCV(make_pipeline(SelectKBest(mutual_info_regression, k=1),
                                    LinearRegression()),
                      {'selectkbest__score_func__n_neighbors': [3, 4]},
                      cv=5, return_train_score=False).fit(X, y)
    print(pd.DataFrame(gs.cv_results_))
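
To illustrate (this snippet is not part of the gist, it just exercises the class above): the wrapper's get_params/set_params can be called directly, which is all GridSearchCV effectively does when it routes the pipeline parameter selectkbest__score_func__n_neighbors down to the wrapped function.

# Illustration only: direct use of the wrapper's get_params/set_params,
# mirroring what GridSearchCV does through the pipeline above.
from sklearn.feature_selection import mutual_info_regression

mi = parametrized_function(mutual_info_regression, n_neighbors=3)
print(mi.get_params())        # {'n_neighbors': 3, '_func': <function mutual_info_regression ...>}
mi.set_params(n_neighbors=4)  # the call GridSearchCV makes via ...score_func__n_neighbors
print(mi.get_params()['n_neighbors'])  # 4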
Reading through the comments, I do not find it too magical. I see this as an extension of how Pipeline or ColumnTransformer expose their estimators with get_params and set_params.

As for refactoring into a BaseParametrized: if we make BaseParametrized easy to use for third parties, then they can easily extend any object to have the get_params and set_params interface.
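
For comparison, this is the existing nested-parameter convention in Pipeline that the comment refers to (illustrative snippet, not from the gist):

# Existing scikit-learn behaviour the wrapper imitates: Pipeline exposes the
# parameters of its steps under a 'stepname__param' prefix.
from sklearn.pipeline import make_pipeline
from sklearn.feature_selection import SelectKBest
from sklearn.linear_model import LinearRegression

pipe = make_pipeline(SelectKBest(k=1), LinearRegression())
print(pipe.get_params()['selectkbest__k'])  # 1
pipe.set_params(selectkbest__k=5)           # same double-underscore routing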
Note that if BaseEstimator's get_params and set_params were refactored to some BaseParametrized where _get_param was used instead of getattr and _set_param was used instead of setattr, we wouldn't need to duplicate so much code here.
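
A minimal sketch of what such a BaseParametrized could look like, assuming hypothetical _get_param/_set_param hook names (nothing below exists in scikit-learn; it only illustrates the refactoring idea):

# Hypothetical sketch only: a BaseParametrized whose get_params/set_params call
# overridable _get_param/_set_param hooks instead of using getattr/setattr directly.
class BaseParametrized:
    def _param_names(self):
        raise NotImplementedError

    def _get_param(self, name):
        return getattr(self, name)      # default: attribute access, as in BaseEstimator

    def _set_param(self, name, value):
        setattr(self, name, value)

    def get_params(self, deep=True):
        out = {name: self._get_param(name) for name in self._param_names()}
        if deep:
            for name, value in list(out.items()):
                if hasattr(value, 'get_params'):
                    out.update((name + '__' + k, v)
                               for k, v in value.get_params().items())
        return out

    def set_params(self, **params):
        # a full version would also route 'name__subparam' keys to nested
        # objects, as set_params does in the wrapper above
        for name, value in params.items():
            if name not in self._param_names():
                raise ValueError('Invalid parameter %s for %s' % (name, self))
            self._set_param(name, value)
        return self


class parametrized_function(BaseParametrized):
    # the wrapper would keep __call__ as above and only override the hooks
    def __init__(self, _func, **kwargs):
        self._func = _func
        self._params = kwargs

    def _param_names(self):
        return ['_func'] + list(self._params)

    def _get_param(self, name):
        return self._func if name == '_func' else self._params[name]

    def _set_param(self, name, value):
        if name == '_func':
            self._func = value
        else:
            self._params[name] = value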