Skip to content

Instantly share code, notes, and snippets.

@suvojit-0x55aa
Created September 30, 2019 07:55
Show Gist options
  • Save suvojit-0x55aa/0afb3eefbb26d33f54e1fb9f94d6b609 to your computer and use it in GitHub Desktop.
Save suvojit-0x55aa/0afb3eefbb26d33f54e1fb9f94d6b609 to your computer and use it in GitHub Desktop.
Label Smoothing in PyTorch
import torch
import torch.nn as nn
class LabelSmoothing(nn.Module):
    """
    Cross-entropy loss with label smoothing, computed from raw logits.

    ``log_softmax`` is applied internally, so ``forward`` expects unnormalized
    scores rather than log-probabilities. The per-element loss is the
    cross-entropy against a smoothed target distribution that puts
    ``1 - smoothing`` on the true class and spreads ``smoothing`` uniformly over
    all C classes:

        loss = (1 - smoothing) * (-log p[target]) + smoothing * mean_c(-log p[c])

    With class-index targets and no class weights this is the same quantity as
    ``torch.nn.functional.cross_entropy(x, target, label_smoothing=smoothing)``.
    """

    # Accepted values for the ``reduction`` argument.
    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, smoothing=0.0, *, reduction="mean"):
        """
        Constructor for the LabelSmoothing module.

        :param smoothing: label smoothing factor in [0, 1]; 0 gives plain
            cross-entropy.
        :param reduction: how per-element losses are combined: "mean" (default,
            averages over every element), "sum", or "none" (no reduction).
        :raises ValueError: if ``smoothing`` is outside [0, 1] or ``reduction``
            is not one of the accepted values.
        """
        super().__init__()
        # Values outside [0, 1] would give a negative weight on one term,
        # i.e. a target "distribution" with negative mass.
        if not 0.0 <= smoothing <= 1.0:
            raise ValueError(f"smoothing must be in [0, 1], got {smoothing!r}")
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.reduction = reduction

    def forward(self, x, target):
        """
        Compute the label-smoothed cross-entropy.

        :param x: logits with classes on the last dimension, shape (..., C).
        :param target: integer (int64) class indices with shape ``x.shape[:-1]``.
        :return: a scalar tensor for "mean"/"sum"; a tensor of shape
            ``x.shape[:-1]`` for "none".
        """
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
        # Index along the class (last) axis. Using -1 instead of a hard-coded 1
        # keeps this correct for inputs with extra dims such as (batch, seq, C);
        # for 2-D input the two are identical.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        # Mean over classes == cross-entropy against the uniform distribution.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment