Last active
February 29, 2024 13:30
-
-
Save smurching/366781ae6a3e5d597d716ef30cf26ba8 to your computer and use it in GitHub Desktop.
creating-child-runs-in-mlflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mlflow

# There are two ways to create parent/child runs in MLflow.
# (1) The most common way is to use the fluent
#     mlflow.start_run API, passing nested=True.
#
# NOTE: `tuning_params`, `train_model` and `train_data` are placeholders for
# the caller's own hyperparameter list, training function and dataset; they
# must be defined before this snippet runs.
with mlflow.start_run():
    num_trials = 10
    mlflow.log_param("num_trials", num_trials)
    # Start with "no model yet" so a sweep that produces no usable trial is
    # detected explicitly instead of failing with a NameError below.
    best_model = None
    best_loss = float("inf")
    for trial_idx in range(num_trials):
        # Create a child run per tuning trial
        with mlflow.start_run(nested=True):
            # Look up params for the current trial and train a model on them.
            # Any fluent mlflow.log_* call made inside this `with` block is
            # recorded on the child run, so per-trial params/metrics belong here.
            params = tuning_params[trial_idx]
            model, loss = train_model(params, train_data)
            if loss < best_loss:
                # Update best model across trial runs
                best_model = model
                best_loss = loss
    # We're in the parent run again and all trials have completed,
    # let's log the best model and best loss
    if best_model is not None:
        # The artifact path is passed explicitly: it is a required argument
        # in MLflow 1.x/2.x.
        mlflow.pytorch.log_model(best_model, artifact_path="best_model")
        mlflow.log_metric("best_loss", best_loss)
# (2) Another way to create parent and child runs is to explicitly set the
# "mlflow.parentRunId" tag on one run. This can be useful if creating nested
# run relationships across processes/distributed workers, or if you don't
# want to change the current active run.
from mlflow.tracking import MlflowClient

client = MlflowClient()
# See docs https://www.mlflow.org/docs/latest/python_api/mlflow.tracking.html#mlflow.tracking.MlflowClient.create_run
# create_run needs the ID of an existing experiment. "0" is the default
# experiment on a standard tracking server -- replace it with your own
# experiment's ID as needed.
experiment_id = "0"
parent_run = client.create_run(experiment_id)
child_run = client.create_run(
    experiment_id,
    tags={"mlflow.parentRunId": parent_run.info.run_id},
)
# We can use the client to statelessly log metrics/params/tags to the runs
client.log_param(parent_run.info.run_id, key="learning_rate", value=0.001)
# Runs created through the client are not tied to a `with` block, so nothing
# ends them automatically; mark them finished explicitly once logging is done.
client.set_terminated(child_run.info.run_id)
client.set_terminated(parent_run.info.run_id)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment