Last active
March 17, 2021 10:52
-
-
Save adhadse/7db6cffe576399052527463c03ceb8fc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def plot_precision_vs_recall(precisions, recalls, metric_name=None, metric_perc=None,
                             thresholds=None):
    """Plot a precision/recall curve and optionally pick a decision threshold.

    NOTE(review): the arrays are assumed to follow the layout of
    ``sklearn.metrics.precision_recall_curve`` -- ``precisions[i]`` and
    ``recalls[i]`` belong to ``thresholds[i]``, recall is non-increasing, and
    the final curve point has no threshold. Confirm against the caller.

    Args:
        precisions: Precision of each point on the curve (array-like).
        recalls: Recall of each point on the curve (array-like).
        metric_name: ``'precision'`` to fix precision at ``metric_perc`` and
            read off the recall, or ``'recall'`` to fix recall at
            ``metric_perc`` and read off the precision. Any other value
            (including ``None``) only draws the curve.
        metric_perc: Target level for ``metric_name`` as a fraction (e.g. 0.9).
        thresholds: Decision thresholds aligned with ``precisions``/``recalls``.
            If omitted, a module-level ``thresholds`` variable is used, which
            is what earlier versions of this function implicitly relied on.

    Returns:
        The threshold at the selected operating point, or ``None`` when
        ``metric_name`` is neither ``'precision'`` nor ``'recall'``.

    Raises:
        ValueError: If ``metric_perc`` or the thresholds are missing, or if no
            thresholded point on the curve reaches ``metric_perc``.
    """
    precisions = np.asarray(precisions)
    recalls = np.asarray(recalls)

    idx = None
    if metric_name in ('precision', 'recall'):
        # Validate everything before opening a figure so a bad call does not
        # leave a dangling, half-drawn plot behind.
        if metric_perc is None:
            raise ValueError("metric_perc is required when metric_name is given")
        if thresholds is None:
            # Backward compatibility: the original code read a global.
            thresholds = globals().get('thresholds')
        if thresholds is None:
            raise ValueError("thresholds must be passed (or defined at module level)")
        thresholds = np.asarray(thresholds)
        idx = _pr_operating_index(precisions, recalls, metric_name, metric_perc,
                                  len(thresholds))

    plt.figure(figsize=(15, 10))
    plt.plot(recalls, precisions, 'b-', linewidth=2)
    plt.xlabel("Recall", fontsize=15)
    plt.ylabel("Precision", fontsize=15)
    plt.axis([0, 1, 0, 1])

    if metric_name == 'precision':
        # Recall traded off, and threshold, at the requested precision level.
        recall_atperc_precision = recalls[idx]
        plt.plot([recall_atperc_precision, recall_atperc_precision], [0., metric_perc], 'r:')
        plt.plot([0., recall_atperc_precision], [metric_perc, metric_perc], "r:")
        plt.plot([recall_atperc_precision], [metric_perc], 'ro')
        plt.title("Precision/Recall plot with Threshold Set to {}% of {}\n trading off Recall at {:.3f}%".format(
            metric_perc*100,
            metric_name.capitalize(),
            recall_atperc_precision*100), fontsize=20)
        plt.show()
        return thresholds[idx]

    if metric_name == 'recall':
        # Precision traded off, and threshold, at the requested recall level.
        precision_atperc_recall = precisions[idx]
        plt.plot([0., metric_perc], [precision_atperc_recall, precision_atperc_recall], 'r:')
        plt.plot([metric_perc, metric_perc], [0., precision_atperc_recall], "r:")
        plt.plot([metric_perc], [precision_atperc_recall], 'ro')
        plt.title("Precision/Recall plot with Threshold Set to {}% of {}\n trading off Precision at {:.3f}%".format(
            metric_perc*100,
            metric_name.capitalize(),
            precision_atperc_recall*100), fontsize=20)
        plt.show()
        return thresholds[idx]

    # No operating point requested: still display the bare curve.
    plt.show()
    return None


def _pr_operating_index(precisions, recalls, metric_name, metric_perc, n_thresholds):
    """Return the curve index whose ``metric_name`` meets ``metric_perc``.

    Raises ValueError when no point that has an associated threshold
    (index < ``n_thresholds``) satisfies the target.
    """
    if metric_name == 'precision':
        # First (lowest-threshold, highest-recall) point reaching the target precision.
        hits = np.flatnonzero(precisions >= metric_perc)
        idx = hits[0] if hits.size else None
    else:
        # Assuming recall is non-increasing along the curve, the last point that
        # still reaches the target recall is the highest threshold meeting it.
        hits = np.flatnonzero(recalls >= metric_perc)
        idx = hits[-1] if hits.size else None
    # The final curve point has no threshold, so it cannot be returned.
    if idx is None or idx >= n_thresholds:
        raise ValueError("no threshold reaches {} >= {}".format(metric_name, metric_perc))
    return int(idx)
# Plot the curve and get the threshold for the 90% recall operating point.
threshold = plot_precision_vs_recall(
    precisions,
    recalls,
    metric_name='recall',
    metric_perc=0.9,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment