Skip to content

Instantly share code, notes, and snippets.

@leo-pfeiffer
Last active May 10, 2021 13:17
Show Gist options
  • Save leo-pfeiffer/0405c66b1610db133d732daf38b0e251 to your computer and use it in GitHub Desktop.
Save leo-pfeiffer/0405c66b1610db133d732daf38b0e251 to your computer and use it in GitHub Desktop.
ROC curve calculation and plot
import argparse
from decimal import Decimal
import matplotlib.pyplot as plt
from numpy import trapz
def create_roc_values(thresholds, scores, true_values):
    """
    Calculate the points on a roc curve.

    Each threshold turns the scores into 0/1 predictions (a score strictly
    below the threshold is predicted 0, anything else 1). The false and true
    positive rates of those predictions form one point of the curve.

    :param thresholds: decision thresholds to evaluate. ``None`` or an empty
        sequence selects the default grid 0.00, 0.01, ..., 1.00.
    :param scores: one classifier score per sample.
    :param true_values: true labels (0 or 1), same length as ``scores``.
    :return: tuple ``(x_values, y_values)``: the false positive rates and the
        true positive rates, one entry per threshold, in threshold order.
    """
    assert len(scores) == len(true_values)

    # len() instead of plain truthiness so that array-likes whose truth value
    # is ambiguous (e.g. NumPy arrays) are still accepted.
    if thresholds is None or not len(thresholds):
        # i / 100 is the correctly rounded float for every grid value, whereas
        # i * 0.01 drifts (70 * 0.01 == 0.7000000000000001) and would push a
        # score that equals the threshold onto the "negative" side.
        thresholds = [i / 100 for i in range(101)]

    x_values = []
    y_values = []
    for threshold in thresholds:
        # predict based on threshold
        predictions = [0 if x < threshold else 1 for x in scores]
        tpr = calc_tpr(predictions, true_values)
        fpr = calc_fpr(predictions, true_values)
        x_values.append(fpr)
        y_values.append(tpr)
    return x_values, y_values
def calc_tpr(predictions, true_values) -> float:
    """
    Calculate true positive rate.

    tpr = tp / (tp + fn)

    :param predictions: predicted labels (0 or 1), one per sample.
    :param true_values: true labels (0 or 1), same length as ``predictions``.
    :return: share of the truly positive samples that were predicted 1, or
        NaN when there is no positive sample (the rate is undefined then).
    """
    assert len(predictions) == len(true_values)

    # Only samples whose true label is 1 contribute to tp and fn.
    on_positives = [p for p, t in zip(predictions, true_values) if t == 1]
    tp = sum(p == 1 for p in on_positives)
    fn = sum(p == 0 for p in on_positives)

    if tp + fn == 0:
        # No positives at all: report "undefined" instead of dividing by zero.
        return float("nan")
    return tp / (tp + fn)
def calc_fpr(predictions, true_values) -> float:
    """
    Calculate false positive rate.

    fpr = fp / (fp + tn)

    :param predictions: predicted labels (0 or 1), one per sample.
    :param true_values: true labels (0 or 1), same length as ``predictions``.
    :return: share of the truly negative samples that were predicted 1, or
        NaN when there is no negative sample (the rate is undefined then).
    """
    assert len(predictions) == len(true_values)

    # Only samples whose true label is 0 contribute to fp and tn.
    on_negatives = [p for p, t in zip(predictions, true_values) if t == 0]
    fp = sum(p == 1 for p in on_negatives)
    tn = sum(p == 0 for p in on_negatives)

    if fp + tn == 0:
        # No negatives at all: report "undefined" instead of dividing by zero.
        return float("nan")
    return fp / (fp + tn)
def create_plot(x_values, y_values):
    """
    Plot a ROC curve.

    The area under the curve is shown in the plot title.

    :param x_values: false positive rates, one per curve point.
    :param y_values: true positive rates, matching ``x_values`` element-wise.
    """
    # Integrate from left to right. Flipping the sign of the integral instead
    # is only correct when the points arrive with descending FPR, i.e. for
    # ascending thresholds; any other order gave a wrong or negative area.
    points = sorted(zip(x_values, y_values))
    fprs = [x for x, _ in points]
    tprs = [y for _, y in points]
    auc = trapz(tprs, fprs)

    plt.style.use('ggplot')
    # Draw the same ordered points that were integrated so the line and the
    # reported AUC describe the same curve.
    plt.plot(fprs, tprs, linestyle='--', marker='o', lw=3, color='red')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.title(f'ROC curve, AUC = {auc:.4f}')
    plt.show()
if __name__ == '__main__':
    # example usage:
    # python roc.py -s 0.9 0.8 0.7 0.7 0.6 0.5 0.4 0.3 0.2 0.1 -t 1 1 0 1 1 0 1 1 0 0

    def _decimal_arg(text):
        """Parse one command-line token as a Decimal."""
        try:
            return Decimal(text)
        except ArithmeticError:
            # Decimal signals bad input with decimal.InvalidOperation (an
            # ArithmeticError), which argparse does not turn into a usage
            # error by itself; ArgumentTypeError makes it do so.
            raise argparse.ArgumentTypeError(
                f"invalid decimal value: {text!r}"
            ) from None

    cli = argparse.ArgumentParser(
        description="Plot the ROC curve of scored binary predictions."
    )
    cli.add_argument("-s", "--scores", nargs="*", type=_decimal_arg, default=[],
                     help="classifier scores, one per sample")
    cli.add_argument("-t", "--true_values", nargs="*", type=int, default=[],
                     help="true labels (0 or 1), one per score")
    cli.add_argument("-ts", "--thresholds", nargs="*", type=_decimal_arg,
                     default=[],
                     help="thresholds to evaluate (default: 0.00 to 1.00 in "
                          "steps of 0.01)")
    args = cli.parse_args()

    scores = args.scores
    true_values = args.true_values
    thresholds = args.thresholds

    # Report unusable input as a usage error instead of a traceback from an
    # assertion or a division by zero further down.
    if not scores:
        cli.error("at least one score is required (-s)")
    if len(scores) != len(true_values):
        cli.error(
            f"got {len(scores)} scores but {len(true_values)} true values"
        )

    roc = create_roc_values(thresholds, scores, true_values)
    create_plot(roc[0], roc[1])

    # Same effect as exit(0), without relying on the helper that the site
    # module injects for interactive sessions.
    raise SystemExit(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment