Skip to content

Instantly share code, notes, and snippets.

@ptrblck
Last active February 9, 2018 15:34
Show Gist options
  • Save ptrblck/af4153bb9e8c100477d21ff2800ec5e4 to your computer and use it in GitHub Desktop.
Save ptrblck/af4153bb9e8c100477d21ff2800ec5e4 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch._C import _infer_size
# Setup: random logits squashed to probabilities, and random binary
# targets whose values are 0. or 1. (random_(2)); both have shape (4, 2, 2).
x = Variable(torch.randn(4, 2, 2))
y = Variable(torch.Tensor(4, 2, 2).random_(2))
output = F.sigmoid(x)

# Create weight according to doc in BCELoss: one weight per batch element.
# A (4,) weight is NOT broadcastable to the (4, 2, 2) target: broadcasting
# aligns trailing dimensions, so its size 4 is compared against y's last
# dimension (size 2) and the size check raises.
weight = torch.randn(4)
criterion_weighted = nn.BCELoss(weight=weight)
try:
    loss_weighted = criterion_weighted(output, y)
except RuntimeError as err:
    # Expected failure; report it instead of aborting the rest of the demo.
    print('Weight of shape (4,) cannot be broadcast to the target:', err)

# Unsqueeze weight tensor: a (4, 1, 1) weight broadcasts over the two
# trailing dimensions, so every element of sample i is scaled by weight[i].
weight = torch.randn(4, 1, 1)
criterion_weighted = nn.BCELoss(weight=weight)
loss_weighted = criterion_weighted(output, y)

# Create class weights: 0.1 for label 0 and 0.9 for label 1.
weight = torch.FloatTensor([0.1, 0.9])
# Internally, BCELoss expands weight to the target's shape like this. The (2,)
# weight lines up with y's LAST dimension, so weight[j] scales every element
# at position j of that dimension, whatever its label is.
size = _infer_size(weight.size(), y.size())
weight_expanded = weight.expand(size)  # This is not what we wanted as class weights!
criterion_weighted = nn.BCELoss(weight=weight)
loss_weighted = criterion_weighted(output, y)

# Rebuild the same loss by hand from the per-element (unreduced) losses.
# Newer PyTorch releases spell reduce=False as reduction='none'.
criterion_nonreduced = nn.BCELoss(reduce=False)
loss_unreduced = criterion_nonreduced(output, y)
loss_weighted_manual = (Variable(weight_expanded) * loss_unreduced).mean()
# The two reductions may sum in a different order, so they can differ in the
# last bits; compare with a tolerance instead of exact `==`. Agreement shows
# the weight was applied by position, not by class.
if torch.allclose(loss_weighted.data, loss_weighted_manual.data):
    print('Class weighting failed')

# Let's use weight as class weights: index the (2,) weight with each target
# label (0 or 1) so every element gets the weight of its own class.
weight_ = weight[y.data.view(-1).long()].view_as(y)
criterion = nn.BCELoss(reduce=False)
loss = criterion(output, y)
loss_class_weighted = loss * Variable(weight_)
# Scalar loss, averaged over all elements like BCELoss's default reduction
# (the same .mean() used for loss_weighted_manual above).
loss_class_weighted_mean = loss_class_weighted.mean()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment