Last active
March 12, 2018 02:52
-
-
Save khanhnamle1994/f17cda305e3c3cae6f7cb65e8810547a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Setup | |
| from __future__ import print_function, division | |
| import tensorflow as tf | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| %matplotlib inline | |
# Load Dataset
from tensorflow.examples.tutorials.mnist import input_data

# one_hot=False: labels are returned as class indices rather than one-hot rows.
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)

# Reshape Data
# next_batch(1) returns (images, labels); take the images of a single example.
batch_image = mnist.train.next_batch(1)[0]
# ndarray.reshape returns a new array and leaves its input untouched, so the
# result has to be kept. Reshape the flat 784-pixel row into a 28x28 image.
batch_image = batch_image.reshape([28, 28])
plt.imshow(batch_image, cmap='gray')
plt.show()
# Implement LeakyReLU
def leaky_relu(x, alpha=0.01):
    """Apply the Leaky ReLU activation elementwise.

    Args:
        x: input tensor.
        alpha: slope used for negative inputs.

    Returns:
        max(x, alpha * x) elementwise. For 0 <= alpha <= 1 this is x where
        x >= 0 and alpha * x where x < 0.
    """
    scaled = alpha * x
    return tf.maximum(x, scaled)
# Random Noise
def sample_noise(batch_size, dim):
    """Build a tensor of uniform noise in [-1, 1) for the generator.

    Args:
        batch_size: number of noise vectors (rows).
        dim: length of each noise vector (columns).

    Returns:
        A tensor of shape [batch_size, dim].
    """
    noise_shape = [batch_size, dim]
    return tf.random_uniform(shape=noise_shape, minval=-1, maxval=1)
# Discriminator
def discriminator(x):
    """Score flattened images with an MLP and return real-vs-fake logits.

    The network is two 256-unit dense layers with leaky_relu, followed by a
    single linear output unit. All variables are created under the
    "discriminator" variable scope. Weight sharing between calls is
    controlled by the reuse flag of whatever scope encloses the call.

    Args:
        x: tensor of shape [batch, features].

    Returns:
        Unnormalised logits of shape [batch, 1].
    """
    with tf.variable_scope("discriminator"):
        hidden = x
        for _ in range(2):
            hidden = tf.layers.dense(inputs=hidden, units=256, activation=leaky_relu)
        return tf.layers.dense(inputs=hidden, units=1)
# Generator
def generator(z):
    """Map noise vectors to flattened 784-pixel images.

    The network is two 1024-unit ReLU dense layers and a 784-unit tanh
    output layer, so every generated pixel lies in (-1, 1). All variables
    are created under the "generator" variable scope.

    Args:
        z: noise tensor of shape [batch, noise_dim].

    Returns:
        Tensor of shape [batch, 784].
    """
    with tf.variable_scope("generator"):
        hidden = z
        for _ in range(2):
            hidden = tf.layers.dense(inputs=hidden, units=1024, activation=tf.nn.relu)
        return tf.layers.dense(inputs=hidden, units=784, activation=tf.nn.tanh)
# Compute GAN Loss
def gan_loss(logits_real, logits_fake):
    """Compute the discriminator and (non-saturating) generator GAN losses.

    Args:
        logits_real: discriminator logits for real images, shape [N_real, 1].
        logits_fake: discriminator logits for generated images, shape [N_fake, 1].

    Returns:
        (D_loss, G_loss), both scalar tensors:
            D_loss = mean BCE(logits_real, 1) + mean BCE(logits_fake, 0)
            G_loss = mean BCE(logits_fake, 1)
    """
    # Each target is shaped after the logits it is compared with, so the real
    # and fake batches do not need to be the same size.
    real_targets = tf.ones_like(logits_real)
    fake_targets = tf.zeros_like(logits_fake)
    generator_targets = tf.ones_like(logits_fake)

    # DISCRIMINATOR: classify real images as 1 and generated images as 0.
    # Each term is averaged over its own batch; when N_real == N_fake this
    # equals the mean of the elementwise sum of the two per-example losses.
    real_image_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=real_targets))
    fake_image_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=fake_targets))
    D_loss = real_image_loss + fake_image_loss

    # GENERATOR: push the discriminator towards outputting 1 for generated images.
    G_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=generator_targets))
    return D_loss, G_loss
# Optimizing GAN Loss
def get_solvers(learning_rate=1e-3, beta1=0.5):
    """Create two independent Adam optimizers, one per network.

    Args:
        learning_rate: Adam step size shared by both optimizers.
        beta1: Adam first-moment decay shared by both optimizers.

    Returns:
        (D_solver, G_solver): discriminator optimizer first, generator second.
    """
    def make_adam():
        return tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)

    return make_adam(), make_adam()
##############################################################################
# Final Model


def preprocess_img(x):
    """Rescale pixel values to [-1, 1] to match the generator's tanh output range.

    Assumes input pixels are in [0, 1].
    """
    return 2 * x - 1.0


# Number of images in each generated batch (the noise tensor has this fixed size).
batch_size = 128
# Length of each generator input noise vector.
noise_dim = 96

# Placeholder for flattened 784-pixel images from the training dataset.
x = tf.placeholder(tf.float32, [None, 784])
# Random noise for the generator.
z = sample_noise(batch_size, noise_dim)
# Generated images.
G_sample = generator(z)

with tf.variable_scope("") as scope:
    # Scale real images to -1..1 before the discriminator sees them.
    logits_real = discriminator(preprocess_img(x))
    # Re-use the discriminator weights on the generated images.
    scope.reuse_variables()
    logits_fake = discriminator(G_sample)

# Trainable variables of each network, selected by variable-scope name.
D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')

D_solver, G_solver = get_solvers()
D_loss, G_loss = gan_loss(logits_real, logits_fake)

# Each training step only updates its own network's variables.
D_train_step = D_solver.minimize(D_loss, var_list=D_vars)
G_train_step = G_solver.minimize(G_loss, var_list=G_vars)
# Extra update ops (e.g. batch-norm statistics) registered under each scope.
# These are expected to be empty for the current dense-only networks.
D_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'discriminator')
G_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'generator')
##############################################################################
# Training a GAN


def show_images(images):
    """Plot a batch of square images on a near-square grid.

    Args:
        images: array whose first axis is the batch. Each item must flatten to
            a perfect-square number of pixels (784 -> 28x28).

    Returns:
        The matplotlib Figure that was drawn on.
    """
    images = np.reshape(images, [images.shape[0], -1])
    grid = int(np.ceil(np.sqrt(images.shape[0])))
    side = int(np.ceil(np.sqrt(images.shape[1])))
    fig, axes = plt.subplots(grid, grid, figsize=(grid, grid), squeeze=False)
    fig.subplots_adjust(wspace=0.05, hspace=0.05)
    for ax in axes.flat:
        ax.axis('off')
    for ax, img in zip(axes.flat, images):
        ax.imshow(img.reshape([side, side]), cmap='gray')
    return fig


def _plot_samples(sess, num_images=16):
    """Evaluate the module-level G_sample once and plot its first images."""
    samples = sess.run(G_sample)
    show_images(samples[:num_images])
    plt.show()


def training_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,
                 show_every=250, print_every=50, batch_size=128, num_epoch=10):
    """Alternate one discriminator step and one generator step over MNIST.

    Reads the module-level names ``mnist``, ``x`` (real-image placeholder) and
    ``G_sample`` (generator output).

    Args:
        sess: session in which the graph's variables are already initialised.
        G_train_step, D_train_step: optimizer ops for each network.
        G_loss, D_loss: scalar loss tensors, fetched for logging.
        G_extra_step, D_extra_step: extra update ops for each network (e.g. as
            returned by tf.get_collection). They are run in the same
            sess.run call as the matching train step. None is treated as
            "no extra ops".
        show_every: plot generated samples every this many iterations.
        print_every: print both losses every this many iterations.
        batch_size: real images fed to each discriminator step.
        num_epoch: number of passes over the training set.
    """
    d_extra = [] if D_extra_step is None else D_extra_step
    g_extra = [] if G_extra_step is None else G_extra_step

    # Number of iterations needed to cover num_epoch passes over the data.
    max_iter = int(mnist.train.num_examples * num_epoch / batch_size)
    for it in range(max_iter):
        # Every show_every iterations, show a sample of generated images.
        if it % show_every == 0:
            _plot_samples(sess)
            print()
        # Labels are not needed for GAN training.
        minibatch, _ = mnist.train.next_batch(batch_size)
        _, D_loss_curr, _ = sess.run([D_train_step, D_loss, d_extra], feed_dict={x: minibatch})
        _, G_loss_curr, _ = sess.run([G_train_step, G_loss, g_extra])
        # Every print_every iterations, print both losses.
        if it % print_every == 0:
            print('Iter: {}, D: {:.4}, G:{:.4}'.format(it, D_loss_curr, G_loss_curr))
    print('Final images')
    _plot_samples(sess)
# Run the helper function.
# allow_growth lets TensorFlow claim GPU memory on demand instead of reserving
# all of it when the session starts.
session_config = tf.ConfigProto()
session_config.gpu_options.allow_growth = True
with tf.Session(config=session_config) as sess:
    sess.run(tf.global_variables_initializer())
    training_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment