Vanilla Char-RNN class using TensorFlow
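For reference (this summary is not part of the original gist), the class below implements the standard vanilla char-RNN recurrence over one-hot encoded characters, written in the row-vector convention the code uses, where $x_t$ is the one-hot input at step $t$, $V$ the vocabulary size, and $T$ = seq_length:

$$h_t = \tanh\!\left(x_t W_{xh} + h_{t-1} W_{hh} + b_h\right), \qquad y_t = h_t W_{hy} + b_y$$

$$L = -\frac{1}{T} \sum_{t=1}^{T} \sum_{k=1}^{V} \mathrm{target}_{t,k}\, \log \operatorname{softmax}(y_t)_k$$

The weights correspond one-to-one to the Wxh, Whh, Why, bh, and by variables defined in _create_variables, and the loss matches the softmax cross-entropy averaged over the sequence in _create_loss_and_optimizer.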
""" | |
Vanilla Char-RNN class using TensorFlow by Wojciech Kryscinski (@muggin). | |
Adapted from Vinh Khuc's min-char-rnn-tensorflow.py | |
https://gist.github.com/vinhkhuc/7ec5bf797308279dc587 | |
Requires tensorflow>=1.0 | |
BSD License | |
""" | |
import numpy as np
import tensorflow as tf


class TFBasicRNN(object):
    def __init__(self, vocab_size, hidden_size, seq_length):
        # size of vocabulary
        self.vocab_size = vocab_size
        # size of hidden state
        self.hidden_size = hidden_size
        # length of input sequence
        self.seq_length = seq_length
        # tf computation graph
        self.graph = tf.Graph()
        # define and set up tf graph nodes
        self._create_constants()
        self._create_variables()
        self._create_placeholders()
        self._create_loss_and_optimizer()
    def _create_placeholders(self):
        with self.graph.as_default():
            with tf.name_scope('placeholders'):
                self.inputs = tf.placeholder(shape=[self.seq_length, self.vocab_size], dtype=tf.float32, name='inputs')
                self.targets = tf.placeholder(shape=[self.seq_length, self.vocab_size], dtype=tf.float32, name='targets')
                self.init_state = tf.placeholder(shape=[1, self.hidden_size], dtype=tf.float32, name='state')
    def _create_constants(self):
        with self.graph.as_default():
            with tf.name_scope('constants'):
                self.grad_limit = tf.constant(5.0, dtype=tf.float32, name='grad_limit')
    def _create_variables(self):
        with self.graph.as_default():
            with tf.name_scope('weights'):
                self.Wxh = tf.Variable(tf.random_normal(stddev=0.1, shape=(self.vocab_size, self.hidden_size)), name='Wxh')
                self.Whh = tf.Variable(tf.random_normal(stddev=0.1, shape=(self.hidden_size, self.hidden_size)), name='Whh')
                self.Why = tf.Variable(tf.random_normal(stddev=0.1, shape=(self.hidden_size, self.vocab_size)), name='Why')
                self.bh = tf.Variable(tf.zeros((self.hidden_size,)), name='bh')
                self.by = tf.Variable(tf.zeros((self.vocab_size,)), name='by')
    def _create_loss_and_optimizer(self):
        with self.graph.as_default():
            with tf.name_scope('loss'):
                hs_t = self.init_state
                ys = []
                # forward pass through the network
                for t, xs_t in enumerate(tf.split(self.inputs, self.seq_length, axis=0)):
                    hs_t = tf.tanh(tf.matmul(xs_t, self.Wxh) + tf.matmul(hs_t, self.Whh) + self.bh)
                    ys_t = tf.matmul(hs_t, self.Why) + self.by
                    ys.append(ys_t)
                outputs = tf.concat(ys, axis=0)
                # hidden state after the full forward pass
                self.update_state = hs_t
                # transform network output to probabilities
                self.outputs_softmax = tf.nn.softmax(ys[-1])
                # define loss
                self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.targets, logits=outputs))
                # compute, clip, and apply gradients
                self.optimizer = tf.train.AdamOptimizer()
                grads_and_vars = self.optimizer.compute_gradients(self.loss)
                clipped_grads_and_vars = []
                for grad, var in grads_and_vars:
                    clipped_grad = tf.clip_by_value(grad, -self.grad_limit, self.grad_limit)
                    clipped_grads_and_vars.append((clipped_grad, var))
                self.updates = self.optimizer.apply_gradients(clipped_grads_and_vars)
    def _create_ohe(self, vector):
        # one-hot encode a list of integer indices as a (len(vector), vocab_size) array
        return np.eye(self.vocab_size)[vector]
    def train(self, text, max_iter=100001, sample_dist=True):
        # build the character vocabulary from the training text
        # (assumes the text contains exactly vocab_size distinct characters)
        chars = sorted(set(text))
        char_to_ix = {ch: i for i, ch in enumerate(chars)}
        ix_to_char = {i: ch for i, ch in enumerate(chars)}
        # pointer to input data
        p = 0
        with tf.Session(graph=self.graph) as sess:
            # init variables
            sess.run(tf.global_variables_initializer())
            for n in range(max_iter):
                # do data bookkeeping
                if p + self.seq_length + 1 >= len(text) or n == 0:
                    # reset hidden state and data pointer
                    p = 0
                    state_prev = np.zeros((1, self.hidden_size))
                # prepare data
                inputs = [char_to_ix[ch] for ch in text[p:p + self.seq_length]]
                targets = [char_to_ix[ch] for ch in text[p + 1:p + self.seq_length + 1]]
                # one hot encode data
                inputs_ohe = self._create_ohe(inputs)
                targets_ohe = self._create_ohe(targets)
                # training step
                feed_data = {self.inputs: inputs_ohe, self.targets: targets_ohe, self.init_state: state_prev}
                current_state, loss, _ = sess.run([self.update_state, self.loss, self.updates], feed_dict=feed_data)
                # carry the hidden state over to the next chunk
                state_prev = current_state
                if n % 1000 == 0:
                    print('step: %d - p: %d -- loss: %f' % (n, p, loss))
                if sample_dist and n % 1000 == 0:
                    # sample by sliding a seq_length window over the generated characters
                    sample_length = 200
                    start_ix = np.random.randint(0, len(text) - self.seq_length)
                    sample_seq_ix = [char_to_ix[ch] for ch in text[start_ix:start_ix + self.seq_length]]
                    ixes = []
                    sample_state_prev = np.copy(current_state)
                    for t in range(sample_length):
                        sample_input_vals = self._create_ohe(sample_seq_ix)
                        feed_data = {self.inputs: sample_input_vals, self.init_state: sample_state_prev}
                        sample_output_softmax_val, sample_current_state = \
                            sess.run([self.outputs_softmax, self.update_state], feed_dict=feed_data)
                        ix = np.random.choice(range(self.vocab_size), p=sample_output_softmax_val.ravel())
                        ixes.append(ix)
                        sample_seq_ix = sample_seq_ix[1:] + [ix]
                    txt = ''.join(ix_to_char[ix] for ix in ixes)
                    print('----\n %s \n----\n' % (txt,))
                p += self.seq_length
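A minimal usage sketch, not part of the original gist: the corpus path input.txt and the hyperparameter values below are placeholders, and vocab_size is taken from the corpus's character set so that it matches the vocabulary train() rebuilds internally.

if __name__ == '__main__':
    # load a small plain-text corpus; 'input.txt' is a placeholder path
    with open('input.txt') as f:
        text = f.read()
    # the one-hot encoding assumes vocab_size equals the number of distinct characters
    vocab_size = len(set(text))
    # hidden_size and seq_length follow the usual min-char-rnn defaults (assumed, not from the gist)
    model = TFBasicRNN(vocab_size=vocab_size, hidden_size=100, seq_length=25)
    model.train(text, max_iter=10001)

Training prints the loss every 1000 steps and, when sample_dist is True, also prints a 200-character sample drawn from the model's softmax output.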