import torch
from torch.autograd import Variable
from torch.nn import functional


def _sequence_mask(sequence_length, max_len=None):
    # Builds a (batch, max_len) mask that is 1 at positions before each
    # sequence's length and 0 at padded positions.
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_range_expand = Variable(seq_range_expand)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand


def compute_loss(logits, target, length):
    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each sequence in the batch.

    Returns:
        loss: An average loss value masked by the length.
    """
    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = functional.log_softmax(logits_flat, dim=1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
    # Zero out the loss at padded positions, then average over real tokens only.
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss
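
For reference, a minimal usage sketch on toy data (the shapes and values below are made up for illustration):

batch, max_len, num_classes = 2, 4, 5
logits = Variable(torch.randn(batch, max_len, num_classes))
# Targets padded with zeros past each sequence's real length.
target = Variable(torch.LongTensor([[1, 2, 3, 0], [4, 1, 0, 0]]))
length = Variable(torch.LongTensor([3, 2]))  # real (unpadded) lengths
loss = compute_loss(logits, target, length)  # scalar averaged over the 5 real tokens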
I (genuinely) wonder how this is different from using:
weight = torch.ones(vocab_size)
weight[pad_idx] = 0.0
crit = nn.CrossEntropyLoss(weight=weight)
crit(output, targets)
I seem to get the same numbers (assuming that you have padded every sequence with pad_idx up to the maximum sentence length in the batch).
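A quick way to check that claim empirically, reusing compute_loss from the gist above; this sketch assumes targets at padded positions are filled with pad_idx and that no real token carries class pad_idx:

import torch
from torch import nn
from torch.autograd import Variable

pad_idx, vocab_size = 0, 5
logits = Variable(torch.randn(2, 4, vocab_size))
target = Variable(torch.LongTensor([[1, 2, 3, pad_idx], [4, 1, pad_idx, pad_idx]]))
length = Variable(torch.LongTensor([3, 2]))

weight = torch.ones(vocab_size)
weight[pad_idx] = 0.0
crit = nn.CrossEntropyLoss(weight=weight)
# CrossEntropyLoss wants (N, C) and (N,), so flatten batch and time first.
weighted = crit(logits.view(-1, vocab_size), target.view(-1))
masked = compute_loss(logits, target, length)
# Both divide by the number of real tokens (weight=0 drops padded targets
# from CrossEntropyLoss's weighted average), so the values should match.
print(weighted.data[0], masked.data[0])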
Interesting solution, @emanjavacas!
@emanjavacas, @lzfelix
Yeah, it's a good one, and shouldn't this be even more "correct"?
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
I hope this gist was created for masking and no other purpose?
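For completeness, a minimal sketch of that variant (pad_idx here stands for whatever index the vocabulary reserves for padding):

import torch
from torch import nn
from torch.autograd import Variable

pad_idx = 0
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
# Flattened shapes: logits (batch * max_len, num_classes), targets (batch * max_len,)
logits = Variable(torch.randn(8, 5))
targets = Variable(torch.LongTensor([1, 2, 3, pad_idx, 4, 1, pad_idx, pad_idx]))
loss = criterion(logits, targets)  # positions whose target is pad_idx are skipped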
The masking in the proposed gist is per-sample, i.e. it tells you whether each position should contribute a loss or not, while @emanjavacas's weight is per-class.
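A small illustration of where the two can diverge, assuming a vocabulary in which a real token can legitimately carry class pad_idx:

import torch

pad_idx = 0
# length = 3: only the last position is padding.
target = torch.LongTensor([[pad_idx, 2, 3, pad_idx]])
length = torch.LongTensor([3])
# The length-based mask keeps positions 0..2, including the real class-0
# token at position 0; a per-class weight of 0 on pad_idx would silence
# that position too, changing the loss.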
@viig99
Hi, it seems that `functional.cross_entropy` still doesn't support >2D input. I think the `log_softmax` + `gather` calls can be merged into one `cross_entropy` call with `reduce=False`, and I expect there might be some performance gain.
I will update this gist soon.
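A sketch of what that merged version could look like; at the time the argument was spelled `reduce=False`, while current PyTorch uses `reduction='none'`:

def compute_loss_merged(logits, target, length):
    # One fused call replaces log_softmax + gather; per-element losses kept.
    losses_flat = functional.cross_entropy(
        logits.view(-1, logits.size(-1)), target.view(-1), reduce=False)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    return losses.sum() / length.float().sum()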