# Classifier evaluation
# GitHub Gist aefc33594e2f518869a32052f62c1397 by @jtrive84, created April 2, 2024.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import (
    PrecisionRecallDisplay, precision_recall_curve, RocCurveDisplay, roc_curve,
    ConfusionMatrixDisplay, confusion_matrix
)
# F-beta weight: beta > 1 weights recall more heavily than precision.
beta = 2
# Fixed decision threshold used to binarize scores for the confusion matrix.
prior_thresh = .50


def _fbeta_scores(precision, recall, beta=1.0):
    """Return the F-beta score at every point of a precision-recall curve.

    F_beta = (1 + beta**2) * P * R / (beta**2 * P + R); beta=1 gives F1.
    Points where the denominator is zero (P == R == 0) score 0.0 instead of
    NaN, because ``np.argmax`` would otherwise report the first NaN as the
    maximum.
    """
    precision = np.asarray(precision, dtype=float)
    recall = np.asarray(recall, dtype=float)
    b2 = beta ** 2
    numer = (1 + b2) * (precision * recall)
    denom = (b2 * precision) + recall
    return np.divide(numer, denom, out=np.zeros_like(numer), where=denom > 0)


def _best_threshold(scores, thresholds):
    """Return the threshold at which ``scores`` is highest.

    ``precision_recall_curve`` returns precision/recall (and therefore
    ``scores``) with one more element than ``thresholds``: a final P=1, R=0
    point that has no threshold. That trailing score is dropped so the
    argmax index lines up with ``thresholds``.
    """
    return thresholds[np.argmax(scores[:-1])]


# NOTE(review): yactual, ypred1 and ypred2 are not defined in this snippet;
# presumably true 0/1 labels and two models' positive-class scores — confirm.
p1, r1, thresh1 = precision_recall_curve(yactual, ypred1)
p2, r2, thresh2 = precision_recall_curve(yactual, ypred2)

# F1 score at every point of each model's PR curve.
f1_1 = _fbeta_scores(p1, r1)
f1_2 = _fbeta_scores(p2, r2)
# F-beta score (beta defined above) at every point of each model's PR curve.
fb_1 = _fbeta_scores(p1, r1, beta)
fb_2 = _fbeta_scores(p2, r2, beta)

# Thresholds that maximize F1 / F-beta for each model.
best_f1_1 = _best_threshold(f1_1, thresh1)
best_f1_2 = _best_threshold(f1_2, thresh2)
best_fb_1 = _best_threshold(fb_1, thresh1)
best_fb_2 = _best_threshold(fb_2, thresh2)

print(f"best f1 thresh1: {best_f1_1:.5f}")
print(f"best f1 thresh2: {best_f1_2:.5f}")
print(f"best fb thresh1: {best_fb_1:.5f}")
print(f"best fb thresh2: {best_fb_2:.5f}")
# Hard 0/1 predictions for model 1: 1 where the score is >= prior_thresh.
# np.asarray lets ypred1 be a list as well as an array/Series.
yhat1 = np.where(np.asarray(ypred1) >= prior_thresh, 1, 0)

# Three side-by-side panels for model 1 ("mm"): precision-recall curve,
# ROC curve, and the confusion matrix at prior_thresh.
# NOTE(review): plot_chance_level requires a recent scikit-learn — confirm version.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 3.5), tight_layout=True)

pr_disp = PrecisionRecallDisplay.from_predictions(
    yactual, ypred1, name="mm", plot_chance_level=False, ax=ax1, color="#CD0066"
)
pr_disp.ax_.set_title("mm Precision-Recall curve", fontsize=9)
ax1.grid(True)

roc_disp = RocCurveDisplay.from_predictions(
    yactual, ypred1, name="mm", plot_chance_level=True, ax=ax2, color="#191964"
)
ax2.grid(True)
roc_disp.ax_.set_title("mm ROC curve", fontsize=9)

cm_disp = ConfusionMatrixDisplay.from_predictions(yactual, yhat1, ax=ax3, colorbar=False)
cm_disp.ax_.set_title(f"mm confusion matrix (thresh={prior_thresh:.2})", fontsize=9)
plt.show()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.