Created
May 8, 2018 16:45
-
-
Save benoitdescamps/4272f8f8110f65473a98211ca3000ae2 to your computer and use it in GitHub Desktop.
code snippet for Tuning Hyperparameters (part I): SuccessiveHalving
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SHBaseEstimator(ABC):
    """Abstract wrapper around an estimator, presumably for SuccessiveHalving tuning.

    The wrapped ``model`` must expose ``fit``, ``predict``, ``get_params`` and
    ``set_params`` (the scikit-learn estimator interface); those calls are
    delegated to it directly. Library-specific behaviour (``save``, ``load``,
    ``remove`` and ``update``) must be supplied by concrete subclasses.

    Args:
        model: The estimator instance to wrap.
    """

    def __init__(self, model):
        self.model = model
        # Not read anywhere in this class. NOTE(review): presumably set by a
        # subclass or by the tuning loop; confirm before relying on it.
        self.env = None

    def fit(self, X, y):
        """Fit the wrapped model on ``X`` and ``y``.

        Returns:
            ``self``, following the scikit-learn ``fit`` convention so calls
            can be chained. Callers that ignored the previous ``None`` return
            value are unaffected.
        """
        self.model.fit(X, y)
        return self

    def predict(self, X):
        """Return the wrapped model's predictions for ``X``."""
        return self.model.predict(X)

    # The abstract methods below *raise* NotImplementedError. The original
    # code returned the exception class instead, so a subclass that delegated
    # via ``super()`` silently received a truthy class object, not an error.

    @abstractmethod
    def save(self, name=None):
        """Save the model, optionally under ``name`` (subclass-defined)."""
        raise NotImplementedError

    @abstractmethod
    def load(self, model_name):
        """Load the model identified by ``model_name`` (subclass-defined)."""
        raise NotImplementedError

    @abstractmethod
    def remove(self, model_name):
        """Remove the model identified by ``model_name`` (subclass-defined)."""
        raise NotImplementedError

    def get_params(self):
        """Return the wrapped model's hyperparameters via ``get_params()``."""
        return self.model.get_params()

    def set_params(self, *args, **kwargs):
        """Forward all arguments to the wrapped model's ``set_params``.

        Returns:
            ``self``, mirroring the scikit-learn ``set_params`` convention.
        """
        self.model.set_params(*args, **kwargs)
        return self

    def n_iteration(self, ressource_name):
        """Return the value of hyperparameter ``ressource_name`` on the model.

        The (misspelled) parameter name is kept so keyword callers keep working.
        NOTE(review): assumes ``get_params()`` returns a dict, in which case a
        missing name raises ``KeyError``.
        """
        return self.model.get_params()[ressource_name]

    @abstractmethod
    def update(self, Xtrain, ytrain, Xval, yval, scoring, n_iterations):
        """Further train the model after it has been reloaded.

        How this is done depends on the library being wrapped, so each
        subclass must provide its own implementation.
        """
        raise NotImplementedError
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment