import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

"""
Utilities
"""

def orthogonal_initializer(scale=1.1):
    '''From Lasagne and Keras. Reference: Saxe et al., http://arxiv.org/abs/1312.6120
    '''
    def get_orthogonal(shape):
        flat_shape = (shape[0], np.prod(shape[1:]))
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        # pick the SVD factor with the correct shape
        q = u if u.shape == flat_shape else v
        q = q.reshape(shape)
        initial_val = scale * q[:shape[0], :shape[1]]
        return initial_val

    def _initializer(shape, dtype=tf.float32):
        initial_val = get_orthogonal(shape)
        return tf.constant(initial_val, dtype=dtype)
    return _initializer

def smooth(x, window_len=11, window='hanning'):
    """
    Ripped from http://scipy-cookbook.readthedocs.org/items/SignalSmooth.html
    """
    s = np.r_[x[window_len - 1:0:-1], x, x[-1:-window_len:-1]]
    if window == 'flat':  # moving average
        w = np.ones(window_len, 'd')
    else:
        w = getattr(np, window)(window_len)
    y = np.convolve(w / w.sum(), s, mode='valid')
    return y

mnist = input_data.read_data_sets('MNIST_data', one_hot=False)


def get_batch(batch_size, which_set="train"):
    if which_set == "train":
        X, Y = mnist.train.next_batch(batch_size)
    elif which_set == "test":
        X, Y = mnist.test.next_batch(batch_size)
    else:
        raise ValueError("which_set must be 'train' or 'test'")
    X = X.reshape((batch_size, 28, 28)).astype("float32")
    X = X.transpose(1, 0, 2)
    Y = Y.astype("int32")
    return (X, Y)

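# Illustrative note (not in the original gist): each MNIST image is fed to the
# GRU as a sequence of its 28 pixel rows, so a batch comes back time-major:
#   X_demo, Y_demo = get_batch(batch_size=4)
#   X_demo.shape  -> (28, 4, 28), i.e. (seq_len, batch_size, input_dim)
#   Y_demo.shape  -> (4,), integer class labels
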
""" | |
GRU Model | |
""" | |
class GRU(object): | |
def __init__(self, input_dim, n_hidden): | |
self.input_dim = input_dim | |
self.n_hidden = n_hidden | |
with tf.variable_scope("weights", initializer=orthogonal_initializer()): | |
self.W_z = tf.get_variable("W_z", [self.input_dim, self.n_hidden]) | |
self.W_r = tf.get_variable("W_r", [self.input_dim, self.n_hidden]) | |
self.W_h = tf.get_variable("W_h", [self.input_dim, self.n_hidden]) | |
self.U_z = tf.get_variable("U_z", [self.n_hidden, self.n_hidden]) | |
self.U_r = tf.get_variable("U_r", [self.n_hidden, self.n_hidden]) | |
self.U_h = tf.get_variable("U_h", [self.n_hidden, self.n_hidden]) | |
with tf.variable_scope("biases", initializer=tf.constant_initializer(0.0)): | |
self.b_z = tf.get_variable("b_z", [self.n_hidden]) | |
self.b_r = tf.get_variable("b_r", [self.n_hidden]) | |
self.b_h = tf.get_variable("b_h", [self.n_hidden]) | |
def step(self, h_tm1, x): | |
z = tf.nn.sigmoid(tf.nn.xw_plus_b(x, self.W_z, self.b_z) + | |
tf.matmul(h_tm1, self.U_z)) | |
r = tf.nn.sigmoid(tf.nn.xw_plus_b(x, self.W_r, self.b_r) + | |
tf.matmul(h_tm1, self.U_r)) | |
h = tf.nn.tanh(tf.nn.xw_plus_b(x, self.W_h, self.b_h) + | |
tf.matmul(tf.mul(h_tm1, r), self.U_h)) | |
h_t = tf.mul((1. - z), h) + tf.mul(z, h_tm1) | |
return h_t | |
def initial_state(self, batch_size): | |
return tf.zeros([batch_size, self.n_hidden]) | |
def __call__(self, X): | |
batch_size = X.get_shape()[1].value | |
return tf.scan(self.step, X, initializer=self.initial_state(batch_size)) | |
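# For reference, step() computes the GRU update exactly as written above:
#   z_t  = sigmoid(x_t W_z + h_{t-1} U_z + b_z)           (update gate)
#   r_t  = sigmoid(x_t W_r + h_{t-1} U_r + b_r)           (reset gate)
#   h~_t = tanh(x_t W_h + (r_t * h_{t-1}) U_h + b_h)      (candidate state)
#   h_t  = (1 - z_t) * h~_t + z_t * h_{t-1}
# tf.scan threads h_t along the first (time) axis of X, so __call__ returns a
# tensor of shape [seq_len, batch_size, n_hidden].
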
""" | |
Logits and Cost ops | |
""" | |
def get_logits(input_, input_size, n_hidden, num_classes): | |
gru = GRU(input_size, n_hidden) | |
seq_len = input_.get_shape()[0].value | |
batch_size = input_.get_shape()[1].value | |
W_out = tf.get_variable("W_out", [n_hidden, num_classes], initializer=orthogonal_initializer()) | |
b_out = tf.get_variable("b_out", [num_classes], initializer=tf.constant_initializer(0.0)) | |
states = tf.split(0, seq_len, gru(input_)) | |
final_state = states[-1] | |
final_state = tf.squeeze(final_state) | |
final_state.set_shape([batch_size, n_hidden]) | |
return tf.nn.xw_plus_b(final_state, W_out, b_out) | |
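# Only the hidden state after the final (28th) time step is passed through the
# output layer, so each image is classified from the GRU's last state;
# tf.squeeze drops the leading length-1 time dimension left over from tf.split.
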
def get_cost(logits, targets):
    return tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)

def train():
    bs = 64
    n_iter = 1000
    X = tf.placeholder(tf.float32, [28, bs, 28])
    targets = tf.placeholder(tf.int32, [bs])
    logits = get_logits(X, input_size=28, n_hidden=256, num_classes=10)
    loss = tf.reduce_mean(get_cost(logits, targets))
    train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

    sess = tf.Session()
    writer = tf.train.SummaryWriter("./gru_logs", sess.graph)
    sess.run(tf.initialize_all_variables())

    tr_losses = []
    te_losses = []
    tr_accs = []
    te_accs = []
    for i in range(n_iter):
        trX, trY = get_batch(batch_size=bs)
        teX, teY = get_batch(batch_size=bs, which_set="test")
        sess.run(train_op, feed_dict={X: trX, targets: trY})
        tr_loss, tr_logits = sess.run([loss, logits], feed_dict={X: trX, targets: trY})
        # Loss/accuracy on a single held-out batch, so only a rough estimate of test error.
        te_loss, te_logits = sess.run([loss, logits], feed_dict={X: teX, targets: teY})
        tr_acc = (trY == np.argmax(tr_logits, axis=1)).mean()
        te_acc = (teY == np.argmax(te_logits, axis=1)).mean()
        tr_losses.append(tr_loss)
        te_losses.append(te_loss)
        tr_accs.append(tr_acc)
        te_accs.append(te_acc)
        print "iter: %d, train_loss: %f, test_loss: %f, train_acc: %f, test_acc: %f" % (i, tr_loss, te_loss, tr_acc, te_acc)

    plt.subplot(211)
    plt.title('cost')
    plt.plot(tr_losses)
    plt.plot(te_losses, '--')
    plt.subplot(212)
    plt.title('accuracy')
    plt.plot(smooth(tr_accs))
    plt.plot(smooth(te_accs), '--')
    plt.show()


if __name__ == "__main__":
    train()