Last active
January 13, 2020 19:49
-
-
Save ZaxR/ad6eccfaeba3a98d2273200d9d9b5359 to your computer and use it in GitHub Desktop.
Attempt at a general ML model architecture
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Demo Model class that can be used generally for ML projects. | |
Attempts to solve the following problems: | |
1. Ensure data preprocessing is consistent between data for model training and prediction | |
2. Have a common architecture for any data types / ML model types/libraries (as long as a consistent API is ued) | |
3. Allow for easy swapping of preprocessing, modeling, and/or postprocessing | |
4. Enforce an input (and potentially output) data schema(s) | |
Bonus: | |
- Follow's sklearn's estimator/transformer/predictor APIs, allowing use of sklearn Pipelines and GridSearchCV. | |
""" | |
from pyspins.io import load_file, output_file | |
from schema import Schema | |
from sklearn.base import BaseEstimator | |
# The below imports are for the demo only | |
import pandas as pd | |
from sklearn.ensemble import RandomForestClassifier | |
from sklearn.model_selection import train_test_split | |
from sklearn.pipeline import Pipeline | |
from sklearn.preprocessing import MinMaxScaler, StandardScaler | |
class NoModelError(Exception):
    """Signals that an operation required a model but none was set."""
class MissingFilePathError(Exception):
    """Signals that a file path was required but none was provided."""
class MLModelError(Exception):
    """Base exception raised from within MLModel-derived classes."""

    def __init__(self, *args):
        # Forward every positional argument to Exception unchanged.
        super().__init__(*args)
class MLModelSchemaValidationError(MLModelError):
    """Schema-validation failure raised from within MLModel-derived classes."""

    def __init__(self, *args):
        # Forward every positional argument up the exception hierarchy unchanged.
        super().__init__(*args)
class Model(BaseEstimator):  # maybe don't inherit from BaseEstimator to remove sklearn dependency
    """Build, run predictions on, and store metadata for a ML model.

    Args:
        model: ML model with a `predict` method. Replaced by the fitted model
            when `fit` is called.
        input_schema: Schema used to validate prediction inputs; pass `None`
            to skip validation. See https://pypi.org/project/schema/
            NOTE(review): the default `Schema(None)` appears to accept only
            `None` as input, so callers should pass an explicit schema - confirm.
        preprocessor: One or more preprocessing steps for fitting/predicting data.
            Must have `fit_transform` (used by `fit`) and `transform`
            (used by `predict`) methods.
            Works with sklearn Pipelines and Transformers.
            E.g. `preprocessor=Pipeline([('minmax', MinMaxScaler()),
                                         ('std', StandardScaler())])`
        trainer: One or more model ML estimators.
            Must have a fit method.
            Works with sklearn Pipelines and Estimators.
            E.g. `trainer=RandomForestClassifier()`
        postprocessor: One or more postprocessing steps for transforming predictions.
            Must have a `transform` method; it is never fitted by this class.
            Works with sklearn Pipelines and Transformers.
        model_path: Destination handed to `output_file` by `save`.
    """

    def __init__(self, model=None, input_schema=Schema(None),
                 preprocessor=None, trainer=None, postprocessor=None, model_path=None):
        # Every argument is stored unmodified under its own name: sklearn's
        # BaseEstimator.get_params/clone rely on that convention.
        self.model = model
        self.input_schema = input_schema
        # preprocessing may depend on self.training; postprocess will also depend on self.training
        self.preprocessor = preprocessor
        self.trainer = trainer
        self.postprocessor = postprocessor
        self.model_path = model_path

    def fit(self, X, y, **train_kwargs):
        """Fit the preprocessor (if any) and the trainer, then store the fitted model.

        Args:
            X: Training features. Passed through `preprocessor.fit_transform`
                when a preprocessor is configured.
            y: Training targets, forwarded to `trainer.fit` unchanged.
            **train_kwargs: Extra keyword arguments forwarded to `trainer.fit`.

        Returns:
            This instance, following the sklearn estimator convention.

        Raises:
            MLModelError: If no trainer was configured.

        Note:
            Can be used either to train a single model or with GridSearchCV.
        """
        if self.trainer is None:
            raise MLModelError("Cannot fit: no trainer was provided.")
        # The preprocessor is optional here, exactly as it is in predict().
        if self.preprocessor is not None:
            X = self.preprocessor.fit_transform(X)
        fitted = self.trainer.fit(X, y, **train_kwargs)
        # A trainer whose fit() works in place and returns None is itself the fitted model.
        self.model = self.trainer if fitted is None else fitted
        return self

    def predict(self, prediction_inputs):
        """Validate, preprocess, predict, and postprocess `prediction_inputs`.

        Args:
            prediction_inputs: Data to predict on; validated against
                `input_schema` unless that is `None`.

        Returns:
            The model's predictions, passed through `postprocessor.transform`
            when a postprocessor is configured.

        Raises:
            MLModelSchemaValidationError: If schema validation fails.
            NoModelError: If no model has been fitted or provided.
        """
        if self.input_schema is not None:
            try:
                self.input_schema.validate(prediction_inputs)
            except Exception as e:
                # Chain the cause so the validator's traceback is preserved.
                raise MLModelSchemaValidationError(
                    "Failed to validate input data: {}".format(str(e))) from e
        if self.model is None:
            raise NoModelError("Cannot predict: no model has been fitted or provided.")
        if self.preprocessor is not None:
            # transform(), not fit_transform(): the preprocessing learned during
            # fit() must be reused as-is, never refitted on prediction data.
            prediction_inputs = self.preprocessor.transform(prediction_inputs)
        predictions = self.model.predict(prediction_inputs)
        if self.postprocessor is not None:
            predictions = self.postprocessor.transform(predictions)
        return predictions

    def save(self):
        """Write this whole `Model` instance to `model_path` via `output_file`.

        Raises:
            NoModelError: If there is no model to save.
            MissingFilePathError: If `model_path` is not set.
        """
        if self.model is None:
            raise NoModelError("Cannot save: no model has been fitted or provided.")
        if self.model_path is None:
            raise MissingFilePathError("Cannot save: model_path is not set.")
        output_file(self, self.model_path)

    @classmethod
    def load(cls, model_path):
        """Load a previously saved `Model` instance from `model_path`.

        Args:
            model_path: Location handed to `load_file`.

        Returns:
            The loaded instance.

        Raises:
            TypeError: If the loaded object is not an instance of this class.
        """
        # NOTE(review): presumably load_file unpickles the file - only load
        # files from trusted sources; confirm against pyspins.io.
        model_instance = load_file(model_path)
        if isinstance(model_instance, cls):
            return model_instance
        raise TypeError("The file at model_path isn't an instance of Model.")
# Demo
if __name__ == '__main__':
    df = pd.DataFrame({"feature1": [1, 2, 3, 4, 5, 6, 7, 8],
                       "target": [1, 1, 1, 1, 2, 2, 2, 2]})
    # A scalar column name makes df[target] a 1-D Series, which is the shape
    # the classifier expects for y (a one-column DataFrame is 2-D).
    target = "target"
    model_path = "modeltestv1.pkl"

    # Training + Predicting
    X_train, X_test, y_train, y_test = train_test_split(df.drop(target, axis="columns"),
                                                        df[target],
                                                        test_size=0.2)
    preprocessor = Pipeline([('minmax', MinMaxScaler()), ('std', StandardScaler())])
    trainer = RandomForestClassifier()
    # postprocessor = PredCleaner
    model = Model(model=None,
                  input_schema=Schema(pd.DataFrame),
                  preprocessor=preprocessor,
                  trainer=trainer,
                  postprocessor=None,
                  model_path=model_path)
    model.fit(X_train, y_train)
    model.save()
    preds = model.predict(X_test)
    actuals = y_test
    # Show the results instead of silently discarding them.
    print("Held-out predictions:", list(preds))
    print("Held-out actuals:", list(actuals))

    # Predicting from existing model
    # Use the same feature column name the model was trained on.
    to_predict = pd.DataFrame({"feature1": [100, 0]})
    model = Model.load(model_path)
    preds = model.predict(to_predict)
    print("Predictions from reloaded model:", list(preds))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
2 votes for setting path in the method, not the class instantiation