-
-
Save nikitakit/6ab61a73b86c50ad88d409bac3c3d09f to your computer and use it in GitHub Desktop.
""" | |
Beam decoder for tensorflow

Sample usage:

```
from tf_beam_decoder import beam_decoder

decoded_sparse, decoded_logprobs = beam_decoder(
    cell=cell,
    beam_size=7,
    stop_token=2,
    initial_state=initial_state,
    initial_input=initial_input,
    tokens_to_inputs_fn=lambda tokens: tf.nn.embedding_lookup(my_embedding, tokens),
)
```

See the `beam_decoder` function for complete documentation. (Only the
`beam_decoder` function is part of the public API here.)
""" | |
import tensorflow as tf | |
import numpy as np | |
from tensorflow.python.util import nest | |
# %% | |
def nest_map(func, nested):
    """
    Apply `func` to every leaf of `nested`, keeping the nesting structure.

    If `nested` is not a sequence (as judged by `nest.is_sequence`), `func` is
    applied to it directly and the result is returned as-is.
    """
    if nest.is_sequence(nested):
        mapped_leaves = [func(leaf) for leaf in nest.flatten(nested)]
        return nest.pack_sequence_as(nested, mapped_leaves)
    return func(nested)
# %% | |
def sparse_boolean_mask(tensor, mask):
    """
    Creates a sparse tensor from masked elements of `tensor`

    Inputs:
      tensor: a 2-D tensor, [batch_size, T]
      mask: a 2-D mask, [batch_size, T]

    Output: a 2-D sparse tensor whose rows hold the selected elements of the
      corresponding row of `tensor`, packed to the left.
    """
    # Number of selected elements per row, kept as a column for broadcasting.
    mask_lens = tf.reduce_sum(tf.cast(mask, tf.int32), -1, keep_dims=True)
    mask_shape = tf.shape(mask)
    # True for the first `mask_lens[b]` positions of each row b: these are the
    # positions the selected values occupy once packed to the left.
    left_shifted_mask = tf.tile(
        tf.expand_dims(tf.range(mask_shape[1]), 0),
        [mask_shape[0], 1]
    ) < mask_lens
    # For 2D only
    dense_shape = tf.cast(
        tf.stack([mask_shape[0], tf.reduce_max(mask_lens)]), tf.int64)
    # Arguments are passed positionally because the keyword name of the third
    # argument differs between TensorFlow releases (`shape` / `dense_shape`).
    return tf.SparseTensor(
        tf.where(left_shifted_mask),
        tf.boolean_mask(tensor, mask),
        dense_shape
    )
# %% | |
def flat_batch_gather(flat_params, indices, validate_indices=None,
                      batch_size=None,
                      options_size=None):
    """
    Per-example gather over a tensor whose batch and option axes are merged.

    output[(b * indices_size + i), :, ..., :] = flat_params[(b * options_size + indices[b, i]), :, ..., :]

    When `batch_size` / `options_size` are supplied they are used directly
    instead of being derived from the input shapes, which may avoid redundant
    computation (TODO: figure out if tensorflow's optimizer can do this automatically)

    Args:
      flat_params: A `Tensor`, [batch_size * options_size, ...]
      indices: A `Tensor`, [batch_size, indices_size]
      validate_indices: An optional `bool`. Defaults to `True`
      batch_size: (optional) an integer or scalar tensor representing the batch size
      options_size: (optional) an integer or scalar Tensor representing the number of options to choose from
    """
    if batch_size is None:
        # Prefer the static dimension; fall back to the dynamic shape.
        static_batch = indices.get_shape()[0].value
        batch_size = tf.shape(indices)[0] if static_batch is None else static_batch

    if options_size is None:
        # The leading axis of flat_params holds batch_size * options_size rows.
        static_rows = flat_params.get_shape()[0].value
        total_rows = tf.shape(flat_params)[0] if static_rows is None else static_rows
        options_size = total_rows // batch_size

    # Offset of each example's first row, shaped to broadcast against indices.
    trailing_ones = [1] * (len(indices.get_shape()) - 1)
    row_offsets = tf.reshape(tf.range(batch_size) * options_size, [-1] + trailing_ones)
    absolute_rows = indices + tf.cast(row_offsets, indices.dtype)

    return tf.gather(flat_params,
                     tf.reshape(absolute_rows, [-1]),
                     validate_indices=validate_indices)
