Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Last active June 20, 2020 20:39
Show Gist options
  • Save sadimanna/13a5ea581eb01d410bf87ee6daf18ab2 to your computer and use it in GitHub Desktop.
Save sadimanna/13a5ea581eb01d410bf87ee6daf18ab2 to your computer and use it in GitHub Desktop.
def roc_curve(y, pred):
    """Sample ROC points for scores ``pred`` against labels ``y``.

    Evaluates the sibling ``recall`` (stored as the true-positive rate) and
    ``fpr`` (false-positive rate) helpers at 100 evenly spaced thresholds,
    swept from 1.00 down to 0.01.

    Args:
        y: ground-truth labels, passed unchanged to ``recall``/``fpr``.
        pred: predicted scores, passed unchanged to ``recall``/``fpr``.

    Returns:
        ``(TPR, FPR, thresholds)``. ``TPR`` and ``FPR`` are lists of length
        ``len(thresholds) + 1``: element 0 is the ``(0, 0)`` anchor of the
        curve and element ``i + 1`` was computed at ``thresholds[i]``.
        ``thresholds`` is a descending NumPy array.
    """
    # linspace hits both endpoints exactly; np.arange with a float step can
    # gain or lose its last element (or overshoot 1.0) through rounding.
    # Sweeping from high to low keeps the (0, 0) anchor adjacent to the
    # strictest threshold, so consecutive points move from (0, 0) toward
    # (1, 1) -- assuming recall/fpr treat pred >= th as positive (TODO confirm).
    thresholds = np.linspace(1.00, 0.01, 100)
    TPR = [0.0]
    FPR = [0.0]
    for th in thresholds:
        TPR.append(recall(y, pred, th))
        FPR.append(fpr(y, pred, th))
    return TPR, FPR, thresholds
def roc_auc_from_rates(TPR, FPR):
    """Return the area under an ROC curve given paired TPR/FPR samples.

    The samples may be in any order. The curve is anchored at (0, 0) and
    (1, 1), sorted by FPR (ties broken by TPR), and integrated with the
    trapezoid rule over the resulting piecewise-linear curve.

    Args:
        TPR: sequence of true-positive rates.
        FPR: sequence of false-positive rates, paired element-wise with TPR.

    Returns:
        The area as a Python float.

    Raises:
        ValueError: if TPR and FPR have different lengths.
    """
    tpr = np.asarray(TPR, dtype=float).ravel()
    fpr_pts = np.asarray(FPR, dtype=float).ravel()
    if tpr.shape != fpr_pts.shape:
        raise ValueError("TPR and FPR must have the same length")
    # Every ROC curve passes through (0, 0) (nothing predicted positive) and
    # (1, 1) (everything predicted positive). Adding them explicitly replaces
    # the fragile "divide by the last element" normalisation.
    tpr = np.concatenate(([0.0], tpr, [1.0]))
    fpr_pts = np.concatenate(([0.0], fpr_pts, [1.0]))
    # lexsort treats its last key as primary: order by FPR, then TPR.
    order = np.lexsort((tpr, fpr_pts))
    tpr = tpr[order]
    fpr_pts = fpr_pts[order]
    # Trapezoid rule: sum of width * mean height over consecutive points.
    return float(np.sum(np.diff(fpr_pts) * (tpr[1:] + tpr[:-1]) / 2.0))


def auc_score(y, pred):
    """Return the ROC AUC of scores ``pred`` against labels ``y``.

    Samples the curve with ``roc_curve`` and integrates it with
    ``roc_auc_from_rates``. That helper sorts the points, so the result does
    not depend on the order in which ``roc_curve`` sweeps its thresholds.
    """
    TPR, FPR, _thresholds = roc_curve(y, pred)
    return roc_auc_from_rates(TPR, FPR)
def plot_roc_curve(y, pred, label="ROC", filename="ROC_Curve.png"):
    """Plot the ROC curve of ``pred`` against ``y`` and save it as an image.

    Args:
        y: ground-truth labels, forwarded to ``roc_curve``/``auc_score``.
        pred: predicted scores, forwarded to ``roc_curve``/``auc_score``.
        label: legend text for the curve; the AUC is appended to it.
            (The original body referenced an undefined name ``label``.)
        filename: destination path passed to ``plt.savefig``.

    Returns:
        The matplotlib figure, so callers can show or close it.
    """
    TPR, FPR, _thresholds = roc_curve(y, pred)
    roc_auc = auc_score(y, pred)
    # Pin the (0, 0) and (1, 1) endpoints and order the points by FPR (then
    # TPR). The line then traces the curve monotonically, whatever order
    # roc_curve sampled its thresholds in.
    tpr = np.concatenate(([0.0], np.asarray(TPR, dtype=float), [1.0]))
    fpr_pts = np.concatenate(([0.0], np.asarray(FPR, dtype=float), [1.0]))
    order = np.lexsort((tpr, fpr_pts))
    fig = plt.figure(figsize=(10, 10))
    plt.plot(fpr_pts[order], tpr[order], linewidth=2,
             label='{} (AUC={:.3f})'.format(label, roc_auc))
    plt.plot([0.0, 1.0], [0.0, 1.0], linestyle='dashed', color='red', linewidth=2, label='random')
    plt.xlim(0.0, 1.0)
    plt.ylim(0.0, 1.0)
    plt.xlabel("FPR", fontsize=12)
    plt.ylabel("TPR", fontsize=12)
    plt.legend(fontsize=10, loc='best')
    plt.title("ROC Curve", fontsize=12)
    plt.savefig(filename)
    return fig
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment