Created March 6, 2018 23:00
Scikit-learn: Cache an estimator's fit with a mixin
from sklearn.externals import joblib  # on recent scikit-learn versions, use `import joblib` directly
from sklearn.linear_model import LogisticRegression
from sklearn.feature_selection import RFE
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import make_classification

# On-disk cache for fitted estimators
memory = joblib.Memory('/tmp')


class MemoryFit:
    """Mixin that memoizes the parent estimator's fit with joblib.Memory."""
    def fit(self, *args, **kwargs):
        fit = memory.cache(super(MemoryFit, self).fit)
        cached_self = fit(*args, **kwargs)
        # Copy the fitted attributes from the cached result onto this instance
        vars(self).update(vars(cached_self))
        return self  # follow the scikit-learn convention that fit returns self


class CachedLogisticRegression(MemoryFit, LogisticRegression):
    pass


gs = GridSearchCV(RFE(CachedLogisticRegression()),
                  {'n_features_to_select': [1, 2, 3]}, verbose=10)
gs.fit(*make_classification())
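In short, the mixin routes fit through joblib.Memory, so fitting an estimator with the same parameters on the same data a second time loads the fitted attributes from the on-disk cache rather than recomputing them. A rough usage sketch (the data and variable names are only illustrative; cache hits depend on joblib hashing the bound method's arguments identically):

X, y = make_classification(random_state=0)

clf = CachedLogisticRegression()
clf.fit(X, y)   # computed and written to the cache

clf2 = CachedLogisticRegression()
clf2.fit(X, y)  # same parameters and data, so loaded from the cache
print(clf2.coef_.shape)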
Hi - is there a way to pass the memory in as a kwarg during instantiation? I tried writing this:
then in my script code:
The first CachedLinearSVC instantiation works fine, but scikit-learn then instantiates CachedLinearSVC again (possibly during cloning?), and it passes in **kwargs taken not from the original dict in the script but with memory missing, since I popped it off before forwarding **kwargs to super.
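For context, a hypothetical sketch of the kind of subclass described above, reusing the MemoryFit mixin from the gist (the commenter's actual code isn't shown; LinearSVC and the memory kwarg handling are assumptions). Because memory is swallowed by **kwargs and popped before reaching the parent __init__, it is never among the named constructor parameters that get_params() reports, and clone() rebuilds estimators from get_params() during grid search, so the rebuilt copy is constructed without it:

from sklearn.svm import LinearSVC

class CachedLinearSVC(MemoryFit, LinearSVC):
    # Hypothetical reconstruction: accept a Memory object at construction time
    def __init__(self, **kwargs):
        self.memory = kwargs.pop('memory', None)  # removed before forwarding, so lost on clone()
        super(CachedLinearSVC, self).__init__(**kwargs)

# Hypothetical script usage:
# clf = CachedLinearSVC(memory=joblib.Memory('/tmp'), C=1.0)
# clone(clf) rebuilds the estimator as type(clf)(**clf.get_params()), and
# get_params() only sees parameters named in the __init__ signature, so memory
# (and, with a bare **kwargs signature, every other parameter) is dropped in
# the cloned copy.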