Created
December 12, 2017 22:36
-
-
Save RustyNail/4064ae24f864530258447a37eca9915a to your computer and use it in GitHub Desktop.
mnist_tutorial.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | |
import tensorflow as tf # TensorFlowをインポート | |
from tensorflow.examples.tutorials.mnist import input_data # MNISTデータのロード | |
mnist = input_data.read_data_sets("MNIST_data/", one_hot = True) | |
train_image = tf.placeholder(tf.float32, [None, 784]) # 28x28pxの訓練画像を格納する変数 | |
train_label = tf.placeholder(tf.float32, [None, 10]) # 正解データのラベルを格納する変数 | |
W = tf.Variable(tf.zeros([784, 10])) # 重み(初期値0) | |
b = tf.Variable(tf.zeros([10])) # バイアス(初期値0) | |
y = tf.nn.softmax(tf.matmul(train_image, W) + b) # ソフトマックス回帰を実行 | |
learning_rate = 0.01 # 学習率 | |
learning_count = 1000 # 学習回数 | |
cross_entropy = -tf.reduce_sum(train_label * tf.log(y)) # 交差エントロピー誤差 | |
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy) # 勾配降下法で交差エントロピー誤差が最小となるようyを最適化 | |
with tf.Session() as sess: | |
init = tf.global_variables_initializer() | |
sess.run(init) | |
for i in range(learning_count): | |
batch_image, batch_label = mnist.train.next_batch(100) # ランダムに抽出した100個の訓練データ(画像と対応するラベル)を選択 | |
sess.run(train_step, feed_dict = { train_image: batch_image, train_label: batch_label }) # 学習(train_stepを実行) | |
# 正答率を算出する処理 | |
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(train_label, 1)) | |
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) | |
test_images = mnist.test.images | |
test_labels = mnist.test.labels | |
print(sess.run(accuracy, feed_dict = { train_image: test_images, train_label: test_labels })) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment