Usage Example for PyTorch CrossEntropyLoss and BCELoss (Binary Cross Entropy Loss) with Images (e.g. semantic segmentation)
# https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
# https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
import torch
from torch.nn import CrossEntropyLoss, BCELoss
# For this example:
# batch size = 1
# number of classes = 2
# size of tensors = 2x2
# outputs are any float values (including negative ones!)
output_class_0 = [[.0,1.],
                  [.0,1.]]
output_class_1 = [[1.,0.],
                  [1.,0.]]
outputs = torch.Tensor([[output_class_0,output_class_1]]) # [batch, num_classes, H, W]
# targets are class indices (therefore only integers!)
# and 255 is the "ignore" class
target = [[1,255],
          [1,0]]
targets = torch.Tensor([target]).long() # [batch, H, W] (must be long for indices!)
# ignore_index is only valid when using targets as class indices!
criterion = CrossEntropyLoss(ignore_index=255)
print("Using indices as targets (and ignore_index=255):")
print(f"outputs.shape: {outputs.shape}, targets.shape:{targets.shape}")
print(f"Loss: {criterion(outputs, targets)}")
# targets are class indices (therefore only integers!)
target = [[1,0],
          [1,0]]
targets = torch.Tensor([target]).long() # [batch, H, W] (must be long for indices!)
# ignore_index is only valid when using targets as class indices!
criterion = CrossEntropyLoss(ignore_index=255)
print("Using indices as targets:")
print(f"outputs.shape: {outputs.shape}, targets.shape:{targets.shape}")
print(f"Loss: {criterion(outputs, targets)}")
# or per-class values such as intensities (they don't need to sum to 1)
# (it accepts the same values as the outputs, including negative ones!)
target_class_0 = [[.0,1.],
                  [.0,1.]]
target_class_1 = [[1.,0.],
                  [1.,0.]]
targets = torch.Tensor([[target_class_0,target_class_1]]) # [batch, num_classes, H, W]
criterion = CrossEntropyLoss()
print("Using probabilities as targets:")
print(f"outputs.shape: {outputs.shape}, targets.shape:{targets.shape}")
print(f"Loss: {criterion(outputs, targets)}")
# Now using Binary Cross Entropy Loss (values MUST be in the interval [0,1])
criterion = BCELoss()
print("Using probabilities as targets with BCELoss:")
print(f"outputs.shape: {outputs.shape}, targets.shape:{targets.shape}")
print(f"Loss: {criterion(outputs, targets)}")
# And now the outputs are totally wrong (the targets are the exact opposite of the outputs)
target_class_0 = [[1.,0.],
                  [1.,0.]]
target_class_1 = [[0.,1.],
                  [0.,1.]]
targets = torch.Tensor([[target_class_0,target_class_1]]) # [batch, num_classes, H, W]
criterion = CrossEntropyLoss()
print("Using probabilities as targets (totally wrong outputs, so loss must be bigger than with correct ones):")
print(f"outputs.shape: {outputs.shape}, targets.shape:{targets.shape}")
print(f"Loss: {criterion(outputs, targets)}")
# Now using Binary Cross Entropy Loss (values MUST be in the interval [0,1])
criterion = BCELoss()
print("Using probabilities as targets with BCELoss (totally wrong outputs, so loss must be bigger than with correct ones):")
print(f"outputs.shape: {outputs.shape}, targets.shape:{targets.shape}")
print(f"Loss: {criterion(outputs, targets)}")
# And now half of the outputs are wrong
target_class_0 = [[1.,0.],
                  [1.,0.]]
target_class_1 = [[1.,0.],
                  [1.,0.]]
targets = torch.Tensor([[target_class_0,target_class_1]]) # [batch, num_classes, H, W]
criterion = CrossEntropyLoss()
print("Using probabilities as targets (half of the outputs wrong, so loss must be bigger than with correct ones):")
print(f"outputs.shape: {outputs.shape}, targets.shape:{targets.shape}")
print(f"Loss: {criterion(outputs, targets)}")
# Now using Binary Cross Entropy Loss (values MUST be in the interval [0,1])
criterion = BCELoss()
print("Using probabilities as targets with BCELoss (half of the outputs wrong, so loss must be bigger than with correct ones):")
print(f"outputs.shape: {outputs.shape}, targets.shape:{targets.shape}")
print(f"Loss: {criterion(outputs, targets)}")
# For this example:
# batch size = 1
# number of classes = 1
# size of tensors = 2x2
# the outputs are totally wrong in the next examples,
# so we would expect the loss to be much bigger than 0...
output_class_0 = [[.0,1],
                  [.0,1]]
