
@benjamintanweihao
Created March 14, 2022 07:31

Supporting Arbitrary ML Models with MlFlow

Using PyCaret as an Example

MlFlow is awesome. We use it all the time to track our ML models and their artifacts. Logging model training runs is super easy, and if you need any custom logic, the API is also pretty easy to use.

The one small caveat is that the model should already be supported by MlFlow. See the following list for all the models that MlFlow natively supports. However, what happens when you have to log a model that isn't natively supported by MlFlow?

Enter mlflow.pyfunc.PythonModel

Thankfully, MlFlow exposes a base class that, once you've implemented the needed methods, allows MlFlow to treat your custom model (almost) like a native one.

Here, we have a PyCaret model that we need to log into the model registry, and that we would want to load later on. For our use case, we wanted to log the model from a Jupyter Notebook and later load it from a Kubeflow component.

While PyCaret supports MlFlow, it does this only for a fresh model. In other words, you have to train the model from scratch. In our case, however, we already had a trained model that had to be saved into the model registry.

Create an mlflow.pyfunc.PythonModel class

Here, you can see that we've implemented a PyCaretModel that inherits from the mlflow.pyfunc.PythonModel base class. This requires us to implement two methods: load_context and predict.


load_context

load_context is called when the model is loaded; context contains all the information the model needs to load itself. But how does this context get populated in the first place?

Remember that in our case, we already have an existing PyCaret model that we want to log. Let's call this segmentation_model.pkl. Now, we want to save this into the model registry. Since mlflow.pycaret.log_model() doesn't exist, we have to use mlflow.pyfunc.log_model() instead:

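A sketch of the logging call, wrapped in a helper for illustration (the artifact path "model" and the file name are illustrative choices; PyCaretModel is our mlflow.pyfunc.PythonModel subclass):

```python
def log_pycaret_model(pkl_path="segmentation_model.pkl"):
    # Deferred import: MlFlow is only needed where the logging happens,
    # e.g. inside the Jupyter Notebook.
    import mlflow.pyfunc

    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            artifact_path="model",
            python_model=PyCaretModel(),  # our wrapper class
            artifacts={"pycaret_model_path": pkl_path},
            # registered_model_name="segmentation_model",  # uncomment to register
        )
```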

So, we pass in a dictionary that maps the key pycaret_model_path to the path where we are storing the model. Also, note that we are using PyCaretModel as the model class.

So now you know how the context gets populated:

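Here is load_context again, shown as a bare function excerpt of the wrapper's method:

```python
import os


def load_context(self, context):  # excerpt: a method of the PyCaretModel wrapper
    # Deferred import; use the PyCaret module matching your task.
    from pycaret.classification import load_model

    # context.artifacts is populated from the dict passed to mlflow.pyfunc.log_model.
    # PyCaret's load_model appends ".pkl" itself, so strip the extension first.
    model_path, _ = os.path.splitext(context.artifacts["pycaret_model_path"])
    self.model = load_model(model_path)
```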

The only small thing to handle here is that load_model, a PyCaret function, tries to be overly smart and appends .pkl to the model_path. This means that we have to remove the extension first, which is why we apply os.path.splitext here.

predict

predict is more straightforward. We don't need the context here, just data, which is the input data to the model. This is a thin wrapper around PyCaret's predict_model:

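The method excerpt, in the same style as above:

```python
def predict(self, context, data):  # excerpt: a method of the PyCaretModel wrapper
    # Deferred import, as in load_context.
    from pycaret.classification import predict_model

    # We don't need `context` here; `data` is the model's input (a DataFrame).
    return predict_model(self.model, data=data)
```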

Loading the Model

So once you call mlflow.pyfunc.log_model, the PyCaret model is logged as normal. To load it back, you just have to supply the right MlFlow model URI. For example:

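A sketch of the loading side (both URI forms below are illustrative; use whichever matches how you logged and registered the model):

```python
def load_segmentation_model(model_uri="models:/segmentation_model/1"):
    # Deferred import: MlFlow is only needed where the model is loaded,
    # e.g. inside the Kubeflow component.
    import mlflow.pyfunc

    # A run URI such as "runs:/<run_id>/model" works too.
    return mlflow.pyfunc.load_model(model_uri)
```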

Finally, for predictions:

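A sketch of calling the loaded model (the column names here are made up; use whatever features your model was trained on):

```python
def predict_segments(model, rows):
    # `model` is the pyfunc model loaded above; pyfunc routes .predict()
    # through our wrapper's predict, i.e. PyCaret's predict_model.
    import pandas as pd

    data = pd.DataFrame(rows)  # e.g. [{"feature_a": 1.0, "feature_b": 2.0}]
    return model.predict(data)
```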

That's pretty much it!

This approach is nice because you are not bound by what MlFlow natively supports and can easily extend it to any model you might come across.
