RNN_Basic_Demo - counts the number of 1s in a binary input sequence
""" | |
Attempt to understand RNN | |
url : http://monik.in/a-noobs-guide-to-implementing-rnn-lstm-using-tensorflow/ | |
""" | |
import numpy as np | |
from random import shuffle | |
import tensorflow as tf | |
# makes 20 digits binary values | |
train_input = ['{0:020b}'.format(i) for i in range(2 ** 20)] | |
shuffle(train_input) | |
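# sanity check (illustrative): each string is the zero-padded binary form of its index,
# e.g. '{0:020b}'.format(5) == '00000000000000000101'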
# convert each string into a sequence of single-element lists: '0010' -> [[0], [0], [1], [0]]
train_input = [list(map(int, i)) for i in train_input]  # list() keeps each sequence reusable under Python 3
ti = []
for i in train_input:
    temp_list = []
    for j in i:
        temp_list.append([j])
    ti.append(np.array(temp_list))
train_input = ti
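# sanity check: every sample should now be a (20, 1) column of 0/1 values
assert train_input[0].shape == (20, 1)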
# build train_output as one-hot vectors:
# a 20-bit sequence can contain 0..20 ones, so there are 21 classes
# [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> the sequence has 2 ones
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1] -> the sequence has 20 ones
# ...
train_output = []  # one-hot vectors
for i in train_input:
    temp_list = [0] * 21
    count = 0
    for j in i:  # j is a length-1 array holding a single bit
        if j[0] == 1:
            count += 1
    temp_list[count] = 1
    train_output.append(temp_list)
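# sanity check: the label index equals the number of 1s in the sequence
assert train_output[0].index(1) == int(train_input[0].sum())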
# train/test split
NUM_SAMPLES = 10000  # ~1% of the 2**20 sequences; the rest is held out for testing
test_input = train_input[NUM_SAMPLES:]
test_output = train_output[NUM_SAMPLES:]
train_input = train_input[:NUM_SAMPLES]
train_output = train_output[:NUM_SAMPLES]
# tensorflow graph
data = tf.placeholder(tf.float32, [None, 20, 1])  # [batch_size, sequence_length, input_dimension]
target = tf.placeholder(tf.float32, [None, 21])
num_hidden = 24
cell = tf.nn.rnn_cell.LSTMCell(num_hidden, state_is_tuple=True)  # LSTM cell
val, state = tf.nn.dynamic_rnn(cell, data, dtype=tf.float32)  # val -> outputs; the final state is not needed here
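# with the default time_major=False, val has shape [batch_size, 20, num_hidden]
# and state is the final LSTMStateTuple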
# transpose the output to swap the batch and sequence axes, then take the value
# at the sequence's last step: for a string of 20 bits we only care about the
# output after the 20th character; outputs at earlier steps are irrelevant here
val = tf.transpose(val, [1, 0, 2])
last = tf.gather(val, int(val.get_shape()[0]) - 1)
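# note (illustrative, not in the original tutorial): since the gather runs along
# axis 0 of the transposed tensor, the same last-step output could be taken
# without the transpose as val[:, -1, :] on the [batch, time, hidden] tensor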
# apply a final dense transformation that maps the LSTM's last output to the 21 output classes
weight = tf.Variable(tf.truncated_normal([num_hidden, int(target.get_shape()[1])]))
bias = tf.Variable(tf.constant(.1, shape=[target.get_shape()[1]]))
# prediction
prediction = tf.nn.softmax(tf.matmul(last, weight) + bias)
# loss: cross-entropy, with prediction clipped to avoid log(0)
cross_entropy = -tf.reduce_sum(target * tf.log(tf.clip_by_value(prediction, 1e-10, 1.0)))
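# note (alternative, not part of the original tutorial): TF 1.x also offers
# tf.nn.softmax_cross_entropy_with_logits, which fuses softmax and log on the
# pre-softmax logits for better numerical stability than manual clipping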
optimizer = tf.train.AdamOptimizer()
minimize = optimizer.minimize(cross_entropy)
# error: fraction of samples whose predicted class differs from the true class
mistakes = tf.not_equal(tf.argmax(target, 1), tf.argmax(prediction, 1))
error = tf.reduce_mean(tf.cast(mistakes, tf.float32))
init = tf.global_variables_initializer()
# hyperparameters
batch_size = 1000
no_of_batches = int(len(train_input) / batch_size)
epochs = 5000
with tf.Session() as sess:
    sess.run(init)
    for i in range(epochs):
        ptr = 0
        for j in range(no_of_batches):
            inp, out = train_input[ptr:ptr + batch_size], train_output[ptr:ptr + batch_size]
            ptr += batch_size
            sess.run(minimize, feed_dict={data: inp, target: out})
        print('epoch %d' % i)
    incorrect = sess.run(error, {data: test_input, target: test_output})
    print('Epoch {:2d} error {:3.1f}%'.format(epochs, 100 * incorrect))
    print(sess.run(prediction, {
        data: [[[1], [0], [0], [1], [1], [0], [1], [1], [1], [0], [1], [0], [0], [1], [1], [0], [1], [1], [1], [0]]]}))
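# the test sequence above contains twelve 1s, so after training the printed
# 21-way softmax vector should put most of its probability mass at index 12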