# Source: GitHub gist goldsborough/07304d204c4eb125b933ee62eaeba8a5
# (author: @goldsborough, created September 6, 2017).
import os
import keras.backend as K
import matplotlib.pyplot as plot
import numpy as np
import tensorflow as tf
from keras.datasets import mnist
from keras.initializers import TruncatedNormal
from keras.layers import (Activation, BatchNormalization, Conv2D,
Conv2DTranspose, Dense, Input, LeakyReLU,
MaxPooling2D, Reshape, UpSampling2D)
from keras.models import Model, Sequential
from keras.optimizers import SGD, Adam, RMSprop
# Quiet TensorFlow's C++ runtime: level '2' hides INFO and WARNING messages.
# This must happen *before* the session below is created, because the
# "TensorFlow was not compiled to use <CPU instructions>" warnings are
# typically emitted while the session/devices are being set up; setting it
# afterwards (as this script originally did) is too late to hide them.
# NOTE(review): setting it before `import tensorflow` is the most reliable
# place, if the import order can be changed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Let TensorFlow allocate GPU memory on demand instead of reserving it all
# up front, and make Keras use this configured session.
gpu_options = tf.GPUOptions(allow_growth=True)
session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
K.set_session(session)
# Training hyperparameters:
#   noise_size       -- length of each latent vector fed to the generator
#   number_of_epochs -- full passes over the training images
#   batch_size       -- number of real images per training step
noise_size, number_of_epochs, batch_size = 100, 100, 64
# Shared kernel initializer for every Dense/Conv layer: truncated normal,
# mean 0, standard deviation 0.02.
initializer = TruncatedNormal(mean=0.0, stddev=0.02)

# Generator: maps a latent vector of length `noise_size` to a 28 x 28 x 1
# image. The final sigmoid keeps every pixel in (0, 1).
z = Input(shape=[noise_size])
generator_layers = [
    # noise_size -> 7 * 7 * 128, then viewed as a 7 x 7 x 128 feature map
    Dense(7 * 7 * 128, kernel_initializer=initializer),
    BatchNormalization(),
    Activation('relu'),
    Reshape([7, 7, 128]),
    # 7 x 7 x 128 -> 14 x 14 x 64
    Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same',
                    kernel_initializer=initializer),
    BatchNormalization(),
    Activation('relu'),
    # 14 x 14 x 64 -> 28 x 28 x 1
    Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same',
                    kernel_initializer=initializer),
    Activation('sigmoid'),
]
# Apply the layers in order; they are constructed in the same order as they
# are applied, so auto-generated layer names are unchanged.
g = z
for generator_layer in generator_layers:
    g = generator_layer(g)
generator = Model(z, g)
# Discriminator: maps a 28 x 28 x 1 image to a single sigmoid score. During
# training it is fit with label 1 for real images and 0 for generated ones.
x = Input(shape=[28, 28, 1])
discriminator_layers = [
    # 28 x 28 x 1 -> 14 x 14 x 128
    Conv2D(128, (5, 5), strides=(2, 2), padding='same',
           kernel_initializer=initializer),
    LeakyReLU(alpha=0.2),
    # 14 x 14 x 128 -> 7 x 7 x 64
    Conv2D(64, (5, 5), strides=(2, 2), padding='same',
           kernel_initializer=initializer),
    BatchNormalization(),
    LeakyReLU(alpha=0.2),
    # 7 x 7 x 64 -> 4 x 4 x 32 ('same' padding with stride 2 rounds 7/2 up)
    Conv2D(32, (5, 5), strides=(2, 2), padding='same',
           kernel_initializer=initializer),
    BatchNormalization(),
    LeakyReLU(alpha=0.2),
    # 4 x 4 x 32 -> 1 x 1 x 1: a 4 x 4 'valid' kernel covers the whole map
    Conv2D(1, (4, 4), padding='valid', activation='sigmoid',
           kernel_initializer=initializer),
]
d = x
for discriminator_layer in discriminator_layers:
    d = discriminator_layer(d)
# Collapse the 1 x 1 x 1 map to shape [1] per sample.
p = Reshape([1])(d)
discriminator = Model(x, p)
# Compile the standalone discriminator, which is trained directly on mixed
# real/generated batches. An alternative optimizer that was tried:
#   SGD(lr=5e-5, momentum=0.9, nesterov=True, decay=6e-8)
discriminator.compile(
loss='binary_crossentropy', optimizer=RMSprop(lr=1e-5, decay=5e-10))
# Combined model: noise -> generator -> discriminator. It is trained with all
# labels set to 1, i.e. it updates the generator to fool the discriminator.
gan = Sequential()
gan.add(generator)
# Freeze the discriminator while `gan` is compiled so that training `gan`
# only updates the generator's weights.
# NOTE(review): this relies on Keras capturing the trainable state at
# compile time -- confirm for the Keras version in use.
discriminator.trainable = False
gan.add(discriminator)
# The generator's learning rate (5e-4) is 50x the discriminator's (1e-5).
gan.compile(loss='binary_crossentropy', optimizer=RMSprop(lr=5e-4, decay=5e-9))
# Restore the flag; the standalone `discriminator` was compiled above while
# it was still trainable.
discriminator.trainable = True
# Print layer-by-layer architecture summaries of both networks.
generator.summary()
discriminator.summary()
# Load MNIST; the digit labels are discarded because the GAN only needs
# images. A trailing channel axis is added so each image is 28 x 28 x 1, the
# discriminator's input shape. Dividing by 255 rescales the 8-bit pixel values
# into [0, 1], the same range as the generator's sigmoid output.
(x_train, _), (x_test, _) = mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1) / 255
# Bug fix: this previously read `x_train`, which was already expanded and
# scaled, producing a 5-D, double-scaled copy of the training set.
x_test = np.expand_dims(x_test, axis=-1) / 255
def noise(size):
    """Return a batch of ``size`` latent vectors for the generator.

    The result is an array of shape ``(size, noise_size)`` whose entries are
    drawn uniformly from the half-open interval [-1, 1).
    """
    shape = (size, noise_size)
    return np.random.uniform(low=-1, high=+1, size=shape)
# Adversarial training. For each batch of real images, do one discriminator
# update on a balanced real/fake batch, then one generator update through the
# combined `gan` model. Pressing Ctrl-C stops training early; execution then
# continues after the loop.
progress_message = '\rBatch: {0} | D: {1:.7f} | G: {2:.7f}'
try:
    for epoch in range(number_of_epochs):
        print('Epoch: {0}/{1}'.format(epoch + 1, number_of_epochs))
        for batch_start in range(0, len(x_train), batch_size):
            real_images = x_train[batch_start:batch_start + batch_size]
            # The final slice is shorter than batch_size whenever
            # len(x_train) is not a multiple of it. Generate exactly as many
            # fakes as there are real images so that batch stays balanced.
            real_count = len(real_images)
            generated_images = generator.predict_on_batch(noise(real_count))

            # Discriminator step: fakes come first (label 0), reals last (1).
            all_images = np.concatenate(
                [generated_images, real_images], axis=0)
            d_labels = np.zeros(len(all_images))
            d_labels[real_count:] = 1
            d_loss = discriminator.train_on_batch(all_images, d_labels)

            # Generator step: train `gan` to have fresh fakes scored as real.
            # NOTE(review): the frozen state used by `gan` was captured when
            # it was compiled; these toggles are likely redundant in Keras 2.
            g_labels = np.ones(len(all_images))
            discriminator.trainable = False
            g_loss = gan.train_on_batch(noise(len(all_images)), g_labels)
            discriminator.trainable = True

            batch_index = batch_start // batch_size + 1
            # The message rewrites the current line via '\r' and end=''. With
            # no newline, a line-buffered stdout would not show it until much
            # later, so flush explicitly.
            print(progress_message.format(batch_index, d_loss, g_loss),
                  end='', flush=True)
        # Terminate the in-place progress line for this epoch.
        print()
        # Present the real images in a new order next epoch.
        np.random.shuffle(x_train)
except KeyboardInterrupt:
    # Terminate the in-place progress line before reporting completion.
    print()
print('Training complete!')
# Sample a few images from the trained generator and save them side by side
# in a single row.
display_images = 4
images = generator.predict_on_batch(noise(display_images))
# Switch to the non-interactive Agg backend so the figure can be written to a
# file without a display.
plot.switch_backend('Agg')
plot.figure(figsize=(10, 4))
for slot, image in enumerate(images[:display_images], start=1):
    axis = plot.subplot(1, display_images, slot)
    plot.imshow(image.reshape(28, 28), cmap='gray')
    # Hide tick marks and labels on both axes.
    axis.xaxis.set_visible(False)
    axis.yaxis.set_visible(False)
print('Saving fig.png')
plot.savefig('fig.png')
# (GitHub gist page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment