Skip to content

Instantly share code, notes, and snippets.

@alrojo
Last active February 19, 2018 06:25
Show Gist options
  • Save alrojo/a3124ba7a756c95024f0acf3a200cd2b to your computer and use it in GitHub Desktop.
Save alrojo/a3124ba7a756c95024f0acf3a200cd2b to your computer and use it in GitHub Desktop.
rnn_decoder_attention.md
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import rnn
from tensorflow.python.ops import nn
from tensorflow.python.ops import init_ops

from tensorflow.contrib.layers import conv2d, fully_connected
from functools import partial

def mask(sequence_lengths):
  """ Mask function used by recurrent decoder with attention.

  Given a vector with varying lengths, produces an explicit matrices with
    True/False values. E.g.
    > mask([1,2,3])
    [[True, False, False],
     [True, True, False],
     [True, True, True]]
  Args:
    sequence_lengths: An int32/int64 vector of size n.
  Return:
    true_false: A [n, max_len(sequence_lengths)] sized Tensor.
  """
  # based on this SO answer: http://stackoverflow.com/a/34138336/118173
  batch_size = array_ops.shape(sequence_lengths)[0]
  max_len = math_ops.reduce_max(sequence_lengths)

  lengths_transposed = array_ops.expand_dims(sequence_lengths, 1)

  rng = math_ops.range(max_len)
  rng_row = array_ops.expand_dims(rng, 0)
  true_false = math_ops.less(rng_row, lengths_transposed)
  return true_false

