from typing import Tuple

import torch

from allennlp.models.model import Model
from allennlp.nn import util


class CopyNetSeq2Seq(Model):

    # snip...

    def _get_ll_contrib(self,
                        generation_scores: torch.Tensor,
                        generation_scores_mask: torch.Tensor,
                        copy_scores: torch.Tensor,
                        target_tokens: torch.Tensor,
                        target_to_source: torch.Tensor,
                        copy_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
""" | |
Get the log-likelihood contribution from a single timestep. | |
Parameters | |
---------- | |
generation_scores : ``torch.Tensor`` | |
Shape: `(batch_size, target_vocab_size)` | |
generation_scores_mask : ``torch.Tensor`` | |
Shape: `(batch_size, target_vocab_size)`. This is just a tensor of 1's. | |
copy_scores : ``torch.Tensor`` | |
Shape: `(batch_size, trimmed_source_length)` | |
target_tokens : ``torch.Tensor`` | |
Shape: `(batch_size,)` | |
target_to_source : ``torch.Tensor`` | |
Shape: `(batch_size, trimmed_source_length)` | |
copy_mask : ``torch.Tensor`` | |
Shape: `(batch_size, trimmed_source_length)` | |
Returns | |
------- | |
Tuple[torch.Tensor, torch.Tensor] | |
Shape: `(batch_size,), (batch_size, max_input_sequence_length)` | |
""" | |
        _, target_size = generation_scores.size()
        # The point of this mask is to mask out all source token scores
        # that just represent padding. We apply the mask to the concatenation
        # of the generation scores and the copy scores to normalize the scores
        # correctly during the softmax.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        mask = torch.cat((generation_scores_mask, copy_mask), dim=-1)
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
        # Normalize generation and copy scores.
        # shape: (batch_size, target_vocab_size + trimmed_source_length)
        log_probs = util.masked_log_softmax(all_scores, mask)
        # Calculate the log probability (`copy_log_probs`) for each token in the source sentence
        # that matches the current target token. We use the sum of these copy probabilities
        # for matching tokens in the source sentence to get the total probability
        # for the target token. We also need to normalize the individual copy probabilities
        # to create `selective_weights`, which are used in the next timestep to create
        # a selective read state.
        # shape: (batch_size, trimmed_source_length)
        copy_log_probs = log_probs[:, target_size:] + (target_to_source.float() + 1e-45).log()
        # Since `log_probs[:, target_size:]` gives us the raw copy log probabilities,
        # we use a non-log softmax to get the normalized non-log copy probabilities.
        selective_weights = util.masked_softmax(log_probs[:, target_size:], target_to_source)
        # This mask ensures that an item in the batch has non-zero generation probabilities
        # for this timestep only when the gold target token is not OOV or there are no
        # matching tokens in the source sentence.
        # shape: (batch_size, 1)
        gen_mask = ((target_tokens != self._oov_index) | (target_to_source.sum(-1) == 0)).float()
        log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
        # Now we get the generation score for the gold target token.
        # shape: (batch_size, 1)
        generation_log_probs = log_probs.gather(1, target_tokens.unsqueeze(1)) + log_gen_mask
        # ... and add the copy score to get the step log likelihood.
        # shape: (batch_size, 1 + trimmed_source_length)
        combined_gen_and_copy = torch.cat((generation_log_probs, copy_log_probs), dim=-1)
        # shape: (batch_size,)
        step_log_likelihood = util.logsumexp(combined_gen_and_copy)
        return step_log_likelihood, selective_weights
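
Below is a minimal, self-contained sketch of the same normalization and combination steps, written with plain PyTorch ops in place of `allennlp.nn.util`. All of the shapes, values, and the `oov_index` variable are made-up stand-ins for illustration, not anything produced by the real model.

import torch

# Toy dimensions and inputs (made up for illustration).
batch_size, target_vocab_size, trimmed_source_length = 2, 5, 3
oov_index = 1  # stand-in for self._oov_index

generation_scores = torch.randn(batch_size, target_vocab_size)
copy_scores = torch.randn(batch_size, trimmed_source_length)
generation_scores_mask = torch.ones(batch_size, target_vocab_size)
copy_mask = torch.tensor([[1., 1., 0.],    # last source position of instance 0 is padding
                          [1., 1., 1.]])
target_tokens = torch.tensor([3, oov_index])      # gold target token ids for this timestep
target_to_source = torch.tensor([[0., 1., 0.],    # source positions that match the gold target
                                 [1., 0., 1.]])

# Joint normalization: concatenate generation and copy scores, then take a
# masked log-softmax (same idea as util.masked_log_softmax).
mask = torch.cat((generation_scores_mask, copy_mask), dim=-1)
all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
log_probs = torch.log_softmax(all_scores + (mask + 1e-45).log(), dim=-1)

# Copy log-probs for source tokens that match the gold target; non-matching
# positions are pushed to ~ -inf by adding log(1e-45).
copy_log_probs = log_probs[:, target_vocab_size:] + (target_to_source + 1e-45).log()

# Selective weights: re-normalize the copy probabilities over matching positions only
# (a masked softmax in the same spirit as util.masked_softmax).
selective_weights = torch.softmax(copy_log_probs, dim=-1)

# Keep the generation term only if the target is in-vocabulary or cannot be copied.
gen_mask = ((target_tokens != oov_index) | (target_to_source.sum(-1) == 0)).float()
log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
generation_log_probs = log_probs.gather(1, target_tokens.unsqueeze(1)) + log_gen_mask

# Total probability of the gold target = generate it OR copy any matching source token.
combined = torch.cat((generation_log_probs, copy_log_probs), dim=-1)
step_log_likelihood = torch.logsumexp(combined, dim=-1)

print(step_log_likelihood.shape)   # torch.Size([2])
print(selective_weights.sum(-1))   # tensor([1., 1.]), mass concentrated on matching positions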
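
Two design points stand out in this step, consistent with the comments above and with the CopyNet formulation of Gu et al. (2016): the generation and copy scores are normalized in a single softmax, so generating a target token and copying it from the source compete for the same probability mass, and the likelihood of the gold target at each timestep is the sum (a logsumexp in log space) over the generate path and every matching copy path. The separately re-normalized `selective_weights` are returned so that the next timestep can form its selective read state, as noted in the comment above `copy_log_probs`.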