def batch_gather(params, indices, validate_indices=None,
                 batch_size=None,
                 options_size=None):
    """
    Per-example gather: each batch element selects from its own options.

    output[b, i, ..., j, :, ..., :] = params[b, indices[b, i, ..., j], :, ..., :]

    When `batch_size` / `options_size` are supplied they are used directly
    instead of being derived from the input shapes, which may avoid redundant
    computation (TODO: figure out if tensorflow's optimizer can do this automatically)

    Args:
      params: A `Tensor`, [batch_size, options_size, ...]
      indices: A `Tensor`, [batch_size, ...]
      validate_indices: An optional `bool`. Defaults to `True`
      batch_size: (optional) an integer or scalar tensor representing the batch size
      options_size: (optional) an integer or scalar Tensor representing the number of options to choose from
    """
    if batch_size is None:
        # Combine what is statically known about the batch axis of both inputs.
        static_batch = params.get_shape()[0].merge_with(indices.get_shape()[0]).value
        batch_size = tf.shape(indices)[0] if static_batch is None else static_batch

    if options_size is None:
        static_options = params.get_shape()[1].value
        options_size = tf.shape(params)[1] if static_options is None else static_options

    # TODO(nikita): consider using gather_nd. However as of 1/9/2017 gather_nd
    # has no gradients implemented.
    # Merge the batch and options axes so a single tf.gather can be used.
    merged_shape = tf.concat(0, [[batch_size * options_size], tf.shape(params)[2:]])
    flat_params = tf.reshape(params, merged_shape)

    # Offset of each example's first row, shaped to broadcast against indices.
    trailing_ones = [1] * (len(indices.get_shape()) - 1)
    row_offsets = tf.reshape(tf.range(batch_size) * options_size, [-1] + trailing_ones)
    absolute_rows = indices + tf.cast(row_offsets, indices.dtype)

    return tf.gather(flat_params, absolute_rows, validate_indices=validate_indices)
# %% | |
class BeamFlattenWrapper(tf.nn.rnn_cell.RNNCell):
    """
    Wraps a cell so it accepts tensors shaped [batch_size, beam_size, ...].

    The two leading axes are merged into one before calling the wrapped cell
    and split apart again on its outputs and next state.
    """

    def __init__(self, cell, beam_size):
        self.cell = cell
        self.beam_size = beam_size

    def merge_batch_beam(self, tensor):
        """Reshape [batch, beam, ...] into [batch * beam, ...]."""
        remaining_shape = tf.shape(tensor)[2:]
        res = tf.reshape(tensor, tf.concat(0, [[-1], remaining_shape]))
        # Restore the static shape information lost by the dynamic reshape.
        res.set_shape(tf.TensorShape((None,)).concatenate(tensor.get_shape()[2:]))
        return res

    def unmerge_batch_beam(self, tensor):
        """Reshape [batch * beam, ...] back into [batch, beam, ...]."""
        remaining_shape = tf.shape(tensor)[1:]
        res = tf.reshape(tensor, tf.concat(0, [[-1, self.beam_size], remaining_shape]))
        res.set_shape(tf.TensorShape((None, self.beam_size)).concatenate(tensor.get_shape()[1:]))
        return res

    def prepend_beam_size(self, element):
        """Return the shape `element` with `beam_size` prepended."""
        return tf.TensorShape(self.beam_size).concatenate(element)

    def tile_along_beam(self, state):
        """
        Insert a beam axis after the first axis and tile along it.

        Accepts a tensor of shape [batch_size, ...] or a nested structure of
        such tensors; each tensor becomes [batch_size, beam_size, ...].

        Raises:
          ValueError: if a leaf is not a `tf.Tensor`.
        """
        if nest.is_sequence(state):
            return nest_map(self.tile_along_beam, state)

        if not isinstance(state, tf.Tensor):
            raise ValueError("State should be a sequence or tensor")

        tensor = state

        tensor_shape = tensor.get_shape().with_rank_at_least(1)
        new_tensor_shape = tensor_shape[:1].concatenate(self.beam_size).concatenate(tensor_shape[1:])

        # tf.unstack is the non-deprecated name for tf.unpack and is what the
        # rest of this file uses.
        dynamic_tensor_shape = tf.unstack(tf.shape(tensor))
        res = tf.expand_dims(tensor, 1)
        res = tf.tile(res, [1, self.beam_size] + [1] * (tensor_shape.ndims - 1))
        res = tf.reshape(res, [-1, self.beam_size] + list(dynamic_tensor_shape[1:]))
        res.set_shape(new_tensor_shape)
        return res

    def __call__(self, inputs, state, scope=None):
        """Run the wrapped cell on inputs/state with merged batch and beam axes."""
        flat_inputs = nest_map(self.merge_batch_beam, inputs)
        flat_state = nest_map(self.merge_batch_beam, state)

        flat_output, flat_next_state = self.cell(flat_inputs, flat_state, scope=scope)

        output = nest_map(self.unmerge_batch_beam, flat_output)
        next_state = nest_map(self.unmerge_batch_beam, flat_next_state)

        return output, next_state

    @property
    def state_size(self):
        return nest_map(self.prepend_beam_size, self.cell.state_size)

    @property
    def output_size(self):
        return nest_map(self.prepend_beam_size, self.cell.output_size)
# %% | |
class BeamReplicateWrapper(tf.nn.rnn_cell.RNNCell):
    """
    Wraps a cell so it accepts tensors shaped [batch_size, beam_size, ...].

    The wrapped cell is invoked once per beam position on the corresponding
    [batch_size, ...] slice, and the per-position results are stacked back
    along axis 1.
    """

    def __init__(self, cell, beam_size):
        self.cell = cell
        self.beam_size = beam_size

    def prepend_beam_size(self, element):
        """Return the shape `element` with `beam_size` prepended."""
        return tf.TensorShape(self.beam_size).concatenate(element)

    def tile_along_beam(self, state):
        """
        Insert a beam axis after the first axis and tile along it.

        Accepts a tensor of shape [batch_size, ...] or a nested structure of
        such tensors; each tensor becomes [batch_size, beam_size, ...].

        Raises:
          ValueError: if a leaf is not a `tf.Tensor`.
        """
        if nest.is_sequence(state):
            return nest_map(self.tile_along_beam, state)

        if not isinstance(state, tf.Tensor):
            raise ValueError("State should be a sequence or tensor")

        tensor = state

        tensor_shape = tensor.get_shape().with_rank_at_least(1)
        new_tensor_shape = tensor_shape[:1].concatenate(self.beam_size).concatenate(tensor_shape[1:])

        # Use tf.unstack here as well, matching __call__ below (tf.unpack is
        # the deprecated alias).
        dynamic_tensor_shape = tf.unstack(tf.shape(tensor))
        res = tf.expand_dims(tensor, 1)
        res = tf.tile(res, [1, self.beam_size] + [1] * (tensor_shape.ndims - 1))
        res = tf.reshape(res, [-1, self.beam_size] + list(dynamic_tensor_shape[1:]))
        res.set_shape(new_tensor_shape)
        return res

    def __call__(self, inputs, state, scope=None):
        """Run the wrapped cell once per beam position and restack the results."""
        varscope = scope or tf.get_variable_scope()

        flat_inputs = nest.flatten(inputs)
        flat_state = nest.flatten(state)

        # Transpose "per tensor, a list of beam slices" into "per beam
        # position, a tuple of tensors".
        flat_inputs_unstacked = list(zip(*[tf.unstack(tensor, num=self.beam_size, axis=1) for tensor in flat_inputs]))
        flat_state_unstacked = list(zip(*[tf.unstack(tensor, num=self.beam_size, axis=1) for tensor in flat_state]))

        flat_output_unstacked = []
        flat_next_state_unstacked = []
        # The last cell results are kept as structure templates for repacking.
        output_sample = None
        next_state_sample = None

        for i, (inputs_k, state_k) in enumerate(zip(flat_inputs_unstacked, flat_state_unstacked)):
            inputs_k = nest.pack_sequence_as(inputs, inputs_k)
            state_k = nest.pack_sequence_as(state, state_k)

            # TODO(nikita): is this scope stuff correct?
            if i == 0:
                # First replica: called with the caller's scope unchanged.
                output_k, next_state_k = self.cell(inputs_k, state_k, scope=scope)
            else:
                # Later replicas: called inside the same scope with reuse=True.
                with tf.variable_scope(varscope, reuse=True):
                    output_k, next_state_k = self.cell(inputs_k, state_k, scope=varscope if scope is not None else None)

            flat_output_unstacked.append(nest.flatten(output_k))
            flat_next_state_unstacked.append(nest.flatten(next_state_k))

            output_sample = output_k
            next_state_sample = next_state_k

        flat_output = [tf.stack(tensors, axis=1) for tensors in zip(*flat_output_unstacked)]
        flat_next_state = [tf.stack(tensors, axis=1) for tensors in zip(*flat_next_state_unstacked)]

        output = nest.pack_sequence_as(output_sample, flat_output)
        next_state = nest.pack_sequence_as(next_state_sample, flat_next_state)

        return output, next_state

    @property
    def state_size(self):
        return nest_map(self.prepend_beam_size, self.cell.state_size)

    @property
    def output_size(self):
        return nest_map(self.prepend_beam_size, self.cell.output_size)
