Skip to content

Instantly share code, notes, and snippets.

@khizirsiddiqui
Created April 17, 2021 05:16
Show Gist options
  • Save khizirsiddiqui/559a91dab223944fb83f8480715d2582 to your computer and use it in GitHub Desktop.
Save khizirsiddiqui/559a91dab223944fb83f8480715d2582 to your computer and use it in GitHub Desktop.
AUC-ROC in PyTorch
def _auroc_collect(model, loader, n_classes):
    """Run ``model`` over ``loader``; return stacked one-hot labels and raw outputs.

    Returns a pair ``(y_true, y_score)`` of arrays with one row per sample
    (``y_true`` has ``n_classes`` columns; ``y_score`` has as many columns as
    the model outputs). 5-D batches ``(bs, ncrops, c, h, w)`` -- presumably
    from a multi-crop transform -- are flattened into ``bs * ncrops`` images
    and the outputs are averaged over the crops. 4-D batches
    ``(bs, c, h, w)`` are passed through as is.

    Raises:
        ValueError: if the loader yields no batches, or a batch is neither
            4-D nor 5-D.
    """
    true_batches = []
    score_batches = []
    with torch.no_grad():
        for inputs, classes in loader:
            inputs = inputs.to(device)
            # one_hot needs an int64 tensor; .numpy() needs it on the CPU.
            true_batches.append(F.one_hot(classes.cpu().long(), n_classes).numpy())
            if inputs.dim() == 5:
                bs, ncrops, c, h, w = inputs.size()
                outputs = model(inputs.reshape(-1, c, h, w))
                outputs = outputs.reshape(bs, ncrops, -1).mean(1)
            elif inputs.dim() == 4:
                outputs = model(inputs)
            else:
                raise ValueError(
                    'expected a 4-D (bs, c, h, w) or 5-D (bs, ncrops, c, h, w) '
                    'batch, got shape {}'.format(tuple(inputs.shape)))
            score_batches.append(outputs.cpu().numpy())
    if not true_batches:
        raise ValueError('dataloader yielded no batches')
    # Stack along the sample axis so every sample of every batch is kept
    # (and a smaller final batch is fine); rows = samples, columns = classes.
    return np.concatenate(true_batches, axis=0), np.concatenate(score_batches, axis=0)


def _auroc_curves(y_true, y_score, n_classes):
    """Compute per-class, micro- and macro-average ROC curves and their AUCs.

    Returns three dicts ``(fpr, tpr, roc_auc)``, each keyed by class index
    ``0..n_classes-1`` plus ``"micro"`` and ``"macro"``.
    """
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_true[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Micro-average: pool every (sample, class) pair into one binary problem.
    # NOTE(review): scores are the model's raw outputs (no softmax here); if
    # they are unnormalized logits, pooling them across classes assumes their
    # scales are comparable -- confirm this is intended.
    fpr["micro"], tpr["micro"], _ = roc_curve(y_true.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro-average: interpolate each class's curve onto the union of all FPR
    # points, then average the TPRs with equal weight per class.
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= n_classes
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
    return fpr, tpr, roc_auc


def _auroc_plot(fpr, tpr, roc_auc, n_classes):
    """Draw the micro/macro-average and per-class ROC curves, then show the figure."""
    plt.figure()
    plt.plot(fpr["micro"], tpr["micro"],
             label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc["micro"]),
             color='deeppink', linestyle=':', linewidth=4)
    plt.plot(fpr["macro"], tpr["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc["macro"]),
             color='navy', linestyle=':', linewidth=4)
    # Per-class curves reuse a small palette cyclically.
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
    for i, color in zip(range(n_classes), colors):
        plt.plot(fpr[i], tpr[i], color=color, lw=2,
                 label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([-1e-2, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristics')
    plt.legend(loc="lower right")
    plt.show()


def auroc(model, loader_name='val', N_classes=4):
    """Evaluate ``model`` on a named dataloader, plot its ROC curves, return the AUCs.

    One-hot labels are scored against the model's raw outputs (no softmax or
    sigmoid is applied here). Per-class, micro-average and macro-average ROC
    curves are computed with ``roc_curve`` / ``auc`` (presumably from
    ``sklearn.metrics``) and drawn in one matplotlib figure, which is shown
    before returning.

    Args:
        model: module to evaluate. It is put in eval mode for the forward
            passes, and its previous train/eval mode is restored afterwards.
        loader_name: key into the module-level ``dataloaders`` mapping.
        N_classes: number of classes; labels must lie in ``[0, N_classes)``.

    Returns:
        dict mapping each class index, ``"micro"`` and ``"macro"`` to its
        ROC AUC.

    Raises:
        ValueError: if the loader yields no batches, or a batch is neither
            4-D ``(bs, c, h, w)`` nor 5-D ``(bs, ncrops, c, h, w)``.
    """
    was_training = model.training
    model.eval()
    try:
        y_true, y_score = _auroc_collect(model, dataloaders[loader_name], N_classes)
    finally:
        # Don't leave a model that is mid-training stuck in eval mode.
        model.train(was_training)
    fpr, tpr, roc_auc = _auroc_curves(y_true, y_score, N_classes)
    _auroc_plot(fpr, tpr, roc_auc, N_classes)
    return roc_auc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment