Skip to content

Instantly share code, notes, and snippets.

@iver56
Last active October 29, 2019 17:10
Show Gist options
  • Save iver56/7aad92f1a22912d40eedca56666aa668 to your computer and use it in GitHub Desktop.
Calculate and plot a confusion matrix, and then log it as an artifact in MLflow
import matplotlib.pyplot as plt
import mlflow
import numpy as np
from sklearn import metrics
def _row_normalize(cm):
    """Return a float copy of *cm* with every row divided by its own sum.

    A row whose sum is zero (a class with no ground-truth samples) is left as
    all zeros instead of becoming NaN from a 0/0 division. A NaN would
    otherwise appear in the cell annotations and make ``cm.max()`` NaN, which
    breaks the text-colour threshold.
    """
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)


def plot_confusion_matrix(
    cm,
    class_names,
    title="Confusion matrix",
    cmap=plt.cm.Blues,
    normalize=False,
    fmt=None,
):
    """Plot a confusion matrix as an annotated heatmap.

    Rows of *cm* are drawn along the y axis ("Ground truth label") and
    columns along the x axis ("Predicted label"). This matches the layout
    returned by ``sklearn.metrics.confusion_matrix``.

    Args:
        cm: 2-D array-like. ``class_names`` labels both axes, so it is
            expected to be square.
        class_names: Tick labels, one per row/column of *cm*, in order.
        title: Axes title.
        cmap: Matplotlib colormap used for the heatmap.
        normalize: If true, divide each row by its sum. Each cell then shows
            the fraction of that ground-truth class assigned to each
            prediction. All-zero rows stay zero rather than becoming NaN.
        fmt: Format spec for the per-cell annotations. If None, uses ``"d"``
            for integer matrices and ``".2f"`` otherwise. That includes every
            normalized matrix, since normalization produces floats.

    Returns:
        ``(ax, fig)``: the Axes and Figure the matrix was drawn on. The
        figure is left open; closing it is the caller's responsibility.
    """
    cm = np.asarray(cm)
    if normalize:
        cm = _row_normalize(cm)
    if fmt is None:
        # Raw counts read better as integers; "2.00" for a count is noise.
        fmt = "d" if np.issubdtype(cm.dtype, np.integer) else ".2f"

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    ax.set(
        xticks=np.arange(cm.shape[1]),
        yticks=np.arange(cm.shape[0]),
        # Bottom > top inverts the y axis so row 0 sits at the top, with each
        # edge row fully visible. Presumably a workaround for matplotlib
        # versions that clipped the edge rows; confirm before removing.
        ylim=(cm.shape[0] - 0.5, -0.5),
        xticklabels=class_names,
        yticklabels=class_names,
        title=title,
        ylabel="Ground truth label",
        xlabel="Predicted label",
    )
    # Rotate the predicted-label ticks so long class names don't overlap.
    plt.setp(ax.get_xticklabels(), rotation=30, ha="right", rotation_mode="anchor")

    # Annotate every cell. Use white text on dark (above half-max) cells and
    # black text elsewhere so the numbers stay legible on the colormap.
    thresh = cm.max() / 2.0
    for (i, j), value in np.ndenumerate(cm):
        ax.text(
            j,
            i,
            format(value, fmt),
            ha="center",
            va="center",
            color="white" if value > thresh else "black",
        )
    fig.tight_layout()
    return ax, fig
# Update the following three lines with _your_ data
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 2])
class_names = ["Cats", "Dogs", "Rabbits"]
# Pin the label set so the matrix always has one row/column per class name,
# even when a class never appears in y_true or y_pred. Without this the
# matrix shrinks and no longer lines up with class_names. Assumes labels are
# the integers 0..len(class_names)-1, as in the sample data above.
confusion_matrix = metrics.confusion_matrix(
    y_true, y_pred, labels=np.arange(len(class_names))
)
ax, fig = plot_confusion_matrix(confusion_matrix, class_names, normalize=True)
output_path = "confusion_matrix.png"
# Save through the returned figure rather than pyplot's implicit "current
# figure", so the right plot is written even if other figures are open.
fig.savefig(output_path)
plt.close(fig)
# Uncomment the following line to log the confusion matrix to mlflow
# mlflow.log_artifact(output_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment