Last active
April 2, 2024 11:56
-
-
Save samson-wang/e5cee676f2ae97795356d9c340d1ec7f to your computer and use it in GitHub Desktop.
A really simple PyTorch implementation of focal loss for both sigmoid and softmax predictions.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torch.nn.functional import log_softmax, logsigmoid
def sigmoid_focal_loss(logits, target, gamma=2., alpha=0.25):
    """Compute the sigmoid (one-vs-all) focal loss, averaged over all entries.

    Each column of ``logits`` is treated as an independent binary classifier.
    For a sample with label ``t``, column ``t`` is a positive and every other
    column is a negative. Samples whose label is negative (e.g. ``-1``)
    contribute neither a positive nor a negative term, i.e. they are ignored.

    Args:
        logits: Raw scores indexed as ``logits[sample, class]`` (assumed shape
            ``(N, num_classes)``).
        target: Class label per sample (assumed shape ``(N,)``); values in
            ``[0, num_classes)`` select a class, negative values are ignored.
        gamma: Focusing parameter; larger values down-weight easy examples.
        alpha: Weight of the positive term; negative terms get ``1 - alpha``.

    Returns:
        Scalar tensor: the mean of the element-wise loss over all
        ``N * num_classes`` entries (ignored entries count as zero).
    """
    num_classes = logits.shape[1]
    class_range = torch.arange(
        num_classes, dtype=target.dtype, device=target.device
    ).unsqueeze(0)
    t = target.unsqueeze(1)
    p = torch.sigmoid(logits)
    # Compute log(p) and log(1 - p) == log(sigmoid(-x)) in log-space.
    # torch.log(p) / torch.log(1 - p) hit -inf once the sigmoid saturates
    # (|x| >~ 17 in float32). The masking below would then evaluate
    # 0 * -inf = NaN and poison the whole loss.
    term1 = (1 - p) ** gamma * logsigmoid(logits)
    term2 = p ** gamma * logsigmoid(-logits)
    pos_mask = (t == class_range).float()
    # Negatives are the non-target columns of samples with a valid (>= 0) label.
    neg_mask = ((t != class_range) & (t >= 0)).float()
    loss = -pos_mask * term1 * alpha - neg_mask * term2 * (1 - alpha)
    return loss.mean()
def softmax_focal_loss(x, target, gamma=2., alpha=0.25):
    """Compute the softmax (multi-class) focal loss.

    Args:
        x: Raw class scores indexed as ``x[sample, class]`` (assumed shape
            ``(N, C)``).
        target: Integer class index per sample (assumed shape ``(N,)``).
        gamma: Focusing parameter; larger values down-weight easy examples.
        alpha: Constant weight applied to every sample's loss.

    Returns:
        Scalar tensor: the summed per-sample loss divided by ``C``
        (``x.shape[1]``).
    """
    # NOTE(review): the sum is normalised by the number of classes
    # (x.shape[1]), not by the number of samples N = x.shape[0]. The name
    # "pos_num" suggests N may have been intended. Confirm before relying on
    # the loss scale; kept as-is for compatibility with existing callers.
    pos_num = float(x.shape[1])
    rows = torch.arange(x.shape[0], dtype=torch.int64, device=target.device)
    # Take log-probabilities directly from log_softmax.
    # torch.log(torch.softmax(x)) yields -inf (an inf loss and NaN gradients)
    # when the true class's probability underflows to 0.
    log_p = log_softmax(x, dim=1)[rows, target]
    p = log_p.exp()
    loss = -(1 - p) ** gamma * alpha * log_p
    return loss.sum() / pos_num
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
You have swapped your alpha and gamma values.