Skip to content

Instantly share code, notes, and snippets.

@jtrive84
Created April 2, 2024 18:12
Show Gist options
  • Save jtrive84/03e885c915fe86048a43d912819987ab to your computer and use it in GitHub Desktop.
Classifier metrics
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import (
    PrecisionRecallDisplay, precision_recall_curve, RocCurveDisplay, roc_curve
)
def _f1_by_threshold(precision, recall):
    """Return F1 scores aligned one-to-one with ``precision_recall_curve`` thresholds.

    ``precision_recall_curve`` returns ``len(thresholds) + 1`` precision and
    recall values; the final pair (precision=1, recall=0) has no associated
    threshold, so it is dropped here so that ``f1[i]`` corresponds to
    ``thresholds[i]``.

    Points where precision and recall are both 0 get an F1 of 0 instead of
    NaN: ``np.argmax`` returns the index of the first NaN it encounters, which
    would otherwise hijack the search for the best threshold.

    Parameters
    ----------
    precision, recall : array-like
        The first two outputs of ``precision_recall_curve``.

    Returns
    -------
    numpy.ndarray
        Float array of F1 scores, one per threshold.
    """
    precision = np.asarray(precision, dtype=float)[:-1]
    recall = np.asarray(recall, dtype=float)[:-1]
    denom = precision + recall
    return np.divide(
        2 * precision * recall,
        denom,
        out=np.zeros_like(denom),
        where=denom > 0,
    )


p1, r1, thresh1 = precision_recall_curve(yactual, ypred1)
p2, r2, thresh2 = precision_recall_curve(yactual, ypred2)

# Determine, for each model, the threshold that maximizes the F1 score.
f1_1 = _f1_by_threshold(p1, r1)
f1_2 = _f1_by_threshold(p2, r2)
best_f1_1 = np.argmax(f1_1)
best_f1_2 = np.argmax(f1_2)
# Each best index must be looked up in that same model's thresholds array.
best_thresh1 = thresh1[best_f1_1]
best_thresh2 = thresh2[best_f1_2]
print(f"best_thresh1: {best_thresh1:.5f}")
print(f"best_thresh2: {best_thresh2:.5f}")
# Side-by-side diagnostics for the "mm" model (scores in ypred1):
# precision-recall curve on the left axes, ROC curve on the right axes.
fig, (ax1, ax2) = plt.subplots(
    nrows=1,
    ncols=2,
    figsize=(10, 4.5),
    tight_layout=True,
)

pr_disp1 = PrecisionRecallDisplay.from_predictions(
    yactual,
    ypred1,
    name="mm",
    plot_chance_level=False,
    ax=ax1,
    color="#CD0066",
)
roc_disp1 = RocCurveDisplay.from_predictions(
    yactual,
    ypred1,
    name="mm",
    plot_chance_level=True,
    ax=ax2,
    color="#191964",
)

# Each display draws onto the axes passed via ``ax=``, so ``disp.ax_`` is that
# same axes object; decorate both panels uniformly.
panel_titles = (
    (pr_disp1, "mm Precision-Recall curve"),
    (roc_disp1, "mm ROC curve"),
)
for disp, title in panel_titles:
    disp.ax_.grid(True)
    disp.ax_.set_title(title, fontsize=9)

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment