@DavidWalz
Last active March 22, 2020 01:23
MLflow - Plot mean and standard deviation for each metric grouped by parameter settings
import mlflow
import numpy as np
import matplotlib.pyplot as plt
# get runs
exp_name = "experiment_name"
exp_id = mlflow.get_experiment_by_name(exp_name).experiment_id
runs = mlflow.search_runs(exp_id)
# group runs by parameter
params = [p for p in runs.columns if p.startswith("params")]
params_common = [p for p in params if len(runs[p].unique()) == 1]
params_varied = [p for p in params if len(runs[p].unique()) > 1]
groups = list(runs.groupby(params_varied).run_id)
# get history for each metric and run and plot mean + std across identical runs
client = mlflow.tracking.MlflowClient()
metrics = [c[8:] for c in runs.columns if c.startswith("metrics")]
fig, axs = plt.subplots(len(metrics), figsize=(6, 3 * len(metrics)), sharex=True)
fig.subplots_adjust(hspace=0.05)
# plt.subplots returns a single Axes (not an array) when there is only one metric
for ax, metric in zip(np.atleast_1d(axs), metrics):
    for g, run_ids in groups:
        try:
            values = [
                [s.value for s in client.get_metric_history(r, metric)] for r in run_ids
            ]
        except mlflow.exceptions.MlflowException:
            continue
        values = np.atleast_2d(values)
        mean = values.mean(axis=0)
        std = values.std(axis=0)
        step = np.arange(len(mean)) + 1
        ax.plot(step, mean, label=g)
        ax.fill_between(step, mean - std, mean + std, alpha=0.2)
    ax.set_ylabel(metric)
    ax.legend(fontsize="x-small")
ax.set_xlabel("step")
fig.suptitle(
    ", ".join([f"{p[7:]}={runs.loc[0, p]}" for p in params_common]),
    wrap=True,
    fontsize="small",
)
fig.savefig(f"{exp_name}.png")
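
The snippet stacks the metric histories with np.atleast_2d, which assumes every run in a group logged the same number of steps. If some runs stopped early, the nested lists are ragged and recent NumPy versions refuse to build a regular array from them. A minimal sketch of one possible workaround, padding shorter histories with NaN and using the NaN-aware reductions, is shown below. The helper name mean_std_ragged and the sample values are hypothetical and not part of the original snippet.

```python
import numpy as np


def mean_std_ragged(histories):
    """Return per-step mean and std over runs whose histories may differ in length."""
    n_steps = max(len(h) for h in histories)
    # Pad shorter histories with NaN so that every run occupies one full row.
    padded = np.full((len(histories), n_steps), np.nan)
    for row, history in enumerate(histories):
        padded[row, : len(history)] = history
    # nanmean/nanstd skip the padding; nanstd uses ddof=0 by default, like ndarray.std.
    return np.nanmean(padded, axis=0), np.nanstd(padded, axis=0)


# Hypothetical metric histories for three runs with identical parameters.
mean, std = mean_std_ragged([[1.0, 2.0, 3.0], [3.0, 4.0], [2.0, 3.0, 5.0]])
```

With this approach, later steps are averaged over fewer runs, so the band there reflects only the runs that reached those steps.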

MLflow lets you log metrics per step and view the resulting curves in the MLflow UI. However, the UI offers no way to summarize these curves across multiple identical runs in order to average out stochastic effects. For each metric, this snippet plots the mean and standard deviation across all runs in an experiment that share the same logged parameters.
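
To make "identical runs, as defined by the logged parameters" concrete, here is a small sketch of the grouping step on a toy DataFrame. It only mimics the column layout that mlflow.search_runs returns (run_id plus params.* columns). The run IDs and parameter names are made up for illustration.

```python
import pandas as pd

# Toy stand-in for the DataFrame returned by mlflow.search_runs();
# the run IDs and parameter names below are hypothetical.
runs = pd.DataFrame(
    {
        "run_id": ["a", "b", "c", "d"],
        "params.lr": ["0.1", "0.1", "0.01", "0.01"],
        "params.batch_size": ["32", "32", "32", "32"],
    }
)

params = [p for p in runs.columns if p.startswith("params")]
# Parameters with a single value across all runs go into the figure title ...
params_common = [p for p in params if len(runs[p].unique()) == 1]
# ... while parameters that take several values define the groups.
params_varied = [p for p in params if len(runs[p].unique()) > 1]

# Each group holds the IDs of the runs that share the same varied-parameter values.
groups = {key: list(ids) for key, ids in runs.groupby(params_varied)["run_id"]}
```

Here runs a and b form one group and runs c and d another. Note that any logged parameter that differs between repeats, such as a random seed, would split those repeats into separate groups.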
