Last active
March 6, 2016 02:36
-
-
Save naoyashiga/712bd44c33710eeef818 to your computer and use it in GitHub Desktop.
MNISTをTensorBoardに表示する
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import input_data | |
import tensorflow as tf | |
# Hyperparameters.
learning_rate = 0.01  # gradient-descent step size
# NOTE: despite its name, this is the number of mini-batch training steps
# (main() draws one batch of 100 training images per step), not the number
# of full passes over the training set.
training_epochs = 1000
def main():
    """Train a softmax-regression MNIST classifier and log its loss to TensorBoard.

    Loads MNIST from ``MNIST_data/`` via ``input_data.read_data_sets`` with
    one-hot labels. Runs ``training_epochs`` gradient-descent steps, each on a
    mini-batch of 100 training images, with step size ``learning_rate``.
    Writes the per-step loss as a scalar summary to ``/tmp/tensorflow_logs``.
    Finally prints the classification accuracy on the test set.
    """
    # one_hot=True: labels are 10-element indicator vectors, matching `y` below.
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    # Weights and biases.
    W = tf.Variable(tf.zeros([784, 10]), name="weights")
    b = tf.Variable(tf.zeros([10]), name="bias")

    # MNIST images, flattened: 28*28 = 784 pixels per row.
    x = tf.placeholder("float", [None, 784])
    # Digits 0-9, so 10 classes.
    y = tf.placeholder("float", [None, 10])

    # Softmax regression: unnormalized scores, then class probabilities.
    logits = tf.matmul(x, W) + b
    activation = tf.nn.softmax(logits)

    # Loss: cross-entropy -Σ y * log(softmax(logits)), summed over the batch.
    # Computed from the logits rather than as log(activation): once a softmax
    # output underflows to 0, log(0) = -inf and the loss/gradients become NaN.
    cost = tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))

    # Plain gradient descent; the argument is the learning rate.
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    tf.scalar_summary("lossを表示しますよ", cost)

    # Collect every summary op defined above into one op for logging.
    summary_op = tf.merge_all_summaries()
    init = tf.initialize_all_variables()

    # The context manager closes the session (releasing its resources) on exit.
    with tf.Session() as sess:
        sess.run(init)
        summary_writer = tf.train.SummaryWriter("/tmp/tensorflow_logs", graph_def=sess.graph_def)
        try:
            for epoch in range(training_epochs):
                batch_xs, batch_ys = mnist.train.next_batch(100)
                feed = {x: batch_xs, y: batch_ys}
                # One run applies the update and evaluates the loss summary on
                # the same batch, instead of a second separate forward pass.
                _, summary_str = sess.run([train_step, summary_op], feed_dict=feed)
                summary_writer.add_summary(summary_str, epoch)
        finally:
            # Flush buffered events so TensorBoard sees every logged step.
            summary_writer.close()
        print("Optimization Finished!")

        # Accuracy: fraction of images whose most probable class is the true label.
        correct_prediction = tf.equal(tf.argmax(activation, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        print(sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
# Script entry point: train, write the TensorBoard logs, print test accuracy.
if __name__ == "__main__":
    main()
# To inspect the logged loss afterwards, run from a shell:
#     tensorboard --logdir=/tmp/tensorflow_logs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment