test_tf.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
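
# Toy sequence-to-sequence example in graph-mode TensorFlow 1.x: a GRU encoder
# reads a random integer sequence and a GRU decoder is trained, with teacher
# forcing, to emit the same sequence shifted by +1 and terminated with EOS.
# All sequence tensors are fed time-major ([time, batch]).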

def load_data(size,
              min_len=5,
              max_len=15,
              min_word=3,
              max_word=100,
              epoch=10,
              batch_size=64,
              pad=0,
              bos=1,
              eos=2):
    src = [
        np.random.randint(min_word, max_word - 1,
                          np.random.randint(min_len, max_len)).tolist()
        for _ in range(size)
    ]
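    # Teacher-forcing decoder input starts with BOS; the target output is the
    # same shifted-by-one sequence followed by EOS.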
    tgt_in = [[bos] + [xi + 1 for xi in x] for x in src]
    tgt_out = [[xi + 1 for xi in x] + [eos] for x in src]

    def _pad(batch):
        max_len = max(len(x) for x in batch)
        return np.asarray(
            [
                np.pad(
                    x, (0, max_len - len(x)),
                    mode='constant',
                    constant_values=pad) for x in batch
            ],
            dtype=np.int64)

    def _len(batch):
        return np.asarray([len(x) for x in batch], dtype=np.int64)

    for e in range(epoch):
        batch_start = 0
        while batch_start < size:
            batch_end = batch_start + batch_size
            s, ti, to = (src[batch_start:batch_end],
                         tgt_in[batch_start:batch_end],
                         tgt_out[batch_start:batch_end])
            lens, lent = _len(s), _len(ti)
            s, ti, to = _pad(s).T, _pad(ti).T, _pad(to).T
            yield s, ti, to, lens, lent
            batch_start += batch_size
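

# A quick sanity check of the generator output (kept as a comment so the
# training loop below still consumes the full stream); the shapes shown assume
# an illustrative call with batch_size=4:
#
#     s, ti, to, lens, lent = next(load_data(8, batch_size=4))
#     print(s.shape, ti.shape, to.shape)  # (T_src, 4), (T_tgt, 4), (T_tgt, 4)
#     print(lens, lent)                   # unpadded lengths, dtype int64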


def print_sample(x, y, pred):
    # Transpose from time-major back to batch-major so each row is one sample.
    x = x.T
    y = y.T
    pred = pred.T
    for u, v, w in zip(x, y, pred):
        print('--------')
        print('S: ', u)
        print('T: ', v)
        print('P: ', w)


class Seq2seq(object):
    def __init__(self, vocab_size, embedding_size, hidden_size):
        src = tf.placeholder(tf.int32, [None, None], name='src')
        src_len = tf.placeholder(tf.int32, [None], name='src_len')
        tgt_len = tf.placeholder(tf.int32, [None], name='tgt_len')
        tgt_in = tf.placeholder(tf.int32, [None, None], name='tgt_in')
        tgt_out = tf.placeholder(tf.int32, [None, None], name='tgt_out')

        # Input embeddings (separate tables for source and target tokens)
        src_embedding = tf.Variable(
            tf.random_uniform([vocab_size, embedding_size], -1, 1))
        tgt_embedding = tf.Variable(
            tf.random_uniform([vocab_size, embedding_size], -1, 1))
        embedding_inputs = tf.nn.embedding_lookup(src_embedding, src)
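        # embedding_inputs is [T_src, B, embedding_size] because src is fed
        # time-major and dynamic_rnn below is called with time_major=True.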
        # Encode
        with tf.variable_scope('encoder'):
            encoder_cell = tf.nn.rnn_cell.GRUCell(hidden_size)
            _, encoder_final_state = tf.nn.dynamic_rnn(
                cell=encoder_cell,
                inputs=embedding_inputs,
                sequence_length=src_len,
                dtype=tf.float32,
                time_major=True)

        # Output projection
        output_weights = tf.get_variable(
            'output_weights',
            shape=[hidden_size, vocab_size],
            initializer=tf.contrib.layers.xavier_initializer())
        output_biases = tf.Variable(
            tf.constant(0.0, shape=[vocab_size]), name='output_biases')

        # Decode
        with tf.variable_scope('decoder'):
            decoder_cell = tf.nn.rnn_cell.GRUCell(hidden_size)
            decoder_outputs, _ = tf.nn.dynamic_rnn(
                cell=decoder_cell,
                inputs=tf.nn.embedding_lookup(tgt_embedding, tgt_in),
                sequence_length=tgt_len,
                dtype=tf.float32,
                time_major=True)
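
        # Note: encoder_final_state is exposed on the instance below but is not
        # passed as the decoder's initial_state, so the decoder starts from a
        # zero state and learns the shift task from tgt_in alone.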
        # T x B x H
        (decoder_max_steps, decoder_batch_size,
         decoder_hidden_size) = tf.unstack(tf.shape(decoder_outputs))
        # TB x H
        decoder_outputs_flat = tf.reshape(decoder_outputs,
                                          (-1, decoder_hidden_size))
        # TB x V
        decoder_logits_flat = tf.add(
            tf.matmul(decoder_outputs_flat, output_weights), output_biases)
        # T x B x V
        decoder_logits = tf.reshape(decoder_logits_flat,
                                    (decoder_max_steps, decoder_batch_size,
                                     vocab_size))
        # T x B
        predictions = tf.argmax(decoder_logits, axis=2)

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tgt_out, logits=decoder_logits)
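        # Mask out padded positions: sequence_mask gives B x T, so transpose to
        # T x B to match the time-major loss, then sum the masked cross-entropy
        # and average over the batch (total log-loss per sequence).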
        tgt_weights = tf.sequence_mask(
            tgt_len, decoder_max_steps, dtype=decoder_logits.dtype)
        tgt_weights = tf.transpose(tgt_weights)
        loss = tf.reduce_sum(
            loss * tgt_weights) / tf.to_float(decoder_batch_size)

        # Expose the tensors needed by the training loop.
        self.src = src
        self.src_len = src_len
        self.tgt_in = tgt_in
        self.tgt_out = tgt_out
        self.tgt_len = tgt_len
        self.predictions = predictions
        self.loss = loss
        self.encoder_final_state = encoder_final_state


n_data = 40
min_len = 5
max_len = 10
vocab_size = 101
n_samples = 5
epoch = 10000
batch_size = 32
lr = 1e-2
clip = 3
emb_size = 50
hidden_size = 50
num_layers = 1
max_length = 15
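
# Note: clip, num_layers, and max_length are defined but never used below
# (in particular, no gradient clipping is applied).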

loader = load_data(
    n_data,
    min_len=min_len,
    max_len=max_len,
    max_word=vocab_size,
    epoch=epoch,
    batch_size=batch_size)
net = Seq2seq(vocab_size, emb_size, hidden_size)
train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(net.loss)

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

for i, (x, yin, yout, lenx, leny) in enumerate(loader):
    _, loss_val = sess.run(
        [train_op, net.loss],
        feed_dict={
            net.src: x,
            net.tgt_in: yin,
            net.tgt_out: yout,
            net.src_len: lenx,
            net.tgt_len: leny
        })
    if i % 10 == 0:
        print('step: {}, loss: {}'.format(i, loss_val))
    if i % 200 == 0 and i > 0:
        preds = sess.run(
            net.predictions,
            feed_dict={
                net.src: x[:, :n_samples],
                net.tgt_in: yin[:, :n_samples],
                net.tgt_out: yout[:, :n_samples],
                net.src_len: lenx[:n_samples],
                net.tgt_len: leny[:n_samples]
            })
        print_sample(x[:, :n_samples], yout[:, :n_samples], preds)
    if i % 1000 == 0 and i > 0:
        h = sess.run(
            net.encoder_final_state,
            feed_dict={
                net.src: x,
                net.tgt_in: yin,
                net.tgt_out: yout,
                net.src_len: lenx,
                net.tgt_len: leny
            })
        print(h)
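
# Written against TensorFlow 1.x (tf.placeholder, tf.Session, tf.contrib,
# tf.to_float). Under TensorFlow 2.x it would roughly need to go through
# tf.compat.v1 with eager execution disabled, plus a replacement for
# tf.contrib.layers.xavier_initializer (e.g. tf.keras.initializers.GlorotUniform),
# since tf.contrib was removed.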