@alrojo
Last active December 20, 2016 17:13
attention_rnn_decoder
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest


def mask(sequence_lengths):
  """Mask function used by the recurrent decoder with attention.

  Given a vector of sequence lengths, produces an explicit boolean matrix.
  E.g.

  > mask([1, 2, 3])
  [[True, False, False],
   [True, True, False],
   [True, True, True]]

  Args:
    sequence_lengths: An int32/int64 vector of size n.

  Returns:
    true_false: A [n, max(sequence_lengths)] sized bool Tensor.
  """
  # Based on this SO answer: http://stackoverflow.com/a/34138336/118173
  max_len = math_ops.reduce_max(sequence_lengths)
  lengths_transposed = array_ops.expand_dims(sequence_lengths, 1)
  rng = math_ops.range(max_len)
  rng_row = array_ops.expand_dims(rng, 0)
  true_false = math_ops.less(rng_row, lengths_transposed)
  return true_false
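
# Usage sketch (not part of the original gist): evaluating mask() on the
# lengths from the docstring reproduces the lower-triangular pattern.
#
#   import tensorflow as tf
#   with tf.Session() as sess:
#     print(sess.run(mask(tf.constant([1, 2, 3]))))
#   # [[ True False False]
#   #  [ True  True False]
#   #  [ True  True  True]]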


def simple_attention_fn(attention_inputs, number_attention_units,
                        attention_length, scope, normalizer=nn.softmax):
  """Builds a `context_fn` computing additive (Bahdanau-style) attention.

  Args:
    attention_inputs: Encoder outputs, a [batch, time, depth] float Tensor.
    number_attention_units: Size of the attention projection.
    attention_length: An int vector of encoder sequence lengths, or None.
    scope: VariableScope for the attention projections.
    normalizer: Function turning scores into weights; defaults to softmax.

  Returns:
    context_fn: A function (state, inp, reuse) -> (concatenated input, alpha).
  """
  attention_input_depth = int(attention_inputs.get_shape()[2])
  # Precompute the projection of the encoder outputs once for all
  # positions, using a 1x1 convolution over a dummy spatial axis.
  hidden = array_ops.expand_dims(attention_inputs, axis=2)
  with vs.variable_scope("part1") as varscope:
    part1 = layers.conv2d(hidden, number_attention_units, (1, 1),
                          scope=varscope)
    part1 = array_ops.squeeze(part1, [2])

  def context_fn(state, inp, reuse):
    # Project the decoder state and combine with the precomputed part1.
    with vs.variable_scope("part21", reuse=reuse) as varscope:
      part2 = layers.fully_connected(state, number_attention_units,
                                     activation_fn=None,
                                     biases_initializer=None,
                                     scope=varscope)
      part2 = array_ops.expand_dims(part2, 1)
    with vs.variable_scope("part22", reuse=reuse) as varscope:
      cmb_attn = layers.fully_connected(math_ops.tanh(part1 + part2), 1,
                                        activation_fn=None,
                                        biases_initializer=None,
                                        scope=varscope)
    e = array_ops.squeeze(cmb_attn, axis=[2])
    alpha = normalizer(e)
    # Mask out weights beyond each sequence's true length, then
    # renormalize so the weights still sum to one.
    if attention_length is not None:
      alpha = math_ops.to_float(mask(attention_length)) * alpha
      alpha = alpha / math_ops.reduce_sum(alpha, [1], keep_dims=True)
    context = math_ops.reduce_sum(array_ops.expand_dims(alpha, 2)
                                  * attention_inputs, [1])
    context.set_shape([None, attention_input_depth])
    con = array_ops.concat(1, (inp, context))
    return con, alpha
  return context_fn
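
# The context_fn above implements additive (Bahdanau-style) attention:
#   e_{t,j}   = v^T tanh(W_1 h_j + W_2 s_t)
#   context_t = sum_j alpha_{t,j} h_j
# where h_j are the encoder outputs and s_t is the decoder state. The
# scores e_t are normalized (softmax by default), masked to each
# sequence's true length, and renormalized. part1 precomputes W_1 h_j once
# for all positions (the 1x1 convolution); part2 re-projects the decoder
# state at every step.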


def simple_decoder_fn_train(encoder_state, context_fn=None, name=None):
  """Returns a decoder function for training, optionally with attention."""
  with ops.name_scope(name, "simple_decoder_fn_train",
                      [encoder_state, context_fn]):
    pass

  def decoder_fn(time, cell_state, cell_input, cell_output, context_state,
                 reuse=False):
    with ops.name_scope(name, "simple_decoder_fn_train",
                        [time, cell_state, cell_input, cell_output,
                         context_state]):
      if cell_state is None:  # first call, return encoder_state
        if context_fn is not None:
          cell_input, _ = context_fn(encoder_state, cell_input, reuse=reuse)
        return (None, encoder_state, cell_input, cell_output, context_state)
      else:
        if context_fn is not None:
          cell_input, _ = context_fn(cell_state, cell_input, reuse=reuse)
        return (None, cell_state, cell_input, cell_output, context_state)
  return decoder_fn
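
# Note on the decoder_fn contract: dynamic_rnn_decoder calls decoder_fn
# with cell_state=None before the first step, so the branch above seeds
# the decoder with encoder_state. Returning None for the `done` element
# defers early stopping to `sequence_length` inside dynamic_rnn_decoder.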


def simple_decoder_fn_inference(output_fn, encoder_state, embeddings,
                                start_of_sequence_id, end_of_sequence_id,
                                maximum_length, num_decoder_symbols,
                                context_fn=None,
                                dtype=dtypes.int32, name=None):
  """Returns a greedy decoder function for inference, optionally with
  attention."""
  with ops.name_scope(name, "simple_decoder_fn_inference",
                      [output_fn, encoder_state, embeddings,
                       start_of_sequence_id, end_of_sequence_id,
                       maximum_length, num_decoder_symbols, dtype]):
    start_of_sequence_id = ops.convert_to_tensor(start_of_sequence_id, dtype)
    end_of_sequence_id = ops.convert_to_tensor(end_of_sequence_id, dtype)
    maximum_length = ops.convert_to_tensor(maximum_length, dtype)
    num_decoder_symbols = ops.convert_to_tensor(num_decoder_symbols, dtype)
    encoder_info = nest.flatten(encoder_state)[0]
    batch_size = encoder_info.get_shape()[0].value
    if output_fn is None:
      output_fn = lambda x: x
    if batch_size is None:
      batch_size = array_ops.shape(encoder_info)[0]

  def decoder_fn(time, cell_state, cell_input, cell_output, context_state,
                 reuse=True):
    with ops.name_scope(name, "simple_decoder_fn_inference",
                        [time, cell_state, cell_input, cell_output,
                         context_state]):
      if cell_input is not None:
        raise ValueError("Expected cell_input to be None, but saw: %s" %
                         cell_input)
      if cell_output is None:
        # invariant that this is time == 0
        next_input_id = array_ops.ones([batch_size,], dtype=dtype) * (
            start_of_sequence_id)
        done = array_ops.zeros([batch_size,], dtype=dtypes.bool)
        cell_state = encoder_state
        cell_output = array_ops.zeros([num_decoder_symbols],
                                      dtype=dtypes.float32)
      else:
        cell_output = output_fn(cell_output)
        next_input_id = math_ops.cast(
            math_ops.argmax(cell_output, 1), dtype=dtype)
        done = math_ops.equal(next_input_id, end_of_sequence_id)
      next_input = array_ops.gather(embeddings, next_input_id)
      if context_fn is not None:
        next_input, _ = context_fn(cell_state, next_input, reuse=reuse)
      # if time > maximum_length, mark the whole batch as done
      done = control_flow_ops.cond(
          math_ops.greater(time, maximum_length),
          lambda: array_ops.ones([batch_size,], dtype=dtypes.bool),
          lambda: done)
      return (done, cell_state, next_input, cell_output, context_state)
  return decoder_fn
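
# Inference proceeds greedily: each step embeds the argmax of the previous
# output (run through context_fn when attention is enabled) and feeds it
# back as the next input. A per-example `done` flag is raised on
# end_of_sequence_id, and the whole batch stops once time > maximum_length.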


def dynamic_rnn_decoder(cell, decoder_fn, inputs=None, sequence_length=None,
                        number_attention_units=None,
                        parallel_iterations=None, swap_memory=False,
                        time_major=False, scope=None, name=None):
  """Dynamic RNN decoder for a sequence-to-sequence model specified by
  RNNCell and decoder function.

  The `dynamic_rnn_decoder` is similar to `tf.python.ops.rnn.dynamic_rnn`
  in that the decoder makes no assumptions about the sequence length or
  batch size of the input.

  The `dynamic_rnn_decoder` has two modes, training and inference, and
  expects the user to create separate decoder functions for each.

  Under both training and inference `cell` and `decoder_fn` are expected,
  where `cell` performs the computation at every timestep (via `raw_rnn`)
  and `decoder_fn` allows modelling of early stopping, output, state, and
  next input and context.

  When training, the user is expected to supply `inputs`. At every time
  step a slice of the supplied input is fed to the `decoder_fn`, which
  modifies and returns the input for the next time step.

  `sequence_length` is needed at training time, i.e., when `inputs` is not
  None, for dynamic unrolling. At test time, when `inputs` is None,
  `sequence_length` is not needed.

  Under inference `inputs` is expected to be `None` and the input is
  inferred solely from the `decoder_fn`.

  Args:
    cell: An instance of RNNCell.
    decoder_fn: A function that takes time, cell state, cell input,
      cell output and context state. It returns an early stopping vector,
      cell state, next input, cell output and context state.
      Examples of decoder_fn are `simple_decoder_fn_train` and
      `simple_decoder_fn_inference` above.
    inputs: The inputs for decoding (embedded format).
      If `time_major == False` (default), this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`.
      If `time_major == True`, this must be a `Tensor` of shape:
        `[max_time, batch_size, ...]`.
      The input to `cell` at each time step will be a `Tensor` with
      dimensions `[batch_size, ...]`.
    sequence_length: (optional) An int32/int64 vector sized `[batch_size]`.
      If `inputs` is not None and `sequence_length` is None it is inferred
      from `inputs` as the maximal possible sequence length.
    number_attention_units: (optional) Currently unused; see the
      commented-out input-depth adjustment in the function body.
    parallel_iterations: (Default: 32). The number of iterations to run in
      parallel. Those operations which do not have any temporal dependency
      and can be run in parallel will be. This parameter trades off
      time for space. Values >> 1 use more memory but take less time,
      while smaller values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward
      inference but needed for back prop from GPU to CPU. This allows
      training RNNs which would typically not fit on a single GPU, with
      very minimal (or no) performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors.
      If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
      If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
      Using `time_major = True` is a bit more efficient because it avoids
      transposes at the beginning and end of the RNN calculation. However,
      most TensorFlow data is batch-major, so by default this function
      accepts input and emits output in batch-major form.
    scope: VariableScope for the `raw_rnn`; defaults to None.
    name: NameScope for the decoder; defaults to "dynamic_rnn_decoder".

  Returns:
    A pair (outputs, state) where:
      outputs: the RNN output `Tensor`.
        If time_major == False (default), this will be a `Tensor` shaped:
          `[batch_size, max_time, cell.output_size]`.
        If time_major == True, this will be a `Tensor` shaped:
          `[max_time, batch_size, cell.output_size]`.
      state: The final state, shaped `[batch_size, cell.state_size]`.

  Raises:
    ValueError: if `inputs` is not None and has fewer than two dimensions.
  """
with ops.name_scope(name, "dynamic_rnn_decoder",
[cell, decoder_fn, inputs, sequence_length,
parallel_iterations, swap_memory, time_major, scope]):
if inputs is not None:
# Convert to tensor
inputs = ops.convert_to_tensor(inputs)
# Test input dimensions
if inputs.get_shape().ndims is not None and (
inputs.get_shape().ndims < 2):
raise ValueError("Inputs must have at least two dimensions")
# Setup of RNN (dimensions, sizes, length, initial state, dtype)
if not time_major:
# [batch, seq, features] -> [seq, batch, features]
inputs = array_ops.transpose(inputs, perm=[1, 0, 2])
dtype = inputs.dtype
# Get data input information
input_depth = int(inputs.get_shape()[2])
#if number_attention_units is not None:
# input_depth += number_attention_units
batch_depth = inputs.get_shape()[1].value
max_time = inputs.get_shape()[0].value
if max_time is None:
max_time = array_ops.shape(inputs)[0]
# Setup decoder inputs as TensorArray
inputs_ta = tensor_array_ops.TensorArray(dtype, size=max_time)
inputs_ta = inputs_ta.unpack(inputs)
def loop_fn(time, cell_output, cell_state, loop_state):
if cell_state is None: # first call, before while loop (in raw_rnn)
if cell_output is not None:
raise ValueError("Expected cell_output to be None when cell_state "
"is None, but saw: %s" % cell_output)
if loop_state is not None:
raise ValueError("Expected loop_state to be None when cell_state "
"is None, but saw: %s" % loop_state)
context_state = None
else: # subsequent calls, inside while loop, after cell excution
if isinstance(loop_state, tuple):
(done, context_state) = loop_state
else:
done = loop_state
context_state = None
# call decoder function
if inputs is not None: # training
# get next_cell_input
if cell_state is None:
read_input = inputs_ta.read(0)
print("first read:", read_input)
(next_done, next_cell_state, next_cell_input, emit_output,
next_context_state) = decoder_fn(time, cell_state, read_input,
cell_output, context_state)
print("first next_cell_input:", next_cell_input)
else:
if batch_depth is not None:
batch_size = batch_depth
else:
batch_size = array_ops.shape(done)[0]
read_input = control_flow_ops.cond(
math_ops.equal(time, max_time),
lambda: array_ops.zeros([batch_size, input_depth], dtype=dtype),
lambda: inputs_ta.read(time))
print("second read:", read_input)
(next_done, next_cell_state, next_cell_input, emit_output,
next_context_state) = decoder_fn(time, cell_state, read_input,
cell_output, context_state, reuse=True)
print("second next_cell_input:", next_cell_input)
else: # inference
# next_cell_input is obtained through decoder_fn
(next_done, next_cell_state, next_cell_input, emit_output,
next_context_state) = decoder_fn(time, cell_state, None, cell_output,
context_state)
# check if we are done
if next_done is None: # training
next_done = time >= sequence_length
# build next_loop_state
if next_context_state is None:
next_loop_state = next_done
else:
next_loop_state = (next_done, next_context_state)
return (next_done, next_cell_input, next_cell_state,
emit_output, next_loop_state)
# Run raw_rnn function
outputs_ta, state, _ = rnn.raw_rnn(
cell, loop_fn, parallel_iterations=parallel_iterations,
swap_memory=swap_memory, scope=scope)
outputs = outputs_ta.pack()
if not time_major:
# [seq, batch, features] -> [batch, seq, features]
outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
return outputs, state
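
# dynamic_rnn_decoder delegates the time loop to rnn.raw_rnn: loop_fn must
# return (finished, next_input, next_cell_state, emit_output, loop_state),
# and the (done, context_state) pair is threaded through loop_state so the
# decoder_fn can carry arbitrary decoding state between steps.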


def sequence_loss(logits, targets, weights,
                  average_across_timesteps=True, average_across_batch=True,
                  softmax_loss_function=None, name=None):
  """Weighted cross-entropy loss for a sequence of logits (per example).

  Args:
    logits: A 3D Tensor of shape
      [batch_size x sequence_length x num_decoder_symbols] and dtype float.
      The logits correspond to the prediction across all classes at each
      timestep.
    targets: A 2D Tensor of shape [batch_size x sequence_length] and dtype
      int. The target represents the true class at each timestep.
    weights: A 2D Tensor of shape [batch_size x sequence_length] and dtype
      float. Weights constitute the weighting of each prediction in the
      sequence. When using weights as masking, set all valid timesteps to 1
      and all padded timesteps to 0.
    average_across_timesteps: If set, sum the cost across the sequence
      dimension and divide the cost by the total label weight across
      timesteps.
    average_across_batch: If set, sum the cost across the batch dimension
      and divide the returned cost by the batch size.
    softmax_loss_function: Function (inputs-batch, labels-batch) ->
      loss-batch to be used instead of the standard softmax (the default if
      this is None).
    name: Optional name for this operation, defaults to "sequence_loss".

  Returns:
    A scalar float Tensor: The average log-perplexity per symbol (weighted).

  Raises:
    ValueError: logits does not have 3 dimensions or targets does not have
      2 dimensions or weights does not have 2 dimensions.
  """
  if len(logits.get_shape()) != 3:
    raise ValueError("Logits must be a "
                     "[batch_size x sequence_length x logits] tensor")
  if len(targets.get_shape()) != 2:
    raise ValueError("Targets must be a [batch_size x sequence_length] "
                     "tensor")
  if len(weights.get_shape()) != 2:
    raise ValueError("Weights must be a [batch_size x sequence_length] "
                     "tensor")
  with ops.name_scope(name, "sequence_loss", [logits, targets, weights]):
    num_classes = array_ops.shape(logits)[2]
    probs_flat = array_ops.reshape(logits, [-1, num_classes])
    targets = array_ops.reshape(targets, [-1])
    if softmax_loss_function is None:
      crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
          probs_flat, targets)
    else:
      crossent = softmax_loss_function(probs_flat, targets)
    crossent = crossent * array_ops.reshape(weights, [-1])
    if average_across_timesteps and average_across_batch:
      crossent = math_ops.reduce_sum(crossent)
      total_size = math_ops.reduce_sum(weights)
      total_size += 1e-12  # to avoid division by 0 for all-0 weights
      crossent /= total_size
    else:
      batch_size = array_ops.shape(logits)[0]
      sequence_length = array_ops.shape(logits)[1]
      crossent = array_ops.reshape(crossent, [batch_size, sequence_length])
      if average_across_timesteps and not average_across_batch:
        crossent = math_ops.reduce_sum(crossent, axis=[1])
        total_size = math_ops.reduce_sum(weights, axis=[1])
        total_size += 1e-12  # to avoid division by 0 for all-0 weights
        crossent /= total_size
      if not average_across_timesteps and average_across_batch:
        crossent = math_ops.reduce_sum(crossent, axis=[0])
        total_size = math_ops.reduce_sum(weights, axis=[0])
        total_size += 1e-12  # to avoid division by 0 for all-0 weights
        crossent /= total_size
    return crossent
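
# Worked example (hypothetical shapes, not in the original gist): with
# weights [[1., 1., 0.]] the padded third timestep contributes nothing,
# and with both averaging flags left at their defaults the result is
#   sum(weights * crossent) / sum(weights)
# i.e. the mean cross-entropy over the valid timesteps only.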


if __name__ == '__main__':
  import tensorflow as tf

  # Setting up hyperparameters and general configs
  MAX_DIGITS = 5
  MIN_DIGITS = 5
  NUM_INPUTS = 27
  NUM_OUTPUTS = 11  # 0-9 + '#'
  BATCH_SIZE = 100
  # try various learning rates 1e-2 to 1e-5
  LEARNING_RATE = 0.005
  X_EMBEDDINGS = 8
  t_EMBEDDINGS = 8
  NUM_UNITS_ENC = 12
  NUM_UNITS_DEC = 12

  # Setting up placeholders: the tensors we "feed" to our network
  Xs = tf.placeholder(tf.int32, shape=[None, None], name='X_input')
  ts_in = tf.placeholder(tf.int32, shape=[None, None], name='t_input_in')
  ts_out = tf.placeholder(tf.int32, shape=[None, None], name='t_input_out')
  X_len = tf.placeholder(tf.int32, shape=[None], name='X_len')
  t_len = tf.placeholder(tf.int32, shape=[None], name='t_len')
  t_mask = tf.placeholder(tf.float32, shape=[None, None], name='t_mask')

  # First we build the embeddings to turn our characters into dense,
  # trainable vectors
  X_embeddings = tf.get_variable(
      'X_embeddings', [NUM_INPUTS, X_EMBEDDINGS],
      initializer=tf.random_normal_initializer(stddev=0.1))
  t_embeddings = tf.get_variable(
      't_embeddings', [NUM_OUTPUTS, t_EMBEDDINGS],
      initializer=tf.random_normal_initializer(stddev=0.1))
  # Setting up weights for computing the final output
  W_out = tf.get_variable('W_out', [NUM_UNITS_DEC, NUM_OUTPUTS])
  b_out = tf.get_variable('b_out', [NUM_OUTPUTS])
  X_embedded = tf.gather(X_embeddings, Xs, name='embed_X')
  t_embedded = tf.gather(t_embeddings, ts_in, name='embed_t')

  with tf.variable_scope("rnn") as scope:
    # forward encoding
    enc_cell = tf.nn.rnn_cell.GRUCell(NUM_UNITS_ENC)
    encoder_outputs, enc_state = tf.nn.dynamic_rnn(
        cell=enc_cell, inputs=X_embedded,
        sequence_length=X_len, dtype=tf.float32)

  with tf.variable_scope("decoder") as scope:
    # with tf.variable_scope("attention") as scope:
    context_fn = simple_attention_fn(encoder_outputs, 20, X_len, scope=scope)
    output_fn = lambda x: tf.contrib.layers.linear(x, NUM_OUTPUTS, scope=scope)
    decoder_cell = tf.nn.rnn_cell.GRUCell(NUM_UNITS_DEC)
    decoder_fn_train = simple_decoder_fn_train(
        encoder_state=enc_state, context_fn=context_fn)
    decoder_outputs_train, decoder_state_train = (
        dynamic_rnn_decoder(
            cell=decoder_cell,
            decoder_fn=decoder_fn_train,
            inputs=t_embedded,
            sequence_length=t_len,
            number_attention_units=20,
            scope=scope))
    decoder_outputs_train = output_fn(decoder_outputs_train)
    scope.reuse_variables()
    decoder_fn_inference = (
        simple_decoder_fn_inference(
            output_fn=output_fn,
            encoder_state=enc_state,
            embeddings=t_embeddings,
            start_of_sequence_id=10,
            end_of_sequence_id=10,
            # TODO: find out why it goes to +1
            maximum_length=MAX_DIGITS,
            num_decoder_symbols=NUM_OUTPUTS,
            context_fn=context_fn,
            dtype=tf.int32))
    decoder_outputs_inference, decoder_state_inference = (
        dynamic_rnn_decoder(
            cell=decoder_cell,
            decoder_fn=decoder_fn_inference,
            scope=scope))

  y = decoder_outputs_train
  y_valid = decoder_outputs_inference

  # loss and optimizer
  loss = sequence_loss(y, ts_out, t_mask)
  global_step = tf.Variable(0, name='global_step', trainable=False)
  optimizer = tf.train.AdamOptimizer(LEARNING_RATE)
  # extract gradients for each variable
  grads_and_vars = optimizer.compute_gradients(loss)
  train_op = optimizer.apply_gradients(grads_and_vars,
                                       global_step=global_step)
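
  # --- Minimal training-loop sketch (not part of the original gist). ---
  # Runs a few updates on random synthetic batches just to exercise the
  # graph; a real experiment would swap in a proper task data generator.
  import numpy as np
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(3):
      feed = {
          Xs: np.random.randint(0, NUM_INPUTS, (BATCH_SIZE, MAX_DIGITS)),
          ts_in: np.random.randint(0, NUM_OUTPUTS, (BATCH_SIZE, MAX_DIGITS)),
          ts_out: np.random.randint(0, NUM_OUTPUTS, (BATCH_SIZE, MAX_DIGITS)),
          X_len: [MAX_DIGITS] * BATCH_SIZE,
          t_len: [MAX_DIGITS] * BATCH_SIZE,
          t_mask: np.ones((BATCH_SIZE, MAX_DIGITS), dtype=np.float32),
      }
      _, loss_val = sess.run([train_op, loss], feed)
      print("step %d, loss %.4f" % (step, loss_val))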