Skip to content

Instantly share code, notes, and snippets.

@Steboss89
Created September 18, 2022 17:50
Show Gist options
  • Save Steboss89/e2705dd3b758a955345aa54013760fb7 to your computer and use it in GitHub Desktop.
Save Steboss89/e2705dd3b758a955345aa54013760fb7 to your computer and use it in GitHub Desktop.
Setting up MLflow tracking for general models
import mlflow
from mlflow.tracking.client import MlflowClient
# CLASS AND PIPELINE
# MAIN
# argparse ...
# Set up MLflow tracking, train the pipeline inside a tracked run, evaluate it
# on the validation split and store the fitted model both locally and as a
# run artifact.
#
# Names used here but defined in the elided parts of the script:
# mlflow_tracking_uri, run_name_, exp_name_, training_process, model_,
# vectorizer_, X_train, y_train, X_valid, y_valid, classification_report,
# confusion_matrix, joblib.

# Connect the fluent API to the tracking server *before* touching experiments
# so that the fluent calls below and the explicit client use the same
# tracking URI.
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow_client = MlflowClient(tracking_uri=mlflow_tracking_uri)
run_name = run_name_
experiment_family = exp_name_

# Resolve the experiment ID: reuse the experiment when it already exists,
# otherwise create it.
print("setting up experiment ")
experiment = mlflow_client.get_experiment_by_name(experiment_family)
if experiment is None:
    # create_experiment returns the new experiment's ID directly (a string),
    # not an Experiment object, so there is no `.experiment_id` to read.
    experiment_id = mlflow_client.create_experiment(name=experiment_family)
else:
    experiment_id = experiment.experiment_id

# Start the recording. Using the run as a context manager guarantees it is
# ended even if training or evaluation raises, instead of being left active.
with mlflow.start_run(
    experiment_id=experiment_id,
    run_name=run_name,
    nested=False,
) as starter:
    # set the autolog
    mlflow.sklearn.autolog(
        log_models=True,
        log_input_examples=True,
        log_model_signatures=True,
    )

    # fit the pipeline
    trained_model = training_process(model_, vectorizer_)
    trained_model.fit(X_train, y_train)

    # and run predictions on the validation split
    y_pred = trained_model.predict(X_valid)
    report = classification_report(
        y_valid, y_pred, output_dict=True
    )
    cm = confusion_matrix(y_valid, y_pred)

    # save the final model locally
    joblib.dump(trained_model, "final_model.joblib")

    # and port it to the server under "model"
    # NOTE(review): autolog(log_models=True) presumably already logs the
    # fitted model; the explicit call is kept to preserve the original
    # behavior -- confirm whether both are needed.
    mlflow.sklearn.log_model(sk_model=trained_model, artifact_path="model")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment