
@MLWhiz
Created September 7, 2020 15:16
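import torch
import torch.nn as nn


class CustomNLLLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # x should be the output of a LogSoftmax layer, shape (batch, num_classes)
        log_prob = -1.0 * x
        # Pick -log p of each sample's true class (index y), since loss = -mean(y log p)
        loss = log_prob.gather(1, y.unsqueeze(1))
        loss = loss.mean()
        return loss


criterion = CustomNLLLoss()
# Illustrative inputs only: 4 samples, 3 classes; y holds int64 class indices
preds = nn.LogSoftmax(dim=1)(torch.randn(4, 3))
y = torch.tensor([0, 2, 1, 2])
loss = criterion(preds, y)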
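To make explicit what `gather` computes, here is a dependency-free sketch of the same loss in plain Python, assuming a batch of raw scores (logits) and integer class targets. The helper names `log_softmax` and `nll_loss` are hypothetical and exist only for this illustration: each sample contributes the negated log-probability of its true class, and the results are averaged over the batch. For unweighted inputs this is the same quantity PyTorch's built-in `nn.NLLLoss` returns with its default `reduction='mean'`.

```python
import math
from typing import List, Sequence


def log_softmax(row: Sequence[float]) -> List[float]:
    # Hypothetical helper: subtract the row max before exponentiating for numerical stability.
    m = max(row)
    log_sum = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_sum for v in row]


def nll_loss(log_probs: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    # Hypothetical helper: pick log p[i][y_i] for each sample (what
    # gather(1, y.unsqueeze(1)) does), negate it, and average over the batch.
    picked = [row[t] for row, t in zip(log_probs, targets)]
    return -sum(picked) / len(picked)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]
log_probs = [log_softmax(r) for r in logits]
print(nll_loss(log_probs, targets))
```

With uniform scores over C classes every log-probability equals -log C, so the loss is exactly log C, which is a quick sanity check for any implementation.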