def rnn_decoder_attention(cell, num_attention_units,
                          attention_inputs, decoder_inputs, initial_state,
                          decoder_length, decoder_fn, attention_length=None,
                          weight_initializer=None, encoder_projection=None,
                          parallel_iterations=None, swap_memory=False,
                          time_major=False, scope=None):
  """ Dynamic RNN decoder with attention for a sequence-to-sequence model
  specified by RNNCell 'cell'.

  The 'rnn_decoder_attention' is similar to the
  'tf.python.ops.rnn.dynamic_rnn'. As the decoder does not make any
  assumptions of sequence length of the input or how many steps it can decode,
  since 'dynamic_rnn_decoder' uses dynamic unrolling. This allows
  'attention_inputs' and 'decoder_inputs' to have [None] in the sequence
  length of the decoder inputs.

  The parameters attention_inputs and  decoder_inputs are nessesary for both
  training and evaluation. During training all of attention_inputs and a slice
  of decoder_inputs is feed at every timestep. During evaluation
  decoder_inputs it is only feed at time==0, as the decoder needs the
  'start-of-sequence' symbol, known from Bahdanau et al., 2014
  https://arxiv.org/abs/1409.0473, at the beginning of decoding.

  The parameter  initial_state is used to initialize the decoder RNN.
  As default a linear transformation with a tf.nn.tanh linearity is used.
  By a linear transformation we can have different number of units between
  the encoder and decoder.

  The parameter sequence length is nessesary as it determines how many
  timesteps to decode for each sample. TODO: Could make it optional for
  training.

  The parameter attention_length is used for masking the alpha values
  computes over the attention_input. Is set to None (default) no mask is
  computed.

  Extensions of interest:
  - Support time_major=True for attention_input (not using conv2D)
  - Look into rnn.raw_rnn so we don't need to handle zero states
  - Make 'alpha' usable
  - Don't use decoder_inputs for evaluation
  - Make a attention class to allow custom attention functions
  - Multi-layered decoder
  - Beam search

  Args:
    cell: An instance of RNNCell.
    num_attention_units: The number of units used for attention.
    attention_inputs: The encoded inputs.
      The input used to attend over at every timestep, must be of size
      [batch_size, seq_len, features]
    decoder_inputs: The inputs for decoding (embedded format).
      If `time_major == False` (default), this must be a `Tensor` of shape:
        `[batch_size, max_time, ...]`.
      If `time_major == True`, this must be a `Tensor` of shape:
        `[max_time, batch_size, ...]`.
      The input to `cell` at each time step will be a `Tensor` with dimensions
        `[batch_size, ...]`.
    initial_state: An initial state for the decoder's RNN.
      Must be [batch_size, num_features], where num_features does not have to
      match the cell.state_size. As a projection is performed at the beginning
      of the decoding.
    decoder_length: An int32/int64 vector sized `[batch_size]`.
    decoder_fn: A function that takes a state and returns an embedding.
      Here is an example of a `decoder_fn`:
      def decoder_fn(embeddings, weight, bias):
        def dec_fn(state):
          prev = tf.matmul(state, weight) + bias
          return tf.gather(embeddings, tf.argmax(prev, 1))
        return dec_fn
    encoder_projection: (optional) given that the encoder might have a
      different size than the decoder, we project the intial state as
      described in Bahdanau, 2014 (https://arxiv.org/abs/1409.0473).
      The optional `encoder_projection` is a
      `tf.contrib.layers.fully_connected` with
      `activation_fn=tf.python.ops.nn.tanh`.
    weight_initializer: (optional) An initializer used for attention.
    attention_length: (optional) An int32/int64 vector sized `[batch_size]`.
    parallel_iterations: (Default: 32).  The number of iterations to run in
      parallel.  Those operations which do not have any temporal dependency
      and can be run in parallel, will be.  This parameter trades off
      time for space.  Values >> 1 use more memory but take less time,
      while smaller values use less memory but computations take longer.
    swap_memory: Transparently swap the tensors produced in forward inference
      but needed for back prop from GPU to CPU.  This allows training RNNs
      which would typically not fit on a single GPU, with very minimal (or no)
      performance penalty.
    time_major: The shape format of the `inputs` and `outputs` Tensors.
      If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`.
      If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`.
      Using `time_major = True` is a bit more efficient because it avoids
      transposes at the beginning and end of the RNN calculation.  However,
      most TensorFlow data is batch-major, so by default this function
      accepts input and emits output in batch-major form.
    scope: VariableScope for the created subgraph;
      defaults to "decoder_attention".
  Returns:
    A pair (outputs_train, outputs_eval) where:
      outputs_train/eval: the RNN output 'Tensor'
        If time_major == False (default), this will be a `Tensor` shaped:
          `[batch_size, max_time, cell.output_size]`.
        If time_major == True, this will be a `Tensor` shaped:
          `[max_time, batch_size, cell.output_size]`.
      NOTICE: output_train is commonly used for calculating loss.
  Raises:
    #TODO Put up some raises
  """

  with vs.variable_scope(scope or "decoder") as varscope:
    # Project initial_state as described in Bahdanau et al. 2014
    # https://arxiv.org/abs/1409.0473
    if encoder_projection is None:
      encoder_projection = partial(fully_connected, activation_fn=math_ops.tanh)
    state = encoder_projection(initial_state, cell.output_size)
    # Setup of RNN (dimensions, sizes, length, initial state, dtype)
    # Setup dtype
    dtype = state.dtype
    if not time_major:
      # [batch, seq, features] -> [seq, batch, features]
      decoder_inputs = array_ops.transpose(decoder_inputs, perm=[1, 0, 2])
    # Get data input information
    batch_size = array_ops.shape(decoder_inputs)[1]
    attention_input_depth = int(attention_inputs.get_shape()[2])
    decoder_input_depth = int(decoder_inputs.get_shape()[2])
    attention_max_length = array_ops.shape(attention_inputs)[1]
    # Setup decoder inputs as TensorArray
    decoder_inputs_ta = tensor_array_ops.TensorArray(dtype, size=0,
                                                     dynamic_size=True)
    decoder_inputs_ta = decoder_inputs_ta.unpack(decoder_inputs)

    print "attention_input_depth,", attention_input_depth
    print "decoder_input_depth,", decoder_input_depth
    # Setup attention weight
    if weight_initializer is None:
      weight_initializer = init_ops.truncated_normal_initializer(stddev=0.1) 
    with vs.variable_scope("attention") as attnscope:
      v_a = vs.get_variable('v_a',
                            shape=[num_attention_units],
                            initializer=weight_initializer)
      W_a = vs.get_variable('W_a',
                            shape=[cell.output_size, num_attention_units],
                            initializer=weight_initializer)

    # Encode attention_inputs for attention
    hidden = array_ops.reshape(attention_inputs, [-1, attention_max_length,
                                                   1, attention_input_depth])
    part1 = conv2d(hidden, num_attention_units, (1, 1))
    part1 = array_ops.squeeze(part1, [2])  # Squeeze out the third dimension

    def context_fn(state, inp):
      with vs.variable_scope("attention") as attnscope:
        part2 = math_ops.matmul(state, W_a) # [batch, attn_units]
        part2 = array_ops.expand_dims(part2, 1) # [batch, 1, attn_units]
        cmb_attn = part1 + part2 # [batch, seq, attn_units]
        e = math_ops.reduce_sum(v_a * math_ops.tanh(cmb_attn), [2]) # [batch, seq]
        alpha = nn.softmax(e)
        # Mask
        if attention_length is not None:
          alpha = math_ops.to_float(mask(attention_length)) * alpha
          alpha = alpha / math_ops.reduce_sum(alpha, [1], keep_dims=True)
        # [batch, features]
        context = math_ops.reduce_sum(array_ops.expand_dims(alpha, 2)
                                      * attention_inputs, [1])
        context.set_shape([None, attention_input_depth])
        con = array_ops.concat(1, (inp, context))
        print "con,", con.get_shape()
        return con, alpha

    # loop function train
    def loop_fn_train(time, cell_output, cell_state, loop_state):
      print "@@@TRAIN@@@"
      emit_output = cell_output
      if cell_output is None:
        next_cell_state = state # Use projection of prev encoder state
      else:
        next_cell_state = cell_state
      elements_finished = (time >= decoder_length) #TODO handle seq_len=None
      finished = math_ops.reduce_all(elements_finished)

      next_input, _ = control_flow_ops.cond(
          finished,
          # Handle zero states
          lambda: (array_ops.zeros([batch_size,
              decoder_input_depth+attention_input_depth], dtype=dtype),
                   array_ops.zeros([batch_size, attention_max_length],
              dtype=dtype)),
          # Read data and calculate attention
          lambda: context_fn(next_cell_state, decoder_inputs_ta.read(time)))
      next_input.set_shape([None, decoder_input_depth+attention_input_depth]) # it loses its shape at some point
      next_loop_state = None
      return (elements_finished, next_input, next_cell_state,
              emit_output, next_loop_state)

    # loop function eval
    def loop_fn_eval(time, cell_output, cell_state, loop_state):
      print "@@@EVAL@@@"
      emit_output = cell_output
      if cell_output is None:
        next_cell_state = state
      else:
        next_cell_state = cell_state
      elements_finished = (time >= decoder_length) #TODO handle seq_len=None
      finished = math_ops.reduce_all(elements_finished)
      varscope.reuse_variables()
      next_input, _ = control_flow_ops.cond(
          finished,
          # Handle zero states
          lambda: (array_ops.zeros([batch_size,
              decoder_input_depth+attention_input_depth], dtype=dtype),
                   array_ops.zeros([batch_size, attention_max_length],
              dtype=dtype)),
          # Read data and calculate attention
          lambda: control_flow_ops.cond(math_ops.greater(time, 0),
              lambda: context_fn(next_cell_state,
                  decoder_fn(next_cell_state)),
              lambda: context_fn(next_cell_state,
                  decoder_inputs_ta.read(0))))
      # next_input loses its shape at some point
      next_input.set_shape([None, decoder_input_depth+attention_input_depth]) 
      next_loop_state = None
      print "next_input,", next_input.get_shape()
      return (elements_finished, next_input, next_cell_state,
              emit_output, next_loop_state)

    # Run raw_rnn function
    outputs_ta_train, _, _ = rnn.raw_rnn(cell, loop_fn_train)
    varscope.reuse_variables()
    outputs_ta_eval, _, _ = rnn.raw_rnn(cell, loop_fn_eval)
    outputs_train = outputs_ta_train.pack()
    outputs_eval = outputs_ta_eval.pack()
    if not time_major:
      outputs_train = array_ops.transpose(outputs_train, perm=[1, 0, 2])
      outputs_eval = array_ops.transpose(outputs_eval, perm=[1, 0, 2])
    return outputs_train, outputs_eval
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment