Last active
January 22, 2018 15:33
-
-
Save solaris33/8063d4172959ffda1ce3c48441a3e9cf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
# Example of tf.get_variable & tf.variable_scope: variables are shared between
# two classifier instances built from the same function.
#
# NOTE: this script targets the TensorFlow 1.x graph API
# (tensorflow.examples.tutorials, placeholders, sessions); it does not run
# unchanged on TensorFlow 2.x.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Download the MNIST data into MNIST_data/, with labels one-hot encoded.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
def softmax_classifier_vs(x, reuse_flag=False):
    """Build a single-layer softmax classifier whose variables can be shared.

    The weight and bias are obtained with ``tf.get_variable`` inside the
    ``"softmax_classifier"`` variable scope. Calling this function a second
    time with ``reuse_flag=True`` therefore yields a classifier that uses the
    very same ``W`` and ``b`` variables as the first call.

    Args:
        x: float tensor of shape [batch, 784]; it is matrix-multiplied by a
            [784, 10] weight matrix.
        reuse_flag: forwarded as ``reuse`` to ``tf.variable_scope``. Leave it
            False for the call that creates the variables; pass True to look
            up the already-created ones instead.

    Returns:
        Tensor of softmax probabilities over 10 classes.
    """
    with tf.variable_scope("softmax_classifier", reuse=reuse_flag):
        weights = tf.get_variable("W", [784, 10])
        bias = tf.get_variable("b", [10])
        logits = tf.matmul(x, weights) + bias
        probabilities = tf.nn.softmax(logits)
    return probabilities
# Input placeholder: a batch of flattened MNIST images, 784 features per row.
x = tf.placeholder(tf.float32, [None, 784])

# First call creates the variables softmax_classifier/W and softmax_classifier/b.
classifier1_vs = softmax_classifier_vs(x)
# Calling softmax_classifier_vs(x) again WITHOUT reuse would raise a ValueError,
# because those variables already exist. With reuse_flag=True, tf.get_variable
# returns the existing variables, so both classifiers share the same weights.
classifier2_vs = softmax_classifier_vs(x, True)

# Cross-entropy loss against one-hot labels.
y_ = tf.placeholder(tf.float32, [None, 10])
# Clip the probabilities before the log: a softmax output that underflows to 0
# would otherwise give log(0) = -inf and turn the loss into NaN. (A logits-based
# loss such as tf.nn.softmax_cross_entropy_with_logits would be more robust,
# but the classifier only exposes probabilities.)
cross_entropy = tf.reduce_mean(
    -tf.reduce_sum(
        y_ * tf.log(tf.clip_by_value(classifier1_vs, 1e-10, 1.0)),
        axis=1,
    )
)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# Accuracy ops. Because the variables are shared, classifier2_vs is never
# trained explicitly, yet it reports the same accuracy as classifier1_vs.
correct_prediction1 = tf.equal(tf.argmax(classifier1_vs, 1), tf.argmax(y_, 1))
correct_prediction2 = tf.equal(tf.argmax(classifier2_vs, 1), tf.argmax(y_, 1))
accuracy1 = tf.reduce_mean(tf.cast(correct_prediction1, tf.float32))
accuracy2 = tf.reduce_mean(tf.cast(correct_prediction2, tf.float32))

# global_variables_initializer replaces the deprecated initialize_all_variables.
init = tf.global_variables_initializer()

# The context manager closes the session (releasing its resources) on exit.
with tf.Session() as sess:
    sess.run(init)

    # Train with mini-batch gradient descent: 1000 steps of 100 examples each.
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Report test-set accuracy for both classifiers.
    test_feed = {x: mnist.test.images, y_: mnist.test.labels}
    print(sess.run(accuracy1, feed_dict=test_feed))
    print(sess.run(accuracy2, feed_dict=test_feed))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment