TensorFlow Beam Search
import tensorflow as tf

# This gist targets the pre-1.0 TensorFlow API (tf.select, tf.concat(dim, values)),
# where the legacy seq2seq ops live under tf.nn.seq2seq.
seq2seq = tf.nn.seq2seq
def beam_decoder(decoder_inputs, initial_state, cell, loop_function, scope=None,
                 beam_size=7, done_token=0):
""" | |
Beam search decoder | |
Args: | |
decoder_inputs: A list of 2D Tensors [batch_size x input_size]. | |
initial_state: 2D Tensor with shape [batch_size x cell.state_size]. | |
cell: rnn_cell.RNNCell defining the cell function and size. | |
loop_function: This function will be applied to the i-th output | |
in order to generate the i+1-st input, and decoder_inputs will be ignored, | |
except for the first element ("GO" symbol). | |
Signature -- loop_function(prev_symbol, i) = next | |
* prev_symbol is a 1D Tensor of shape [batch_size*beam_size] | |
* i is an integer, the step number (when advanced control is needed), | |
* next is a 2D Tensor of shape [batch_size*beam_size, input_size]. | |
scope: Passed to seq2seq.rnn_decoder | |
beam_size: An integer beam size to use for each example | |
done_token: An integer token that specifies the STOP symbol | |
Return: | |
A tensor of dimensions [batch_size, len(decoder_inputs)] that corresponds to | |
the 1-best beam for each batch. | |
Known limitations: | |
* The output sequence consisting of only a STOP symbol is not considered | |
(zero-length sequences are not very useful, so this wasn't implemented) | |
* The computation graph this creates is messy and not very well-optimized | |
""" | |
    decoder = BeamDecoder(decoder_inputs, initial_state, beam_size=beam_size,
                          done_token=done_token)

    _ = seq2seq.rnn_decoder(
        decoder.decoder_inputs,
        decoder.initial_state,
        cell=cell,
        loop_function=lambda prev, i: loop_function(decoder.take_step(prev, i), i),
        scope=scope
    )

    return decoder.finished_beams
class BeamDecoder():
    """
    Main class implementing the beam decoder.
    """
    def __init__(self, decoder_inputs, initial_state, beam_size=7, done_token=0,
                 batch_size=None, num_classes=None):
        self.beam_size = beam_size
        self.batch_size = batch_size
        if self.batch_size is None:
            self.batch_size = tf.shape(decoder_inputs[0])[0]

        self.max_len = len(decoder_inputs)
        self.num_classes = num_classes
        self.done_token = done_token

        self.past_logprobs = None
        self.past_symbols = None

        self.finished_beams = tf.zeros((self.batch_size, self.max_len), dtype=tf.int32)
        self.logprobs_finished_beams = tf.ones((self.batch_size,), dtype=tf.float32) * -float('inf')

        self.decoder_inputs = [None] * len(decoder_inputs)
        # Only the first ("GO") input is consumed; tile it so each beam gets its own copy.
        self.decoder_inputs[0] = self.tile_along_beam(decoder_inputs[0])

        # Convert the state input to the decoder
        if isinstance(initial_state, tf.nn.rnn_cell.LSTMStateTuple):
            self.initial_state = tf.nn.rnn_cell.LSTMStateTuple(
                c=self.tile_along_beam(initial_state.c),
                h=self.tile_along_beam(initial_state.h)
            )
        else:
            self.initial_state = self.tile_along_beam(initial_state)
    def tile_along_beam(self, tensor):
        """
        Helps tile tensors for each beam.

        Args:
            tensor: a 2-D tensor, [batch_size x T]
        Return:
            A [batch_size*beam_size x T] tensor, where each row of the input
            tensor is copied beam_size times in a row in the output
        """
        res = tf.expand_dims(tensor, 1)
        res = tf.tile(res, [1, self.beam_size, 1])
        res = tf.reshape(res, [-1, tf.shape(tensor)[1]])

        try:
            new_first_dim = tensor.get_shape()[0] * self.beam_size
        except:
            new_first_dim = None
        res.set_shape((new_first_dim, tensor.get_shape()[1]))
        return res
    def take_step(self, prev, i):
        logprobs = tf.nn.log_softmax(prev)

        if self.num_classes is None:
            self.num_classes = tf.shape(logprobs)[1]

        logprobs_batched = tf.reshape(logprobs, [-1, self.beam_size, self.num_classes])
        logprobs_batched.set_shape((None, self.beam_size, None))

        # Note: masking out entries to -inf plays poorly with top_k, so just subtract out
        # a large number.
        nondone_mask = tf.reshape(
            tf.cast(tf.equal(tf.range(self.num_classes), self.done_token), tf.float32) * -1e18,
            [1, 1, self.num_classes]
        )

        if self.past_logprobs is not None:
            logprobs_batched = logprobs_batched + tf.expand_dims(self.past_logprobs, 2)
            self.past_logprobs, indices = tf.nn.top_k(
                tf.reshape(logprobs_batched + nondone_mask, [-1, self.beam_size * self.num_classes]),
                self.beam_size
            )
        else:
            self.past_logprobs, indices = tf.nn.top_k(
                (logprobs_batched + nondone_mask)[:, 0, :],
                self.beam_size
            )
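        # top_k above runs over the flattened [beam_size * num_classes] axis, so the
        # beam_size survivors are chosen jointly across all beams of each example
        # (on the very first step, only beam 0 is considered). Each flat index encodes
        # both the parent beam and the new symbol, decoded by the div/mod below.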
        # For continuing to the next symbols
        symbols = indices % self.num_classes
        parent_refs = indices // self.num_classes

        if self.past_symbols is not None:
            # Offsets turn per-example beam indices into row indices of the
            # batch-major [batch_size*beam_size, ...] history tensor below.
            parent_refs_offsets = tf.reshape(
                (tf.range(self.batch_size * self.beam_size) // self.beam_size) * self.beam_size,
                [self.batch_size, self.beam_size]
            )

            past_symbols_batch_major = tf.reshape(self.past_symbols, [-1, i - 1])
            beam_past_symbols = tf.gather(past_symbols_batch_major,  # batch-major
                                          parent_refs + parent_refs_offsets)
            self.past_symbols = tf.concat(2, [beam_past_symbols, tf.expand_dims(symbols, 2)])

            # For finishing the beam here
            logprobs_done = logprobs_batched[:, :, self.done_token]
            done_parent_refs = tf.cast(tf.argmax(logprobs_done, 1), tf.int32)
            done_parent_refs_offsets = tf.range(self.batch_size) * self.beam_size

            done_past_symbols = tf.gather(past_symbols_batch_major,
                                          done_parent_refs + done_parent_refs_offsets)
            symbols_done = tf.concat(1, [done_past_symbols,
                                         tf.ones_like(done_past_symbols[:, 0:1]) * self.done_token,
                                         tf.tile(tf.zeros_like(done_past_symbols[:, 0:1]),
                                                 [1, self.max_len - i])
                                         ])

            logprobs_done_max = tf.reduce_max(logprobs_done, 1)
            self.finished_beams = tf.select(logprobs_done_max > self.logprobs_finished_beams,
                                            symbols_done,
                                            self.finished_beams)
            self.logprobs_finished_beams = tf.maximum(logprobs_done_max, self.logprobs_finished_beams)
        else:
            self.past_symbols = tf.expand_dims(symbols, 2)
            # NOTE: outputting a zero-length sequence is not supported for simplicity reasons

        symbols_flat = tf.reshape(symbols, [-1])

        return symbols_flat
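
For context, a minimal usage sketch follows. It assumes the same pre-1.0 TensorFlow API as the gist (tf.nn.rnn_cell, tf.nn.seq2seq); the sizes, GRU cell, output projection, and embedding-based loop_function are illustrative placeholders, not part of the original code. Note that the cell must emit vocabulary logits, since take_step applies log_softmax directly to the cell output.

import tensorflow as tf

# Illustrative sizes -- not from the gist.
batch_size, max_len, vocab_size, hidden_size, embed_size = 32, 20, 5000, 256, 128

embedding = tf.get_variable("embedding", [vocab_size, embed_size])

# Wrap the cell so its outputs are logits over the vocabulary, as take_step expects.
cell = tf.nn.rnn_cell.OutputProjectionWrapper(
    tf.nn.rnn_cell.GRUCell(hidden_size), vocab_size)

# Only decoder_inputs[0] (the "GO" embedding) is consumed; the list length fixes max_len.
decoder_inputs = [tf.zeros([batch_size, embed_size])] * max_len
initial_state = tf.zeros([batch_size, hidden_size])  # e.g. an encoder's final state

def loop_function(prev_symbol, i):
    # prev_symbol: [batch_size * beam_size] symbol ids chosen by BeamDecoder.take_step
    return tf.nn.embedding_lookup(embedding, prev_symbol)

best_beams = beam_decoder(decoder_inputs, initial_state, cell, loop_function,
                          beam_size=7, done_token=0)  # done_token: id of the STOP symbol
# best_beams: [batch_size, max_len] int32, the highest-scoring finished sequence per example.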