import numpy as np
import torch

# Note: this is a method of a model wrapper class (hence `self`); the
# START_TOKEN_ID, STOP_TOKEN_ID and DEVICE constants are assumed to be
# defined elsewhere in the module.
def _beam_search(self, batch, beam_width, max_len):
    """Beam search for predicting a sentence."""
    batch_size = batch['input_t'].size(1)

    with torch.no_grad():
        encoder_hidden, encoder_final = self.model.encode(
            batch['input_t'].transpose(0, 1),
            batch['input_mask'].transpose(0, 1),
            batch['input_len'])

    prev_y = torch.ones(batch_size, 1).fill_(START_TOKEN_ID).type_as(
        batch['input_t'].transpose(0, 1))
    trg_mask = torch.ones_like(prev_y)

    candidate = {
        'prev_y': prev_y,
        'output': [prev_y],
        'attention': [],
        'hidden': None,
        'score': torch.tensor([1.0] * batch_size, device=DEVICE)
    }
    candidates = [candidate]  # Start beam search with only 1 candidate

    for _ in range(max_len):
        next_candidates = []
        for candidate in candidates:
            with torch.no_grad():
                out, hidden, pre_output = self.model.decode(
                    encoder_hidden,
                    encoder_final,
                    batch['input_mask'].transpose(0, 1),
                    candidate['prev_y'],
                    trg_mask,
                    candidate['hidden'])
            # We predict from the pre-output layer, which is a combination
            # of the decoder state, the previous embedding and the context.
            prob = self.model.generator(pre_output[:, -1])
            topb, next_word = prob.topk(beam_width)  # [batch_size, beam_width]
            # Build a fresh attention history for this step so that sibling
            # candidates do not share (and double-append to) the same list.
            attention = candidate['attention'] + [
                self.model.decoder.attention.alphas.cpu().numpy()]
            for i in range(beam_width):
                next_word_i = next_word[:, i].unsqueeze(-1)
                topb_i = topb[:, i]
                next_candidates.append({
                    'prev_y': next_word_i,
                    'output': candidate['output'] + [next_word_i],
                    'attention': attention,
                    'hidden': hidden,
                    # Scores accumulate by multiplication; this assumes the
                    # generator returns probabilities (with log-probabilities
                    # the scores would be summed instead).
                    'score': candidate['score'] * topb_i
                })
        # For each current candidate we add beam_width candidates, so we have
        # beam_width candidates after the first iteration and beam_width^2 for
        # every following iteration (e.g. with beam_width = 3: 1 -> 3, then
        # 9 pruned back down to 3 each time).
        # Sort all candidates based on score.
        next_candidates.sort(key=lambda k: k['score'].sum(), reverse=True)
        # Keep only the beam_width best.
        candidates = next_candidates[:beam_width]

    # Take the output / attention of the best candidate.
    output = candidates[0]['output']
    attention_scores = candidates[0]['attention']

    # Reorganize the output into a flat array of token ids and cut it at the
    # first STOP token (the trimming assumes decoding a single example,
    # i.e. batch_size == 1).
    output = np.array(torch.cat(output, dim=-1).tolist()).flatten()
    first_stop = np.where(output == STOP_TOKEN_ID)[0]
    if len(first_stop) > 0:
        output = output[:first_stop[0]]
    return output, np.concatenate(attention_scores, axis=1)
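
# --- Hypothetical illustration (not part of the original gist): one beam
# update step in isolation, assuming the generator returns probabilities.
# With batch_size = 1 and beam_width = 2, a candidate with running score 0.5
# proposes its top-2 next tokens, and each continuation is scored by
# multiplying the parent score with the token probability.
probs = torch.tensor([[0.1, 0.6, 0.3]])          # fake generator output, [1, vocab]
topb, next_word = probs.topk(2)                  # probs [[0.6, 0.3]], tokens [[1, 2]]
parent_score = torch.tensor([0.5])
child_scores = [parent_score * topb[:, i] for i in range(2)]  # [0.30], [0.15]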
outputs, _ = self._beam_search(batch, self.bsw, max_len)
# The batch is:
# {
#     'input_t': batched_article,      # [padded_seq_len, batch_size]
#     'target_t': batched_abstract,
#     'input_mask': article_mask,
#     'target_mask': abstract_mask,
#     'input_len': article_len,
#     'target_len': abstract_len
# }
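
# --- Hypothetical usage sketch (not from the original gist): building such a
# batch for decoding. The token ids, PAD_TOKEN_ID and the `summarizer` wrapper
# object are illustrative assumptions; only the tensor layout above is given.
# The target_* entries are not needed for beam-search decoding.
PAD_TOKEN_ID = 0

batched_article = torch.tensor([                 # [padded_seq_len, batch_size]
    [4, 7],
    [5, 8],
    [6, 9],
    [3, PAD_TOKEN_ID],
    [10, PAD_TOKEN_ID],
])
article_mask = (batched_article != PAD_TOKEN_ID)  # [padded_seq_len, batch_size]
article_len = torch.tensor([5, 3])                # true lengths, [batch_size]

batch = {
    'input_t': batched_article,
    'input_mask': article_mask,
    'input_len': article_len,
}
# output, attention = summarizer._beam_search(batch, beam_width=4, max_len=50)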