MLflow is awesome. We use it all the time to track our ML models and their artifacts. Logging model training runs is super easy, and if you need any custom logic, the API is also pretty easy to use.
The one small caveat is that the model has to be supported by MLflow already. See the following list for all the models that MLflow natively supports. But what happens when you have to log a model that isn't natively supported by MLflow?
Thankfully, MLflow exposes a base class that, once you've implemented all the needed functions, allows MLflow to treat your custom model (almost) like a native one.
Here, we have a PyCaret model that we need to log to the model registry and load later on. For our use case, we wanted to log the model from a Jupyter Notebook and later load it from a Kubeflow component.
While PyCaret supports MLflow, it only does so for a fresh model. In other words, you have to train the model from scratch. In our case, however, we already had a trained model that needed to be saved to the model registry.
Here, you can see that we've implemented a `PyCaretModel` class that inherits from the `mlflow.pyfunc.PythonModel` base class. This requires us to implement two methods: `load_context` and `predict`.
`load_context` is called when the model is loaded. `context` contains all the information needed for the model to be loaded. But how does this `context` get populated in the first place?
Remember that in our case, we already have an existing PyCaret model that we want to log. Let's call this `segmentation_model.pkl`. Now, we want to save this into the model registry. Since `mlflow.pycaret.log_model()` doesn't exist, we have to use `mlflow.pyfunc.log_model()` instead:
So, we pass in a dictionary containing the path where the model is stored, under the key `pycaret_model_path`. Also, note that we are using `PyCaretModel` as the model class.
So, now you know how the `context` gets populated. Here's how `load_context` uses it:
The only small thing to handle here is that `load_model`, a PyCaret function, tries to be overly smart and appends `.pkl` to the `model_path`. This means we have to remove the extension first, which is why we apply `os.path.splitext` here.
The prediction side is more straightforward. We don't need the `context` here, just `data`, which is the input data to the model. This is a thin wrapper around PyCaret's `predict_model`:
So, once you call `mlflow.pyfunc.log_model`, the PyCaret model will have been logged as per normal. You just have to supply the right MLflow URI when loading it. For example:
Finally, for predictions:
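Calling the loaded model is then just a regular pyfunc `predict`, sketched here as a tiny helper (`data` would be a pandas DataFrame with the features your model expects):

```python
def predict_segments(model, data):
    # `model` is the pyfunc model returned by mlflow.pyfunc.load_model;
    # calling predict routes through our PyCaretModel wrapper.
    return model.predict(data)
```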
That's pretty much it!
This approach is nice because you are not bound by what MLflow natively supports, and you can easily extend it to any model you might come across.