Created August 7, 2016 09:40
TensorFlow basic RNN sample
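The script feeds word-id batches of shape [batch_size, num_steps] into its two placeholders, with the targets shifted one word ahead of the inputs. A minimal illustration of that layout with toy ids (the numbers below are chosen only for demonstration, not taken from the gist):

import numpy as np

ids = np.arange(14, dtype=np.int32)         # toy word-id stream
batch_size, num_steps = 2, 3
batch_len = len(ids) // batch_size          # 7
data = ids[:batch_size * batch_len].reshape(batch_size, batch_len)
x = data[:, 0:num_steps]                    # [[0 1 2], [7 8 9]]   -> fed to _input_data
y = data[:, 1:num_steps + 1]                # [[1 2 3], [8 9 10]]  -> fed to _targets

The gist file itself follows.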
import numpy as np
import tensorflow as tf
import reader
import time
class PTBModel:
    @property
    def optimizer(self):
        return self._optimizer

    def __init__(self):
        # internal setting
        self._optimizer = tf.train.AdamOptimizer()

        # config
        self._batch_size = 20
        self._num_steps = 2
        self._hidden_size = 2
        self._vocab_size = 10000
        self._num_layers = 1
        self._keep_prob = 0.5
        self._max_grad_norm = 1

        # input and output variables
        self._input_data = tf.placeholder(tf.int32, [self._batch_size, self._num_steps])
        self._targets = tf.placeholder(tf.int32, [self._batch_size, self._num_steps])
        self._initial_state = None
        self._final_state = None
        self._cost = None
        self._train_op = None
        self._logits = None

        self._build_graph(True)
    def _build_graph(self, is_training):
        # LSTM
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self._hidden_size, forget_bias=0.0, state_is_tuple=True)

        # add dropout
        if is_training:
            lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=self._keep_prob)

        # add multi layers
        cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * self._num_layers, state_is_tuple=True)

        # initial state setup
        self._initial_state = cell.zero_state(self._batch_size, tf.float32)

        # Load predefined layer "embedding"
        with tf.device("/cpu:0"):
            embedding = tf.get_variable("embedding", [self._vocab_size, self._hidden_size])
            inputs = tf.nn.embedding_lookup(embedding, self._input_data)

        # Add dropout after embedding layer
        if is_training:
            inputs = tf.nn.dropout(inputs, self._keep_prob)

        # Calculate the LSTM layer for _num_steps
        outputs = []
        state = self._initial_state
        with tf.variable_scope("RNN"):
            for time_step in range(self._num_steps):
                if time_step > 0:
                    tf.get_variable_scope().reuse_variables()
                (cell_output, state) = cell(inputs[:, time_step, :], state)
                outputs.append(cell_output)

        # Final output layer for getting word label
        output = tf.reshape(tf.concat(1, outputs), [-1, self._hidden_size])
        softmax_w = tf.get_variable("softmax_w", [self._hidden_size, self._vocab_size])
        softmax_b = tf.get_variable("softmax_b", [self._vocab_size])
        self._logits = tf.matmul(output, softmax_w) + softmax_b

        # loss function
        loss = tf.nn.seq2seq.sequence_loss_by_example(
            [self._logits],
            [tf.reshape(self._targets, [-1])],
            [tf.ones([self._batch_size * self._num_steps])])
        self._cost = cost = tf.reduce_sum(loss) / self._batch_size
        self._final_state = state

        if not is_training:
            return

        # Gradient calculator
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), self._max_grad_norm)
        self._train_op = self._optimizer.apply_gradients(zip(grads, tvars))
    def _one_loop_setup(self, eval_op, state):
        fetches = []
        fetches.append(self._cost)
        fetches.append(eval_op)
        for c, m in self._final_state:  # _final_state: ((c1, m1), (c2, m2))
            fetches.append(c)
            fetches.append(m)
        feed_dict = {}
        for i, (c, m) in enumerate(self._initial_state):
            feed_dict[c], feed_dict[m] = state[i]
        return fetches, feed_dict
    def _run_epoch(self, session, data, eval_op, verbose=False):
        epoch_size = ((len(data) // self._batch_size) - 1) // self._num_steps
        start_time = time.time()
        costs = 0.0
        iters = 0
        state = []
        # change state to tuple, as referred from
        # https://github.com/jihunchoi/tensorflow/blob/ptb_use_state_tuple/tensorflow/models/rnn/ptb/ptb_word_lm.py
        for c, m in self._initial_state:  # _initial_state: ((c1, m1), (c2, m2))
            state.append((c.eval(), m.eval()))
        for step, (x, y) in enumerate(reader.ptb_iterator(data, self._batch_size, self._num_steps)):
            fetches, feed_dict = self._one_loop_setup(eval_op, state)
            feed_dict[self._input_data] = x
            feed_dict[self._targets] = y

            res = session.run(fetches, feed_dict)
            cost = res[0]
            state_flat = res[2:]  # [c1, m1, c2, m2]
            state = [state_flat[i:i + 2] for i in range(0, len(state_flat), 2)]

            costs += cost
            iters += self._num_steps

            if verbose and step % (epoch_size // 10) == 10:
                print("%.3f perplexity: %.3f speed: %.0f wps" %
                      (step * 1.0 / epoch_size, np.exp(costs / iters),
                       iters * self._batch_size / (time.time() - start_time)))

        return np.exp(costs / iters)
    def train(self, session, data):
        return self._run_epoch(session, data, self._train_op, verbose=True)

    def evaluate(self, session, data):
        return self._run_epoch(session, data, tf.no_op())
    def predict(self, session, data, word_to_id):
        # prediction = tf.argmax(self._logits, 1)
        state = []

        def _get_word_fromid(word_to_id, search_id):
            for word, wid in word_to_id.items():
                if wid == search_id:
                    return word

        for c, m in self._initial_state:  # _initial_state: ((c1, m1), (c2, m2))
            state.append((c.eval(), m.eval()))
        for step, (x, y) in enumerate(reader.ptb_iterator(data, self._batch_size, self._num_steps)):
            fetches, feed_dict = self._one_loop_setup(self._logits, state)
            feed_dict[self._input_data] = x
            feed_dict[self._targets] = y

            res = session.run(fetches, feed_dict)
            state_flat = res[2:]  # [c1, m1, c2, m2]
            state = [state_flat[i:i + 2] for i in range(0, len(state_flat), 2)]

            label = res[1]
            label = np.argmax(label, 1)
            y = np.reshape(y, (self._batch_size * self._num_steps))
            for pre, real in zip(label, y):
                print("Predict %s : Real %s" % (_get_word_fromid(word_to_id, pre), _get_word_fromid(word_to_id, real)))
def main():
    print("start ptb")
    raw_data = reader.ptb_raw_data("")
    train_data, valid_data, test_data, word_to_id = raw_data

    with tf.Graph().as_default(), tf.Session() as session:
        initializer = tf.random_uniform_initializer(-0.04, 0.04)
        with tf.variable_scope("model", reuse=None, initializer=initializer):
            model = PTBModel()

        tf.initialize_all_variables().run()

        for i in range(20):
            print("Epoch: %d" % (i + 1))
            train_perplexity = model.train(session, train_data)
            print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
            valid_perplexity = model.evaluate(session, valid_data)
            print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))

        model.predict(session, test_data, word_to_id)


if __name__ == '__main__':
    main()
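The gist imports a reader module that is not included here. What follows is only a sketch of the interface the script appears to expect, modeled on the TensorFlow r0.x PTB tutorial reader; note that the stock tutorial version returns the vocabulary size as its fourth value, whereas this script unpacks a word_to_id dictionary, so every name and detail below is an assumption rather than the author's file.

import collections
import numpy as np

def _read_words(path):
    # Split a whitespace-tokenized corpus, marking line ends with <eos>.
    with open(path) as f:
        return f.read().replace("\n", " <eos> ").split()

def ptb_raw_data(data_path=""):
    # Build the vocabulary from the training file, then map each corpus to word ids.
    train_words = _read_words(data_path + "ptb.train.txt")
    counter = collections.Counter(train_words)
    word_to_id = {w: i for i, (w, _) in enumerate(counter.most_common())}

    def to_ids(words):
        return [word_to_id[w] for w in words if w in word_to_id]

    train_data = to_ids(train_words)
    valid_data = to_ids(_read_words(data_path + "ptb.valid.txt"))
    test_data = to_ids(_read_words(data_path + "ptb.test.txt"))
    return train_data, valid_data, test_data, word_to_id

def ptb_iterator(raw_data, batch_size, num_steps):
    # Yield (x, y) pairs of shape [batch_size, num_steps]; y is x shifted one step ahead.
    raw_data = np.array(raw_data, dtype=np.int32)
    batch_len = len(raw_data) // batch_size
    data = raw_data[:batch_size * batch_len].reshape(batch_size, batch_len)
    epoch_size = (batch_len - 1) // num_steps
    for i in range(epoch_size):
        x = data[:, i * num_steps:(i + 1) * num_steps]
        y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]
        yield (x, y)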