@theeluwin
Created July 19, 2017 22:46
Testing a fully-connected MNIST classifier with YellowFin. See https://github.com/JianGoForIt/YellowFin
# -*- coding: utf-8 -*-
import tensorflow as tf

from yellowfin import YFOptimizer
from tensorflow.examples.tutorials.mnist import input_data


def linear(X, fanout, scope=None):
    """Fully-connected layer: returns X @ W + b."""
    with tf.variable_scope(scope or 'linear'):
        W = tf.get_variable('W', shape=[X.get_shape().as_list()[-1], fanout], initializer=tf.contrib.layers.xavier_initializer())
        b = tf.get_variable('b', shape=[fanout], initializer=tf.contrib.layers.xavier_initializer())
        a = tf.nn.xw_plus_b(X, W, b)
        return a


def main():
    mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

    # hyperparameters
    alpha = 1e-3  # learning rate
    beta = 1e-2   # momentum
    num_epochs = 15
    batch_size = 100
    image_size = 28
    num_classes = 10

    # placeholders; `keep_prob` is the dropout keep probability
    # (0.75 during training, 1.0 during evaluation)
    keep_prob = tf.placeholder(tf.float32)
    X = tf.placeholder(tf.float32, [None, image_size * image_size])
    y = tf.placeholder(tf.float32, [None, num_classes])

    # two-hidden-layer fully-connected network: 784 -> 200 -> 100 -> 10
    a1 = linear(X, 200, 'layer1')
    z1 = tf.nn.relu(a1)
    h1 = tf.nn.dropout(z1, keep_prob)
    a2 = linear(h1, 100, 'layer2')
    z2 = tf.nn.relu(a2)
    h2 = tf.nn.dropout(z2, keep_prob)
    logits = linear(h2, num_classes, 'layer3')

    # accuracy metric and softmax cross-entropy loss
    y_ = tf.nn.softmax(logits)
    correct = tf.equal(tf.argmax(y_, -1), tf.argmax(y, -1))
    acc = tf.reduce_mean(tf.cast(correct, tf.float32))
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))

    # swap in the Adam line below to compare against YellowFin
    # optim = tf.train.AdamOptimizer(learning_rate=alpha).minimize(loss)
    optim = YFOptimizer(learning_rate=alpha, momentum=beta).minimize(loss)

    # TensorBoard logging
    loss_summ = tf.summary.scalar('loss', loss)
    summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter('./logs/yellowfin')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)
        step = 0
        for epoch in range(num_epochs):
            avg_loss = 0
            total_batch = int(mnist.train.num_examples / batch_size)
            for i in range(total_batch):
                step += 1
                X_batch, y_batch = mnist.train.next_batch(batch_size)
                _, summarized, lost = sess.run([optim, summary, loss], feed_dict={X: X_batch, y: y_batch, keep_prob: 0.75})
                writer.add_summary(summarized, global_step=step)
                avg_loss += lost / total_batch
            print("epoch: %04d\tloss: %12.9f" % (epoch + 1, avg_loss))
        print("training done")
        print("accuracy: %5.2f%%" % (100 * sess.run(acc, feed_dict={X: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})))


if __name__ == '__main__':
    main()
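
To run this, place yellowfin.py from the linked repo next to the script (it targets TensorFlow 1.x, since it uses tf.contrib; the MNIST data downloads automatically on first run), then inspect the logged loss curve with `tensorboard --logdir ./logs`. The commented-out Adam line suggests an A/B comparison between the two optimizers; a minimal sketch of that, with an illustrative `use_yellowfin` flag and an assumed './logs/adam' directory (neither is in the original), might look like:

use_yellowfin = True  # flip to False for the Adam baseline
if use_yellowfin:
    optim = YFOptimizer(learning_rate=alpha, momentum=beta).minimize(loss)
    logdir = './logs/yellowfin'
else:
    optim = tf.train.AdamOptimizer(learning_rate=alpha).minimize(loss)
    logdir = './logs/adam'
writer = tf.summary.FileWriter(logdir)  # one log directory per run

With one run logged per directory under ./logs, TensorBoard overlays the two loss curves for a direct comparison.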