Last active: October 29, 2019 17:10
Save iver56/7aad92f1a22912d40eedca56666aa668 to your computer and use it in GitHub Desktop.
Calculate and plot a confusion matrix, and then log it as an artifact in MLflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
import mlflow | |
import numpy as np | |
from sklearn import metrics | |
def plot_confusion_matrix(
    cm, class_names, title="Confusion matrix", cmap=plt.cm.Blues, normalize=False
):
    """
    Plot a confusion matrix as an annotated heatmap.

    Args:
        cm: Square confusion matrix (array-like). Rows are plotted as the
            ground-truth labels and columns as the predicted labels (see the
            axis labels set below), e.g. the output of
            ``sklearn.metrics.confusion_matrix``.
        class_names: Tick labels, one per row/column of ``cm``.
        title: Title of the plot.
        cmap: Matplotlib colormap used for the heatmap.
        normalize: If True, divide each row by its sum so each cell shows the
            fraction of that ground-truth class predicted as each label.
            Rows that sum to zero (classes with no samples) are shown as 0
            instead of NaN.

    Returns:
        Tuple ``(ax, fig)`` with the Axes and Figure containing the plot.
        The figure is neither saved nor closed; that is left to the caller.
    """
    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Divide only where the row sum is non-zero: a class that never occurs
        # in the ground truth would otherwise yield 0/0 -> NaN (with a
        # RuntimeWarning) and poison cm.max() / the colour threshold below.
        cm = np.divide(
            cm,
            row_sums,
            out=np.zeros(cm.shape, dtype=float),
            where=row_sums != 0,
        )

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    ax.set(
        xticks=np.arange(cm.shape[1]),
        yticks=np.arange(cm.shape[0]),
        # Pin the y-range to exactly cover the rows, with row 0 at the top.
        # NOTE(review): presumably a workaround for matplotlib versions that
        # clipped the outer rows of imshow heatmaps - confirm before removing.
        ylim=(cm.shape[0] - 0.5, -0.5),
        xticklabels=class_names,
        yticklabels=class_names,
        title=title,
        ylabel="Ground truth label",
        xlabel="Predicted label",
    )

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=30, ha="right", rotation_mode="anchor")

    # Show raw counts as integers and normalized (or float) values with two
    # decimals.
    fmt = "d" if np.issubdtype(cm.dtype, np.integer) else ".2f"
    # Use white text on dark cells and black text on light cells.
    thresh = cm.max() / 2.0
    for i, j in np.ndindex(cm.shape):
        ax.text(
            j,
            i,
            format(cm[i, j], fmt),
            ha="center",
            va="center",
            color="white" if cm[i, j] > thresh else "black",
        )
    fig.tight_layout()
    return ax, fig
# Update the following three lines with _your_ data
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 2])
class_names = ["Cats", "Dogs", "Rabbits"]

# Pass ``labels`` explicitly so the matrix always has exactly one row/column
# per entry in class_names, even if a class never appears in y_true or y_pred;
# otherwise the matrix shrinks and the tick labels become misaligned.
# Assumes the labels are the integers 0..len(class_names)-1, as in the sample
# data above.
confusion_matrix = metrics.confusion_matrix(
    y_true, y_pred, labels=np.arange(len(class_names))
)
_, fig = plot_confusion_matrix(confusion_matrix, class_names, normalize=True)
# Save this specific figure rather than whatever pyplot considers "current".
fig.savefig("confusion_matrix.png")
plt.close(fig)
# Uncomment the following line to log the confusion matrix to mlflow
# mlflow.log_artifact("confusion_matrix.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.