'''
# sequence to sequence model with TensorFlow
#
# Environment
# 1. HW: Intel i7-4770 RAM 32G, NVIDIA GeForce GTX 960 RAM 4G
# 2. SW: Windows 10 x64, Python 3.5.2, TensorFlow-gpu 1.0.1
#
# Reference
# 1. TensorFlow tutorials : https://www.tensorflow.org/tutorials/seq2seq
# 2. RNN Encoder–Decoder : https://arxiv.org/pdf/1406.1078.pdf
# 3. Seq2Seq Learning : https://arxiv.org/pdf/1409.3215.pdf
# 4. Practical seq2seq : http://suriyadeepan.github.io/2016-12-31-practical-seq2seq
# 5. RNN in TensorFlow : http://r2rt.com/recurrent-neural-networks-in-tensorflow-ii.html
# 6. Stanford CS20si : http://web.stanford.edu/class/cs20si/lectures/slides_11.pdf
# 7. Stanford CS231n : http://cs231n.github.io/
# 8. codes : https://github.com/hunkim/DeepLearningZeroToAll
#          : https://github.com/farizrahman4u/seq2seq
#          : https://github.com/ematvey/tensorflow-seq2seq-tutorials
#          : https://github.com/tensorflow/models/tree/master/tutorials/rnn
#          : https://github.com/nicolas-ivanov/tf_seq2seq_chatbot
#          : https://github.com/sherjilozair/char-rnn-tensorflow
#
# TensorFlow, Sequence to Sequence, RNN, Encoder, Decoder, seq2seq, s2s
'''
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

tf.set_random_seed(777)  # reproducibility
np.random.seed(77)       # reproducibility
# dictionary: ids 0 and 1 are reserved for padding and EOS,
# so the characters 'a'..'j' are mapped to 2..11
alpha = "abcdefghij"
char_set = sorted(list(set(alpha)))
char_dic = {w: i + 2 for i, w in enumerate(char_set)}

# parameters
eos = 1
num_class = len(char_dic) + 2  # characters plus PAD(0) and EOS(1)
input_embedded_size = 32
hidden_size = 32
dtype = tf.float32
learning_rate = 0.001
epochs = 3001
test_epoch = 1
batch_size = 16

# global parameter on feed_gen
batch_cnt = 0

# cumulative loss
cum_loss = []
# create data: one pre-built batch of random sequences per training/test step
data = []
for _ in range(epochs + test_epoch):
    section = []
    for b in range(batch_size):
        # random sequence of length 4..8 over the character ids 2..num_class-1
        w = np.random.randint(2, num_class, np.random.randint(4, 9)).tolist()
        section.append(w)
    data.append(section)
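# e.g. data[0] is one batch: a list of batch_size integer sequences such as [5, 9, 3, 7]
# (the actual values and lengths depend on the numpy seed)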
# to transform a batch of variable-length sequences into a zero-padded, time-major array
def batch_transform(batch_data):
    seq_len = [len(seq) for seq in batch_data]
    b_size = len(batch_data)
    max_seq = max(seq_len)
    inp_batch = np.zeros(shape=[b_size, max_seq], dtype=np.int32)
    for i, seq in enumerate(batch_data):
        for j, val in enumerate(seq):
            inp_batch[i, j] = val
    # swap to [max_time, batch] because the RNNs below run with time_major=True
    inp_batch_pivot = inp_batch.swapaxes(0, 1)
    return inp_batch_pivot
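# quick sanity check: sequences are padded with 0 and returned time-major,
# so [[3, 4, 5], [6, 7]] becomes [[3, 6], [4, 7], [5, 0]]
assert batch_transform([[3, 4, 5], [6, 7]]).tolist() == [[3, 6], [4, 7], [5, 0]]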
# to generate a feed_dict for each batch
def feed_gen():
    global batch_cnt
    batch_data = data[batch_cnt]
    batch_cnt += 1
    enc_input = batch_transform(batch_data)
    # EOS doubles as the GO symbol: targets end with EOS, decoder inputs start with it
    dec_target = batch_transform([seq + [eos] for seq in batch_data])
    dec_input = batch_transform([[eos] + seq for seq in batch_data])
    return {encoder_input: enc_input, decoder_input: dec_input, decoder_target: dec_target}
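# e.g. for a single sequence [3, 4, 5] (before the time-major transform):
#   encoder_input  : [3, 4, 5]
#   decoder_input  : [1, 3, 4, 5]   # EOS(1) prepended as the GO symbol (teacher forcing)
#   decoder_target : [3, 4, 5, 1]   # same sequence shifted, ending with EOS(1)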
# to trim padding and end-of-sentence symbols and decode ids back to characters
def to_alpha(num2char):
    def trim_zero_eos(trim_data):
        # drop trailing PAD(0)/EOS(1) symbols, keeping everything up to the last real character
        data_len = len(trim_data)
        list.reverse(trim_data)
        for idx, num in enumerate(trim_data):
            if num > 1:
                end_num = data_len - idx
                break
        else:
            end_num = data_len
        revised = []
        list.reverse(trim_data)
        for j in range(end_num):
            revised.append(trim_data[j])
        return revised
    rev_data = trim_zero_eos(num2char)
    rev_len = len(rev_data)
    decoded = str()
    for i in range(rev_len):
        # inverse lookup in char_dic: id -> character
        decoded += list(char_dic.keys())[list(char_dic.values()).index(rev_data[i])]
    return decoded
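# quick sanity check: trailing EOS/PAD symbols are trimmed before decoding
# ('a' -> 2, 'b' -> 3, 'c' -> 4)
assert to_alpha([2, 3, 4, 1, 0, 0]) == 'abc'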
encoder_input = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_input')
decoder_target = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_target')
decoder_input = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_input')
embedded = tf.Variable(tf.random_uniform([num_class, input_embedded_size], -1., 1.), dtype=dtype)
encoder_input_embedded = tf.nn.embedding_lookup(embedded, encoder_input)
decoder_input_embedded = tf.nn.embedding_lookup(embedded, decoder_input)
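# the placeholders hold time-major [max_time, batch] id matrices; the shared embedding
# turns them into [max_time, batch, input_embedded_size] float tensors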
# encoder part
encoder_cell = tf.contrib.rnn.LSTMCell(hidden_size)
encoder_output, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_input_embedded, dtype=dtype, time_major=True)
# decoder part
decoder_cell = tf.contrib.rnn.LSTMCell(hidden_size)
decoder_output, decoder_state = tf.nn.dynamic_rnn(decoder_cell, decoder_input_embedded, initial_state=encoder_state,
                                                  dtype=dtype, time_major=True, scope="decoder_session")
decoder_logit = tf.contrib.layers.linear(decoder_output, num_class)
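# the encoder's final LSTMStateTuple(c, h), each [batch, hidden_size], seeds the decoder
# via initial_state; the linear layer projects decoder outputs to per-timestep logits
# of shape [max_time, batch, num_class]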
# loss
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
    labels=tf.one_hot(decoder_target, depth=num_class, dtype=dtype),
    logits=decoder_logit)
loss = tf.reduce_mean(cross_entropy)
# optimizer
train = tf.train.AdamOptimizer(learning_rate).minimize(loss)
prediction = tf.argmax(decoder_logit, axis=2)
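# argmax over the class axis gives predicted symbol ids in the same time-major
# [max_time, batch] layout as the targets; padded positions are trained toward PAD(0)
# because the loss averages over every timestep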
# session start
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# training
for epoch in range(epochs):
    feed = feed_gen()
    t, l = sess.run([train, loss], feed_dict=feed)
    cum_loss.append(l)
    if epoch % 100 == 0:
        print('number of batch:', epoch, '\t,', 'loss:', l)

# predicting
for epoch in range(test_epoch):
    feed = feed_gen()
    results = sess.run(prediction, feed_dict=feed)
    # feed[encoder_input] and results are time-major, so transpose back to [batch, max_time]
    # before decoding each row with to_alpha
    for p, (test, result) in enumerate(zip(feed[encoder_input].T, results.T)):
        print(p, ':', to_alpha(test.tolist()), '\t->\t', to_alpha(result.tolist()))
# plotting of cumulative loss
plt.plot(cum_loss)
plt.xlabel("Epochs")
plt.ylabel("Loss")
fname = "plot_loss_adam_1e-3_epoch" + str(epochs) + ".png"
plt.savefig(fname)
plt.show()
'''
number of batch: 0 , loss: 2.01557
number of batch: 100 , loss: 1.09725
number of batch: 200 , loss: 0.962277
number of batch: 300 , loss: 0.573369
number of batch: 400 , loss: 0.700255
number of batch: 500 , loss: 0.589177
number of batch: 2600 , loss: 0.67308
number of batch: 2700 , loss: 0.721901
number of batch: 2800 , loss: 0.606156
number of batch: 2900 , loss: 0.646664
number of batch: 3000 , loss: 0.639117
0 : hcheffc -> hchfff
1 : agii -> agcc
2 : gfcgg -> ggggg
3 : degf -> degf
4 : hjfbebff -> hcbbcfff
5 : hfcg -> ffcg
6 : ebcje -> eccee
7 : bgah -> bgah
8 : cgeeadaa -> eeeaaaaa
9 : fhgh -> fhgh
10 : hdbfbaah -> hdbaaaa
11 : fjdadf -> fddaff
12 : adjiah -> acccah
13 : abhgdjig -> ahhgccgg
14 : cabb -> bbbb
15 : jhhhc -> hhhhc
'''