import tensorflow as tf
import random
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST splits from "MNIST_data/"; one_hot=True makes each label a
# one-hot vector instead of a digit index.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Output classes: the ten digits 0-9.
nb_classes = 10

# Input batch: any number of flattened 28x28 images, 784 pixels per row.
X = tf.placeholder(dtype=tf.float32, shape=[None, 28 * 28])
# Target batch: one one-hot label row per image, one column per class.
Y = tf.placeholder(dtype=tf.float32, shape=[None, nb_classes])
# Fully connected network: 784 -> 50 (ReLU) -> 100 (ReLU) -> nb_classes.
# Weights use Xavier initialization; biases are drawn with tf.random_normal
# (default mean 0, stddev 1).
W1 = tf.get_variable("W1", shape=[784, 50],
                     initializer=tf.contrib.layers.xavier_initializer())
b1 = tf.Variable(tf.random_normal([50]))
L1 = tf.nn.relu(tf.matmul(X, W1) + b1)

W2 = tf.get_variable("W2", shape=[50, 100],
                     initializer=tf.contrib.layers.xavier_initializer())
b2 = tf.Variable(tf.random_normal([100]))
L2 = tf.nn.relu(tf.matmul(L1, W2) + b2)

W3 = tf.get_variable("W3", shape=[100, nb_classes],
                     initializer=tf.contrib.layers.xavier_initializer())
b3 = tf.Variable(tf.random_normal([nb_classes]))
# Unnormalized class scores; the loss is computed directly from these.
logits = tf.matmul(L2, W3) + b3
# Class probabilities, used for prediction and accuracy.
hypothesis = tf.nn.softmax(logits)

# Mean cross-entropy between the one-hot targets Y and the predictions.
# Computing -sum(Y * log(softmax)) by hand gives inf/NaN as soon as any
# softmax output underflows to 0; the fused op works on the logits and is
# numerically stable. (Y is a placeholder, so no gradient flows into labels.)
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=logits))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)

# Test model: a prediction is correct when the most probable class equals
# the labelled class.
is_correct = tf.equal(tf.argmax(hypothesis, 1), tf.argmax(Y, 1))
# Calculate accuracy: fraction of correct predictions in the fed batch.
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
# Training parameters.
# One epoch is one full pass over the training set. Here we run 15 epochs,
# and each epoch is processed in mini-batches of batch_size examples.
training_epochs = 15
batch_size = 100  # number of training examples per optimization step

sess = tf.Session()
# All variables must be initialized before any op that reads them is run;
# otherwise the first sess.run below fails on uninitialized variables.
sess.run(tf.global_variables_initializer())

for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = mnist.train.num_examples // batch_size
    for _ in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        c, _ = sess.run([cost, optimizer],
                        feed_dict={X: batch_xs, Y: batch_ys})
        # Accumulate the mean of the per-batch costs for this epoch.
        avg_cost += c / total_batch
    print('Epoch', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost))

# Accuracy on the whole test set. Tensor.eval(session=sess, ...) does the
# same thing as sess.run(accuracy, ...).
print("Accuracy: ", accuracy.eval(session=sess,
                                  feed_dict={X: mnist.test.images,
                                             Y: mnist.test.labels}))

# Predict one randomly chosen test example and show its true label.
r = random.randint(0, mnist.test.num_examples - 1)
# Labels are one-hot (one_hot=True at load time); argmax recovers the digit.
print("Label:", mnist.test.labels[r:r + 1].argmax(axis=1))
print("Prediction:", sess.run(tf.argmax(hypothesis, 1),
                              feed_dict={X: mnist.test.images[r:r + 1]}))

# Display the chosen image. Plotting is optional: skip it when matplotlib
# is not installed instead of failing after training has finished.
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None
if plt is not None:
    plt.imshow(mnist.test.images[r:r + 1].reshape(28, 28), cmap='Greys',
               interpolation='nearest')
    plt.show()
