@enochkan
Created June 12, 2020 05:36
Generalized Dice Loss
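For reference, the class below computes a soft Dice loss over all flattened elements. It uses squared sums in the denominator and a smoothing constant $s$ (`smooth = 1.` in the code). Writing $p_i$ for the flattened predictions and $g_i$ for the flattened targets, the returned value is:

$$
\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i p_i g_i + s}{\sum_i p_i^2 + \sum_i g_i^2 + s}
$$
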
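```python
import torch


class DiceLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        # pred: probabilities (e.g. after a sigmoid), same shape as target
        smooth = 1.  # avoids division by zero when both inputs are empty
        # Flatten both tensors so every element contributes equally
        iflat = pred.contiguous().view(-1)
        tflat = target.contiguous().view(-1)
        intersection = (iflat * tflat).sum()
        # Squared sums in the denominator
        A_sum = torch.sum(iflat * iflat)
        B_sum = torch.sum(tflat * tflat)
        return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth))
```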
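The gist is titled Generalized Dice Loss, but the forward pass above computes a single soft Dice score over all elements, with no per-class weighting. The class-weighted variant described by Sudre et al. (2017) gives each class $l$ the weight $w_l = 1 / (\sum_n r_{ln})^2$, where $r_{ln}$ is the reference (target) label. A minimal sketch of that variant follows. It assumes `probs` has shape `(N, C, ...)` holding per-class probabilities (e.g. a softmax over dim 1) and `target` is a one-hot tensor of the same shape. The function name `generalized_dice_loss` and the `eps` parameter are hypothetical, chosen for illustration.

```python
import torch


def generalized_dice_loss(probs: torch.Tensor, target: torch.Tensor,
                          eps: float = 1e-6) -> torch.Tensor:
    """Class-weighted Generalized Dice Loss (sketch after Sudre et al., 2017).

    Assumes probs and target share shape (N, C, ...), with probs holding
    per-class probabilities and target one-hot encoded along dim 1.
    """
    # Reduce over the batch dim and all spatial dims, keeping the class dim C
    dims = (0,) + tuple(range(2, probs.dim()))
    ref_vol = target.sum(dim=dims)                      # sum_n r_ln per class
    weights = 1.0 / (ref_vol * ref_vol).clamp(min=eps)  # w_l = 1 / (sum_n r_ln)^2
    intersect = (probs * target).sum(dim=dims)          # sum_n r_ln * p_ln
    denom = (probs + target).sum(dim=dims)              # sum_n (r_ln + p_ln)
    # Weighted numerator and denominator are summed over classes before dividing
    ratio = (weights * intersect).sum() / (weights * denom).sum().clamp(min=eps)
    return 1.0 - 2.0 * ratio
```

A perfect one-hot prediction gives a loss of 0, and a prediction with no overlap gives 1. Note that a class absent from the target receives a weight of about `1 / eps` in this sketch and can dominate the loss; implementations differ in how they handle that case.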