# %% | |
class BeamSearchHelper(object):
    """
    Builds the `loop_fn` that drives beam search through `tf.nn.raw_rnn`.

    The loop state is a 4-tuple
    (cand_symbols, cand_logprobs, beam_symbols, beam_logprobs): the best
    finished sequence found so far per batch element with its score, and the
    unfinished sequences currently on the beam with their scores.
    """

    # Our beam scores are stored in a fixed-size tensor, but sometimes the
    # tensor size is greater than the number of elements actually on the beam.
    # The invalid elements are assigned a highly negative score.
    # However, top_k errors if any of the inputs have a score of -inf, so we use
    # a large negative constant instead
    INVALID_SCORE = -1e18

    def __init__(self, cell, beam_size, stop_token, initial_state, initial_input,
                 score_upper_bound=None,
                 max_len=100,
                 outputs_to_score_fn=None,
                 tokens_to_inputs_fn=None,
                 cell_transform='default',
                 scope=None
                 ):
        """
        Args:
          cell: the RNN cell to decode with.
          beam_size: number of entries kept on the beam per batch element.
          stop_token: index of the symbol that ends a sequence.
          initial_state: initial cell state.
          initial_input: initial cell input.
          score_upper_bound: upper bound on sequence scores, or None. Defaults
            to 0.0 when `outputs_to_score_fn` is also None; values above 3e38
            are treated as None.
          max_len: time step at which decoding is cut off, or None.
          outputs_to_score_fn: optional replacement for the default
            log-softmax scoring.
          tokens_to_inputs_fn: optional replacement for the default
            token-to-input mapping.
          cell_transform: 'default', 'flatten', 'replicate' or 'none'.
          scope: scope forwarded to `tf.nn.raw_rnn` by `decode_dense`.

        Raises:
          ValueError: if neither `max_len` nor a score upper bound is
            available, or if `cell_transform` is not a recognised value.
        """
        self.beam_size = beam_size
        self.stop_token = stop_token
        self.max_len = max_len
        self.scope = scope

        if score_upper_bound is None and outputs_to_score_fn is None:
            self.score_upper_bound = 0.0
        elif score_upper_bound is None or score_upper_bound > 3e38:
            # Note: 3e38 is just a little smaller than the largest float32
            # Second condition allows for Infinity as a synonym for None
            self.score_upper_bound = None
        else:
            self.score_upper_bound = float(score_upper_bound)

        if self.max_len is None and self.score_upper_bound is None:
            raise ValueError("Beam search needs a stopping criterion. Please provide max_len or score_upper_bound.")

        if cell_transform == 'default':
            # Exact type match on purpose: only these cell classes are chosen
            # for 'flatten'; everything else gets 'replicate'.
            if type(cell) in [tf.nn.rnn_cell.LSTMCell,
                              tf.nn.rnn_cell.GRUCell,
                              tf.nn.rnn_cell.BasicLSTMCell,
                              tf.nn.rnn_cell.BasicRNNCell]:
                cell_transform = 'flatten'
            else:
                cell_transform = 'replicate'

        if cell_transform == 'none':
            self.cell = cell
            self.initial_state = initial_state
            self.initial_input = initial_input
        elif cell_transform == 'flatten':
            self.cell = BeamFlattenWrapper(cell, self.beam_size)
            self.initial_state = self.cell.tile_along_beam(initial_state)
            self.initial_input = self.cell.tile_along_beam(initial_input)
        elif cell_transform == 'replicate':
            self.cell = BeamReplicateWrapper(cell, self.beam_size)
            self.initial_state = self.cell.tile_along_beam(initial_state)
            self.initial_input = self.cell.tile_along_beam(initial_input)
        else:
            raise ValueError("cell_transform must be one of: 'default', 'flatten', 'replicate', 'none'")

        self._cell_transform_used = cell_transform

        # Instance attributes shadow the default methods defined on the class.
        if outputs_to_score_fn is not None:
            self.outputs_to_score_fn = outputs_to_score_fn
        if tokens_to_inputs_fn is not None:
            self.tokens_to_inputs_fn = tokens_to_inputs_fn

        # Infer the static batch size by merging the leading dimension of
        # every state and input tensor.
        batch_size = tf.Dimension(None)
        if not nest.is_sequence(self.initial_state):
            batch_size = batch_size.merge_with(self.initial_state.get_shape()[0])
        else:
            for tensor in nest.flatten(self.initial_state):
                batch_size = batch_size.merge_with(tensor.get_shape()[0])

        if not nest.is_sequence(self.initial_input):
            batch_size = batch_size.merge_with(self.initial_input.get_shape()[0])
        else:
            for tensor in nest.flatten(self.initial_input):
                batch_size = batch_size.merge_with(tensor.get_shape()[0])

        self.inferred_batch_size = batch_size.value
        if self.inferred_batch_size is not None:
            self.batch_size = self.inferred_batch_size
        else:
            # Batch size is not statically known: read it from the first
            # state tensor at run time.
            if not nest.is_sequence(self.initial_state):
                self.batch_size = tf.shape(self.initial_state)[0]
            else:
                self.batch_size = tf.shape(list(nest.flatten(self.initial_state))[0])[0]

        self.inferred_batch_size_times_beam_size = None
        if self.inferred_batch_size is not None:
            self.inferred_batch_size_times_beam_size = self.inferred_batch_size * self.beam_size

        self.batch_size_times_beam_size = self.batch_size * self.beam_size

    def outputs_to_score_fn(self, cell_output):
        """Default scoring: log-softmax of the cell output."""
        return tf.nn.log_softmax(cell_output)

    def tokens_to_inputs_fn(self, symbols):
        """Default input mapping: append a trailing axis of size 1."""
        return tf.expand_dims(symbols, -1)

    def beam_setup(self, time):
        """Return the initial `loop_fn` result, before any cell output exists."""
        next_cell_state = self.initial_state
        next_input = self.initial_input

        # Set up the beam search tracking state
        cand_symbols = tf.fill([self.batch_size, 0], tf.constant(self.stop_token, dtype=tf.int32))
        cand_logprobs = tf.ones((self.batch_size,), dtype=tf.float32) * -float('inf')

        # Only the first beam entry of each batch element starts out valid.
        first_in_beam_mask = tf.equal(tf.range(self.batch_size_times_beam_size) % self.beam_size, 0)

        beam_symbols = tf.fill([self.batch_size_times_beam_size, 0], tf.constant(self.stop_token, dtype=tf.int32))
        beam_logprobs = tf.select(
            first_in_beam_mask,
            tf.fill([self.batch_size_times_beam_size], 0.0),
            tf.fill([self.batch_size_times_beam_size], self.INVALID_SCORE)
        )

        # Set up correct dimensions for maintaining loop invariants.
        # Note that the last dimension (initialized to zero) is not a loop invariant,
        # so we need to clear it. TODO(nikita): is there a public API for clearing shape
        # inference so that _shape is not necessary?
        cand_symbols._shape = tf.TensorShape((self.inferred_batch_size, None))
        cand_logprobs._shape = tf.TensorShape((self.inferred_batch_size,))
        beam_symbols._shape = tf.TensorShape((self.inferred_batch_size_times_beam_size, None))
        beam_logprobs._shape = tf.TensorShape((self.inferred_batch_size_times_beam_size,))

        next_loop_state = (
            cand_symbols,
            cand_logprobs,
            beam_symbols,
            beam_logprobs,
        )

        emit_output = tf.zeros(self.cell.output_size)
        elements_finished = tf.zeros([self.batch_size], dtype=tf.bool)

        return (elements_finished, next_input, next_cell_state,
                emit_output, next_loop_state)

    def beam_loop(self, time, cell_output, cell_state, loop_state):
        """Perform one beam search step and return the next `loop_fn` result."""
        (
            past_cand_symbols,   # [batch_size, time-1]
            past_cand_logprobs,  # [batch_size]
            past_beam_symbols,   # [batch_size*beam_size, time-1], right-aligned
            past_beam_logprobs,  # [batch_size*beam_size]
        ) = loop_state

        # We don't actually use this, but emit_output is required to match the
        # cell output size specification. Otherwise we would leave this as None.
        emit_output = cell_output

        # 1. Get scores for all candidate sequences
        logprobs = self.outputs_to_score_fn(cell_output)

        # Prefer the static number of classes; fall back to the dynamic shape
        # when shape inference could not determine it.
        num_classes = logprobs.get_shape()[-1].value
        if num_classes is None:
            num_classes = tf.shape(logprobs)[-1]

        logprobs_batched = tf.reshape(logprobs + tf.expand_dims(tf.reshape(past_beam_logprobs, [self.batch_size, self.beam_size]), 2),
                                      [self.batch_size, self.beam_size * num_classes])

        # 2. Determine which states to pass to next iteration

        # TODO(nikita): consider using slice+fill+concat instead of adding a mask
        # Adds INVALID_SCORE at the stop token so that finished sequences are
        # never selected to stay on the beam.
        nondone_mask = tf.reshape(
            tf.cast(tf.equal(tf.range(num_classes), self.stop_token), tf.float32) * self.INVALID_SCORE,
            [1, 1, num_classes])

        nondone_mask = tf.reshape(tf.tile(nondone_mask, [1, self.beam_size, 1]),
                                  [-1, self.beam_size * num_classes])

        beam_logprobs, indices = tf.nn.top_k(logprobs_batched + nondone_mask, self.beam_size)
        beam_logprobs = tf.reshape(beam_logprobs, [-1])

        # For continuing to the next symbols
        symbols = indices % num_classes       # [batch_size, self.beam_size]
        parent_refs = indices // num_classes  # [batch_size, self.beam_size]

        symbols_history = flat_batch_gather(past_beam_symbols, parent_refs, batch_size=self.batch_size, options_size=self.beam_size)
        beam_symbols = tf.concat(1, [symbols_history, tf.reshape(symbols, [-1, 1])])

        # Handle the output and the cell state shuffling
        next_cell_state = nest_map(
            lambda element: batch_gather(element, parent_refs, batch_size=self.batch_size, options_size=self.beam_size),
            cell_state
        )

        next_input = self.tokens_to_inputs_fn(tf.reshape(symbols, [-1, self.beam_size]))

        # 3. Update the candidate pool to include entries that just ended with a stop token
        logprobs_done = tf.reshape(logprobs_batched, [-1, self.beam_size, num_classes])[:, :, self.stop_token]
        done_parent_refs = tf.argmax(logprobs_done, 1)
        done_symbols = flat_batch_gather(past_beam_symbols, done_parent_refs, batch_size=self.batch_size, options_size=self.beam_size)

        logprobs_done_max = tf.reduce_max(logprobs_done, 1)

        cand_symbols_unpadded = tf.select(logprobs_done_max > past_cand_logprobs,
                                          done_symbols,
                                          past_cand_symbols)
        cand_logprobs = tf.maximum(logprobs_done_max, past_cand_logprobs)

        cand_symbols = tf.concat(1, [cand_symbols_unpadded, tf.fill([self.batch_size, 1], self.stop_token)])

        # 4. Check the stopping criteria
        if self.max_len is not None:
            elements_finished_clip = (time >= self.max_len)

        if self.score_upper_bound is not None:
            elements_finished_bound = tf.reduce_max(tf.reshape(beam_logprobs, [-1, self.beam_size]), 1) < (cand_logprobs - self.score_upper_bound)

        if self.max_len is not None and self.score_upper_bound is not None:
            elements_finished = elements_finished_clip | elements_finished_bound
        elif self.score_upper_bound is not None:
            elements_finished = elements_finished_bound
        elif self.max_len is not None:
            # this broadcasts elements_finished_clip to the correct shape
            elements_finished = tf.zeros([self.batch_size], dtype=tf.bool) | elements_finished_clip
        else:
            assert False, "Lack of stopping criterion should have been caught in constructor"

        # 5. Prepare return values

        # While loops require strict shape invariants, so we manually set shapes
        # in case the automatic shape inference can't calculate these. Even when
        # this is redundant it has the benefit of helping catch shape bugs.
        for tensor in list(nest.flatten(next_input)) + list(nest.flatten(next_cell_state)):
            tensor.set_shape(tf.TensorShape((self.inferred_batch_size, self.beam_size)).concatenate(tensor.get_shape()[2:]))

        for tensor in [cand_symbols, cand_logprobs, elements_finished]:
            tensor.set_shape(tf.TensorShape((self.inferred_batch_size,)).concatenate(tensor.get_shape()[1:]))

        for tensor in [beam_symbols, beam_logprobs]:
            tensor.set_shape(tf.TensorShape((self.inferred_batch_size_times_beam_size,)).concatenate(tensor.get_shape()[1:]))

        next_loop_state = (
            cand_symbols,
            cand_logprobs,
            beam_symbols,
            beam_logprobs,
        )

        return (elements_finished, next_input, next_cell_state,
                emit_output, next_loop_state)

    def loop_fn(self, time, cell_output, cell_state, loop_state):
        """`loop_fn` for `tf.nn.raw_rnn`: setup on the first call, then step."""
        if cell_output is None:
            return self.beam_setup(time)
        else:
            return self.beam_loop(time, cell_output, cell_state, loop_state)

    def decode_dense(self):
        """Run the search; return (cand_symbols, cand_logprobs) as dense tensors."""
        _, _, final_loop_state = tf.nn.raw_rnn(self.cell, self.loop_fn, scope=self.scope)
        cand_symbols, cand_logprobs, _, _ = final_loop_state
        return cand_symbols, cand_logprobs

    def decode_sparse(self, include_stop_tokens=True):
        """
        Run the search; return the decoded symbols as a sparse tensor.

        Symbols equal to the stop token are masked out. When
        `include_stop_tokens` is true the mask is shifted right by one
        position (with a leading column of ones) before being applied.
        """
        dense_symbols, logprobs = self.decode_dense()
        mask = tf.not_equal(dense_symbols, self.stop_token)
        if include_stop_tokens:
            mask = tf.concat(1, [tf.ones_like(mask[:, :1]), mask[:, :-1]])
        return sparse_boolean_mask(dense_symbols, mask), logprobs
# %% | |
def beam_decoder(
        cell,
        beam_size,
        stop_token,
        initial_state,
        initial_input,
        tokens_to_inputs_fn,
        outputs_to_score_fn=None,
        score_upper_bound=None,
        max_len=None,
        cell_transform='default',
        output_dense=False,
        scope=None
        ):
    """Beam search decoder

    Args:
        cell: tf.nn.rnn_cell.RNNCell defining the cell to use
        beam_size: the beam size for this search
        stop_token: the index of the symbol used to indicate the end of the
            output
        initial_state: initial cell state for the decoder
        initial_input: initial input into the decoder (typically the embedding
            of a START token)
        tokens_to_inputs_fn: function to go from token numbers to cell inputs.
            A typical implementation would look up the tokens in an embedding
            matrix.
            (signature: [batch_size, beam_size] int32 -> [batch_size, beam_size, ...])
        outputs_to_score_fn: function to go from RNN cell outputs to scores for
            different tokens. If left unset, log-softmax is used (i.e. the cell
            outputs are treated as unnormalized logits).
            Inputs to the function are cell outputs, i.e. a possibly nested
            structure of items with shape [batch_size, beam_size, ...].
            Must return a single Tensor with shape [batch_size, beam_size, num_classes]
        score_upper_bound: (float or None). An upper bound on sequence scores.
            Used to determine a stopping criterion for beam search: the search
            stops if the highest-scoring complete sequence found so far beats
            anything on the beam by at least score_upper_bound. For typical
            sequence decoder models, outputs_to_score_fn returns normalized
            logits and this upper bound should be set to 0. Defaults to 0 if
            outputs_to_score_fn is not provided, otherwise defaults to None.
        max_len: (default None) maximum length after which to abort beam search.
            This provides an alternative stopping criterion.
        cell_transform: 'flatten', 'replicate', 'none', or 'default'. Most RNN
            primitives require inputs/outputs/states to have a shape that starts
            with [batch_size]. Beam search instead relies on shapes that start
            with [batch_size, beam_size]. This parameter controls how the arguments
            cell/initial_state/initial_input are transformed to comply with this.
            * 'flatten' creates a virtual batch of size batch_size*beam_size, and
              uses the cell with such a batch size. This transformation is only
              valid for cells that do not rely on the batch ordering in any way.
              (This is true of most RNNCells, but notably excludes cells that
              use attention.)
              The values of initial_state and initial_input are expanded and
              tiled along the beam_size dimension.
            * 'replicate' creates beam_size virtual replicas of the cell, each
              one of which is applied to batch_size elements. This should yield
              correct results (even for models with attention), but may not have
              ideal performance.
              The values of initial_state and initial_input are expanded and
              tiled along the beam_size dimension.
            * 'none' passes along cell/initial_state/initial_input as-is.
              Note that this requires initial_state and initial_input to already
              have a shape [batch_size, beam_size, ...] and a custom cell type
              that can handle this
            * 'default' selects 'flatten' for LSTMCell, GRUCell, BasicLSTMCell,
              and BasicRNNCell. For all other cell types, it selects 'replicate'
        output_dense: (default False) toggles returning the decoded sequence as
            dense tensor.
        scope: VariableScope for the created subgraph; defaults to "RNN".

    Returns:
        A tuple of the form (decoded, log_probabilities) where:
        decoded: A SparseTensor (or dense Tensor if output_dense=True), of
            underlying shape [batch_size, ?] that contains the decoded sequence
            for each batch element
        log_probability: a [batch_size] tensor containing sequence
            log-probabilities

    Raises:
        ValueError: (from BeamSearchHelper) if no stopping criterion is
            available or cell_transform is not a recognised value.
    """
    with tf.variable_scope(scope or "RNN") as varscope:
        helper = BeamSearchHelper(
            cell=cell,
            beam_size=beam_size,
            stop_token=stop_token,
            initial_state=initial_state,
            initial_input=initial_input,
            tokens_to_inputs_fn=tokens_to_inputs_fn,
            outputs_to_score_fn=outputs_to_score_fn,
            score_upper_bound=score_upper_bound,
            max_len=max_len,
            cell_transform=cell_transform,
            scope=varscope
        )

        if output_dense:
            return helper.decode_dense()
        else:
            return helper.decode_sparse()
import tensorflow as tf | |
from tensorflow.python.platform import test | |
import numpy as np | |
from tf_beam_decoder import beam_decoder, BeamSearchHelper | |
# %% | |
# Module-level interactive session, created as a side effect of importing this
# file. NOTE(review): the tests below open their own session via
# `self.test_session()`, so this looks like a leftover from interactive
# (`# %%` cell) development -- confirm before relying on it.
sess = tf.InteractiveSession()
# %% | |
class MarkovChainCell(tf.nn.rnn_cell.RNNCell):
    """
    This cell type is only used for testing the beam decoder.

    It represents a Markov chain characterized by a probability table p(x_t|x_{t-1},x_{t-2}).
    """

    def __init__(self, table):
        """
        table[a,b,c] = p(x_t=c|x_{t-1}=b,x_{t-2}=a)

        `table` must be a 3-D array with three equal dimensions.
        """
        assert len(table.shape) == 3 and table.shape[0] == table.shape[1] == table.shape[2]
        with np.errstate(divide='ignore'):  # ignore warning for log(0)
            self.log_table = np.log(np.asarray(table, dtype=np.float32))
        # Set on the first call; used to check that later calls get the same variable.
        self.log_table_var = None
        self._output_size = table.shape[0]

    def __call__(self, inputs, state, scope=None):
        """
        inputs: [batch_size, 1] int tensor
        state: 1-tuple holding a [batch_size, 1] int tensor

        Returns the rows of the log-probability table selected by
        (state[0], inputs), together with the new state `(inputs,)`.
        `scope` is accepted but not used.
        """
        # Simulate variable creation, to ensure scoping works correctly.
        # The variable takes the shape of the table given to the constructor
        # rather than a fixed (3,3,3), so any cubic table works.
        log_table = tf.get_variable('log_table',
                                    shape=self.log_table.shape,
                                    dtype=tf.float32,
                                    initializer=tf.constant_initializer(self.log_table))
        if self.log_table_var is None:
            self.log_table_var = log_table
        else:
            assert self.log_table_var == log_table
        # Flatten the two conditioning axes so one row index selects p(.|b,a).
        logits = tf.reshape(log_table, [-1, self.output_size])
        indices = state[0] * self.output_size + inputs
        return tf.gather(logits, tf.reshape(indices, [-1])), (inputs,)

    @property
    def state_size(self):
        return (1,)

    @property
    def output_size(self):
        return self._output_size
class BeamSearchTest(test.TestCase):
    """Checks the beam decoder against small hand-built Markov chains."""

    # Every test is repeated once for each of these cell_transform settings.
    _TRANSFORMS = ('default', 'flatten', 'replicate')

    @staticmethod
    def _early_stop_table():
        """Table used by test1 and test3."""
        return np.array([[[0.0, 0.6, 0.4],
                          [0.0, 0.4, 0.6],
                          [0.0, 0.0, 1.0]]] * 3)

    @staticmethod
    def _sticky_table():
        """Table used by test2 and test4."""
        return np.array([[[0.9, 0.1, 0],
                          [0, 0.9, 0.1],
                          [0, 0, 1.0]]] * 3)

    @staticmethod
    def _decode(cell, initial_state, initial_input, cell_transform, scope=None):
        """Call beam_decoder with the settings shared by test1 and test3."""
        return beam_decoder(
            cell=cell,
            beam_size=7,
            stop_token=2,
            initial_state=initial_state,
            initial_input=initial_input,
            tokens_to_inputs_fn=lambda x: tf.expand_dims(x, -1),
            max_len=5,
            cell_transform=cell_transform,
            output_dense=False,
            scope=scope
        )

    @staticmethod
    def _final_beam(sess, cell, initial_state, initial_input, cell_transform, scope_name):
        """Run raw_rnn with a BeamSearchHelper and fetch the final beam state."""
        with tf.variable_scope(scope_name):
            helper = BeamSearchHelper(
                cell=cell,
                beam_size=10,
                stop_token=2,
                initial_state=initial_state,
                initial_input=initial_input,
                tokens_to_inputs_fn=lambda x: tf.expand_dims(x, -1),
                max_len=3,
                cell_transform=cell_transform
            )

            _, _, final_loop_state = tf.nn.raw_rnn(helper.cell, helper.loop_fn)
            _, _, beam_symbols, beam_logprobs = final_loop_state

        tf.variables_initializer([cell.log_table_var]).run()
        return sess.run((beam_symbols, beam_logprobs))

    @staticmethod
    def _check_beam(candidates, candidate_logprobs, first):
        """Assert the expected beam contents starting at row `first`."""
        assert all(candidates[first, :] == [0, 0, 0])
        assert np.isclose(np.exp(candidate_logprobs[first]), 0.9 * 0.9 * 0.9)
        # Note that these three candidates all have the same score, and the sort order
        # may change in the future
        assert all(candidates[first + 1, :] == [0, 0, 1])
        assert np.isclose(np.exp(candidate_logprobs[first + 1]), 0.9 * 0.9 * 0.1)
        assert all(candidates[first + 2, :] == [0, 1, 1])
        assert np.isclose(np.exp(candidate_logprobs[first + 2]), 0.9 * 0.1 * 0.9)
        assert all(candidates[first + 3, :] == [1, 1, 1])
        assert np.isclose(np.exp(candidate_logprobs[first + 3]), 0.1 * 0.9 * 0.9)
        assert all(np.isclose(np.exp(candidate_logprobs[first + 4:]), 0.0))

    def test1(self):
        """
        test correct decode in sequence
        """
        with self.test_session() as sess:
            table = self._early_stop_table()

            for cell_transform in self._TRANSFORMS:
                cell = MarkovChainCell(table)
                initial_state = cell.zero_state(1, tf.int32)
                initial_input = initial_state[0]

                with tf.variable_scope('test1_{}'.format(cell_transform)):
                    best_sparse, best_logprobs = self._decode(
                        cell, initial_state, initial_input, cell_transform)

                tf.variables_initializer([cell.log_table_var]).run()
                assert all(best_sparse.eval().values == [2])
                assert np.isclose(np.exp(best_logprobs.eval())[0], 0.4)

    def test2(self):
        """
        test correct intermediate beam states
        """
        with self.test_session() as sess:
            table = self._sticky_table()

            for cell_transform in self._TRANSFORMS:
                cell = MarkovChainCell(table)
                initial_state = cell.zero_state(1, tf.int32)
                initial_input = initial_state[0]

                candidates, candidate_logprobs = self._final_beam(
                    sess, cell, initial_state, initial_input, cell_transform,
                    'test2_{}'.format(cell_transform))

                self._check_beam(candidates, candidate_logprobs, 0)

    def test3(self):
        """
        test that variable reuse works as expected
        """
        with self.test_session() as sess:
            table = self._early_stop_table()

            for cell_transform in self._TRANSFORMS:
                cell = MarkovChainCell(table)
                initial_state = cell.zero_state(1, tf.int32)
                initial_input = initial_state[0]

                with tf.variable_scope('test3_{}'.format(cell_transform)) as scope:
                    best_sparse, best_logprobs = self._decode(
                        cell, initial_state, initial_input, cell_transform, scope=scope)

                tf.variables_initializer([cell.log_table_var]).run()

                # Second decode re-enters the same scope with reuse enabled.
                with tf.variable_scope(scope, reuse=True) as varscope:
                    best_sparse_2, best_logprobs_2 = self._decode(
                        cell, initial_state, initial_input, cell_transform, scope=varscope)

                assert all(sess.run(tf.equal(best_sparse.values, best_sparse_2.values)))
                assert np.isclose(*sess.run((best_logprobs, best_logprobs_2)))

    def test4(self):
        """
        test batching, with statically unknown batch size
        """
        with self.test_session() as sess:
            table = self._sticky_table()

            for cell_transform in self._TRANSFORMS:
                cell = MarkovChainCell(table)
                initial_state = (tf.constant([[2], [0]]),)
                initial_input = initial_state[0]
                # Erase the static batch dimension.
                initial_input._shape = tf.TensorShape([None, 1])

                candidates, candidate_logprobs = self._final_beam(
                    sess, cell, initial_state, initial_input, cell_transform,
                    'test4_{}'.format(cell_transform))

                # Rows 10 onwards are checked (beam_size is 10 and the batch
                # has two elements).
                self._check_beam(candidates, candidate_logprobs, 10)
# Run the TensorFlow test runner when this file is executed as a script.
if __name__ == '__main__':
    test.main()
@nikitakit I tried this with version 1.1.0. After adapting the code to some API changes, I'm stuck with the error:
ValueError: Attempt to reuse RNNCell <tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.BasicLSTMCell object at 0x7f16c0ca3588> with a different variable scope than its first use. First use of cell was with scope 'seq2seq_att/social_bot/Decoder/decoder/decoder/basic_lstm_cell', this attempt is with scope 'seq2seq_att/social_bot/Decoder/decoder/rnn/basic_lstm_cell'
at this line: output_k, next_state_k = self.cell(inputs_k, state_k,scope=scope)
@nikitakit I have some questions about outputs_to_score_fn and the return value.
- What are the expected input and output shapes of this function?
- The output should be the log_softmax results rather than the logits before the softmax, right?
- If I want to get an n-best list, what should I do?
Thanks!
@avostryakov did you solve your problem? I think I have exactly the same problem as you: cand_symbols is just a sequence of stop tokens...
@nikitakit hey! does this work with TF 1.0? thanks!