attention_rnn_decoder
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest
def mask(sequence_lengths):
  """ Mask function used by the recurrent decoder with attention.

  Given a vector of sequence lengths, produces an explicit boolean matrix
  with True for valid timesteps and False for padding. E.g.

  > mask([1, 2, 3])
  [[True, False, False],
   [True, True, False],
   [True, True, True]]

  Args:
    sequence_lengths: An int32/int64 vector of size n.
  Returns:
    true_false: A [n, max_len(sequence_lengths)] sized Tensor.
  """
  # based on this SO answer: http://stackoverflow.com/a/34138336/118173
  max_len = math_ops.reduce_max(sequence_lengths)
  lengths_transposed = array_ops.expand_dims(sequence_lengths, 1)
  rng = math_ops.range(max_len)
  rng_row = array_ops.expand_dims(rng, 0)
  # Broadcasting compares each position index against each length.
  true_false = math_ops.less(rng_row, lengths_transposed)
  return true_false
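
# Minimal usage sketch (illustrative; assumes a tf.Session named `sess`).
# Below, `mask` is used to zero out attention weights on padded timesteps,
# via math_ops.to_float(mask(lengths)), a [batch, max_len] 0/1 matrix:
#   sess.run(mask(tf.constant([1, 2, 3])))
#   # -> [[ True, False, False],
#   #     [ True,  True, False],
#   #     [ True,  True,  True]]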
def simple_attention_fn(attention_inputs, number_attention_units,
                        attention_length, scope, normalizer=nn.softmax):
  attention_input_depth = int(attention_inputs.get_shape()[2])
  # Reshape to [batch, time, 1, depth] so a 1x1 conv acts as a per-timestep
  # linear projection of the encoder outputs.
  hidden = array_ops.expand_dims(attention_inputs, axis=2)
  with vs.variable_scope("part1") as varscope:
    part1 = layers.conv2d(hidden, number_attention_units, (1, 1),
                          scope=varscope)
    part1 = array_ops.squeeze(part1, [2])
  def context_fn(state, inp, reuse):
    with vs.variable_scope("part21", reuse=reuse) as varscope:
      part2 = layers.fully_connected(state, number_attention_units,
                                     activation_fn=None,
                                     biases_initializer=None,
                                     scope=varscope)
      part2 = array_ops.expand_dims(part2, 1)
    with vs.variable_scope("part22", reuse=reuse) as varscope:
      cmb_attn = layers.fully_connected(math_ops.tanh(part1 + part2), 1,
                                        activation_fn=None,
                                        biases_initializer=None,
                                        scope=varscope)
    e = array_ops.squeeze(cmb_attn, axis=[2])
    alpha = normalizer(e)
    # Mask out padded timesteps and renormalize the attention weights.
    if attention_length is not None:
      alpha = math_ops.to_float(mask(attention_length)) * alpha
      alpha = alpha / math_ops.reduce_sum(alpha, [1], keep_dims=True)
    context = math_ops.reduce_sum(array_ops.expand_dims(alpha, 2)
                                  * attention_inputs, [1])
    context.set_shape([None, attention_input_depth])
    con = array_ops.concat(1, (inp, context))
    return con, alpha
  return context_fn
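
# The closure above implements Bahdanau-style (additive) attention. For
# decoder state s_t and encoder outputs h_1..h_T it computes, per timestep i:
#   e_{t,i}   = v^T tanh(W1 h_i + W2 s_t)   # part1 + part2, then "part22"
#   alpha_t   = normalizer(e_t)             # softmax by default, then masked
#   context_t = sum_i alpha_{t,i} h_i
# and returns the current input concatenated with context_t.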
def simple_decoder_fn_train(encoder_state, context_fn=None, name=None):
  """ Simple decoder function for a sequence-to-sequence model used in the
  `dynamic_rnn_decoder` during training. Optionally runs `context_fn` on
  the input at every timestep (e.g. to concatenate an attention context).
  """
  with ops.name_scope(name, "simple_decoder_fn_train",
                      [encoder_state, context_fn]):
    pass
  def decoder_fn(time, cell_state, cell_input, cell_output, context_state,
                 reuse=False):
    with ops.name_scope(name, "simple_decoder_fn_train",
                        [time, cell_state, cell_input, cell_output,
                         context_state]):
      if cell_state is None:  # first call, return encoder_state
        if context_fn is not None:
          cell_input, _ = context_fn(encoder_state, cell_input, reuse=reuse)
        return (None, encoder_state, cell_input, cell_output, context_state)
      else:
        if context_fn is not None:
          cell_input, _ = context_fn(cell_state, cell_input, reuse=reuse)
        return (None, cell_state, cell_input, cell_output, context_state)
  return decoder_fn
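
# Note on the contract: every decoder_fn returns the 5-tuple
# (done, next cell state, next cell input, cell output, next context state).
# Returning done=None, as above, makes `dynamic_rnn_decoder` fall back on
# `sequence_length` to decide when each example is finished.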
def simple_decoder_fn_inference(output_fn, encoder_state, embeddings,
                                start_of_sequence_id, end_of_sequence_id,
                                maximum_length, num_decoder_symbols,
                                context_fn=None,
                                dtype=dtypes.int32, name=None):
  """ Simple decoder function for a sequence-to-sequence model used in the
  `dynamic_rnn_decoder` during inference. Decodes greedily: at every step
  the argmax over `output_fn(cell_output)` is embedded and fed back in.
  """
  with ops.name_scope(name, "simple_decoder_fn_inference",
                      [output_fn, encoder_state, embeddings,
                       start_of_sequence_id, end_of_sequence_id,
                       maximum_length, num_decoder_symbols, dtype]):
    start_of_sequence_id = ops.convert_to_tensor(start_of_sequence_id, dtype)
    end_of_sequence_id = ops.convert_to_tensor(end_of_sequence_id, dtype)
    maximum_length = ops.convert_to_tensor(maximum_length, dtype)
    num_decoder_symbols = ops.convert_to_tensor(num_decoder_symbols, dtype)
    encoder_info = nest.flatten(encoder_state)[0]
    batch_size = encoder_info.get_shape()[0].value
    if output_fn is None:
      output_fn = lambda x: x
    if batch_size is None:
      batch_size = array_ops.shape(encoder_info)[0]
  def decoder_fn(time, cell_state, cell_input, cell_output, context_state,
                 reuse=True):
    with ops.name_scope(name, "simple_decoder_fn_inference",
                        [time, cell_state, cell_input, cell_output,
                         context_state]):
      if cell_input is not None:
        raise ValueError("Expected cell_input to be None, but saw: %s" %
                         cell_input)
      if cell_output is None:
        # invariant that this is time == 0
        next_input_id = array_ops.ones([batch_size,], dtype=dtype) * (
            start_of_sequence_id)
        done = array_ops.zeros([batch_size,], dtype=dtypes.bool)
        cell_state = encoder_state
        cell_output = array_ops.zeros([num_decoder_symbols],
                                      dtype=dtypes.float32)
      else:
        cell_output = output_fn(cell_output)
        next_input_id = math_ops.cast(
            math_ops.argmax(cell_output, 1), dtype=dtype)
        done = math_ops.equal(next_input_id, end_of_sequence_id)
      next_input = array_ops.gather(embeddings, next_input_id)
      if context_fn is not None:
        next_input, _ = context_fn(cell_state, next_input, reuse=reuse)
      # if time > maxlen, return an all-true done vector
      done = control_flow_ops.cond(
          math_ops.greater(time, maximum_length),
          lambda: array_ops.ones([batch_size,], dtype=dtypes.bool),
          lambda: done)
      return (done, cell_state, next_input, cell_output, context_state)
  return decoder_fn
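
# Greedy decoding as above stops an example once it emits end_of_sequence_id,
# and force-stops everything after `maximum_length` steps via the cond. The
# start and end ids may coincide (the __main__ below uses 10 for both).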
def dynamic_rnn_decoder(cell, decoder_fn, inputs=None, sequence_length=None,
                        number_attention_units=None,
                        parallel_iterations=None, swap_memory=False,
                        time_major=False, scope=None, name=None):
  """ Dynamic RNN decoder for a sequence-to-sequence model specified by
  RNNCell and decoder function.

  The `dynamic_rnn_decoder` is similar to `tf.python.ops.rnn.dynamic_rnn`
  in that the decoder does not make any assumptions about the sequence
  length or batch size of the input.

  The `dynamic_rnn_decoder` has two modes, training and inference, and
  expects the user to create separate decoder functions for each.

  Under both training and inference, `cell` and `decoder_fn` are expected,
  where `cell` performs the computation at every timestep using `raw_rnn`
  and `decoder_fn` allows modelling of early stopping, output, state, and
  the next input and context.

  When training, the user is expected to supply `inputs`. At every timestep
  a slice of the supplied input is fed to the `decoder_fn`, which modifies
  and returns the input for the next timestep.

  `sequence_length` is needed at training time, i.e., when `inputs` is not
  None, for dynamic unrolling. At test time, when `inputs` is None,
  `sequence_length` is not needed.

  Under inference, `inputs` is expected to be `None` and the input is
  inferred solely from the `decoder_fn`.

  Args:
    cell: An instance of RNNCell.
    decoder_fn: A function that takes time, cell state, cell input,
      cell output and context state. It returns an early stopping vector,
      cell state, next input, cell output and context state.
      Examples of `decoder_fn` are `simple_decoder_fn_train` and
      `simple_decoder_fn_inference` above.
    inputs: The inputs for decoding (embedded format).
      If `time_major == False` (default), this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`.
      If `time_major == True`, this must be a `Tensor` of shape:
        `[max_time, batch_size, ...]`.
      The input to `cell` at each time step will be a `Tensor` with
      dimensions `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
      If `inputs` is not None and `sequence_length` is None it is inferred
      from `inputs` as the maximal possible sequence length.
    number_attention_units: (optional) Currently unused; kept for the
      commented-out option below of widening `input_depth` by the
      attention size.
    parallel_iterations: (Default: 32). The number of iterations to run in
      parallel. Those operations which do not have any temporal dependency
      and can be run in parallel will be. This parameter trades off time
      for space. Values >> 1 use more memory but take less time, while
      smaller values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward
      inference but needed for back prop from GPU to CPU. This allows
      training RNNs which would typically not fit on a single GPU, with
      very minimal (or no) performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors.
      If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
      If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
      Using `time_major = True` is a bit more efficient because it avoids
      transposes at the beginning and end of the RNN calculation. However,
      most TensorFlow data is batch-major, so by default this function
      accepts input and emits output in batch-major form.
    scope: VariableScope for the `raw_rnn`; defaults to None.
    name: NameScope for the decoder; defaults to "dynamic_rnn_decoder".

  Returns:
    A pair (outputs, state) where:
      outputs: the RNN output `Tensor`.
        If time_major == False (default), this will be a `Tensor` shaped:
          `[batch_size, max_time, cell.output_size]`.
        If time_major == True, this will be a `Tensor` shaped:
          `[max_time, batch_size, cell.output_size]`.
      state: the final state, shaped `[batch_size, cell.state_size]`.

  Raises:
    ValueError: if `inputs` is not None and has fewer than three dimensions.
  """
with ops.name_scope(name, "dynamic_rnn_decoder", | |
[cell, decoder_fn, inputs, sequence_length, | |
parallel_iterations, swap_memory, time_major, scope]): | |
if inputs is not None: | |
# Convert to tensor | |
inputs = ops.convert_to_tensor(inputs) | |
# Test input dimensions | |
if inputs.get_shape().ndims is not None and ( | |
inputs.get_shape().ndims < 2): | |
raise ValueError("Inputs must have at least two dimensions") | |
# Setup of RNN (dimensions, sizes, length, initial state, dtype) | |
if not time_major: | |
# [batch, seq, features] -> [seq, batch, features] | |
inputs = array_ops.transpose(inputs, perm=[1, 0, 2]) | |
dtype = inputs.dtype | |
# Get data input information | |
input_depth = int(inputs.get_shape()[2]) | |
#if number_attention_units is not None: | |
# input_depth += number_attention_units | |
batch_depth = inputs.get_shape()[1].value | |
max_time = inputs.get_shape()[0].value | |
if max_time is None: | |
max_time = array_ops.shape(inputs)[0] | |
# Setup decoder inputs as TensorArray | |
inputs_ta = tensor_array_ops.TensorArray(dtype, size=max_time) | |
inputs_ta = inputs_ta.unpack(inputs) | |
    def loop_fn(time, cell_output, cell_state, loop_state):
      if cell_state is None:  # first call, before while loop (in raw_rnn)
        if cell_output is not None:
          raise ValueError("Expected cell_output to be None when cell_state "
                           "is None, but saw: %s" % cell_output)
        if loop_state is not None:
          raise ValueError("Expected loop_state to be None when cell_state "
                           "is None, but saw: %s" % loop_state)
        context_state = None
      else:  # subsequent calls, inside while loop, after cell execution
        if isinstance(loop_state, tuple):
          (done, context_state) = loop_state
        else:
          done = loop_state
          context_state = None
      # call decoder function
      if inputs is not None:  # training
        # get next_cell_input
        if cell_state is None:
          read_input = inputs_ta.read(0)
          (next_done, next_cell_state, next_cell_input, emit_output,
           next_context_state) = decoder_fn(time, cell_state, read_input,
                                            cell_output, context_state)
        else:
          if batch_depth is not None:
            batch_size = batch_depth
          else:
            batch_size = array_ops.shape(done)[0]
          # Past the last timestep, feed zeros (the read would be out of
          # bounds); raw_rnn masks out these steps anyway.
          read_input = control_flow_ops.cond(
              math_ops.equal(time, max_time),
              lambda: array_ops.zeros([batch_size, input_depth], dtype=dtype),
              lambda: inputs_ta.read(time))
          (next_done, next_cell_state, next_cell_input, emit_output,
           next_context_state) = decoder_fn(time, cell_state, read_input,
                                            cell_output, context_state,
                                            reuse=True)
      else:  # inference
        # next_cell_input is obtained through decoder_fn
        (next_done, next_cell_state, next_cell_input, emit_output,
         next_context_state) = decoder_fn(time, cell_state, None,
                                          cell_output, context_state)
      # check if we are done
      if next_done is None:  # training
        next_done = time >= sequence_length
      # build next_loop_state
      if next_context_state is None:
        next_loop_state = next_done
      else:
        next_loop_state = (next_done, next_context_state)
      return (next_done, next_cell_input, next_cell_state,
              emit_output, next_loop_state)
    # Run raw_rnn function
    outputs_ta, state, _ = rnn.raw_rnn(
        cell, loop_fn, parallel_iterations=parallel_iterations,
        swap_memory=swap_memory, scope=scope)
    outputs = outputs_ta.pack()
    if not time_major:
      # [seq, batch, features] -> [batch, seq, features]
      outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
    return outputs, state
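
# Minimal usage sketch of the two modes (illustrative only; `inference_fn`
# is a made-up name, and a full runnable example follows in the __main__
# block at the bottom of this file):
#   train_out, _ = dynamic_rnn_decoder(cell, simple_decoder_fn_train(enc_state),
#                                      inputs=t_embedded, sequence_length=t_len,
#                                      scope=scope)
#   infer_out, _ = dynamic_rnn_decoder(cell, inference_fn, scope=scope)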
def sequence_loss(logits, targets, weights,
                  average_across_timesteps=True, average_across_batch=True,
                  softmax_loss_function=None, name=None):
  """Weighted cross-entropy loss for a sequence of logits (per example).

  Args:
    logits: A 3D Tensor of shape
      [batch_size x sequence_length x num_decoder_symbols] and dtype float.
      The logits correspond to the prediction across all classes at each
      timestep.
    targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype
      int. The target represents the true class at each timestep.
    weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype
      float. Weights constitutes the weighting of each prediction in the
      sequence. When using weights as masking, set all valid timesteps to 1
      and all padded timesteps to 0.
    average_across_timesteps: If set, sum the cost across the sequence
      dimension and divide the cost by the total label weight across
      timesteps.
    average_across_batch: If set, sum the cost across the batch dimension
      and divide the returned cost by the batch size.
    softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
      to be used instead of the standard softmax (the default if this is
      None).
    name: Optional name for this operation, defaults to "sequence_loss".

  Returns:
    A scalar float Tensor: The average log-perplexity per symbol (weighted).

  Raises:
    ValueError: logits does not have 3 dimensions or targets does not have 2
      dimensions or weights does not have 2 dimensions.
  """
  if len(logits.get_shape()) != 3:
    raise ValueError("Logits must be a "
                     "[batch_size x sequence_length x logits] tensor")
  if len(targets.get_shape()) != 2:
    raise ValueError("Targets must be a [batch_size x sequence_length] "
                     "tensor")
  if len(weights.get_shape()) != 2:
    raise ValueError("Weights must be a [batch_size x sequence_length] "
                     "tensor")
with ops.name_scope(name, "sequence_loss", [logits, targets, weights]): | |
num_classes = array_ops.shape(logits)[2] | |
probs_flat = array_ops.reshape(logits, [-1, num_classes]) | |
targets = array_ops.reshape(targets, [-1]) | |
if softmax_loss_function is None: | |
crossent = nn_ops.sparse_softmax_cross_entropy_with_logits( | |
probs_flat, targets) | |
else: | |
crossent = softmax_loss_function(probs_flat, targets) | |
crossent = crossent * array_ops.reshape(weights, [-1]) | |
if average_across_timesteps and average_across_batch: | |
crossent = math_ops.reduce_sum(crossent) | |
total_size = math_ops.reduce_sum(weights) | |
total_size += 1e-12 # to avoid division by 0 for all-0 weights | |
crossent /= total_size | |
else: | |
batch_size = array_ops.shape(logits)[0] | |
sequence_length = array_ops.shape(logits)[1] | |
crossent = array_ops.reshape(crossent, [batch_size, sequence_length]) | |
if average_across_timesteps and not average_across_batch: | |
crossent = math_ops.reduce_sum(crossent, axis=[1]) | |
total_size = math_ops.reduce_sum(weights, axis=[1]) | |
total_size += 1e-12 # to avoid division by 0 for all-0 weights | |
crossent /= total_size | |
if not average_across_timesteps and average_across_batch: | |
crossent = math_ops.reduce_sum(crossent, axis=[0]) | |
total_size = math_ops.reduce_sum(weights, axis=[0]) | |
total_size += 1e-12 # to avoid division by 0 for all-0 weights | |
crossent /= total_size | |
return crossent | |
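
# Minimal usage sketch (names are illustrative): with float logits of shape
# [batch, time, classes], int targets of shape [batch, time] and a 0/1 float
# mask of the same shape,
#   loss = sequence_loss(logits, targets, mask)
# yields a scalar: cross-entropy summed over unmasked positions, divided by
# the total mask weight.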
if __name__ == '__main__':
  import tensorflow as tf
  # Setting up hyperparameters and general configs
  MAX_DIGITS = 5
  MIN_DIGITS = 5
  NUM_INPUTS = 27
  NUM_OUTPUTS = 11  # (0-9 + '#')
  BATCH_SIZE = 100
  # try various learning rates 1e-2 to 1e-5
  LEARNING_RATE = 0.005
  X_EMBEDDINGS = 8
  t_EMBEDDINGS = 8
  NUM_UNITS_ENC = 12
  NUM_UNITS_DEC = 12
  # Setting up placeholders; these are the tensors that we "feed" to our
  # network
  Xs = tf.placeholder(tf.int32, shape=[None, None], name='X_input')
  ts_in = tf.placeholder(tf.int32, shape=[None, None], name='t_input_in')
  ts_out = tf.placeholder(tf.int32, shape=[None, None], name='t_input_out')
  X_len = tf.placeholder(tf.int32, shape=[None], name='X_len')
  t_len = tf.placeholder(tf.int32, shape=[None], name='t_len')
  t_mask = tf.placeholder(tf.float32, shape=[None, None], name='t_mask')
  # First we build the embeddings to turn our characters into dense,
  # trainable vectors
  X_embeddings = tf.get_variable('X_embeddings', [NUM_INPUTS, X_EMBEDDINGS],
                                 initializer=tf.random_normal_initializer(stddev=0.1))
  t_embeddings = tf.get_variable('t_embeddings', [NUM_OUTPUTS, t_EMBEDDINGS],
                                 initializer=tf.random_normal_initializer(stddev=0.1))
  # Setting up weights for computing the final output
  # (note: W_out/b_out are unused below; output_fn uses layers.linear instead)
  W_out = tf.get_variable('W_out', [NUM_UNITS_DEC, NUM_OUTPUTS])
  b_out = tf.get_variable('b_out', [NUM_OUTPUTS])
  X_embedded = tf.gather(X_embeddings, Xs, name='embed_X')
  t_embedded = tf.gather(t_embeddings, ts_in, name='embed_t')
with tf.variable_scope("rnn") as scope: | |
# forward encoding | |
enc_cell = tf.nn.rnn_cell.GRUCell(NUM_UNITS_ENC) | |
encoder_outputs, enc_state = tf.nn.dynamic_rnn(cell=enc_cell, inputs=X_embedded, | |
sequence_length=X_len, dtype=tf.float32) | |
with tf.variable_scope("decoder") as scope: | |
#with tf.variable_scope("attention") as scope: | |
context_fn = simple_attention_fn(encoder_outputs, 20, X_len, scope=scope) | |
output_fn = lambda x: tf.contrib.layers.linear(x, NUM_OUTPUTS, scope=scope) | |
decoder_cell = tf.nn.rnn_cell.GRUCell(NUM_UNITS_DEC) | |
decoder_fn_train = simple_decoder_fn_train( | |
encoder_state=enc_state, context_fn=context_fn) | |
decoder_outputs_train, decoder_state_train = ( | |
dynamic_rnn_decoder( | |
cell=decoder_cell, | |
decoder_fn=decoder_fn_train, | |
inputs=t_embedded, | |
sequence_length=t_len, | |
number_attention_units=20, | |
scope=scope)) | |
decoder_outputs_train = output_fn(decoder_outputs_train) | |
    scope.reuse_variables()
    decoder_fn_inference = (
        simple_decoder_fn_inference(
            output_fn=output_fn,
            encoder_state=enc_state,
            embeddings=t_embeddings,
            start_of_sequence_id=10,
            end_of_sequence_id=10,
            #TODO: find out why it goes to +1
            maximum_length=MAX_DIGITS,
            num_decoder_symbols=NUM_OUTPUTS,
            context_fn=context_fn,
            dtype=tf.int32))
    decoder_outputs_inference, decoder_state_inference = (
        dynamic_rnn_decoder(
            cell=decoder_cell,
            decoder_fn=decoder_fn_inference,
            scope=scope))
  y = decoder_outputs_train
  y_valid = decoder_outputs_inference
  # loss and optimization
  loss = sequence_loss(y, ts_out, t_mask)
  global_step = tf.Variable(0, name='global_step', trainable=False)
  optimizer = tf.train.AdamOptimizer(LEARNING_RATE)
  # extract gradients for each variable
  grads_and_vars = optimizer.compute_gradients(loss)
  train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
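
  # A training step would then feed batches, e.g. (illustrative only; the
  # data pipeline is not part of this gist, and the batch names are made up):
  #   sess.run(tf.global_variables_initializer())
  #   fetches = [loss, train_op]
  #   feed_dict = {Xs: x_batch, ts_in: t_in_batch, ts_out: t_out_batch,
  #                X_len: x_len_batch, t_len: t_len_batch, t_mask: mask_batch}
  #   batch_loss, _ = sess.run(fetches, feed_dict)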