Skip to content

Instantly share code, notes, and snippets.

Created January 23, 2023 11:10
Show Gist options
  • Save sachinruk/ea39d11c3e339b9f1ee55980e27cd1f0 to your computer and use it in GitHub Desktop.
Save sachinruk/ea39d11c3e339b9f1ee55980e27cd1f0 to your computer and use it in GitHub Desktop.
Asymmetric loss for multi-label classification
import torch
import torch.nn as nn
import torch.nn.functional as F
class AsymmetricLoss(nn.Module):
    """Asymmetric Loss for Multi-label Classification.

    Loss function where negative classes are weighted less than the positive
    classes. Positive and negative targets get separate focusing exponents,
    and the probability used for negative targets is shifted by ``clip`` so
    that sufficiently confident negatives contribute exactly zero loss.

    Note: the inputs are logits and targets, not sigmoids.

    Example::

        inputs = torch.randn(5, 3)
        targets = torch.randint(0, 2, (5, 3))  # must be binary
        loss_fn = AsymmetricLoss()
        loss = loss_fn(inputs, targets)

    Args:
        gamma_neg: loss attenuation factor for negative classes.
        gamma_pos: loss attenuation factor for positive classes.
        clip: shift added to the true-class probability of negative targets;
            the negative loss is zero once the shifted probability reaches 1.
            Accepted range is 0.0 to 1.0 inclusive.
        reduction: how to reduce the final loss. Must be one of
            ``"mean"`` (default), ``"sum"`` or ``"none"``.

    Raises:
        ValueError: if ``clip`` is outside [0, 1], if ``gamma_neg`` is smaller
            than ``gamma_pos``, or if ``reduction`` is not a supported value.
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(
        self,
        gamma_neg: float = 4.0,
        gamma_pos: float = 1.0,
        clip: float = 0.05,
        reduction: str = "mean",
    ):
        super().__init__()
        if clip < 0.0 or clip > 1.0:
            raise ValueError("Clipping value must be non-negative and less than one")
        if gamma_neg < gamma_pos:
            raise ValueError(
                "Need to ensure that loss for hard positive is penalised less than hard negative"
            )
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.reduction = reduction

    def _get_binary_cross_entropy_loss_and_pt_with_logits(
        self, inputs: torch.FloatTensor, targets: torch.LongTensor
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Return the element-wise BCE loss and the matching true-class probability.

        ``pt = exp(-bce)`` is the probability the model assigns to the actual
        label of each element: sigmoid(logit) where the target is 1 and
        1 - sigmoid(logit) where the target is 0.
        """
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets.float(), reduction="none")
        pt = torch.exp(-ce_loss)  # probability assigned to the true label
        return ce_loss, pt

    def forward(self, inputs: torch.FloatTensor, targets: torch.LongTensor) -> torch.FloatTensor:
        """Compute the asymmetric loss for a batch of logits and binary targets.

        Args:
            inputs: raw logits (no sigmoid applied).
            targets: binary labels with the same shape as ``inputs``.

        Returns:
            The loss reduced according to ``self.reduction``: a scalar for
            ``"mean"``/``"sum"``, or the element-wise loss for ``"none"``.
        """
        ce_loss, pt = self._get_binary_cross_entropy_loss_and_pt_with_logits(inputs, targets)

        # shift and clamp (therefore zero gradient) high confidence negative cases.
        # The lower bound only matters when clip == 0 and pt underflows to 0:
        # without it -log(0) = inf, and the masking multiplication below turns
        # that into 0 * inf = NaN for positive targets.
        pt_neg = (pt + self.clip).clamp(min=torch.finfo(pt.dtype).tiny, max=1.0)
        ce_loss_neg = -torch.log(pt_neg)
        loss_neg = (1 - pt_neg) ** self.gamma_neg * ce_loss_neg
        loss_pos = (1 - pt) ** self.gamma_pos * ce_loss

        # Binary targets act as a mask selecting the positive or negative branch.
        loss = targets * loss_pos + (1 - targets) * loss_neg

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment