Skip to content

Instantly share code, notes, and snippets.

@khanhnamle1994
Last active March 12, 2018 02:52
Show Gist options
  • Select an option

  • Save khanhnamle1994/f17cda305e3c3cae6f7cb65e8810547a to your computer and use it in GitHub Desktop.

Select an option

Save khanhnamle1994/f17cda305e3c3cae6f7cb65e8810547a to your computer and use it in GitHub Desktop.
# Setup
from __future__ import print_function, division
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# Load Dataset
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from the 'MNIST_data' directory; labels are kept as integer
# class ids (one_hot=False) rather than one-hot vectors.
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
# Reshape Data
# next_batch(1) returns an (images, labels) pair; [0] keeps the images array,
# a single flattened image that is viewed here as 28x28.
# ndarray.reshape returns a new array instead of modifying in place, so the
# result must be rebound — the original discarded it.
batch_image = mnist.train.next_batch(1)[0]
batch_image = batch_image.reshape([28, 28])
# Implement LeakyReLU
def leaky_relu(x, alpha=0.01):
    """Apply the leaky ReLU activation element-wise.

    Computed as max(x, alpha * x). For 0 <= alpha <= 1 this yields x where
    x >= 0 and alpha * x where x < 0.

    Args:
        x: input tensor.
        alpha: slope applied to negative inputs (default 0.01).

    Returns:
        Tensor with the same shape as x.
    """
    # NOTE: the max() trick only equals leaky ReLU when 0 <= alpha <= 1.
    negative_branch = alpha * x
    return tf.maximum(x, negative_branch)
# Random Noise
def sample_noise(batch_size, dim):
    """Build an op that draws uniform random noise in [-1, 1).

    Args:
        batch_size: number of noise vectors (rows).
        dim: length of each noise vector (columns).

    Returns:
        A tensor of shape [batch_size, dim]. It is a random graph op, so a
        new sample is drawn every time it is evaluated in a session.
    """
    noise_shape = [batch_size, dim]
    return tf.random_uniform(noise_shape, minval=-1, maxval=1)
# Discriminator
def discriminator(x):
    """Score a batch of images with a small fully connected network.

    Two hidden layers of 256 units with leaky ReLU feed a single linear
    output unit. All variables are created under the "discriminator"
    variable scope; sharing weights across calls requires the caller to
    enable reuse on an enclosing scope.

    Args:
        x: batch of flattened images; in this file callers pass
           [batch, 784] tensors.

    Returns:
        Unnormalised logits of shape [batch, 1].
    """
    with tf.variable_scope("discriminator"):
        hidden = x
        # Two identical hidden layers, created in the same order as before.
        for _ in range(2):
            hidden = tf.layers.dense(inputs=hidden, units=256, activation=leaky_relu)
        logits = tf.layers.dense(inputs=hidden, units=1)
    return logits
# Generator
def generator(z):
    """Map noise vectors to flattened images with a fully connected network.

    Two ReLU hidden layers of 1024 units feed a 784-unit tanh output layer,
    so every output value lies in (-1, 1). All variables are created under
    the "generator" variable scope.

    Args:
        z: batch of noise vectors, shape [batch, noise_dim].

    Returns:
        Tensor of shape [batch, 784].
    """
    with tf.variable_scope("generator"):
        features = z
        for width in (1024, 1024):
            features = tf.layers.dense(inputs=features, units=width, activation=tf.nn.relu)
        img = tf.layers.dense(inputs=features, units=784, activation=tf.nn.tanh)
    return img
# Compute GAN Loss
def gan_loss(logits_real, logits_fake):
    """Compute the GAN discriminator and generator losses from logits.

    Args:
        logits_real: discriminator logits for real images, shape [N_real, 1].
        logits_fake: discriminator logits for generated images, shape [N_fake, 1].

    Returns:
        (D_loss, G_loss) as scalar tensors:
          D_loss = mean BCE(logits_real, 1) + mean BCE(logits_fake, 0)
          G_loss = mean BCE(logits_fake, 1)

    Real and fake terms are averaged separately and each label tensor is
    shaped after its own logits, so N_real and N_fake need not match. When
    they do match, this equals averaging the element-wise sum.
    """
    # sigmoid_cross_entropy_with_logits requires labels shaped like logits,
    # so build each label tensor from the logits it is compared against.
    real_labels = tf.ones_like(logits_real)
    fake_labels = tf.zeros_like(logits_fake)
    # The generator wants the discriminator to output 1 for its samples.
    generator_targets = tf.ones_like(logits_fake)

    # DISCRIMINATOR loss: classify real images as 1 and fake images as 0.
    real_image_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=real_labels))
    fake_image_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=fake_labels))
    D_loss = real_image_loss + fake_image_loss

    # GENERATOR loss: push the discriminator's output on fakes toward 1.
    G_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=generator_targets))
    return D_loss, G_loss
# Optimizing GAN Loss
def get_solvers(learning_rate=1e-3, beta1=0.5):
    """Create two independent Adam optimizers for GAN training.

    Args:
        learning_rate: Adam step size shared by both optimizers.
        beta1: Adam first-moment decay shared by both optimizers.

    Returns:
        (D_solver, G_solver): separate tf.train.AdamOptimizer instances for
        the discriminator and the generator, in that order.
    """
    def _make_adam():
        return tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)

    return _make_adam(), _make_adam()
##############################################################################
# Final Model
# Builds the full GAN graph. Statement order matters: variables and ops are
# created in the order below, and scope reuse applies only to the second
# discriminator call.
# Number of Images for Each Batch
# NOTE: this fixes the generator's batch size (z / G_sample below) when the
# graph is built; training_gan's batch_size argument only affects real images.
batch_size = 128
# Noise Dimension
noise_dim = 96
# Placeholder for Images from the Training Dataset (flattened 784-pixel rows)
x = tf.placeholder(tf.float32, [None, 784])
# Random Noise for the Generator; re-sampled on every session run
z = sample_noise(batch_size, noise_dim)
# Generated Images
G_sample = generator(z)
with tf.variable_scope("") as scope:
    # Scale Images to be -1 to 1
    # NOTE(review): preprocess_img is not defined in this file; presumably it
    # maps [0, 1] pixels to [-1, 1] (e.g. 2*x - 1) to match the generator's
    # tanh output range — confirm against the notebook that defines it.
    logits_real = discriminator(preprocess_img(x))
    # Re-use Discriminator Weights on New Inputs: the second call shares the
    # variables created by the first instead of creating new ones.
    scope.reuse_variables()
    logits_fake = discriminator(G_sample)
# Get the List of Variables for the Discriminator and Generator
# (filtered by the variable-scope names used inside discriminator/generator)
D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
# Get the Solver
D_solver, G_solver = get_solvers()
# Get the Loss
D_loss, G_loss = gan_loss(logits_real, logits_fake)
# Setup Training Steps: each optimizer only updates its own network's variables
D_train_step = D_solver.minimize(D_loss, var_list=D_vars)
G_train_step = G_solver.minimize(G_loss, var_list=G_vars)
# Extra update ops (e.g. batch-norm statistics) registered under each scope.
# The networks above use only dense layers, so these lists are presumably
# empty for this model.
D_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'discriminator')
G_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'generator')
##############################################################################
# Training a GAN
def _show_generated_samples(sess, num_images=16):
    """Evaluate the module-level G_sample once and plot its first images.

    NOTE(review): show_images is not defined in this file; presumably it is a
    plotting helper from the surrounding notebook — confirm.
    """
    samples = sess.run(G_sample)
    show_images(samples[:num_images])
    plt.show()


def training_gan(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,
                 show_every=250, print_every=50, batch_size=128, num_epoch=10):
    """Alternate one discriminator and one generator update per iteration.

    Args:
        sess: session whose variables have already been initialised.
        G_train_step, D_train_step: optimizer ops that update G / D.
        G_loss, D_loss: scalar loss tensors, fetched for logging.
        G_extra_step, D_extra_step: lists of extra update ops (e.g. from
            tf.GraphKeys.UPDATE_OPS) run together with the matching train
            step. Previously these were accepted but never executed.
        show_every: plot generated samples every this many iterations.
        print_every: print both losses every this many iterations.
        batch_size: number of real MNIST images fed to x per iteration. The
            generator's batch size is fixed when G_sample is built, not here.
        num_epoch: number of passes over mnist.train.

    Relies on module-level globals: mnist, x (real-image placeholder) and
    G_sample.
    """
    # Compute the number of iterations needed for num_epoch passes.
    max_iter = int(mnist.train.num_examples * num_epoch / batch_size)
    # Append the extra update ops to each fetch list; the loss stays at
    # index 1 of the sess.run result.
    d_fetches = [D_train_step, D_loss] + list(D_extra_step or [])
    g_fetches = [G_train_step, G_loss] + list(G_extra_step or [])
    for it in range(max_iter):
        # Every show_every iterations, plot a sample of generated images.
        if it % show_every == 0:
            _show_generated_samples(sess)
            print()
        # Discriminator step on a real minibatch (labels unused), then a
        # generator step; the noise op is re-sampled on each run.
        minibatch, _ = mnist.train.next_batch(batch_size)
        D_loss_curr = sess.run(d_fetches, feed_dict={x: minibatch})[1]
        G_loss_curr = sess.run(g_fetches)[1]
        # Every print_every iterations, print both losses.
        if it % print_every == 0:
            print('Iter: {}, D: {:.4}, G:{:.4}'.format(it, D_loss_curr, G_loss_curr))
    print('Final images')
    _show_generated_samples(sess)
# Run the helper function
# NOTE(review): get_session is not defined in this file; presumably it returns
# a tf.Session configured by the surrounding notebook — confirm.
with get_session() as sess:
    # Initialise every variable (network weights and the optimizers' slot
    # variables created by minimize) before the first training step.
    sess.run(tf.global_variables_initializer())
    training_gan(
        sess,
        G_train_step=G_train_step,
        G_loss=G_loss,
        D_train_step=D_train_step,
        D_loss=D_loss,
        G_extra_step=G_extra_step,
        D_extra_step=D_extra_step,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment