Skip to content

Instantly share code, notes, and snippets.

@level14taken
Last active January 3, 2021 06:50
Show Gist options
  • Save level14taken/25c631d51048141cc3a439959deadc2c to your computer and use it in GitHub Desktop.
Save level14taken/25c631d51048141cc3a439959deadc2c to your computer and use it in GitHub Desktop.
# Adapted from the Kaggle TGS Salt Identification Challenge discussion:
# https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/67693
def _stepped_iou(target, pred):
    """Score one image: the fraction of IoU thresholds 0.50..0.95 it exceeds.

    An empty ground-truth mask scores 1.0 only if the prediction is also
    empty, and 0.0 otherwise.
    """
    true_area = np.sum(target)
    pred_area = np.sum(pred)
    if true_area == 0:
        # Empty ground truth: only an empty prediction counts as correct.
        return float(pred_area == 0)
    intersection = np.sum(target * pred)
    # The union is positive here because the ground-truth mask is non-empty.
    union = true_area + pred_area - intersection
    # Thresholds are k / 20 for k = 10..19 (0.50, 0.55, ..., 0.95), and a
    # threshold is passed when IoU is strictly greater than it. Comparing
    # 20 * intersection > k * union on the pixel counts is exact. The old
    # floor((iou - 0.45) * 20) form rounds differently at exact threshold
    # values: it counted iou == 0.55 as a hit but iou == 0.5 and 0.6 as misses.
    hits = sum(20 * intersection > k * union for k in range(10, 20))
    return hits / 10


def iou_metric(outputs, labels, logits=True, threshold=None):
    """Return the batch mean of the TGS-style stepped IoU score.

    Each image is scored as the fraction of the IoU thresholds
    0.50, 0.55, ..., 0.95 that its prediction/label IoU strictly exceeds.
    An image with an empty ground-truth mask scores 1 if the prediction is
    also empty and 0 otherwise. The per-image scores are then averaged.

    Args:
        outputs: Prediction tensor with the batch on the first axis; it is
            passed through ``size_correct`` together with ``labels`` first.
        labels: Ground-truth mask tensor with the same leading batch axis.
            Presumably 0/1 valued; TODO confirm against the data pipeline.
        logits: If True, ``outputs`` are treated as raw logits and binarized
            at 0. Otherwise they are treated as probabilities and binarized
            at 0.5.
        threshold: Optional explicit binarization cut that overrides the
            default chosen by ``logits``.

    Returns:
        The mean per-image score in [0, 1], or 0.0 for an empty batch.
    """
    outputs, labels = size_correct(outputs, labels)
    if threshold is None:
        threshold = 0.0 if logits else 0.5
    preds = (outputs > threshold).detach().cpu().numpy()
    targets = labels.detach().cpu().numpy()

    batch_size = preds.shape[0]
    if batch_size == 0:
        # Nothing to score; avoid dividing by zero below.
        return 0.0

    # Index both arrays by position (rather than zip) so that a label batch
    # shorter than the prediction batch raises instead of silently skewing
    # the mean.
    total = 0.0
    for idx in range(batch_size):
        total += _stepped_iou(targets[idx], preds[idx])
    # Average over all images in the batch.
    return total / batch_size
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment