Asymmetric loss for multi-label classification
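For context, the file below implements Asymmetric Loss (ASL) from the linked paper (https://arxiv.org/abs/2009.14119). As I read the paper, with p the predicted positive probability for a label, m the probability margin (clip in the code), and \gamma_+ / \gamma_- the positive/negative focusing parameters, the per-label loss is roughly:

    L   = -y \, L_+ - (1 - y) \, L_-
    L_+ = (1 - p)^{\gamma_+} \log p
    L_- = (p_m)^{\gamma_-} \log(1 - p_m), \quad p_m = \max(p - m, 0)

so easy negatives (small p) are down-weighted by the \gamma_- factor and, once p falls below m, contribute no loss or gradient at all.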
import torch
import torch.nn as nn
import torch.nn.functional as F

class AsymmetricLoss(nn.Module):
    def __init__(
        self,
        gamma_neg: float = 4.0,
        gamma_pos: float = 1.0,
        clip: float = 0.05,
    ):
        """Asymmetric Loss for Multi-label Classification. https://arxiv.org/abs/2009.14119

        Loss function where easy negative classes contribute less to the loss than positive classes.
        Note: the inputs are raw logits and binary targets, not sigmoid probabilities.

        Usage:
            inputs = torch.randn(5, 3)
            targets = torch.randint(0, 2, (5, 3))  # must be binary
            loss_fn = AsymmetricLoss()
            loss = loss_fn(inputs, targets)

        Args:
            gamma_neg: loss attenuation (focusing) factor for negative classes
            gamma_pos: loss attenuation (focusing) factor for positive classes
            clip: probability margin that shifts the negative class probability; negatives
                predicted with probability below `clip` contribute zero loss (and zero gradient)

        The returned loss is the mean over all elements.
        """
        super().__init__()
        if clip < 0.0 or clip > 1.0:
            raise ValueError("Clipping value must be between zero and one")
        if gamma_neg < gamma_pos:
            raise ValueError(
                "gamma_neg must be >= gamma_pos so that negative classes are attenuated at least as strongly as positives"
            )
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip

    def _get_binary_cross_entropy_loss_and_pt_with_logits(
        self, inputs: torch.FloatTensor, targets: torch.LongTensor
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets.float(), reduction="none")
        pt = torch.exp(-ce_loss)  # probability assigned to the target class (p if y=1, 1-p if y=0)
        return ce_loss, pt

    def forward(self, inputs: torch.FloatTensor, targets: torch.LongTensor) -> torch.FloatTensor:
        ce_loss, pt = self._get_binary_cross_entropy_loss_and_pt_with_logits(inputs, targets)
        # shift and clamp (therefore zero loss and gradient) high confidence negative cases
        pt_neg = (pt + self.clip).clamp(max=1.0)
        ce_loss_neg = -torch.log(pt_neg)
        # focal-style weighting: easy examples (pt close to 1) are down-weighted
        loss_neg = (1 - pt_neg) ** self.gamma_neg * ce_loss_neg
        loss_pos = (1 - pt) ** self.gamma_pos * ce_loss
        # use the positive term where targets == 1 and the negative term where targets == 0
        loss = targets * loss_pos + (1 - targets) * loss_neg
        return loss.mean()
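A minimal training-step sketch showing how the loss could be used. The model, optimizer, batch size, and label count here are hypothetical, not part of the gist; the only requirement is that the model outputs raw logits and the targets are multi-hot binary labels.

# hypothetical setup: 10-dimensional features, 5 independent binary labels
model = nn.Linear(10, 5)
loss_fn = AsymmetricLoss(gamma_neg=4.0, gamma_pos=1.0, clip=0.05)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

features = torch.randn(32, 10)          # a batch of 32 examples
targets = torch.randint(0, 2, (32, 5))  # multi-hot binary labels

logits = model(features)                # raw logits, no sigmoid applied
loss = loss_fn(logits, targets)         # scalar mean loss over all elements
loss.backward()
optimizer.step()
optimizer.zero_grad()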