Testing a fully-connected MNIST classifier with YellowFin. See https://github.com/JianGoForIt/YellowFin
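Note: YellowFin is an auto-tuning momentum SGD optimizer, so (if I read the linked repo right) the `learning_rate` and `momentum` passed to `YFOptimizer` below are only initial values that the tuner adapts during training.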
```python
# -*- coding: utf-8 -*-
import tensorflow as tf

from yellowfin import YFOptimizer
from tensorflow.examples.tutorials.mnist import input_data


def linear(X, fanout, scope=None):
    # Fully-connected layer: X @ W + b. Xavier init on the bias is
    # unconventional (zeros is more typical) but works.
    with tf.variable_scope(scope or 'linear'):
        W = tf.get_variable('W', shape=[X.get_shape().as_list()[-1], fanout], initializer=tf.contrib.layers.xavier_initializer())
        b = tf.get_variable('b', shape=[fanout], initializer=tf.contrib.layers.xavier_initializer())
        return tf.nn.xw_plus_b(X, W, b)


def main():
    mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

    alpha = 1e-3  # initial learning rate
    beta = 1e-2   # initial momentum
    num_epochs = 15
    batch_size = 100
    image_size = 28
    num_classes = 10

    # Note: this placeholder feeds tf.nn.dropout's keep_prob, so 0.75 means
    # "keep 75% of activations", not a 75% drop rate.
    dropout_rate = tf.placeholder(tf.float32)

    # Two-hidden-layer MLP: 784 -> 200 -> 100 -> 10.
    X = tf.placeholder(tf.float32, [None, image_size * image_size])
    y = tf.placeholder(tf.float32, [None, num_classes])
    a1 = linear(X, 200, 'layer1')
    z1 = tf.nn.relu(a1)
    h1 = tf.nn.dropout(z1, dropout_rate)
    a2 = linear(h1, 100, 'layer2')
    z2 = tf.nn.relu(a2)
    h2 = tf.nn.dropout(z2, dropout_rate)
    logits = linear(h2, num_classes)
    y_ = tf.nn.softmax(logits)

    correct = tf.equal(tf.argmax(y_, -1), tf.argmax(y, -1))
    acc = tf.reduce_mean(tf.cast(correct, tf.float32))
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))

    # YellowFin as a drop-in replacement for a stock optimizer;
    # the Adam line is kept for comparison.
    # optim = tf.train.AdamOptimizer(learning_rate=alpha).minimize(loss)
    optim = YFOptimizer(learning_rate=alpha, momentum=beta).minimize(loss)

    loss_summ = tf.summary.scalar('loss', loss)
    summary = tf.summary.merge_all()
    writer = tf.summary.FileWriter('./logs/yellowfin')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)
        step = 0
        for epoch in range(num_epochs):
            avg_loss = 0
            total_batch = int(mnist.train.num_examples / batch_size)
            for i in range(total_batch):
                step += 1
                X_batch, y_batch = mnist.train.next_batch(batch_size)
                _, summarized, lost = sess.run([optim, summary, loss], feed_dict={X: X_batch, y: y_batch, dropout_rate: 0.75})
                writer.add_summary(summarized, global_step=step)
                avg_loss += lost / total_batch
            print("epoch: %04d\tloss: %12.9f" % (epoch + 1, avg_loss))
        print("training done")
        # Evaluate on the test set with dropout disabled (keep_prob = 1).
        print("accuracy: %5.2f%%" % (100 * sess.run(acc, feed_dict={X: mnist.test.images, y: mnist.test.labels, dropout_rate: 1})))


if __name__ == '__main__':
    main()
```
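To reproduce the Adam baseline hinted at by the commented-out line, a minimal variant (a sketch, not part of the original gist) is to swap the optimizer and writer lines in `main()` so each run logs to its own directory:

```python
# Hypothetical baseline variant: same graph and training loop, but with
# stock Adam and a separate log directory for side-by-side comparison.
optim = tf.train.AdamOptimizer(learning_rate=alpha).minimize(loss)
writer = tf.summary.FileWriter('./logs/adam')
```

Running `tensorboard --logdir ./logs` then shows the `yellowfin` and `adam` loss curves together.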