|
from tensorflow.examples.tutorials.mnist import input_data |
|
|
|
# Load the MNIST data sets from ../MNIST_data (relative to the working dir).
# one_hot=True makes every label a one-hot vector; with one_hot=False the
# labels would come back as plain digit values instead.
mnist = input_data.read_data_sets(
    "../MNIST_data",
    one_hot=True,
)
|
import tensorflow as tf |
|
|
|
# Training hyper-parameters.
training_epochs = 30   # number of passes over mnist.train
batch_size = 50        # examples fetched per next_batch() call / optimizer step
learning_rate = 0.03   # Adam step size -- NOTE(review): high for Adam; confirm it was tuned
display_step = 2       # print cost/accuracy every `display_step` epochs
|
|
|
|
|
# Graph inputs. The leading None dimension accepts any batch size (training
# feeds batch_size rows, evaluation feeds 3000).
# x: 784 float features per example (presumably flattened 28x28 images).
# y: 10-wide one-hot label rows (the data is loaded with one_hot=True).
x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
y = tf.placeholder(dtype=tf.float32, shape=(None, 10))
|
|
|
# Standard deviation used to initialise both weight matrices.
# tf.truncated_normal defaults to stddev=1.0. Each hidden unit sums 784
# weighted inputs before a sigmoid, so unit-variance weights tend to produce
# large pre-activations that saturate the sigmoid (near-zero gradients).
# A small stddev keeps the initial activations in the sigmoid's responsive range.
weight_init_stddev = 0.1

# Weights and biases between the input layer and the hidden layer
# (784 inputs -> 512 hidden units).
w = tf.Variable(tf.truncated_normal([784, 512], stddev=weight_init_stddev), name='w')
b = tf.Variable(tf.zeros([512]), name='b')

# Weights and biases between the hidden layer and the output layer
# (512 hidden units -> 10 output logits).
w1 = tf.Variable(tf.truncated_normal([512, 10], stddev=weight_init_stddev), name='w1')
b1 = tf.Variable(tf.zeros([10]), name='b1')
|
|
|
# Hidden layer: affine map of the input, then a sigmoid non-linearity.
h_z = tf.nn.bias_add(value=tf.matmul(x, w), bias=b)   # pre-activation
h_a = tf.sigmoid(h_z)                                  # activation

# Output layer: raw logits only. The cost is computed with
# softmax_cross_entropy_with_logits, which applies the softmax internally,
# so no softmax activation is added here.
o_z = tf.nn.bias_add(value=tf.matmul(h_a, w1), bias=b1)
|
|
|
# Training objective: softmax cross-entropy between the output logits and the
# one-hot labels, averaged over the batch.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=o_z)
)

# Running this op performs one Adam update step that reduces `cost`.
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

# Accuracy = fraction of rows whose highest-scoring logit matches the label.
predicted_class = tf.argmax(o_z, axis=1)
true_class = tf.argmax(y, axis=1)
correct_prediction = tf.equal(predicted_class, true_class)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32))
|
|
|
# Built after the optimizer so that it also initialises Adam's own variables.
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    # The number of mini-batches per epoch never changes; compute it once
    # instead of on every epoch.
    total_batch = mnist.train.num_examples // batch_size

    for epoch in range(training_epochs):
        avg_cost = 0.
        for _ in range(total_batch):
            batch_images, batch_labels = mnist.train.next_batch(batch_size)
            # One optimisation step; also fetch this batch's mean cost.
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_images,
                                                          y: batch_labels})
            # Summing cost/total_batch leaves avg_cost = mean batch cost of the epoch.
            avg_cost += c / total_batch

        if (epoch + 1) % display_step == 0:
            # Accuracy is measured on the first 3000 test examples only.
            acc = sess.run(accuracy, feed_dict={x: mnist.test.images[:3000],
                                                y: mnist.test.labels[:3000]})
            print("Epoch: ", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost),
                  "accuracy : {:.3f}".format(acc))

    print("Optimization Finished!")

    # Checkpoint the trained session under the "./sess.p" prefix and report
    # where it was written (the path was previously computed but discarded).
    saver = tf.train.Saver()
    save_path = saver.save(sess, "./sess.p")
    print("Model saved in path: {}".format(save_path))