Formulas for BCE loss in PyTorch
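The snippet below compares several equivalent ways of writing binary cross-entropy in PyTorch. The key identity is the numerically stable rewrite (mathematically equivalent to what BCEWithLogitsLoss computes):

    bce_with_logits(x, y) = mean( max(x, 0) - x * y + log(1 + exp(-|x|)) )
                          = -mean( y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)) )

The first form avoids overflowing exp() and taking log(0) when the logits x are large. When y is a 0/1 indicator, the x * y term only touches the positive entries, so if the targets are given as sparse indices (i, j) of the positives, that term can be computed as a sum of x over those positions; that is what sparse_bce_with_logits does below.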
import numpy as np
import pandas as pd
import torch

def sparse_bce_with_logits(x, i, j):
    # BCE-with-logits when the 0/1 targets are given as indices (i, j) of the positive entries:
    # the dense x * y term collapses to a sum of x over those positions.
    t1 = x.clamp(min=0).mean()
    t2 = -x[(i, j)].sum() / x.numel()
    t3 = torch.log(1 + torch.exp(-torch.abs(x))).mean()
    return t1 + t2 + t3

loss = torch.nn.BCEWithLogitsLoss()
sloss = torch.nn.BCELoss()
all_res = []
for scale in np.arange(0, 100, 2):
    x = torch.randn((100, 10)) * scale          # logits with increasingly large magnitude
    sx = torch.sigmoid(x)
    y = (torch.rand((100, 10)) < 0.1).float()   # sparse 0/1 targets (~10% positives)

    # Indices of the positive entries, for the sparse formulation
    i, j = np.where(y.numpy())
    i, j = torch.LongTensor(i), torch.LongTensor(j)

    bce_logits = loss(x, y)        # reference: BCEWithLogitsLoss on raw logits
    bce_sigmoid = sloss(sx, y)     # BCELoss on sigmoid probabilities
    bce_sigmoid_manual = -(y * sx.log() + (1 - y) * (1 - sx).log()).mean()
    bce_logit_manual = (x.clamp(min=0) - x * y + torch.log(1 + torch.exp(-torch.abs(x)))).mean()
    bce_logit_sparse = sparse_bce_with_logits(x, i, j)

    # Differences from the reference loss; the sigmoid-based versions lose accuracy as scale grows
    res = {
        "sigmoid"        : (bce_logits - bce_sigmoid).item(),
        "sigmoid_manual" : (bce_logits - bce_sigmoid_manual).item(),
        "logit_manual"   : (bce_logits - bce_logit_manual).item(),
        "logit_sparse"   : (bce_logits - bce_logit_sparse).item(),
    }
    all_res.append(res)

pd.DataFrame(all_res)
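
As a quick sanity check (a small sketch added here, not from the original gist), the formulations can also be compared directly at a modest scale, where everything should agree with BCEWithLogitsLoss to within floating-point tolerance. This reuses loss, sloss, and sparse_bce_with_logits from above; the atol value is just a loose, illustrative tolerance.

x = torch.randn((100, 10))                  # modest logits, so sigmoid() does not saturate
y = (torch.rand((100, 10)) < 0.1).float()
i, j = np.where(y.numpy())
i, j = torch.LongTensor(i), torch.LongTensor(j)

print(torch.allclose(loss(x, y), sparse_bce_with_logits(x, i, j), atol=1e-5))  # expected: True
print(torch.allclose(loss(x, y), sloss(torch.sigmoid(x), y), atol=1e-5))       # expected: True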