target_class_0 = [[1,.0],
                  [1,.0]]
outputs = torch.Tensor([[output_class_0]]) # [batch, num_classes, H, W]
targets = torch.Tensor([[target_class_0]]) # [batch, num_classes, H, W]
criterion = CrossEntropyLoss()
print("Using probabilities as targets (wrong outputs, so loss must be bigger than with correct ones), but num_classes==1:")
print(f"outputs.shape: {outputs.shape}, targets.shape:{targets.shape}")
print(f"Loss: {criterion(outputs, targets)}")
output_class_0 = [[.0,1],
                  [.0,1]]
target_class_0 = [[0,255],
                  [0,255]]
outputs = torch.Tensor([[output_class_0]]) # [batch, num_classes, H, W]
targets = torch.Tensor([target_class_0]).long() # [batch, H, W]
criterion = CrossEntropyLoss(ignore_index=255)
print("Using indices as targets (wrong outputs, so loss must be bigger than with correct ones), but num_classes==1:")
print(f"outputs.shape: {outputs.shape}, targets.shape:{targets.shape}")
print(f"Loss: {criterion(outputs, targets)}")
# Now using Binary Cross Entropy Loss (values MUST be in the interval [0,1])
output_class_0 = [[.0,1],
                  [.0,1]]
target_class_0 = [[1,.0],
                  [1,.0]]
outputs = torch.Tensor([[output_class_0]]) # [batch, num_classes, H, W]
targets = torch.Tensor([[target_class_0]]) # [batch, num_classes, H, W]
criterion = BCELoss()
print("Using probabilities as targets (wrong outputs, so loss must be bigger than with correct ones), num_classes==1, but with BCELoss:")
print(f"outputs.shape: {outputs.shape}, targets.shape:{targets.shape}")
print(f"Loss: {criterion(outputs, targets)}")
#######################################
# Expected Outputs: #
# torch.__version__ == '1.13.1+cu117' #
#######################################
# Using indices as targets (and ignore_index=255):
# outputs.shape: torch.Size([1, 2, 2, 2]), targets.shape:torch.Size([1, 2, 2])
# Loss: 0.31326165795326233
# Using indices as targets:
# outputs.shape: torch.Size([1, 2, 2, 2]), targets.shape:torch.Size([1, 2, 2])
# Loss: 0.31326165795326233
# Using probabilities as targets:
# outputs.shape: torch.Size([1, 2, 2, 2]), targets.shape:torch.Size([1, 2, 2, 2])
# Loss: 0.31326165795326233
# Using probabilities as targets with BCELoss:
# outputs.shape: torch.Size([1, 2, 2, 2]), targets.shape:torch.Size([1, 2, 2, 2])
# Loss: 0.0
# Using probabilities as targets (totally wrong outputs, so loss must be bigger than with correct ones):
# outputs.shape: torch.Size([1, 2, 2, 2]), targets.shape:torch.Size([1, 2, 2, 2])
# Loss: 1.31326162815094
# Using probabilities as targets with BCELoss (totally wrong outputs, so loss must be bigger than with correct ones):
# outputs.shape: torch.Size([1, 2, 2, 2]), targets.shape:torch.Size([1, 2, 2, 2])
# Loss: 100.0
# Using probabilities as targets (half of the outputs wrong, so loss must be bigger than with correct ones):
# outputs.shape: torch.Size([1, 2, 2, 2]), targets.shape:torch.Size([1, 2, 2, 2])
# Loss: 0.8132616877555847
# Using probabilities as targets with BCELoss (half of the outputs wrong, so loss must be bigger than with correct ones):
# outputs.shape: torch.Size([1, 2, 2, 2]), targets.shape:torch.Size([1, 2, 2, 2])
# Loss: 50.0
# Using probabilities as targets (wrong outputs, so loss must be bigger than with correct ones), but num_classes==1:
# outputs.shape: torch.Size([1, 1, 2, 2]), targets.shape:torch.Size([1, 1, 2, 2])
# Loss: -0.0
# Using indices as targets (wrong outputs, so loss must be bigger than with correct ones), but num_classes==1:
# outputs.shape: torch.Size([1, 1, 2, 2]), targets.shape:torch.Size([1, 2, 2])
# Loss: 0.0
# Using probabilities as targets (wrong outputs, so loss must be bigger than with correct ones), num_classes==1, but with BCELoss:
# outputs.shape: torch.Size([1, 1, 2, 2]), targets.shape:torch.Size([1, 1, 2, 2])
# Loss: 100.0