Created
April 11, 2016 22:43
-
-
Save nlintz/732c182b072763dd0e9c33bcff09a1e4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.models.rnn import rnn_cell | |
from tensorflow.python.framework import dtypes | |
from tensorflow.python.ops import variable_scope as vs | |
from tensorflow.python.ops.math_ops import sigmoid, tanh | |
import numpy as np | |
def softmax(x):
    """Softmax along axis 1 of a NumPy array.

    Each row is shifted by its maximum before exponentiating so large logits
    do not overflow.
    """
    shifted = x - x.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    totals = exps.sum(axis=1, keepdims=True)
    return exps / totals
def weightedChoice(weights):
    """Sample an index with probability proportional to ``weights``.

    Args:
        weights: 1-D sequence of non-negative weights. They need not sum to 1;
            they are normalized by their total.

    Returns:
        An ``int`` in ``range(len(weights))``.

    Raises:
        ValueError: if ``weights`` is empty or does not have a positive sum.
    """
    cumulative = np.cumsum(weights, dtype=np.float64)
    if cumulative.size == 0:
        raise ValueError("weights must be non-empty")
    total = cumulative[-1]
    if not total > 0:  # also rejects NaN
        raise ValueError("weights must have a positive sum")
    # Inverse-CDF sampling: pick i with cumulative[i-1] <= draw < cumulative[i].
    # Scaling the draw by the total normalizes the weights. side="right" never
    # selects a zero-weight entry, even when the draw is exactly 0.
    draw = np.random.rand() * total
    idx = int(np.searchsorted(cumulative, draw, side="right"))
    # Guard against float round-off pushing the index past the last entry.
    return min(idx, cumulative.size - 1)
def normal_init_fn(shape):
    """Create a ``tf.Variable`` of ``shape`` drawn from a truncated normal (stddev 0.02)."""
    initial_value = tf.truncated_normal(shape, stddev=0.02)
    return tf.Variable(initial_value)
def proj_init_fn(shape):
    """Create a ``tf.Variable`` of ``shape`` drawn from tf.random_uniform(-0.05, 0.05)."""
    low, high = -0.05, 0.05
    initial_value = tf.random_uniform(shape, low, high)
    return tf.Variable(initial_value)
def constant_init_fn(shape):
    """Create a float32 ``tf.Variable`` of ``shape`` filled with 0.0."""
    zeros = tf.constant(0.0, dtype=tf.float32, shape=shape)
    return tf.Variable(zeros)
def _identity_array(shape):
    """Return the float64 NumPy array that ``identity_initializer`` wraps.

    Raises:
        ValueError: if ``shape`` is not 1-D, square 2-D, or 4-D with
            shape[2] == shape[3].
    """
    shape = tuple(int(d) for d in shape)
    if len(shape) == 1:
        # 1-D: all zeros.
        return np.zeros(shape)
    if len(shape) == 2 and shape[0] == shape[1]:
        return np.identity(shape[0])
    if len(shape) == 4 and shape[2] == shape[3]:
        # 4-D (presumably a conv filter [height, width, in, out] -- confirm):
        # identity over the last two axes at the spatial centre, zeros elsewhere.
        # Floor division: `/` yields floats on Python 3, which cannot index.
        array = np.zeros(shape)
        cx, cy = shape[0] // 2, shape[1] // 2
        array[cx, cy] = np.identity(shape[2])
        return array
    raise ValueError("identity_initializer: unsupported shape %r" % (shape,))


def identity_initializer(shape, dtype="float32"):
    """Return a constant Tensor holding an identity-style initial value.

    Args:
        shape: 1-D (all zeros), square 2-D (identity matrix), or 4-D with
            shape[2] == shape[3] (identity at the spatial centre).
        dtype: dtype of the returned constant; the float values are cast to it.

    Raises:
        ValueError: for any other shape (previously this returned None).
    """
    # Every branch goes through the public tf.constant; the old 4-D branch
    # used tf.constant_op.constant instead.
    return tf.constant(_identity_array(shape), dtype=dtype)
def merge_sequences(sequences, output_dim):
    """Linearly combine parallel sequences step by step with shared weights.

    At each step t, ``rnn_cell.linear`` (no bias) is applied to
    ``[seq[t] for seq in sequences]``. Variables live in the
    "merge_sequences" scope, are created at step 0, and are reused afterwards.
    Iterates over ``len(sequences[0])`` steps; returns one output per step.
    """
    outputs = []
    with tf.variable_scope("merge_sequences"):
        num_steps = len(sequences[0])
        for step in range(num_steps):
            if step != 0:
                # Share the projection weights created at step 0.
                tf.get_variable_scope().reuse_variables()
            step_inputs = [seq[step] for seq in sequences]
            outputs.append(rnn_cell.linear(step_inputs, output_dim, False))
    return outputs
class AttentionComputer(object):
    """Additive (tanh) attention over a list of attended vectors.

    Each attended vector s_t is scored against the previous decoder state h as
    ``tanh(s_t A + h H) v``. The scores are softmax-normalized over time, and
    the attention-weighted sum of the attended vectors is returned.
    """

    def __init__(self, hidden_dim, attended_dim, match_dim=None):
        """
        Args:
            hidden_dim: size of the state passed as ``h_tm1``.
            attended_dim: size of each attended vector.
            match_dim: size of the shared scoring space; defaults to hidden_dim.
        """
        self.hidden_dim = hidden_dim
        if match_dim is None:
            match_dim = self.hidden_dim
        self.attended_dim = attended_dim
        self.match_dim = match_dim
        # Plain tf.Variables (via proj_init_fn): created once here, so they
        # are unaffected by any enclosing variable_scope reuse.
        self.h_mdim = proj_init_fn([self.hidden_dim, self.match_dim])
        self.a_mdim = proj_init_fn([self.attended_dim, self.match_dim])
        self.a_wavg = proj_init_fn([self.match_dim, 1])

    def compute_attention(self, h_tm1, attended):
        """Attend over ``attended`` given the previous state ``h_tm1``.

        Args:
            h_tm1: 2-D Tensor [batch_size, hidden_dim].
            attended: list (one entry per time step) of 2-D Tensors
                [batch_size, attended_dim].

        Returns:
            (wavg, weights): wavg is [batch_size, attended_dim], the weighted
            sum of ``attended``; weights is [batch_size, len(attended)],
            softmax-normalized over time.
        """
        attended_state = tf.matmul(h_tm1, self.h_mdim)
        matched_attended = [tf.matmul(s, self.a_mdim) + attended_state
                            for s in attended]
        # One [batch_size, 1] score per step; packed -> [time, batch_size, 1].
        preweights = [tf.matmul(tf.nn.tanh(s), self.a_wavg)
                      for s in matched_attended]
        # Squeeze only the trailing singleton axis. A bare tf.squeeze would
        # also drop the time or batch axis when it has size 1 (a one-step
        # source or batch_size == 1), breaking the transposes below.
        preweights = tf.squeeze(tf.pack(preweights), squeeze_dims=[2])
        # tf.nn.softmax normalizes the last axis, so go to [batch, time] and back.
        weights = tf.transpose(tf.nn.softmax(tf.transpose(preweights)))
        prewavgs = tf.mul(tf.pack(attended), tf.expand_dims(weights, 2))
        wavg = tf.reduce_sum(prewavgs, reduction_indices=[0])
        return wavg, tf.transpose(weights)
class GRUCellWithContext(rnn_cell.RNNCell):
    """GRU cell (cf. http://arxiv.org/abs/1406.1078) with an extra context input.

    The context tensor is concatenated into both the gate and the candidate
    projections alongside the regular input and state.
    """

    def __init__(self, num_units, input_size=None):
        self._num_units = num_units
        self._input_size = input_size if input_size is not None else num_units

    @property
    def input_size(self):
        return self._input_size

    @property
    def output_size(self):
        return self._num_units

    @property
    def state_size(self):
        return self._num_units

    def __call__(self, inputs, state, context, scope=None):
        """Advance one step; returns ``(new_h, new_h)``."""
        # Default scope name is the class name, i.e. "GRUCellWithContext".
        with vs.variable_scope(scope or type(self).__name__):
            with vs.variable_scope("Gates"):
                # Bias starts at 1.0 so the cell initially neither resets nor updates.
                gate_logits = rnn_cell.linear([inputs, state, context],
                                              2 * self._num_units, True, 1.0)
                reset_logits, update_logits = tf.split(1, 2, gate_logits)
                reset = sigmoid(reset_logits)
                update = sigmoid(update_logits)
            with vs.variable_scope("Candidate"):
                gated_state = reset * state
                candidate = tanh(rnn_cell.linear([inputs, gated_state, context],
                                                 self._num_units, True))
            new_h = update * state + (1 - update) * candidate
        return new_h, new_h
class Seq2SeqAttention(object):
    """Attention-based sequence-to-sequence model (TF 0.x graph API).

    - Encoder: a bidirectional GRU over embedded source symbols.
    - Decoder: a GRUCellWithContext that reads the previous target symbol plus
      an attention-weighted average of the encoder outputs.
    - Readout: a linear layer over [decoder_state, attention_average] that
      gives logits over ``vocab_size``.

    ``cost``/``cost_matrix`` must be called first. They build the training
    graph and set ``self.readout_scope``; ``generate_step`` and ``sample``
    reuse the variables created there.
    """

    def __init__(self, vocab_size, hidden_size):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.enc_cell = rnn_cell.GRUCell(hidden_size)
        # Encoder outputs are [batch_size, 2*hidden_size] (forward + backward).
        self.attn_computer = AttentionComputer(hidden_size, hidden_size * 2)
        with tf.variable_scope("encoder_lookup"):
            self.lookup = tf.Variable(
                tf.random_uniform([vocab_size, hidden_size], -1.0, 1.0))
        with tf.variable_scope("lookup_feedback"):
            self.lookup_feedback = tf.Variable(
                tf.random_uniform([vocab_size, hidden_size], -1.0, 1.0))
        with tf.variable_scope("decoder"):
            self.dec_cell = GRUCellWithContext(hidden_size)

    def _encode(self, encoder_inputs, dtype=dtypes.float32):
        """Run the bidirectional GRU encoder.

        Args:
            encoder_inputs: A list of 1D int Tensors [batch_size] of symbol ids.
            dtype: dtype passed to tf.nn.bidirectional_rnn.

        Returns: tuple (encoder_state, attended)
            encoder_state: final forward and backward states concatenated,
                2D tensor [batch_size, 2*hidden_size]
            attended: the encoder outputs to attend over,
                list of 2D Tensors [batch_size, 2*hidden_size]
        """
        embedded = [tf.nn.embedding_lookup(self.lookup, inp)
                    for inp in encoder_inputs]
        enc_outputs, fw_state, bw_state = tf.nn.bidirectional_rnn(
            self.enc_cell, self.enc_cell, embedded, dtype=dtype)
        encoder_state = tf.concat(1, [fw_state, bw_state])
        return encoder_state, enc_outputs

    def _get_readouts(self, sources):
        """Per-step logits [batch_size, vocab_size] from parallel ``sources``."""
        readout_inp = merge_sequences(sources, self.vocab_size)
        return readout_inp

    def _compute_cost_matrix(self, logits, targets):
        """Per-step sparse softmax cross-entropy, one [batch_size] Tensor each."""
        return [tf.nn.sparse_softmax_cross_entropy_with_logits(l, t) for
                (l, t) in zip(logits, targets)]

    def cost(self, chars, chars_mask, targets, targets_mask):
        """Mean cross-entropy per unmasked target position.

        Args:
            chars: list of 1D int Tensors [batch_size] (encoder input).
            chars_mask: passed through to ``cost_matrix``, which ignores it.
            targets: list of 1D int Tensors [batch_size]. targets[0] is only
                fed back as the first decoder input and is never predicted.
            targets_mask: 2D Tensor, presumably [len(targets), batch_size];
                rows 1.. are used.

        Returns:
            Scalar Tensor: sum of masked costs divided by the sum of the mask.
        """
        batch_size = targets[0].get_shape()[0].value
        # Drop the first time step so the mask lines up with targets[1:].
        targets_mask = tf.slice(targets_mask, [1, 0],
                                [len(targets) - 1, batch_size])
        cost_per_char = self.cost_matrix(chars, chars_mask, targets, targets_mask)
        return tf.reduce_sum(cost_per_char) / tf.reduce_sum(targets_mask)

    def cost_matrix(self, chars, chars_mask, targets, targets_mask):
        """Build the training graph and return masked per-step costs.

        Args:
            chars: list of 1D int Tensors [batch_size] (encoder input).
            chars_mask: currently unused.
            targets: list of 1D int Tensors [batch_size]. targets[:-1] are fed
                back into the decoder; targets[1:] are the labels.
            targets_mask: Tensor [len(targets) - 1, batch_size] multiplied
                into the costs.

        Returns:
            2D Tensor [len(targets) - 1, batch_size] of masked cross-entropy.

        Side effects:
            - Adds the packed attention weights to the "attention_weights"
              collection.
            - Sets ``self.readout_scope``.
        """
        batch_size = targets[0].get_shape()[0].value
        with tf.variable_scope("encoder"):
            _, attended = self._encode(chars)
        feedback = [tf.nn.embedding_lookup(self.lookup_feedback, inp)
                    for inp in targets[:-1]]
        targets = targets[1:]
        states = [tf.zeros([batch_size, self.hidden_size], "float32")]
        weights = []
        weighted_averages = []
        # generate_step re-enters this scope name so the decoder variables
        # created here are shared with it.
        with tf.variable_scope("merge_sequences"):
            for i, inp in enumerate(feedback):
                if i > 0:
                    # Share decoder variables across time steps.
                    tf.get_variable_scope().reuse_variables()
                wavg, w = self.attn_computer.compute_attention(states[-1], attended)
                with tf.variable_scope("decoder"):
                    next_state, _ = self.dec_cell(inp, states[-1], wavg)
                states.append(next_state)
                weights.append(w)
                weighted_averages.append(wavg)
        tf.add_to_collection("attention_weights", tf.pack(weights))
        # Opened outside "merge_sequences" so it does not inherit reuse=True.
        with tf.variable_scope("readout") as readout_scope:
            self.readout_scope = readout_scope
            readouts = self._get_readouts([states[1:], weighted_averages])
        costs = tf.pack(self._compute_cost_matrix(readouts, targets))
        return tf.mul(costs, targets_mask)

    def generate_step(self, attended, previous_output, previous_state):
        """Build one decoding step, reusing the variables from cost_matrix.

        Args:
            attended: list of 2D Tensors [batch_size, 2*hidden_size], the
                encoder outputs (as returned by ``_encode``).
            previous_output: An int32 1D Tensor [batch_size]
            previous_state: A float32 2D Tensor (batch_size, hidden_size)

        Returns: tuple (next_state, next_readout)
            next_state: A float32 2D Tensor (batch_size, hidden_size)
            next_readout: A float32 2D Tensor (batch_size, vocab_size)
                of unnormalized logits
        """
        with tf.variable_scope("merge_sequences"):
            inp = tf.nn.embedding_lookup(self.lookup_feedback, previous_output)
            wavg, _ = self.attn_computer.compute_attention(previous_state, attended)
            with tf.variable_scope("decoder", reuse=True):
                next_state, _ = self.dec_cell(inp, previous_state, wavg)
        with tf.variable_scope(self.readout_scope, reuse=True):
            next_readout = self._get_readouts([[next_state], [wavg]])
        return next_state, next_readout[0]

    def _emit_softmax(self, readout):
        """Sample one symbol per row from softmax(readout).

        Args:
            readout: 2D array [batch_size, vocab_size] of unnormalized logits.

        Returns:
            (emit, cost):
            - emit: int array [batch_size] of sampled symbol ids.
            - cost: per-row negative log-probability of the sampled symbol.
        """
        probs = softmax(readout)
        # Use the sampled index directly. The old code built np.eye(vocab_size)
        # (a vocab_size x vocab_size matrix) per row only to argmax it back.
        emit = np.array([weightedChoice(row) for row in probs], dtype=np.intp)
        cost = -np.log(probs[np.arange(len(probs)), emit])
        return emit, cost

    def sample(self, sess, chars_pl, chars_split_pl, prev_output_pl, prev_state_pl,
               chars_input, generation_length):
        """Autoregressively sample ``generation_length`` symbols per sequence.

        Args:
            sess: session holding the trained variables.
            chars_pl: placeholder fed with ``chars_input``. Its dim 1 is taken
                as batch_size (presumably [time, batch_size] -- confirm).
            chars_split_pl: list of per-step Tensors fed to the encoder,
                presumably derived from ``chars_pl`` (confirm against caller).
            prev_output_pl: placeholder [batch_size] for the previous symbol.
            prev_state_pl: placeholder [batch_size, hidden_size] for the state.
            chars_input: value fed into ``chars_pl``.
            generation_length: number of symbols to sample.

        Returns:
            (outputs, costs):
            - outputs: int array [batch_size, generation_length] of samples.
            - costs: float array [batch_size] of summed negative
              log-probabilities of those samples.
        """
        batch_size = chars_pl.get_shape()[1].value
        # NOTE(review): each call adds fresh encoder and decoder ops to the
        # default graph, so repeated calls grow the graph; consider caching.
        with tf.variable_scope("encoder", reuse=True):
            _, attended = self._encode(chars_split_pl)
        gen_step_op = self.generate_step(attended, prev_output_pl, prev_state_pl)
        # Step 0 uses a zero state and symbol id 0 as the previous output
        # (presumably the start symbol -- confirm).
        states = [np.zeros((batch_size, self.hidden_size)).astype("float32")]
        outputs = [np.zeros((batch_size)).astype("int32")]
        costs = []
        for _ in range(generation_length):
            feed_dict = {chars_pl: chars_input, prev_output_pl: outputs[-1],
                         prev_state_pl: states[-1]}
            next_state, next_readout = sess.run(gen_step_op, feed_dict=feed_dict)
            next_output, next_cost = self._emit_softmax(next_readout)
            states.append(next_state)
            outputs.append(next_output)
            costs.append(next_cost)
        # Drop the initial seed symbol; return only generated ones.
        outputs.pop(0)
        return np.array(outputs).T, np.array(costs).sum(axis=0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment