Last active
January 22, 2018 15:25
-
-
Save solaris33/ed312182e3ad94e671dfd76b9bf39229 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
# tf.Variable example: variables are NOT shared between classifier instances.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Download the MNIST data into MNIST_data/ and load it; one_hot=True encodes
# each label as a length-10 one-hot vector.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
def softmax_classifier(x):
    """Build a single-layer softmax classifier on top of ``x``.

    Every call creates a *new* weight matrix ``W`` (784 x 10) and bias ``b``
    (10,) with ``tf.Variable``, both initialized to zeros. Calling this
    function twice therefore yields two independent classifiers that do not
    share parameters.

    Args:
        x: Input tensor whose last dimension must be 784 to match ``W``
           (in this script a ``[None, 784]`` float32 placeholder).

    Returns:
        Tensor of softmax class probabilities, ``softmax(x @ W + b)``,
        with 10 columns.
    """
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    return y
# Input placeholder: a batch (any size) of flattened 784-feature images.
x = tf.placeholder(tf.float32, [None, 784])

# softmax_classifier creates its parameters with tf.Variable, so each call
# declares a separate set of variables: two classifiers that share nothing.
classifier1 = softmax_classifier(x)
classifier2 = softmax_classifier(x)

# Cross-entropy loss, defined on classifier1 only.
y_ = tf.placeholder(tf.float32, [None, 10])
# Clip the probabilities before taking the log: a softmax output can underflow
# to 0, and log(0) = -inf would turn the loss (and its gradients) into NaN.
cross_entropy = tf.reduce_mean(
    -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(classifier1, 1e-10, 1.0)),
                   axis=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# tf.initialize_all_variables() is deprecated; this is its replacement.
init = tf.global_variables_initializer()

# Accuracy ops: fraction of examples whose arg-max prediction equals the
# arg-max of the one-hot label.
correct_prediction1 = tf.equal(tf.argmax(classifier1, 1), tf.argmax(y_, 1))
correct_prediction2 = tf.equal(tf.argmax(classifier2, 1), tf.argmax(y_, 1))
accuracy1 = tf.reduce_mean(tf.cast(correct_prediction1, tf.float32))
accuracy2 = tf.reduce_mean(tf.cast(correct_prediction2, tf.float32))

# The context manager guarantees the session is closed when we are done.
with tf.Session() as sess:
    sess.run(init)
    # Train with mini-batch gradient descent: 1000 steps of 100 examples each.
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Report test-set accuracy. The two classifiers share no variables and only
    # classifier1 appears in the loss, so only classifier1's parameters were
    # optimized; classifier2 stays at its all-zero initialization.
    test_feed = {x: mnist.test.images, y_: mnist.test.labels}
    print(sess.run(accuracy1, feed_dict=test_feed))
    print(sess.run(accuracy2, feed_dict=test_feed))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment