import tensorflow as tf
import numpy as np
from cell.lstm import BN_LSTMCell
import sys

class Seq2Seq(object):

    def __init__(self, xseq_len, yseq_len,
                 xvocab_size, yvocab_size,
                 emb_dim, num_layers, ckpt_path,
                 lr=0.001,
                 epochs=100000, model_name='seq2seq_model'):

        # attach these arguments to self
        self.xseq_len = xseq_len
        self.yseq_len = yseq_len
        self.ckpt_path = ckpt_path
        self.epochs = epochs
        self.model_name = model_name

        # build thy graph
        #  attach any part of the graph that needs to be exposed to self
        def __graph__():

            # placeholders
            tf.reset_default_graph()

            # True while training, False at evaluation/inference time
            # (BN_LSTMCell uses this to switch batch-norm behaviour)
            self.training = tf.placeholder(tf.bool)
            # scalar feed used only to log the validation loss as a summary
            self.val_ip = tf.placeholder(tf.float32)

            # encoder inputs : one int64 placeholder per timestep
            self.enc_ip = [tf.placeholder(shape=[None, ],
                                          dtype=tf.int64,
                                          name='ei_{}'.format(t)) for t in range(xseq_len)]

            # labels that represent the real outputs (training answers)
            self.labels = [tf.placeholder(shape=[None, ],
                                          dtype=tf.int64,
                                          name='label_{}'.format(t)) for t in range(yseq_len)]

            # decoder inputs : 'GO' + [ y1, y2, ... y_t-1 ]
            self.dec_ip = [tf.zeros_like(self.enc_ip[0], dtype=tf.int64, name='GO')] + self.labels[:-1]

            # batch-normalized LSTM cell wrapped in a DropoutWrapper;
            # keep_prob is a placeholder so dropout can be turned off at inference time
            self.keep_prob = tf.placeholder(tf.float32)
            basic_cell = tf.contrib.rnn.core_rnn_cell.DropoutWrapper(
                BN_LSTMCell(emb_dim, self.training), output_keep_prob=self.keep_prob)
            # stack cells together : n layered model
            stacked_lstm = tf.contrib.rnn.core_rnn_cell.MultiRNNCell(
                [basic_cell] * num_layers, state_is_tuple=True)

            # for parameter sharing between the training model and the testing model
            with tf.variable_scope('decoder') as scope:
                # build the seq2seq model
                self.decode_outputs, self.decode_states = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
                    self.enc_ip, self.dec_ip, stacked_lstm,
                    xvocab_size, yvocab_size, emb_dim)
                # share parameters
                scope.reuse_variables()
                # testing model, where the output of the previous timestep is fed as input
                # to the next timestep
                self.decode_outputs_test, self.decode_states_test = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
                    self.enc_ip, self.dec_ip, stacked_lstm, xvocab_size, yvocab_size, emb_dim,
                    feed_previous=True)

            # now, for training, build the weighted loss function
            loss_weights = [tf.ones_like(label, dtype=tf.float32) for label in self.labels]
            self.loss = tf.contrib.legacy_seq2seq.sequence_loss(
                self.decode_outputs, self.labels, loss_weights, yvocab_size)
            self.val_loss = tf.reduce_mean(self.val_ip)
            # train op to minimize the loss
            self.train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.loss)

        sys.stdout.write('<log> Building Graph ')
        # build the compute graph
        __graph__()
        sys.stdout.write('</log>')
    '''
    Training and Evaluation
    '''

    # get the feed dictionary
    def get_feed(self, X, Y, training):
        feed_dict = {self.enc_ip[t]: X[t] for t in range(self.xseq_len)}
        feed_dict.update({self.labels[t]: Y[t] for t in range(self.yseq_len)})
        feed_dict[self.training] = training                    # batch-norm mode
        feed_dict[self.keep_prob] = 0.5 if training else 1.0   # dropout keep probability
        return feed_dict

    # run one batch for training
    def train_batch(self, sess, train_batch_gen, merged):
        # get batches
        batchX, batchY = train_batch_gen.__next__()
        # build feed
        feed_dict = self.get_feed(batchX, batchY, training=True)
        result = sess.run([merged, self.train_op, self.loss], feed_dict)
        return result[0]

    def eval_step(self, sess, eval_batch_gen):
        # get batches
        batchX, batchY = eval_batch_gen.__next__()
        # build feed
        feed_dict = self.get_feed(batchX, batchY, training=False)
        loss_v, dec_op_v = sess.run([self.loss, self.decode_outputs_test], feed_dict)
        # dec_op_v is a list; also need to transpose 0,1 indices
        # (interchange the batch_size and timesteps dimensions)
        dec_op_v = np.array(dec_op_v).transpose([1, 0, 2])
        return loss_v, dec_op_v, batchX, batchY

    # evaluate 'num_batches' batches
    def eval_batches(self, sess, eval_batch_gen, num_batches):
        losses = []
        for i in range(num_batches):
            loss_v, dec_op_v, batchX, batchY = self.eval_step(sess, eval_batch_gen)
            losses.append(loss_v)
        return losses

    # finally the train function that
    #  runs the train_op in a session,
    #  evaluates on the valid set periodically and
    #  prints statistics
    # (i.e. runs training and validation)
    def train(self, train_set, valid_set, sess=None):

        # we need to save the model periodically
        saver = tf.train.Saver()

        # if no session is given
        if not sess:
            # create a session
            sess = tf.Session()
            # init all variables
            sess.run(tf.global_variables_initializer())

        train_loss_summary = tf.summary.scalar("train_loss", self.loss)
        val_loss_summary = tf.summary.scalar("val_loss", self.val_loss)
        merged = tf.summary.merge_all()

        sys.stdout.write('\n<log> Training started </log>\n')
        writer = tf.summary.FileWriter("./logs")
        # run M epochs
        for i in range(self.epochs):
            try:
                summary_str = self.train_batch(sess, train_set, train_loss_summary)
                writer.add_summary(summary_str, i)
                if i and i % (self.epochs // 100) == 0:  # TODO : make this tunable by the user
                    # save model to disk
                    saver.save(sess, self.ckpt_path + self.model_name + '.ckpt', global_step=i)
                    # evaluate to get validation loss
                    val_feed = self.eval_batches(sess, valid_set, 16)  # TODO : and this
                    summary_str, loss = sess.run([val_loss_summary, self.val_loss],
                                                 feed_dict={self.val_ip: val_feed})
                    writer.add_summary(summary_str, i)
                    # print stats
                    print('\nModel saved to disk at iteration #{}'.format(i))
                    print('val loss : {0:.6f}'.format(loss))
                    sys.stdout.flush()
            except KeyboardInterrupt:  # this will most definitely happen, so handle it
                print('Interrupted by user at iteration {}'.format(i))
                self.session = sess
                return sess
        # also hand the session back when the loop runs to completion
        self.session = sess
        return sess
    def restore_last_session(self):
        saver = tf.train.Saver()
        # create a session
        sess = tf.Session()
        # get checkpoint state
        ckpt = tf.train.get_checkpoint_state(self.ckpt_path)
        # restore session
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        # return to user
        return sess
    # prediction
    def predict(self, sess, X):
        feed_dict = {self.enc_ip[t]: X[t] for t in range(self.xseq_len)}
        feed_dict[self.training] = False   # inference mode for batch norm
        feed_dict[self.keep_prob] = 1.     # disable dropout at inference
        dec_op_v = sess.run(self.decode_outputs_test, feed_dict)
        # dec_op_v is a list; also need to transpose 0,1 indices
        # (interchange the batch_size and timesteps dimensions)
        dec_op_v = np.array(dec_op_v).transpose([1, 0, 2])
        # return the index of item with highest probability
        return np.argmax(dec_op_v, axis=2)
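
# ----------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the model above): one plausible way
# to drive the wrapper end to end. The vocabulary sizes, sequence lengths,
# checkpoint path and the random_batches() generator are hypothetical; any
# iterator yielding (batchX, batchY) arrays of shape [seq_len, batch_size]
# works, since get_feed() indexes the batch by timestep.
# ----------------------------------------------------------------------------
if __name__ == '__main__':

    def random_batches(xseq_len, yseq_len, vocab_size, batch_size=32):
        # dummy generator standing in for a real dataset iterator
        while True:
            yield (np.random.randint(0, vocab_size, size=(xseq_len, batch_size)),
                   np.random.randint(0, vocab_size, size=(yseq_len, batch_size)))

    model = Seq2Seq(xseq_len=20, yseq_len=20,
                    xvocab_size=1000, yvocab_size=1000,
                    emb_dim=128, num_layers=2,
                    ckpt_path='ckpt/',      # hypothetical directory, assumed to exist
                    epochs=1000)

    train_gen = random_batches(20, 20, 1000)
    valid_gen = random_batches(20, 20, 1000)

    # create and initialise a session explicitly, then hand it to train()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    model.train(train_gen, valid_gen, sess=sess)

    # greedy decoding on a fresh batch of encoder inputs
    batchX, _ = valid_gen.__next__()
    predictions = model.predict(sess, batchX)   # shape: [batch_size, yseq_len]

    # alternatively, resume inference later from the last saved checkpoint:
    #   sess = model.restore_last_session()
    #   predictions = model.predict(sess, batchX)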