@sadimanna
Created June 29, 2020 15:44
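import matplotlib.pyplot as plt
# Assumption: the gist does not show its imports. auc_score and plot_roc_curve
# are taken here to be aliases for these scikit-learn functions.
from sklearn.metrics import roc_auc_score as auc_score
from sklearn.metrics import roc_curve as plot_roc_curve


def get_roc_curve(gt, pred, target_names):
    """Plot one ROC curve per class on a shared figure and save it as a PNG.

    gt and pred are (n_samples, n_classes) arrays of binary labels and scores.
    """
    plt.figure(1, figsize=(7, 7))
    plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    for i in range(len(target_names)):
        auc_roc = auc_score(gt[:, i], pred[:, i])
        label = "%s AUC: %.3f" % (target_names[i], auc_roc)
        # roc_curve returns (false positive rates, true positive rates, thresholds)
        fpr, tpr, _ = plot_roc_curve(gt[:, i], pred[:, i])
        plt.plot(fpr, tpr, label=label)
    plt.xlabel("False positive rate")
    plt.ylabel("True positive rate")
    plt.legend(loc='upper center', bbox_to_anchor=(1.3, 1),
               fancybox=True, ncol=1)
    # bbox_inches='tight' keeps the legend, which sits outside the axes, in the saved image
    plt.savefig('ROC_Curve.png', bbox_inches='tight')
    plt.show()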
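
The number shown after "AUC:" in each legend entry can be read as the fraction of (positive, negative) sample pairs in which the positive sample gets the higher score, with ties counted as half. The sketch below computes that quantity in plain Python for a single class. The helper `pairwise_auc` is hypothetical and not part of the gist; for binary labels without sample weights it should agree with the area under the curve that the function above reports, but it is meant only as an illustration.

```python
def pairwise_auc(labels, scores):
    """AUC as the share of (positive, negative) pairs ranked correctly.

    Hypothetical helper for illustration; ties between a positive and a
    negative score count as half a correctly ranked pair.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy data for one class: two negatives, two positives.
labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print("AUC: %.3f" % pairwise_auc(labels, scores))  # prints "AUC: 0.750"
```

The double loop is quadratic in the number of samples, so it is only suitable for small inputs or for sanity-checking a label.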