Skip to content

Instantly share code, notes, and snippets.

@sailfish009
Created April 5, 2020 05:53
Show Gist options
  • Save sailfish009/2bbdd5d61879365a911dc2d2cc849967 to your computer and use it in GitHub Desktop.
Save sailfish009/2bbdd5d61879365a911dc2d2cc849967 to your computer and use it in GitHub Desktop.
# https://forums.fast.ai/t/focalloss-with-multi-class/35588/2
def one_hot_embedding(labels, num_classes):
    """Encode integer class labels as one-hot rows.

    Builds a ``num_classes x num_classes`` identity matrix and gathers one
    row per label, so the result has shape ``labels.shape + (num_classes,)``.

    Args:
        labels: integer tensor of class indices. Its ``.data`` is copied to
            the CPU before indexing, so the result is always a CPU tensor.
        num_classes: number of classes, i.e. the width of each one-hot row.

    Returns:
        A CPU tensor of one-hot rows with ``torch.eye``'s default dtype.
    """
    identity = torch.eye(num_classes)
    cpu_indices = labels.data.cpu()
    return identity[cpu_indices]
class myCCELoss(nn.Module):
    """Categorical cross-entropy computed from explicit one-hot targets.

    For each position, computes ``-sum_c y_c * log_softmax(input)_c`` over
    the last (class) axis of ``input``, where ``y`` is the one-hot encoding
    of ``target``, and then averages over all positions.
    """

    def __init__(self):
        super(myCCELoss, self).__init__()

    def forward(self, input, target):
        """Return the mean cross-entropy of ``input`` scores against ``target``.

        Args:
            input: raw (unnormalized) class scores; the last axis indexes
                classes.
            target: integer class indices. Presumably shaped like ``input``
                without its last axis, so the one-hot tensor lines up with
                ``input`` (confirm against callers).

        Returns:
            A scalar tensor: the per-position cross-entropy, averaged.
        """
        # one_hot_embedding always returns a CPU tensor. Move it next to
        # ``input`` and match its dtype, so the product below works on any
        # device without an external Variable/GPU wrapper.
        y = one_hot_embedding(target, input.size(-1)).to(
            device=input.device, dtype=input.dtype)
        # log_softmax is the numerically stable form of log(softmax(x)). It
        # never yields log(0) = -inf, which would turn 0 * -inf into NaN.
        # dim=-1 matches the class axis used for num_classes above.
        log_probs = F.log_softmax(input, dim=-1)
        loss = -y * log_probs  # cross entropy loss
        # Sum over the class axis, then average over every other position.
        return loss.sum(dim=-1).mean()
class FocalLoss(nn.Module):
    """Multi-class focal loss on softmax probabilities.

    For each position, computes
    ``-sum_c y_c * (1 - p_c) ** gamma * log(p_c)``, where ``p`` is the
    softmax of ``input`` over its last axis, clamped to ``[eps, 1 - eps]``,
    and ``y`` is the one-hot encoding of ``target``. The result is averaged
    over all positions. With ``gamma = 0`` this reduces to clamped
    cross-entropy.

    Args:
        gamma: focusing exponent. Because each term is scaled by
            ``(1 - p_c) ** gamma``, larger values shrink the loss of
            confidently correct predictions.
        eps: probabilities are clamped into ``[eps, 1 - eps]`` so ``log``
            never sees 0.
    """

    def __init__(self, gamma=2, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps

    def forward(self, input, target):
        """Return the mean focal loss of ``input`` scores against ``target``.

        Args:
            input: raw (unnormalized) class scores; the last axis indexes
                classes.
            target: integer class indices. Presumably shaped like ``input``
                without its last axis (confirm against callers).

        Returns:
            A scalar tensor: the per-position focal loss, averaged.
        """
        y = one_hot_embedding(target, input.size(-1))
        # Put the targets on input's device with input's floating dtype.
        # This replaces a hard-coded .cuda(), which fails on CPU-only hosts
        # and ignores which GPU ``input`` is on, and a .long() cast, which
        # made integer targets for a floating-point product.
        y = y.to(device=input.device, dtype=input.dtype)
        # dim=-1 matches the class axis used for num_classes above.
        probs = F.softmax(input, dim=-1)
        probs = probs.clamp(self.eps, 1. - self.eps)
        loss = -y * torch.log(probs)  # cross entropy
        loss = loss * (1 - probs) ** self.gamma  # focal loss
        # Sum over the class axis, then average over every other position.
        return loss.sum(dim=-1).mean()
# Install FocalLoss with gamma = 2.0 as the learner's criterion.
# NOTE(review): `learn` is not defined in the visible code. Presumably it is a
# fastai Learner (older API, where `crit` holds the loss function); confirm.
learn.crit = FocalLoss(gamma=2.0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment