@GzuPark
Created March 13, 2017 02:48
sequence to sequence model with TensorFlow
'''
# Environment
# 1. HW: Intel i7-4770 RAM 32G, NVIDIA GeForce GTX 960 RAM 4G
# 2. SW: Windows 10 x64, Python 3.5.2, TensorFlow-gpu 1.0.1
#
# Reference
# 1. TensorFlow tutorials : https://www.tensorflow.org/tutorials/seq2seq
# 2. RNN Encoder–Decoder : https://arxiv.org/pdf/1406.1078.pdf
# 3. Seq2Seq Learning : https://arxiv.org/pdf/1409.3215.pdf
# 4. Practical seq2seq : http://suriyadeepan.github.io/2016-12-31-practical-seq2seq
# 5. RNN in TensorFlow : http://r2rt.com/recurrent-neural-networks-in-tensorflow-ii.html
# 6. Stanford CS20si : http://web.stanford.edu/class/cs20si/lectures/slides_11.pdf
# 7. Stanford CS231n : http://cs231n.github.io/
# 8. codes : https://github.com/hunkim/DeepLearningZeroToAll
#          : https://github.com/farizrahman4u/seq2seq
#          : https://github.com/ematvey/tensorflow-seq2seq-tutorials
#          : https://github.com/tensorflow/models/tree/master/tutorials/rnn
#          : https://github.com/nicolas-ivanov/tf_seq2seq_chatbot
#          : https://github.com/sherjilozair/char-rnn-tensorflow
#
# TensorFlow, Sequence to Sequence, RNN, Encoder, Decoder, seq2seq, s2s
'''
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
tf.set_random_seed(777)  # reproducibility
np.random.seed(77) # reproducibility
# dictionary
alpha = "abcdefghij"
char_set = sorted(list(set(alpha)))
char_dic = {w: i + 2 for i, w in enumerate(char_set)}
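# char_dic maps 'a'..'j' to indices 2..11; index 0 is reserved for padding and 1 for the EOS token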
# parameters
eos = 1  # end-of-sequence token id (0 is reserved for padding)
num_class = len(char_dic) + 2  # vocabulary size: 10 characters plus padding (0) and EOS (1)
input_embedded_size = 32
hidden_size = 32
dtype = tf.float32
learning_rate = 0.001
epoches = 3001
test_epoch = 1
batch_size = 16
# global parameter on feed_gen
batch_cnt = 0
# cumulating loss
cum_loss = []
# create data
data = []
for _ in range(epoches + test_epoch):  # one pre-generated batch per training / test step
    section = []
    for b in range(batch_size):
        # random sequence of 4~8 character indices (values 2..11)
        w = np.random.randint(2, num_class, np.random.randint(4, 9)).tolist()
        section.append(w)
    data.append(section)
# to transform batch data
def batch_transform(batch_data):
    seq_len = [len(seq) for seq in batch_data]
    b_size = len(batch_data)
    max_seq = max(seq_len)
    # zero-pad every sequence to the longest one in the batch
    inp_batch = np.zeros(shape=[b_size, max_seq], dtype=np.int32)
    for i, seq in enumerate(batch_data):
        for j, val in enumerate(seq):
            inp_batch[i, j] = val
    # transpose to time-major layout [max_seq, batch] for dynamic_rnn(time_major=True)
    inp_batch_pivot = inp_batch.swapaxes(0, 1)
    return inp_batch_pivot
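# For illustration (toy values, not from the generated data):
#   batch_transform([[2, 3], [4, 5, 6]])
#   pads to  [[2, 3, 0],      then transposes to time-major  [[2, 4],
#             [4, 5, 6]]                                       [3, 5],
#                                                              [0, 6]]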
# to generate feed_dict each batches
def feed_gen():
    global batch_cnt
    batch_data = data[batch_cnt]
    batch_cnt += 1
    enc_input = batch_transform(batch_data)
    dec_target = batch_transform([seq + [eos] for seq in batch_data])  # target ends with EOS
    dec_input = batch_transform([[eos] + seq for seq in batch_data])   # EOS in front acts as the GO symbol
    return {encoder_input: enc_input, decoder_input: dec_input, decoder_target: dec_target}
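# For a single toy sequence [2, 3, 4] the three feeds are built (before padding/transposing) as:
#   encoder_input : [2, 3, 4]
#   decoder_input : [1, 2, 3, 4]    # EOS reused as the GO symbol
#   decoder_target: [2, 3, 4, 1]    # EOS marks the end of the target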
# to trim padding and end of sentence
def to_alpha(num2char):
    # drop trailing padding (0) and EOS (1) tokens
    def trim_zero_eos(trim_data):
        data_len = len(trim_data)
        list.reverse(trim_data)
        for idx, num in enumerate(trim_data):
            if num > 1:
                end_num = data_len - idx
                break
        else:
            end_num = data_len
        revised = []
        list.reverse(trim_data)
        for j in range(end_num):
            revised.append(trim_data[j])
        return revised
    rev_data = trim_zero_eos(num2char)
    rev_len = len(rev_data)
    decoded = str()
    for i in range(rev_len):
        # reverse lookup: index value -> character
        decoded += list(char_dic.keys())[list(char_dic.values()).index(rev_data[i])]
    return decoded
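# For example (illustrative values only): to_alpha([4, 3, 1, 0, 0]) strips the trailing
# EOS/padding and returns 'cb' (4 -> 'c', 3 -> 'b').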
# placeholders are time-major: [max_time, batch_size]
encoder_input = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_input')
decoder_target = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_target')
decoder_input = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_input')
# embedding matrix [num_class, input_embedded_size], shared by encoder and decoder inputs
embedded = tf.Variable(tf.random_uniform([num_class, input_embedded_size], -1., 1.), dtype=dtype)
encoder_input_embedded = tf.nn.embedding_lookup(embedded, encoder_input)
decoder_input_embedded = tf.nn.embedding_lookup(embedded, decoder_input)
# encoder part
encoder_cell = tf.contrib.rnn.LSTMCell(hidden_size)
encoder_output, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_input_embedded, dtype=dtype, time_major=True)
# decoder part
decoder_cell = tf.contrib.rnn.LSTMCell(hidden_size)
# the decoder starts from the encoder's final state (the "thought vector")
decoder_output, decoder_state = tf.nn.dynamic_rnn(decoder_cell, decoder_input_embedded, initial_state=encoder_state,
                                                  dtype=dtype, time_major=True, scope="decoder_session")
# project decoder outputs to per-character logits: [max_time, batch, num_class]
decoder_logit = tf.contrib.layers.linear(decoder_output, num_class)
# loss
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
    labels=tf.one_hot(decoder_target, depth=num_class, dtype=dtype),
    logits=decoder_logit)
loss = tf.reduce_mean(cross_entropy)
# optimizer
train = tf.train.AdamOptimizer(learning_rate).minimize(loss)
prediction = tf.argmax(decoder_logit, axis=2)
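# prediction has shape [max_time, batch]; each time step is a greedy argmax over the class logits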
# session start
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# training
for epoch in range(epoches):
    feed = feed_gen()
    t, l = sess.run([train, loss], feed_dict=feed)
    cum_loss.append(l)
    if epoch == 0 or epoch % 100 == 0:
        print('number of batch:', epoch, '\t,', 'loss:', l)
# predicting
for epoch in range(test_epoch):
    feed = feed_gen()
    # note: the decoder is still fed the ground-truth decoder_input here (teacher forcing),
    # not its own previous predictions
    results = sess.run(prediction, feed_dict=feed)
    for p, (test, result) in enumerate(zip(feed[encoder_input].T, results.T)):
        print(p, ':', to_alpha(test.tolist()), '\t->\t', to_alpha(result.tolist()))
# plotting of cumulative loss
plt.plot(cum_loss)
plt.xlabel("Epochs")
plt.ylabel("Loss")
fname = "plot_loss_adam_1e-3_epoch" + str(epoches) + ".png"
plt.savefig(fname)
plt.show()
'''
number of batch: 0 , loss: 2.01557
number of batch: 100 , loss: 1.09725
number of batch: 200 , loss: 0.962277
number of batch: 300 , loss: 0.573369
number of batch: 400 , loss: 0.700255
number of batch: 500 , loss: 0.589177
number of batch: 2600 , loss: 0.67308
number of batch: 2700 , loss: 0.721901
number of batch: 2800 , loss: 0.606156
number of batch: 2900 , loss: 0.646664
number of batch: 3000 , loss: 0.639117
0 : hcheffc -> hchfff
1 : agii -> agcc
2 : gfcgg -> ggggg
3 : degf -> degf
4 : hjfbebff -> hcbbcfff
5 : hfcg -> ffcg
6 : ebcje -> eccee
7 : bgah -> bgah
8 : cgeeadaa -> eeeaaaaa
9 : fhgh -> fhgh
10 : hdbfbaah -> hdbaaaa
11 : fjdadf -> fddaff
12 : adjiah -> acccah
13 : abhgdjig -> ahhgccgg
14 : cabb -> bbbb
15 : jhhhc -> hhhhc
'''
@GzuPark commented Mar 13, 2017

[loss plot: plot_loss_adam_1e-3_epoch3001.png]