# Gist by @nlintz, created April 11, 2016
import tensorflow as tf
from tensorflow.models.rnn import rnn_cell
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.math_ops import sigmoid, tanh
import numpy as np

def softmax(x):
    # Row-wise softmax with the usual max-subtraction for numerical stability.
    e_x = np.exp(x - x.max(axis=1)[:, None])
    return e_x / e_x.sum(axis=1)[:, None]

def weightedChoice(weights):
    # Sample an index i with probability weights[i] (weights assumed to sum
    # to 1): count how many cumulative sums fall below a uniform draw.
    cs = np.cumsum(weights)
    idx = sum(cs < np.random.rand())
    return idx

def normal_init_fn(shape):
    return tf.Variable(tf.truncated_normal(shape, stddev=0.02))

def proj_init_fn(shape):
    return tf.Variable(tf.random_uniform(shape, -0.05, 0.05))

def constant_init_fn(shape):
    return tf.Variable(tf.constant(0.0, shape=shape, dtype=tf.float32))

def identity_initializer(shape, dtype="float32"):
    if len(shape) == 1:
        return tf.constant(0., dtype=dtype, shape=shape)
    elif len(shape) == 2 and shape[0] == shape[1]:
        return tf.constant(np.identity(shape[0], dtype))
    elif len(shape) == 4 and shape[2] == shape[3]:
        # Identity convolution kernel: a 1 at the spatial center of each
        # matching input/output channel pair.
        array = np.zeros(shape, dtype=float)
        cx, cy = shape[0] // 2, shape[1] // 2
        for i in range(shape[2]):
            array[cx, cy, i, i] = 1
        return tf.constant(array, dtype=dtype)

def merge_sequences(sequences, output_dim):
    # Applies a single linear map (weights shared across time steps) to the
    # concatenation of the i-th element of every source sequence.
    res = []
    with tf.variable_scope("merge_sequences"):
        for i in range(len(sequences[0])):
            if i > 0:
                tf.get_variable_scope().reuse_variables()
            sources = [s[i] for s in sequences]
            res.append(rnn_cell.linear(sources, output_dim, False))
    return res
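
# AttentionComputer implements additive (Bahdanau-style) attention:
#     e_i   = v^T tanh(W_h h_{t-1} + W_a a_i)
#     alpha = softmax(e)              (normalized over the attended time steps)
#     wavg  = sum_i alpha_i * a_i     (context vector handed to the decoder)
# where h_{t-1} is the previous decoder state and the a_i are the encoder outputs.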
class AttentionComputer(object):
    def __init__(self, hidden_dim, attended_dim, match_dim=None):
        self.hidden_dim = hidden_dim
        if match_dim is None:
            match_dim = self.hidden_dim
        self.attended_dim = attended_dim
        self.match_dim = match_dim
        self.h_mdim = proj_init_fn([self.hidden_dim, self.match_dim])
        self.a_mdim = proj_init_fn([self.attended_dim, self.match_dim])
        self.a_wavg = proj_init_fn([self.match_dim, 1])

    def compute_attention(self, h_tm1, attended):
        attended_state = tf.matmul(h_tm1, self.h_mdim)
        matched_attended = [(tf.matmul(s, self.a_mdim) + attended_state)
                            for s in attended]
        preweights = [tf.matmul(tf.nn.tanh(s), self.a_wavg)
                      for s in matched_attended]
        preweights = tf.squeeze(tf.pack(preweights))
        weights = tf.transpose(tf.nn.softmax(tf.transpose(preweights)))
        prewavgs = tf.mul(tf.pack(attended), tf.expand_dims(weights, 2))
        wavg = tf.reduce_sum(prewavgs, reduction_indices=[0])
        return wavg, tf.transpose(weights)
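
# GRUCellWithContext is a GRU whose gates and candidate activation also
# condition on an extra context vector c (here, the attention-weighted
# average of the encoder outputs):
#     r, u  = sigmoid(W [x; h; c] + b)        (reset and update gates)
#     h~    = tanh(W' [x; r * h; c] + b')     (candidate state)
#     h_new = u * h + (1 - u) * h~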
class GRUCellWithContext(rnn_cell.RNNCell):
    """Gated Recurrent Unit cell with an extra context input
    (cf. http://arxiv.org/abs/1406.1078)."""

    def __init__(self, num_units, input_size=None):
        self._num_units = num_units
        self._input_size = num_units if input_size is None else input_size

    @property
    def input_size(self):
        return self._input_size

    @property
    def output_size(self):
        return self._num_units

    @property
    def state_size(self):
        return self._num_units

    def __call__(self, inputs, state, context, scope=None):
        """Gated recurrent unit (GRU) with nunits cells, conditioned on context."""
        with vs.variable_scope(scope or type(self).__name__):  # "GRUCellWithContext"
            with vs.variable_scope("Gates"):  # Reset gate and update gate.
                # We start with bias of 1.0 to not reset and not update.
                r, u = tf.split(1, 2, rnn_cell.linear([inputs, state, context],
                                                      2 * self._num_units,
                                                      True, 1.0))
                r, u = sigmoid(r), sigmoid(u)
            with vs.variable_scope("Candidate"):
                c = tanh(rnn_cell.linear([inputs, r * state, context],
                                         self._num_units, True))
            new_h = u * state + (1 - u) * c
        return new_h, new_h
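
# Seq2SeqAttention ties the pieces together: a bidirectional GRU encoder over
# input embeddings, additive attention over the concatenated forward/backward
# encoder outputs (hence attended_dim = 2 * hidden_size), a GRUCellWithContext
# decoder fed the previous target embedding plus the attention context, and a
# shared linear readout over [decoder state, context] that produces vocab logits.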
class Seq2SeqAttention(object):
    def __init__(self, vocab_size, hidden_size):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.enc_cell = rnn_cell.GRUCell(hidden_size)
        # The encoder is bidirectional, so the attended vectors are 2 * hidden_size wide.
        self.attn_computer = AttentionComputer(hidden_size, hidden_size * 2)
        with tf.variable_scope("encoder_lookup"):
            self.lookup = tf.Variable(tf.random_uniform([vocab_size, hidden_size],
                                                        -1.0, 1.0))
        with tf.variable_scope("lookup_feedback"):
            self.lookup_feedback = tf.Variable(tf.random_uniform([vocab_size, hidden_size],
                                                                 -1.0, 1.0))
        with tf.variable_scope("decoder"):
            self.dec_cell = GRUCellWithContext(hidden_size)

    def _encode(self, encoder_inputs, dtype=dtypes.float32):
        """
        Args:
            encoder_inputs: A list of 1D int Tensors [batch_size]
        Returns: tuple (encoder_state, attended)
            encoder_state: the final state of the encoder rnn,
                a 2D Tensor [batch_size, 2 * hidden_size]
            attended: the attended sequence from the encoder,
                a list of 2D Tensors [batch_size, 2 * hidden_size]
        """
        encoder_inputs = [tf.nn.embedding_lookup(self.lookup, inp)
                          for inp in encoder_inputs]
        enc_outputs, fw_state, bw_state = tf.nn.bidirectional_rnn(self.enc_cell,
                                                                  self.enc_cell,
                                                                  encoder_inputs,
                                                                  dtype=dtype)
        encoder_state = tf.concat(1, [fw_state, bw_state])
        return encoder_state, enc_outputs

    def _get_readouts(self, sources):
        readout_inp = merge_sequences(sources, self.vocab_size)
        return readout_inp

    def _compute_cost_matrix(self, logits, targets):
        return [tf.nn.sparse_softmax_cross_entropy_with_logits(l, t)
                for (l, t) in zip(logits, targets)]

    def cost(self, chars, chars_mask, targets, targets_mask):
        batch_size = targets[0].get_shape()[0].value
        # Drop the first row of the mask so it lines up with the shifted
        # targets (targets[1:]) used in cost_matrix.
        targets_mask = tf.slice(targets_mask, [1, 0],
                                [len(targets) - 1, batch_size])
        cost_per_char = self.cost_matrix(chars, chars_mask, targets, targets_mask)
        return tf.reduce_sum(cost_per_char) / tf.reduce_sum(targets_mask)

    def cost_matrix(self, chars, chars_mask, targets, targets_mask):
        batch_size = targets[0].get_shape()[0].value
        with tf.variable_scope("encoder"):
            encoder_state, attended = self._encode(chars)
        # Teacher forcing: feed the embedding of the previous target symbol
        # into the decoder at every step.
        feedback = [tf.nn.embedding_lookup(self.lookup_feedback, inp)
                    for inp in targets[:-1]]
        targets = targets[1:]
        states = [tf.zeros([batch_size, self.hidden_size], "float32")]
        weights = []
        weighted_averages = []
        with tf.variable_scope("merge_sequences"):
            for i, inp in enumerate(feedback):
                if i > 0:
                    tf.get_variable_scope().reuse_variables()
                wavg, w = self.attn_computer.compute_attention(states[-1], attended)
                with tf.variable_scope("decoder"):
                    next_state, _ = self.dec_cell(inp, states[-1], wavg)
                states.append(next_state)
                weights.append(w)
                weighted_averages.append(wavg)
        tf.add_to_collection("attention_weights", tf.pack(weights))
        with tf.variable_scope("readout") as readout_scope:
            self.readout_scope = readout_scope
            readouts = self._get_readouts([states[1:], weighted_averages])
        costs = self._compute_cost_matrix(readouts, targets)
        costs = tf.pack(costs)
        costs = tf.mul(costs, targets_mask)
        return costs

    def generate_step(self, attended, previous_output, previous_state):
        """
        Args:
            attended: the attended sequence from the encoder,
                a list of 2D Tensors [batch_size, 2 * hidden_size]
            previous_output: An int32 1D Tensor [batch_size]
            previous_state: A float32 2D Tensor (batch_size, hidden_size)
        Returns: tuple (next_state, next_readout)
            next_state: A float32 2D Tensor (batch_size, hidden_size)
            next_readout: A float32 2D Tensor (batch_size, vocab_size)
                of unnormalized logits
        """
        # with tf.variable_scope("encoder", reuse=True):
        #     encoder_state, attended = self._encode(chars)
        with tf.variable_scope("merge_sequences"):
            inp = tf.nn.embedding_lookup(self.lookup_feedback, previous_output)
            wavg, w = self.attn_computer.compute_attention(previous_state, attended)
            with tf.variable_scope("decoder", reuse=True):
                next_state, _ = self.dec_cell(inp, previous_state, wavg)
        with tf.variable_scope(self.readout_scope, reuse=True):
            next_readout = self._get_readouts([[next_state], [wavg]])
        return next_state, next_readout[0]

    def _emit_softmax(self, readout):
        # Sample one symbol per batch element from the softmax over the
        # readout logits; return the sampled ids and their negative
        # log-probabilities.
        batch_size = len(readout)
        probs = softmax(readout)
        generated = []
        for i in range(len(probs)):
            choice = weightedChoice(probs[i])
            generated.append(np.eye(self.vocab_size)[choice])
        generated = np.array(generated)
        emit = np.argmax(generated, axis=1)
        cost = -np.log(probs[np.arange(batch_size), emit])
        return emit, cost

    def sample(self, sess, chars_pl, chars_split_pl, prev_output_pl, prev_state_pl,
               chars_input, generation_length):
        batch_size = chars_pl.get_shape()[1].value
        with tf.variable_scope("encoder", reuse=True):
            encoder_state, attended = self._encode(chars_split_pl)
        gen_step_op = self.generate_step(attended, prev_output_pl, prev_state_pl)
        states = [np.zeros((batch_size, self.hidden_size)).astype("float32")]
        outputs = [np.zeros((batch_size,)).astype("int32")]
        costs = []
        for i in range(generation_length):
            feed_dict = {chars_pl: chars_input, prev_output_pl: outputs[-1],
                         prev_state_pl: states[-1]}
            next_state, next_readout = sess.run(gen_step_op, feed_dict=feed_dict)
            next_output, next_cost = self._emit_softmax(next_readout)
            states.append(next_state)
            outputs.append(next_output)
            costs.append(next_cost)
        outputs.pop(0)
        return np.array(outputs).T, np.array(costs).sum(axis=0)
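
# A minimal usage sketch of one way the Seq2SeqAttention class above could be
# wired up for training and sampling. The placeholder names, sequence lengths,
# vocabulary size, and optimizer below are illustrative assumptions, and the
# same pre-1.0 TensorFlow API used above is assumed.
if __name__ == "__main__":
    batch_size, src_len, tgt_len = 32, 20, 20
    model = Seq2SeqAttention(vocab_size=128, hidden_size=256)

    # Time-major int32 symbol ids and float32 masks.
    chars = tf.placeholder(tf.int32, [src_len, batch_size])
    chars_mask = tf.placeholder(tf.float32, [src_len, batch_size])
    targets = tf.placeholder(tf.int32, [tgt_len, batch_size])
    targets_mask = tf.placeholder(tf.float32, [tgt_len, batch_size])

    # cost() and _encode() expect lists of per-step 1D tensors.
    chars_split = [tf.squeeze(t, [0]) for t in tf.split(0, src_len, chars)]
    targets_split = [tf.squeeze(t, [0]) for t in tf.split(0, tgt_len, targets)]

    loss = model.cost(chars_split, chars_mask, targets_split, targets_mask)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

    # Placeholders driving the step-by-step sampling loop in sample().
    prev_output = tf.placeholder(tf.int32, [batch_size])
    prev_state = tf.placeholder(tf.float32, [batch_size, 256])

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        # Training step (chars_batch / targets_batch / masks come from your data):
        # sess.run(train_op, feed_dict={chars: chars_batch, chars_mask: ...,
        #                               targets: targets_batch, targets_mask: ...})
        # Sampling from the model:
        # samples, nll = model.sample(sess, chars, chars_split, prev_output,
        #                             prev_state, chars_batch, tgt_len)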