import torch
from torch.autograd import Variable
from torch.nn import functional


def _sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_range_expand = Variable(seq_range_expand)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    # True at positions that fall inside each sequence's real length.
    return seq_range_expand < seq_length_expand


def compute_loss(logits, target, length):
    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each sequence in the batch.

    Returns:
        loss: An average loss value masked by the length.
    """
    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = functional.log_softmax(logits_flat, dim=-1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1); negative log-likelihood of the true class
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len); 1 for valid steps, 0 for padding
    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    # Average over the total number of valid (non-padded) steps.
    loss = losses.sum() / length.float().sum()
    return loss
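For reference, a minimal usage sketch of compute_loss, assuming the Variable-era API used above; the batch size, class count, and lengths below are made-up illustration values:

# Hypothetical sizes: batch=2, max_len=4, num_classes=5.
batch, max_len, num_classes = 2, 4, 5
logits = Variable(torch.randn(batch, max_len, num_classes))
target = Variable(torch.LongTensor(batch, max_len).random_(0, num_classes))
# True (unpadded) length of each sequence in the batch.
length = Variable(torch.LongTensor([4, 2]))
loss = compute_loss(logits, target, length)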
Great way, thanks 👍
Hi, with the introduction of the reduce=False argument, what changes would be needed to simplify the masked cross entropy?
@viig99
Hi, it seems that functional.cross_entropy still doesn't support >2D input. I think the log_softmax + gather calls can be merged into one cross_entropy call with reduce=False, and I expect there might be some performance gain. I will update this gist soon.
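For reference, a sketch of what that merged version might look like with the old reduce=False flag (later PyTorch releases spell this reduction='none'); this is only one reading of the suggestion above, not the author's updated gist:

def compute_loss_with_reduce_false(logits, target, length):
    # Same argument shapes as compute_loss above.
    logits_flat = logits.view(-1, logits.size(-1))   # (batch * max_len, num_classes)
    target_flat = target.view(-1)                    # (batch * max_len,)
    # One cross_entropy call replaces the separate log_softmax + gather.
    losses_flat = functional.cross_entropy(logits_flat, target_flat, reduce=False)
    losses = losses_flat.view(*target.size())        # (batch, max_len)
    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    return losses.sum() / length.float().sum()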
I (genuinely) wonder how this is different from using:

weight = torch.ones(vocab_size)
weight[pad_idx] = 0.0
crit = nn.CrossEntropyLoss(weight=weight)
crit(output, targets)

I seem to get the same numbers (assuming that you have padded every sequence with pad_idx up to the maximum sentence length in the batch).
Interesting solution, @emanjavacas!
@emanjavacas, @lzfelix
Yeah, it's a good one, and shouldn't this be even more "correct"?

criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

I assume ignore_index was created for masking and no other purpose?
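A sketch of that variant, assuming the target tensor is padded with pad_idx at every step beyond a sequence's real length (pad_idx here is a placeholder index):

import torch.nn as nn

pad_idx = 0  # placeholder: whatever index is used to pad the targets
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
# logits: (batch, max_len, num_classes), target: (batch, max_len) padded with pad_idx
loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))

With ignore_index, the padded positions are dropped from both the sum and the averaging, so the result should match the mask-based average as long as every padded target really equals pad_idx.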
@emanjavacas The masking in the proposed gist is per-sample, i.e. it tells you whether each sample should have a loss or not, while your weight here is per-class.
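To make the distinction concrete, a toy sketch (pad_idx and the numbers are placeholders): the class weight zeroes the loss only where the target value happens to be pad_idx, while the sequence mask zeroes the loss for every step beyond a sequence's true length, no matter which class index was used to pad it.

# Toy setup: batch=1, max_len=3, the real sequence has length 2.
pad_idx = 0
target = Variable(torch.LongTensor([[3, 1, pad_idx]]))  # last step is padding
length = Variable(torch.LongTensor([2]))

# Per-class: the loss at a step is dropped only because target == pad_idx there.
weight = torch.ones(5)
weight[pad_idx] = 0.0

# Per-sample: the loss at step index 2 is dropped because 2 >= length,
# regardless of what value sits in the padded target slot.
mask = _sequence_mask(sequence_length=length, max_len=3)  # -> [[1, 1, 0]]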
Does the Variable length have the size B x H x W for an image in a batch (B: batch size, H: height, W: width)?
Thanks,