Created May 27, 2016 08:25
-
-
Save solaris33/a88140b6ed3a5da17b40249856fb5c95 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
# Import the TensorFlow library and the MNIST tutorial data helper.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Download the MNIST data into "MNIST_data/" and load it.
# one_hot=True requests one-hot encoded labels (the model's label
# placeholder expects 10 columns).
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# Build a single-layer softmax-regression model.
# x: a batch of 784-feature input rows (presumably flattened 28x28 MNIST
# images -- confirm against the data loader).
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Keep the unnormalized scores separately: the loss is computed from them,
# while y remains the softmax class probabilities used for prediction.
logits = tf.matmul(x, W) + b
y = tf.nn.softmax(logits)

# Cross-entropy loss against the one-hot labels y_.
# softmax_cross_entropy_with_logits fuses the softmax and the log. This avoids
# the log(0) = -inf -> NaN failure of -sum(y_ * log(softmax(...))) when a
# predicted probability underflows to zero. The batch mean is unchanged.
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))

# Each run of train_step takes one gradient-descent step (learning rate 0.5)
# on the loss.
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# Train the model with mini-batch gradient descent, then print its accuracy
# on the test set.

# Prefer the non-deprecated initializer (TF >= 0.12). Fall back to the older
# name on TensorFlow releases that predate it.
init = (getattr(tf, "global_variables_initializer", None)
        or tf.initialize_all_variables)()

# Use the session as a context manager so its resources are released when
# the script finishes instead of being leaked.
with tf.Session() as sess:
    sess.run(init)

    # 1000 training steps, each on the next mini-batch of 100 examples.
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # A prediction counts as correct when the highest-probability class
    # matches the index of the one-hot label. Accuracy is the mean of those
    # 0/1 outcomes.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(sess.run(accuracy,
                   feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.