Code for Generative Adversarial Networks (GANs): a fully connected GAN trained on MNIST with TensorFlow 1.x, exported from a notebook (CS231n-style helpers and dataset paths).
from __future__ import print_function, division
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
get_ipython().run_line_magic('matplotlib', 'inline')  # notebook export: enables inline plots under IPython/Jupyter
plt.rcParams['figure.figsize'] = (10.0, 8.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
def show_images(images):
    """Display a batch of flattened images on a square grid."""
    images = np.reshape(images, [images.shape[0], -1])  # flatten each image to a row vector
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))      # grid side length
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))    # image side length (28 for MNIST)
    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)
    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    return
def preprocess_img(x):
    """Map pixel values from [0, 1] to [-1, 1] to match the generator's tanh output."""
    return 2 * x - 1.0

def deprocess_img(x):
    """Map pixel values from [-1, 1] back to [0, 1]."""
    return (x + 1.0) / 2.0

def rel_error(x, y):
    """Maximum relative error between two arrays."""
    return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))

def count_params():
    """Count the number of parameters in the current TensorFlow graph."""
    param_count = np.sum([np.prod(x.get_shape().as_list()) for x in tf.global_variables()])
    return param_count

def get_session():
    """Create a Session that grows GPU memory allocation as needed."""
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    session = tf.Session(config=config)
    return session

answers = np.load('gan-checks-tf.npz')  # reference values used by the test functions below

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./cs231n/datasets/MNIST_data', one_hot=False)
show_images(mnist.train.next_batch(16)[0])
def leaky_relu(x, alpha=0.01):
    """Leaky ReLU: max(x, alpha * x)."""
    activation = tf.maximum(x, alpha * x)
    return activation

def test_leaky_relu(x, y_true):
    tf.reset_default_graph()
    with get_session() as sess:
        y_tf = leaky_relu(tf.constant(x))
        y = sess.run(y_tf)
        print('Maximum error: %g' % rel_error(y_true, y))

test_leaky_relu(answers['lrelu_x'], answers['lrelu_y'])
def sample_noise(batch_size, dim):
    """Sample a [batch_size, dim] tensor of uniform noise in [-1, 1]."""
    random_noise = tf.random_uniform(maxval=1, minval=-1, shape=[batch_size, dim])
    return random_noise

def test_sample_noise():
    batch_size = 3
    dim = 4
    tf.reset_default_graph()
    with get_session() as sess:
        z = sample_noise(batch_size, dim)
        # Check shape, type, range, and that two draws differ
        assert z.get_shape().as_list() == [batch_size, dim]
        assert isinstance(z, tf.Tensor)
        z1 = sess.run(z)
        z2 = sess.run(z)
        assert not np.array_equal(z1, z2)
        assert np.all(z1 >= -1.0) and np.all(z1 <= 1.0)
        print("All tests passed!")

test_sample_noise()
def discriminator(x):
    """Two hidden layers of 256 units with leaky ReLU; outputs one real/fake logit per image."""
    with tf.variable_scope("discriminator"):
        fc1 = tf.layers.dense(inputs=x, units=256, activation=leaky_relu)
        fc2 = tf.layers.dense(inputs=fc1, units=256, activation=leaky_relu)
        logits = tf.layers.dense(inputs=fc2, units=1)
        return logits

def test_discriminator(true_count=267009):
    tf.reset_default_graph()
    with get_session() as sess:
        y = discriminator(tf.ones((2, 784)))
        cur_count = count_params()
        if cur_count != true_count:
            print('Incorrect number of parameters in discriminator. {0} instead of {1}. Check your architecture.'.format(cur_count, true_count))
        else:
            print('Correct number of parameters in discriminator.')

test_discriminator()
def generator(z):
    """Two hidden layers of 1024 units with ReLU; outputs a 784-dim image with tanh (values in [-1, 1])."""
    with tf.variable_scope("generator"):
        fc1 = tf.layers.dense(inputs=z, units=1024, activation=tf.nn.relu)
        fc2 = tf.layers.dense(inputs=fc1, units=1024, activation=tf.nn.relu)
        img = tf.layers.dense(inputs=fc2, units=784, activation=tf.nn.tanh)
        return img

def test_generator(true_count=1858320):
    tf.reset_default_graph()
    with get_session() as sess:
        y = generator(tf.ones((1, 4)))
        cur_count = count_params()
        if cur_count != true_count:
            print('Incorrect number of parameters in generator. {0} instead of {1}. Check your architecture.'.format(cur_count, true_count))
        else:
            print('Correct number of parameters in generator.')

test_generator()
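# Not part of the original gist: a small sanity-check sketch that the generator maps noise
# to images of the expected shape and range, using only helpers defined above (the noise
# dimension 96 matches noise_dim used later in the training graph).
tf.reset_default_graph()
with get_session() as sess:
    g = generator(sample_noise(4, 96))
    sess.run(tf.global_variables_initializer())
    imgs = sess.run(g)
    assert imgs.shape == (4, 784)        # one flattened 28x28 image per noise vector
    assert np.all(np.abs(imgs) <= 1.0)   # tanh keeps outputs in [-1, 1]
    print('Generator shape/range sanity check passed.')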
def gan_loss(logits_real, logits_fake):
    """Vanilla GAN losses.
    Discriminator: push real logits toward label 1 and fake logits toward label 0.
    Generator: push fake logits toward label 1 (the non-saturating formulation).
    """
    real_labels = tf.ones_like(logits_real)
    fake_labels = tf.zeros_like(logits_fake)
    real_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_real, labels=real_labels)
    fake_image_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=fake_labels)
    D_loss = tf.reduce_mean(real_image_loss + fake_image_loss)
    G_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits_fake, labels=tf.ones_like(logits_fake))
    G_loss = tf.reduce_mean(G_loss)
    return D_loss, G_loss

def test_gan_loss(logits_real, logits_fake, d_loss_true, g_loss_true):
    tf.reset_default_graph()
    with get_session() as sess:
        d_loss, g_loss = sess.run(gan_loss(tf.constant(logits_real), tf.constant(logits_fake)))
    print("Maximum error in d_loss: %g" % rel_error(d_loss_true, d_loss))
    print("Maximum error in g_loss: %g" % rel_error(g_loss_true, g_loss))

test_gan_loss(answers['logits_real'], answers['logits_fake'],
              answers['d_loss_true'], answers['g_loss_true'])
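# Not part of the original gist: a sketch that cross-checks gan_loss against a hand-computed
# sigmoid cross-entropy in NumPy on arbitrary random stand-in logits, to make the formula explicit:
#   ce(l, y) = -y*log(sigmoid(l)) - (1-y)*log(1-sigmoid(l))
def _np_sigmoid_ce(logits, label):
    p = 1.0 / (1.0 + np.exp(-logits))
    return -label * np.log(p) - (1 - label) * np.log(1 - p)

_rng = np.random.RandomState(0)
_lr = _rng.randn(8, 1).astype(np.float32)   # stand-in logits for real images
_lf = _rng.randn(8, 1).astype(np.float32)   # stand-in logits for fake images
_d_ref = np.mean(_np_sigmoid_ce(_lr, 1.0) + _np_sigmoid_ce(_lf, 0.0))
_g_ref = np.mean(_np_sigmoid_ce(_lf, 1.0))

tf.reset_default_graph()
with get_session() as sess:
    _d_tf, _g_tf = sess.run(gan_loss(tf.constant(_lr), tf.constant(_lf)))
print('Cross-check |D diff|: %g, |G diff|: %g' % (abs(_d_tf - _d_ref), abs(_g_tf - _g_ref)))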
def get_solvers(learning_rate=1e-3, beta1=0.5):
    """Adam optimizers for D and G; beta1=0.5 is a common choice for GAN training."""
    D_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
    G_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
    return D_solver, G_solver
tf.reset_default_graph()

batch_size = 128
noise_dim = 96

# Placeholder for real images and the generator's samples
x = tf.placeholder(tf.float32, [None, 784])
z = sample_noise(batch_size, noise_dim)
G_sample = generator(z)

# Run the discriminator on real images (rescaled to [-1, 1]) and on generated images,
# reusing the same weights for both passes.
with tf.variable_scope("") as scope:
    logits_real = discriminator(preprocess_img(x))
    scope.reuse_variables()
    logits_fake = discriminator(G_sample)

# Separate each sub-network's variables so each optimizer only updates its own.
D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')

D_solver, G_solver = get_solvers()
D_loss, G_loss = gan_loss(logits_real, logits_fake)
D_train_step = D_solver.minimize(D_loss, var_list=D_vars)
G_train_step = G_solver.minimize(G_loss, var_list=G_vars)

# Extra update ops (e.g. batch norm statistics), if any layer registers them
D_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'discriminator')
G_extra_step = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'generator')
def train(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step,
          show_every=250, print_every=50, batch_size=128, num_epoch=10):
    """Alternate one discriminator step and one generator step per minibatch."""
    max_iter = int(mnist.train.num_examples * num_epoch / batch_size)
    for it in range(max_iter):
        # Periodically show a grid of current generator samples
        if it % show_every == 0:
            samples = sess.run(G_sample)
            fig = show_images(samples[:16])
            plt.show()
            print()
        minibatch, minibatch_y = mnist.train.next_batch(batch_size)
        _, D_loss_curr = sess.run([D_train_step, D_loss], feed_dict={x: minibatch})
        _, G_loss_curr = sess.run([G_train_step, G_loss])
        if it % print_every == 0:
            print('Iter: {}, D: {:.4}, G: {:.4}'.format(it, D_loss_curr, G_loss_curr))
    print('Final images')
    samples = sess.run(G_sample)
    fig = show_images(samples[:16])
    plt.show()

with get_session() as sess:
    sess.run(tf.global_variables_initializer())
    train(sess, G_train_step, G_loss, D_train_step, D_loss, G_extra_step, D_extra_step)
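# Not part of the original gist: the training session above has closed, but the graph is still
# in scope, so a per-network parameter breakdown can be printed without a session. This is a
# small sketch complementing the hard-coded counts checked in test_discriminator/test_generator.
d_params = sum(int(np.prod(v.get_shape().as_list())) for v in D_vars)
g_params = sum(int(np.prod(v.get_shape().as_list())) for v in G_vars)
print('Discriminator parameters: %d, Generator parameters: %d' % (d_params, g_params))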