Created
July 14, 2016 03:54
-
-
Save piyo7/f67c18bc44e5901c37a7d86ca9360096 to your computer and use it in GitHub Desktop.
TensorFlowでAutoEncoderを可視化してみたよ ref: http://qiita.com/piyo7/items/355510a1c7ec061d9aff
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.examples.tutorials.mnist import input_data | |
# Command-line configuration. Every value can be overridden with
# `--<name>=<value>` when the script is launched.
flags = tf.app.flags
FLAGS = flags.FLAGS

# Training schedule.
flags.DEFINE_integer('batch_size', 256,
                     'Number of training images fed per optimisation step.')
flags.DEFINE_integer('epoch_num', 10,
                     'Number of passes over the training set before stopping.')
flags.DEFINE_float('learning_rate', 0.01,
                   'Step size for the gradient descent optimiser.')

# Model size.
flags.DEFINE_integer('hidden_num', 64,
                     'Number of units in the hidden (encoding) layer.')

# Input corruption.
flags.DEFINE_float('noise_rate', 0.3,
                   'Expected fraction of pixels that get uniformly random noise; '
                   'the remaining pixels are shifted by -noise_strength. '
                   'Must be positive.')
flags.DEFINE_float('noise_strength', 0.2,
                   'Largest absolute value of the noise added to a pixel.')

# Output location.
flags.DEFINE_string('summary_dir', 'summary',
                    'Directory for TensorBoard summaries; an existing directory '
                    'is deleted at start-up.')
def main():
    """Train a tied-weight denoising autoencoder on MNIST and log it for TensorBoard.

    The graph corrupts each 784-pixel input with noise, encodes it through a
    sigmoid hidden layer of ``FLAGS.hidden_num`` units, and decodes it with
    the transpose of the encoder weights. Training minimises the L2 loss
    between the clean input and the reconstruction using plain gradient
    descent.

    Summaries are written below ``FLAGS.summary_dir``, which is deleted first
    if it already exists:

    * ``input/``       - the graph, the per-step loss, and the clean images
    * ``noisy_input/`` - the corrupted images
    * ``output/``      - the reconstructed images

    Raises:
        ValueError: if ``FLAGS.noise_rate`` is not positive.
    """
    # The noise range below divides by noise_rate, and a negative rate would
    # produce an inverted range, so reject both with a readable message.
    if FLAGS.noise_rate <= 0:
        raise ValueError(
            'noise_rate must be positive, got {}'.format(FLAGS.noise_rate))

    mnist_data = input_data.read_data_sets('MNIST_data', one_hot=True)

    with tf.name_scope('input'):
        input_layer = tf.placeholder(tf.float32, shape=[None, 784])
        # All three image summaries share the tag 'tag' but go to separate
        # writers; presumably so TensorBoard lines them up as three runs.
        input_summary = tf.image_summary(
            'tag', tf.reshape(input_layer, [-1, 28, 28, 1]), 10)

    with tf.name_scope('noisy_input'):
        # Sample uniformly from [-s * (2 / r - 1), s) and clamp from below at
        # -s (s = noise_strength, r = noise_rate). For 0 < r <= 1 this leaves
        # an expected fraction r of the samples uniform in [-s, s) and
        # collapses the rest to exactly -s.
        noise_layer = tf.maximum(
            tf.random_uniform(
                shape=tf.shape(input_layer),
                minval=-FLAGS.noise_strength * (2 / FLAGS.noise_rate - 1),
                maxval=FLAGS.noise_strength),
            -FLAGS.noise_strength)
        # Keep the corrupted pixels inside [0, 1].
        noisy_input_layer = tf.minimum(
            tf.maximum(input_layer + noise_layer, 0.0), 1.0)
        noisy_input_summary = tf.image_summary(
            'tag', tf.reshape(noisy_input_layer, [-1, 28, 28, 1]), 10)

    with tf.name_scope('hidden'):
        w1 = tf.Variable(
            tf.random_uniform([784, FLAGS.hidden_num], minval=-1, maxval=1))
        b1 = tf.Variable(tf.zeros([FLAGS.hidden_num]))
        hidden_layer = tf.sigmoid(tf.matmul(noisy_input_layer, w1) + b1)

    with tf.name_scope('output'):
        # Tied weights: the decoder reuses the encoder matrix, transposed.
        w2 = tf.transpose(w1)
        b2 = tf.Variable(tf.zeros([784]))
        output_layer = tf.sigmoid(tf.matmul(hidden_layer, w2) + b2)
        output_summary = tf.image_summary(
            'tag', tf.reshape(output_layer, [-1, 28, 28, 1]), 10)

    with tf.name_scope('loss'):
        # The target is the clean input, not the corrupted one.
        loss = tf.nn.l2_loss(input_layer - output_layer)
        loss_summary = tf.scalar_summary('loss', loss)

    train_step = tf.train.GradientDescentOptimizer(
        FLAGS.learning_rate).minimize(loss)

    sess = tf.InteractiveSession()
    writers = []
    try:
        sess.run(tf.initialize_all_variables())

        # Start from an empty log directory so old runs do not mix in.
        if tf.gfile.Exists(FLAGS.summary_dir):
            tf.gfile.DeleteRecursively(FLAGS.summary_dir)

        input_writer = tf.train.SummaryWriter(
            FLAGS.summary_dir + '/input', sess.graph)
        writers.append(input_writer)
        noisy_input_writer = tf.train.SummaryWriter(
            FLAGS.summary_dir + '/noisy_input')
        writers.append(noisy_input_writer)
        output_writer = tf.train.SummaryWriter(FLAGS.summary_dir + '/output')
        writers.append(output_writer)

        images = None
        step = 0
        while mnist_data.train.epochs_completed < FLAGS.epoch_num:
            step += 1
            images, _ = mnist_data.train.next_batch(FLAGS.batch_size)
            loss_result, _ = sess.run(
                [loss_summary, train_step], feed_dict={input_layer: images})
            input_writer.add_summary(loss_result, step)

        # With epoch_num <= 0 the loop never runs; fetch one batch so the
        # image summaries below still have something to show.
        if images is None:
            images, _ = mnist_data.train.next_batch(FLAGS.batch_size)

        # NOTE(review): the source lost its indentation; the image dump is
        # assumed to run once after training, on the last batch, because its
        # add_summary calls carry no step.
        input_result, noisy_input_result, output_result = sess.run(
            [input_summary, noisy_input_summary, output_summary],
            feed_dict={input_layer: images})
        input_writer.add_summary(input_result)
        noisy_input_writer.add_summary(noisy_input_result)
        output_writer.add_summary(output_result)
    finally:
        # Closing flushes pending events to disk; without it the last
        # summaries can be lost when the process exits.
        for writer in writers:
            writer.close()
        sess.close()
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ python autoencoder.py
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
$ tensorboard --logdir=summary/
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ python --version
Python 3.5.1
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
numpy==1.11.1             # via tensorflow
protobuf==3.0.0b2         # via tensorflow
six==1.10.0               # via protobuf, tensorflow
tensorflow==0.9.0
wheel==0.29.0             # via tensorflow
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment