# Code for Generative Adversarial Networks
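# A vanilla GAN trained on MNIST, exported from a Jupyter notebook. The tf.*
# calls below (tf.Session, tf.placeholder, tf.layers, ...) are TensorFlow 1.x
# APIs; under TensorFlow 2 they live in tf.compat.v1.
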
from __future__ import print_function, division
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
get_ipython().run_line_magic('matplotlib', 'inline')  # notebook magic: %matplotlib inline

plt.rcParams['figure.figsize'] = (10.0, 8.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

def show_images(images):
    """Tile a batch of flattened images into a square grid and display it."""
    images = np.reshape(images, [images.shape[0], -1])  # each row is one image
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))    # grid side length
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))  # image side length

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    return

def preprocess_img(x):
    """Map pixel values from [0, 1] to [-1, 1], matching the generator's tanh range."""
    return 2 * x - 1.0

def deprocess_img(x):
    """Map pixel values from [-1, 1] back to [0, 1]."""
    return (x + 1.0) / 2.0

def rel_error(x, y):
    """Maximum relative error between two arrays."""
    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

def count_params():
    """Count the number of parameters in the current TensorFlow graph."""
    param_count = np.sum([np.prod(x.get_shape().as_list()) for x in tf.global_variables()])
    return param_count

def get_session():
    """Create a session that grows GPU memory as needed rather than claiming it all."""
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    return session

# Reference values for the sanity checks below.
answers = np.load('gan-checks-tf.npz')

# Load MNIST (labels are not needed for an unconditional GAN).
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./cs231n/datasets/MNIST_data', one_hot=False)
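# Each batch from mnist.train.next_batch is a [batch_size, 784] float array
# with pixel values in [0, 1]; preview 16 training digits below.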
show_images(mnist.train.next_batch(16)[0])

def leaky_relu(x, alpha=0.01):
    """LeakyReLU: max(x, alpha * x)."""
    activation = tf.maximum(x, alpha * x)
    return activation

def test_leaky_relu(x, y_true):
    tf.reset_default_graph()
    with get_session() as sess:
        y_tf = leaky_relu(tf.constant(x))
        y = sess.run(y_tf)
        print('Maximum error: %g' % rel_error(y_true, y))

test_leaky_relu(answers['lrelu_x'], answers['lrelu_y'])

def sample_noise(batch_size, dim):
    """Sample a [batch_size, dim] tensor of uniform noise in [-1, 1]."""
    random_noise = tf.random_uniform(shape=[batch_size, dim], minval=-1, maxval=1)
    return random_noise

def test_sample_noise():
    batch_size = 3
    dim = 4
    tf.reset_default_graph()
    with get_session() as sess:
        z = sample_noise(batch_size, dim)
        assert z.get_shape().as_list() == [batch_size, dim]  # shape is correct
        assert isinstance(z, tf.Tensor)                      # still a graph tensor
        z1 = sess.run(z)
        z2 = sess.run(z)
        assert not np.array_equal(z1, z2)                    # fresh noise each run
        assert np.all(z1 >= -1.0) and np.all(z1 <= 1.0)      # within [-1, 1]
        print("All tests passed!")

test_sample_noise()

def discriminator(x):
    """Fully connected discriminator: 784 -> 256 -> 256 -> 1 real/fake logit."""
    with tf.variable_scope("discriminator"):
        fc1 = tf.layers.dense(inputs=x, units=256, activation=leaky_relu)
        fc2 = tf.layers.dense(inputs=fc1, units=256, activation=leaky_relu)
        logits = tf.layers.dense(inputs=fc2, units=1)
        return logits
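
# Parameter check: (784*256 + 256) + (256*256 + 256) + (256*1 + 1)
#                = 200960 + 65792 + 257 = 267009, matching true_count below.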
def test_discriminator(true_count=267009):
    tf.reset_default_graph()
    with get_session() as sess:
        y = discriminator(tf.ones((2, 784)))
        cur_count = count_params()
        if cur_count != true_count:
            print('Incorrect number of parameters in discriminator. {0} instead of {1}. Check your architecture.'.format(cur_count, true_count))
        else:
            print('Correct number of parameters in discriminator.')

test_discriminator()

def generator(z):
    """Fully connected generator: noise -> 1024 -> 1024 -> 784 pixels in [-1, 1]."""
    with tf.variable_scope("generator"):
        fc1 = tf.layers.dense(inputs=z, units=1024, activation=tf.nn.relu)
        fc2 = tf.layers.dense(inputs=fc1, units=1024, activation=tf.nn.relu)
        img = tf.layers.dense(inputs=fc2, units=784, activation=tf.nn.tanh)
        return img
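
# Parameter check (the test feeds a 4-dim input): (4*1024 + 1024)
#   + (1024*1024 + 1024) + (1024*784 + 784) = 5120 + 1049600 + 803600 = 1858320.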
def test_generator(true_count=1858320):
    tf.reset_default_graph()
    with get_session() as sess:
        y = generator(tf.ones((1, 4)))
        cur_count = count_params()
        if cur_count != true_count:
            print('Incorrect number of parameters in generator. {0} instead of {1}. Check your architecture.'.format(cur_count, true_count))
        else:
            print('Correct number of parameters in generator.')

test_generator()

def gan_loss(logits_real, logits_fake):
    """Vanilla GAN losses for the discriminator and the generator, from raw logits."""
    true_labels = tf.ones_like(logits_fake)
    # Discriminator: real images should be classified 1, generated images 0.
    real_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=true_labels)
    fake_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=1 - true_labels)
    D_loss = tf.reduce_mean(real_image_loss + fake_image_loss)
    # Generator: try to get its samples classified as 1 (non-saturating loss).
    G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=true_labels)
    G_loss = tf.reduce_mean(G_loss)
    return D_loss, G_loss
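
# With D(x) = sigmoid(logits), the discriminator above minimizes
#   -E_x[log D(x)] - E_z[log(1 - D(G(z)))]
# and the generator minimizes the non-saturating objective -E_z[log D(G(z))];
# sigmoid_cross_entropy_with_logits computes both in a numerically stable way.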
def test_gan_loss(logits_real, logits_fake, d_loss_true, g_loss_true):
    tf.reset_default_graph()
    with get_session() as sess:
        d_loss, g_loss = sess.run(gan_loss(tf.constant(logits_real), tf.constant(logits_fake)))
    print("Maximum error in d_loss: %g" % rel_error(d_loss_true, d_loss))
    print("Maximum error in g_loss: %g" % rel_error(g_loss_true, g_loss))

test_gan_loss(answers['logits_real'], answers['logits_fake'],
              answers['d_loss_true'], answers['g_loss_true'])

def get_solvers(learning_rate=1e-3, beta1=0.5):
    """One Adam optimizer each for the discriminator and the generator."""
    D_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
    G_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
    return D_solver, G_solver
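
# beta1=0.5 (instead of Adam's default 0.9) is the setting used in the DCGAN
# paper (Radford et al., 2016) and tends to make GAN training more stable.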

tf.reset_default_graph()

batch_size = 128
noise_dim = 96  # dimension of the generator's input noise

# Placeholder for real images (flattened 28x28 MNIST digits).
x = tf.placeholder(tf.float32, [None, 784])
z = sample_noise(batch_size, noise_dim)
G_sample = generator(z)

with tf.variable_scope("") as scope:
    logits_real = discriminator(preprocess_img(x))
    # Reuse variables so both discriminator calls share the same weights.
    scope.reuse_variables()
    logits_fake = discriminator(G_sample)

# Split the trainable variables so each optimizer only updates its own network.
D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')

D_solver, G_solver = get_solvers()
D_loss, G_loss = gan_loss(logits_real, logits_fake)
D_train_step = D_solver.minimize(D_loss, var_list=D_vars)
G_train_step = G_solver.minimize(G_loss, var_list=G_vars)

# Extra update ops (e.g. batch norm); empty for these fully connected networks.
D_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'discriminator')
G_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'generator')

def train(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,
          show_every=250, print_every=50, batch_size=128, num_epoch=10):
    """Alternate one discriminator step and one generator step per iteration.

    Uses the global x, G_sample, and mnist objects defined above.
    """
    max_iter = int(mnist.train.num_examples * num_epoch / batch_size)
    for it in range(max_iter):
        # Periodically show samples from the generator.
        if it % show_every == 0:
            samples = sess.run(G_sample)
            show_images(samples[:16])
            plt.show()
            print()
        minibatch, _ = mnist.train.next_batch(batch_size)  # labels are unused
        _, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict={x: minibatch})
        _, G_loss_curr = sess.run([G_train_step, G_loss])
        if it % print_every == 0:
            print('Iter: {}, D: {:.4}, G: {:.4}'.format(it, D_loss_curr, G_loss_curr))
    print('Final images')
    samples = sess.run(G_sample)
    show_images(samples[:16])
    plt.show()
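
# With the standard 55,000-image MNIST training split, the defaults above run
# int(55000 * 10 / 128) = 4296 iterations.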
with get_session() as sess:
    sess.run(tf.global_variables_initializer())
    train(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step)