@tejaskhot
tejaskhot / softmax_cross_entropy_with_logits.py
Created July 14, 2018 22:51
PyTorch softmax cross entropy with logits
# PyTorch equivalent of TensorFlow's tf.nn.softmax_cross_entropy_with_logits.
# Works with soft targets (probability distributions over classes) as well as
# one-hot encodings; `target` must have the same shape as `logits`.
import torch
import torch.nn.functional as F

logits = model(input)  # raw, unnormalized scores, shape (batch, num_classes)
# Per-example cross entropy: -sum(target * log_softmax(logits)) over the class dim.
loss = torch.sum(-target * F.log_softmax(logits, dim=-1), dim=-1)
mean_loss = loss.mean()
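As a sanity check, the snippet above can be wrapped in a function and compared against `F.cross_entropy` (which takes integer class indices) on one-hot targets, where the two should agree. This is a minimal sketch; the function name and the random test tensors are illustrative, not part of the original gist.

```python
import torch
import torch.nn.functional as F

def softmax_cross_entropy_with_logits(logits, target):
    """Per-example cross entropy against soft targets,
    mirroring TF's tf.nn.softmax_cross_entropy_with_logits."""
    return torch.sum(-target * F.log_softmax(logits, dim=-1), dim=-1)

torch.manual_seed(0)
logits = torch.randn(4, 5)                        # 4 examples, 5 classes
labels = torch.tensor([1, 0, 3, 2])               # integer class indices
one_hot = F.one_hot(labels, num_classes=5).float()  # soft-target form

soft = softmax_cross_entropy_with_logits(logits, one_hot).mean()
hard = F.cross_entropy(logits, labels)  # built-in, expects class indices
assert torch.allclose(soft, hard)
```

For genuinely soft targets (e.g. label smoothing or distillation), only the formula above applies; `F.cross_entropy` in 2018-era PyTorch accepted integer labels only.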