PyTorch workaround for masking cross entropy loss
import torch
from torch.autograd import Variable
from torch.nn import functional


def _sequence_mask(sequence_length, max_len=None):
    # Build a (batch, max_len) byte/bool mask that is True at positions
    # smaller than the corresponding sequence length.
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_range_expand = Variable(seq_range_expand)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand


def compute_loss(logits, target, length):
    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each sequence in the batch.

    Returns:
        loss: An average loss value, masked by the length.
    """
    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = functional.log_softmax(logits_flat, dim=1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
    # Zero out the loss at padded timesteps, then average over the
    # total number of non-padded timesteps.
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss
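For reference, a minimal usage sketch. The tensor sizes and lengths below are made up for illustration; on recent PyTorch versions (with the merged Tensor/Variable API) the Variable wrapper is a no-op and can be dropped.

import torch
from torch.autograd import Variable

batch, max_len, num_classes = 2, 5, 4  # hypothetical sizes
logits = Variable(torch.randn(batch, max_len, num_classes))
target = Variable(torch.zeros(batch, max_len).long())
length = Variable(torch.LongTensor([5, 3]))  # second sequence is padded after step 3

loss = compute_loss(logits, target, length)
print(loss)  # scalar loss averaged over the non-padded timesteps only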
The masking in the proposed gist is per-sample, i.e. it tells you whether each timestep of each sample should contribute a loss or not, while the weight you are using here is per-class.
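To make that distinction concrete, here is a hedged sketch (the class count, weight values, and the choice of class 0 as a padding token are invented for illustration): the weight argument of nn.CrossEntropyLoss rescales the loss of each class regardless of which sample it came from, whereas the gist's mask zeroes out entire padded timesteps per sample regardless of their class.

import torch
import torch.nn as nn

num_classes = 4  # hypothetical
# Per-class weighting: class 0 (e.g. a PAD token) contributes nothing,
# every other class contributes with weight 1.
class_weights = torch.tensor([0.0, 1.0, 1.0, 1.0])
criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

logits = torch.randn(6, num_classes)        # (batch * max_len, num_classes)
targets = torch.tensor([1, 2, 0, 3, 0, 1])  # 0 marks padded positions in this example
loss = criterion(logits, targets)           # padded positions are dropped via their class weight

# Per-sample masking (as in the gist) instead multiplies each timestep's loss
# by 0 or 1 according to the sequence length, independent of the class index.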