Last active
November 20, 2020 12:37
-
-
Save nkthiebaut/ed02a65791ce248b4366bb59db66c5b2 to your computer and use it in GitHub Desktop.
Plot a good-looking ROC curve with thresholds
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
def plot_roc_curve(true_labels, scores):
    """Plot a ROC curve together with the score threshold behind each point.

    The ROC curve (true positive rate vs. false positive rate) is drawn on the
    left y-axis. The decision threshold that yields each false positive rate
    is drawn as a dashed red line on a secondary right y-axis that shares the
    same x-axis.

    Parameters
    ----------
    true_labels : array-like of shape (n_samples,)
        Ground-truth binary labels, in any form accepted by
        ``sklearn.metrics.roc_curve``.
    scores : array-like of shape (n_samples,)
        Target scores where a higher value means "more likely positive"
        (probabilities or arbitrary decision values).

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the ROC axes and the twin threshold axes.
    """
    fpr, tpr, thresholds = roc_curve(true_labels, scores)
    roc_auc = auc(fpr, tpr)  # area under the ROC curve

    fig, ax = plt.subplots()
    ax.plot(fpr, tpr, label=f'ROC curve (area = {roc_auc:0.2f})')
    ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver operating characteristic')
    ax.legend(loc="lower right")

    # Secondary y-axis for the thresholds. twinx() shares the x-axis with
    # ``ax``, so the [0, 1] x-limits set above already apply to it.
    ax2 = ax.twinx()

    # roc_curve prepends an artificial first threshold for "predict no
    # positives" (np.inf in recent scikit-learn, max(scores) + 1 in older
    # releases). Skip it so it neither drops out of the line nor sits
    # above the visible range.
    fpr_real = fpr[1:]
    thr_real = thresholds[1:]
    ax2.plot(fpr_real, thr_real, linestyle='dashed', color='r')
    ax2.set_ylabel('Threshold', color='r')
    ax2.tick_params(axis='y', labelcolor='r')

    # Thresholds come back in decreasing order. Keep the [min, 1.05] framing
    # for probability-like scores, but extend the top (with a 5% margin)
    # when scores exceed 1 instead of clipping them.
    lowest, highest = thr_real[-1], thr_real[0]
    top = max(1.05, highest + 0.05 * (highest - lowest))
    ax2.set_ylim([lowest, top])
    return fig
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment