Skip to content

Instantly share code, notes, and snippets.

@arestifo
Created May 19, 2020 01:37
Show Gist options
  • Select an option

  • Save arestifo/8ac20f2abd00be917d18eab7b76dde96 to your computer and use it in GitHub Desktop.

Select an option

Save arestifo/8ac20f2abd00be917d18eab7b76dde96 to your computer and use it in GitHub Desktop.
Working (up to 1024x1024) version of my progressively growing GAN implementation in TensorFlow + Keras
import tensorflow as tf
from tensorflow.keras import layers, Model, optimizers, initializers, constraints
from tensorflow.keras import backend as k
from tensorflow.keras.layers import Dense, Conv2D, Input, LeakyReLU, Reshape, Flatten
from tensorflow.keras.layers import AveragePooling2D, UpSampling2D
from tensorflow.keras.constraints import max_norm
import matplotlib.pyplot as plt
import numpy as np
import glob
import os
import progressbar as pb
import pathlib
# Allocate GPU memory on demand for every visible GPU. Iterating (instead of indexing [0]) also makes
# this a no-op on CPU-only machines, where indexing the empty device list raised IndexError.
for _gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(_gpu, True)
# Feature-map count per resolution block: index 0 is the 4x4 base, index i the block at 4 * 2**i.
filters = [512, 512, 512, 512, 256, 128, 64, 32, 16]
# Batch size per train() epoch, indexed by train()'s epoch counter (straight and fade epochs both count).
batch_sizes = [64, 32, 32, 16, 16, 16, 16, 8, 8, 4, 4, 2, 2, 2, 2, 1, 1]
# Real images shown to the discriminator per train() epoch, indexed the same way as batch_sizes.
epoch_images = [8e5] * 17
# He-normal initializer instance passed as kernel_initializer to most layers built in this file.
init = initializers.he_normal()
# Length of the generator's input latent vector.
latent_dim = 128
class WeightedSum(layers.Add):
    """Fade-in merge layer: blends two inputs as ``(1 - alpha) * first + alpha * second``.

    ``alpha`` is held in a backend variable (named 'ws_alpha') so it can be changed during
    training without rebuilding the model; it starts at the constructor value (0.0 by default).
    """

    def __init__(self, alpha=0.0, **kwargs):
        super().__init__(**kwargs)
        self.alpha = k.variable(alpha, name='ws_alpha')

    def _merge_function(self, inputs):
        # Exactly two tensors are expected: the outgoing path first, the incoming path second
        assert len(inputs) == 2
        outgoing, incoming = inputs
        return ((1.0 - self.alpha) * outgoing) + (self.alpha * incoming)
class PixelNorm(layers.Layer):
    """ProGAN pixelwise feature-vector normalisation.

    Each feature vector along the last axis is divided by its root-mean-square:
    ``x / sqrt(mean(x ** 2, axis=-1) + 1e-8)``. The output shape equals the input shape.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        mean_square = tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True)
        # The small epsilon keeps rsqrt finite for all-zero feature vectors
        return inputs * tf.math.rsqrt(mean_square + 1e-8)

    def compute_output_shape(self, input_shape):
        return input_shape
class MinibatchStd(layers.Layer):
    """Append one channel holding the batch-wide average standard deviation.

    For every spatial position and channel, the standard deviation is taken across the batch
    axis. Those values are averaged into a single scalar, which is broadcast to an
    ``(N, H, W, 1)`` map and concatenated on the last axis. Inputs are therefore expected to be
    rank 4 (NHWC), and the output has one more channel than the input.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs, **kwargs):
        # Mean over the batch axis (not across channels), per position/channel
        batch_mean = tf.reduce_mean(inputs, axis=0, keepdims=True)
        # Variance over the batch; the epsilon keeps the following sqrt finite/differentiable at zero
        batch_var = tf.reduce_mean(tf.square(inputs - batch_mean), axis=0, keepdims=True) + 1e-8
        # Average every per-position stddev into one scalar (rank preserved by keepdims)
        mean_std = tf.reduce_mean(tf.sqrt(batch_var), keepdims=True)
        dims = tf.shape(inputs)
        std_map = tf.tile(mean_std, (dims[0], dims[1], dims[2], 1))
        return tf.concat([inputs, std_map], axis=-1)

    def compute_output_shape(self, input_shape):
        # The batch-wide std feature adds exactly one channel
        *leading, channels = input_shape
        return (*leading, channels + 1)
class EqualizedLearningRate(constraints.Constraint):
    """Unfinished placeholder for an equalized-learning-rate weight constraint.

    Currently the identity: weights are returned unchanged. ``gain`` (default sqrt(2)) is
    stored but not yet used. Nothing in this file applies this constraint.
    """

    def __init__(self, gain=np.sqrt(2.0)):
        self.gain = gain

    def __call__(self, w):
        # TODO: Finish this -- no scaling is applied yet
        unchanged = w
        return unchanged
def plot_model(model, save_name):
    """Save a diagram of ``model`` (nested models expanded, shapes shown) to progan_models/<save_name>.png."""
    out_path = 'progan_models/' + save_name + '.png'
    tf.keras.utils.plot_model(model, to_file=out_path, show_shapes=True, expand_nested=True)
# StyleGAN paper: use learning rate 2 orders of magnitude less for the mapping network
# l_mapnet = 0.01 * l_gen
def view_image(image, title=None):
    """Display a single image, min-max normalised to [0, 1], in a new figure.

    Args:
        image: an HWC image, or an NHWC batch (only the first image is shown). Anything
            ``np.asarray`` accepts works, including eager tensors returned by the generator.
        title: optional figure title.
    """
    image = np.asarray(image)
    if image.ndim == 4:
        image = image[0, :, :, :]
    value_range = np.ptp(image)
    if value_range > 0:
        image = (image - np.min(image)) / value_range
    else:
        # A constant image (e.g. a collapsed generator) would otherwise divide by zero -> NaNs
        image = np.zeros_like(image)
    # Use a fresh figure rather than drawing onto whatever figure happens to be current
    fig = plt.figure()
    if title:
        plt.title(title)
    plt.imshow(image)
    plt.show()
    # Release the figure; non-interactive backends otherwise accumulate open figures
    plt.close(fig)
def random_image(generator, dseen=None, save=False, fade_epoch=None):
    """Show one random generator sample, or (``save=True``) save/show the fixed-seed 3x3 panel.

    Args:
        generator: model mapping ``(batch, latent_dim)`` noise to images.
        dseen: images seen by the discriminator; required when ``save`` is true.
        save: if true, delegate to ``save_img_panel`` instead of showing a single sample.
        fade_epoch: fade/straight flag forwarded to ``save_img_panel``; required when ``save`` is true.
    """
    if not save:
        noise = tf.random.normal((1, latent_dim))
        view_image(generator(noise, training=False))
        return
    assert dseen is not None, 'Random image saving requires providing discriminator seen stats'
    assert fade_epoch is not None, 'Random image saving requires providing fade_epoch'
    save_img_panel(generator, dseen, fade_epoch)
def save_img_panel(generator, dseen, fade_epoch):
    """Render the generator's output for the module-level ``seed`` as a 3x3 grid, save it, and show it.

    The image is written to
    ``celeba_output/<res>x<res>_run1/<fade|straight>_image_at_epoch_<dseen:04d>.png``.

    Args:
        generator: model mapping latent vectors to NHWC images.
        dseen: number of images the discriminator has seen; used in the file name.
        fade_epoch: truthy -> 'fade' in the file name, falsy -> 'straight'.
    """
    fig = plt.figure(figsize=(3, 3))
    predictions = generator(seed, training=False)
    for i in range(predictions.shape[0]):
        plt.subplot(3, 3, i + 1)
        pred_image = np.asarray(predictions[i])
        value_range = np.ptp(pred_image)
        if value_range > 0:
            pred_image = (pred_image - np.min(pred_image)) / value_range
        else:
            # Constant output would otherwise divide by zero and produce NaNs
            pred_image = np.zeros_like(pred_image)
        plt.imshow(pred_image)
        plt.axis('off')
    # (A second, discarded generator forward pass on fresh noise used to happen here; removed.)
    res = generator.output_shape[1]
    # Make save directory if it doesn't exist
    pathlib.Path('celeba_output/{}x{}_run1'.format(res, res)).mkdir(parents=True, exist_ok=True)
    plt.savefig('celeba_output/{}x{}_run1/{}_image_at_epoch_{:04d}.png'.format(
        res, res, 'fade' if fade_epoch else 'straight', dseen
    ))
    plt.show()
    # Called periodically during training: close the figure so they do not pile up in memory
    plt.close(fig)
def parse_image(image_path):
    """Read a JPEG file, decode it to 3 channels, and return float32 pixels scaled to [-1, 1].

    NOTE(review): the original comment ties [-1, 1] to a tanh output, but the generator's
    to-RGB convolutions in this file have no tanh activation.
    """
    raw_bytes = tf.io.read_file(image_path)
    pixels = tf.cast(tf.image.decode_jpeg(raw_bytes, channels=3), tf.float32)
    # [0, 255] -> [-1, 1]
    return (pixels - 127.5) / 127.5
def load_dataset(data_dir, num_parallel_calls=12):
    """Build an infinite, shuffled, unbatched dataset of images from ``data_dir/*.jpg``.

    Args:
        data_dir: directory searched (non-recursively) for ``*.jpg`` files.
        num_parallel_calls: decode parallelism passed to ``Dataset.map`` (previously hard-coded 12).

    Returns:
        A repeated ``tf.data.Dataset`` whose elements are produced by ``parse_image``.

    Raises:
        FileNotFoundError: if no ``*.jpg`` file exists in ``data_dir``. Without this check the
            pipeline fails later and far less clearly (zero shuffle buffer / empty repeated dataset).
    """
    print('Loading dataset %s' % data_dir)
    images = glob.glob(os.path.join(data_dir, '*.jpg'))
    if not images:
        raise FileNotFoundError('No *.jpg images found in %s' % data_dir)
    dataset = tf.data.Dataset.from_tensor_slices(images)
    # Shuffle over the full file list, then repeat forever; the training loop decides when to stop
    dataset = dataset.shuffle(len(images)).repeat()
    return dataset.map(parse_image, num_parallel_calls=num_parallel_calls)
def get_weighted_sum_layer(model):
    """Return the index in ``model.layers`` of the first WeightedSum layer, or -1 if there is none."""
    matches = (idx for idx, candidate in enumerate(model.layers) if isinstance(candidate, WeightedSum))
    return next(matches, -1)
def get_disc_train_fn():
    """Return a new ``tf.function`` that performs one discriminator (critic) update.

    Each call creates a fresh function object, so train() can obtain a new one after rebuilding models.
    """

    @tf.function
    def disc_train_on_batch(disc_model, gen_model, disc_opt, real_batch, batch_size, penalty_coeff=10,
                            drift_epsilon=0.001, reg_factor=1, margin_factor=10):
        """Run one discriminator step and return the scalar loss.

        loss = mean(D(fake)) - mean(D(real))
               + reg_factor * mean(|D(real) - D(fake) - margin_factor|)   (WGAN-TV, arXiv:1812.00810)
               + drift_epsilon * mean(D(real) ** 2)                        (ProGAN drift penalty)

        reg_factor: regularization factor from the WGAN-TV paper.
        margin_factor: the margin δ. It controls the trade-off between generative diversity and
            visual quality. Higher values of δ lead to higher visual quality because it helps
            distinguish real data and fake data, so that the generator has to output vivid images
            with more details. Lower values of δ lead to higher image diversity.
        penalty_coeff: unused; it belonged to an earlier WGAN-GP variant (gradient penalty on
            random real/fake interpolates) that the TV term replaced to avoid computing gradients
            with respect to mixed batches.
        """
        noise = tf.random.normal(shape=(batch_size, latent_dim))
        fake_images = gen_model(noise, training=True)
        fake_scores = disc_model(fake_images, training=True)
        real_scores = disc_model(real_batch, training=True)

        loss = tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores)
        # WGAN with total variational regularization
        loss += reg_factor * tf.reduce_mean(tf.abs(real_scores - fake_scores - margin_factor))
        # Keep real scores from drifting away from 0
        loss += drift_epsilon * tf.reduce_mean(tf.square(real_scores))

        critic_vars = disc_model.trainable_variables
        critic_grads = disc_opt.get_gradients(loss, critic_vars)
        disc_opt.apply_gradients(zip(critic_grads, critic_vars))
        return loss

    return disc_train_on_batch
def get_gen_train_fn():
    """Return a new ``tf.function`` that performs one generator update."""

    @tf.function
    def gen_train_on_batch(disc_model, gen_model, gen_opt, batch_size):
        """Run one generator step with the WGAN generator loss ``-mean(D(G(z)))``; return the loss."""
        noise = tf.random.normal(shape=(batch_size, latent_dim))
        critic_scores = disc_model(gen_model(noise, training=True), training=True)
        loss = -tf.reduce_mean(critic_scores)
        # Only the generator's variables are updated here; the discriminator is left untouched.
        # get_gradients goes through gen_opt, so any loss scaling wrapped around it is applied.
        gen_vars = gen_model.trainable_variables
        gen_grads = gen_opt.get_gradients(loss, gen_vars)
        gen_opt.apply_gradients(zip(gen_grads, gen_vars))
        return loss

    return gen_train_on_batch
def get_pb_format():
    """Return the progressbar widget list: ``[cur/target] type g_loss, d_loss [n of max] speed  ETA``."""
    def compact(name):
        # Narrow variable widget that prints only its value
        return pb.Variable(name, format='{formatted_value}', width=1)

    header = ['[', compact('cur_res'), '/', compact('target_res'), ']', ' ', compact('epoch_type'), ' ']
    losses = [pb.Variable('g_loss'), ', ', pb.Variable('d_loss')]
    progress = [' [', pb.SimpleProgress(), '] ', pb.FileTransferSpeed(unit='img', prefixes=['', 'k', 'm'])]
    timing = ['\t', pb.ETA()]
    return header + losses + progress + timing
def train(gen_model, disc_model, target_res=16, update_batches=1, fade_epoch=False, model_restore=None):
    """Progressively train a GAN from the generator's current resolution up to ``target_res``.

    Epoch schedule, e.g. starting at 4x4:
        4x4_straight, 8x8_fade, 8x8_straight, 16x16_fade, 16x16_straight, ..., <res>x<res>_fade,
        <res>x<res>_straight
    Each epoch shows ``epoch_images[epoch]`` real images to the discriminator. After every straight
    epoch below ``target_res`` both models are rebuilt at double resolution, compatible weights are
    copied over, and the optimizer state is zeroed.

    Args:
        gen_model: generator; ``output_shape[1]`` is taken as the starting resolution.
        disc_model: discriminator for the same resolution.
        target_res: final square resolution.
        update_batches: discriminator batches per generator step.
        fade_epoch: whether the first epoch is a fade-in epoch. NOTE(review): the epoch count below
            assumes training starts on a straight epoch.
        model_restore: currently unused.

    Returns:
        ``(gen_model, disc_model)`` at the final resolution.
    """
    print('Starting training...')
    # Get the initial training functions
    disc_train_on_batch = get_disc_train_fn()
    gen_train_on_batch = get_gen_train_fn()
    # Define generator and discriminator optimizers
    gen_opt = optimizers.Adam(learning_rate=0.00001, beta_1=0, beta_2=0.99, epsilon=1e-8)
    disc_opt = optimizers.Adam(learning_rate=0.00001, beta_1=0, beta_2=0.99, epsilon=1e-8)
    # Enable FP16 training
    gen_opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(gen_opt)
    disc_opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(disc_opt)
    # get starting resolution from generator model
    cur_res = gen_model.output_shape[1]  # [N<H>WC], output_shape[2] works equally well
    print_stats_freq = 25  # every print_stats_freq batches shown to discriminator
    epochs = int(np.log2(target_res / cur_res)) * 2 + 1
    print('Goal resolution %dx%d, starting at %dx%d' % (target_res, target_res, cur_res, cur_res))
    for epoch in range(epochs):
        cur_batch_size = batch_sizes[epoch]
        cur_dataset = load_dataset('celeba/data%dx%d' % (cur_res, cur_res)).batch(cur_batch_size).as_numpy_iterator()
        discriminator_seen = 0
        gen_sum_layer_i = get_weighted_sum_layer(gen_model)
        disc_sum_layer_i = get_weighted_sum_layer(disc_model)
        epoch_kind = 'fade' if fade_epoch else 'straight'
        # Save model weights at beginning of each epoch
        gen_model.save_weights('checkpoints/%dx%d_%s_gen' % (cur_res, cur_res, epoch_kind))
        disc_model.save_weights('checkpoints/%dx%d_%s_disc' % (cur_res, cur_res, epoch_kind))
        # Main training loop
        with pb.ProgressBar(widgets=get_pb_format(), max_value=epoch_images[epoch]) as progress:
            while discriminator_seen < epoch_images[epoch]:
                disc_losses = []
                for _ in range(update_batches):
                    disc_batch = next(cur_dataset)
                    # If fade-in epoch, then fade real images in progressively from cur_res / 2 to cur_res
                    # also update fade-in layer alphas
                    if fade_epoch:
                        # 2x downscale -> 2x upscale gives cur_res/2 detail at cur_res resolution
                        disc_batch_downscaled = tf.image.resize(disc_batch, [cur_res // 2, cur_res // 2])
                        disc_batch_downscaled = tf.image.resize(disc_batch_downscaled, [cur_res, cur_res])
                        alpha = min(1.0, discriminator_seen / epoch_images[epoch])
                        # Keep both models' fade-in layers in step with the real-image blend
                        k.set_value(gen_model.get_layer(index=gen_sum_layer_i).alpha, alpha)
                        k.set_value(disc_model.get_layer(index=disc_sum_layer_i).alpha, alpha)
                        # Linearly interpolate between old and new resolutions
                        disc_batch = (disc_batch * alpha) + (disc_batch_downscaled * (1 - alpha))
                    disc_losses.append(disc_train_on_batch(disc_model, gen_model, disc_opt, disc_batch,
                                                           cur_batch_size))
                    discriminator_seen += cur_batch_size
                disc_loss = np.mean(disc_losses)
                gen_loss = gen_train_on_batch(disc_model, gen_model, gen_opt, cur_batch_size)
                if discriminator_seen % (cur_batch_size * print_stats_freq) == 0:
                    # Clamp: with update_batches > 1 the count can overshoot the bar's max_value,
                    # which the progress bar rejects
                    progress.update(min(discriminator_seen, epoch_images[epoch]), cur_res=cur_res,
                                    target_res=target_res, g_loss=gen_loss, d_loss=disc_loss,
                                    epoch_type='f' if fade_epoch else 's')
                # print a grid of random images from the start seed
                if discriminator_seen % (cur_batch_size * print_stats_freq * 12) == 0:
                    # fade_epoch must be passed: random_image asserts it is not None when saving
                    random_image(gen_model, dseen=discriminator_seen, save=True, fade_epoch=fade_epoch)
        # Show random sample
        random_image(gen_model)
        # Double the resolution and generate new models for the new resolution
        if not fade_epoch and cur_res != target_res:
            cur_res *= 2
            print('Rebuilding models at res=%d' % cur_res)
            new_gen = build_generator(cur_res)
            new_disc = build_discriminator(cur_res)
            # Copy weights from the old models to new
            copy_generator_weights(gen_model, new_gen)
            copy_discriminator_weights(disc_model, new_disc)
            gen_model = new_gen
            disc_model = new_disc
            # Reset optimizer internal states
            reset_optimizer(gen_opt)
            reset_optimizer(disc_opt)
            # Generate new training functions (basically rebuilding the graphs to fit the new models/data sizes)
            # TODO: Do we really need to do this with eager execution enabled? Isn't the point supposed to be we
            # TODO: don't have to do this?
            disc_train_on_batch = get_disc_train_fn()
            gen_train_on_batch = get_gen_train_fn()
        # Clean up for the next epoch: epochs alternate straight <-> fade
        fade_epoch = not fade_epoch
    # End main training loop
    return gen_model, disc_model
def reset_optimizer(opt):
    """Overwrite every variable returned by ``opt.variables()`` with zeros, in place."""
    for state_var in opt.variables():
        zero_fill = tf.zeros_like(state_var)
        state_var.assign(zero_fill)
# TODO: Have a function that returns custom layers so we don't have k_i and k_c boilerplate everywhere...
def build_generator(target_res):
    """Build the generator mapping a ``(latent_dim,)`` vector to a ``target_res x target_res x 3`` image.

    Starts from a 4x4 base (Dense -> Reshape -> two convs) and adds one 2x-upsampling block per
    doubling. For ``target_res > 4`` the last block is a fade-in: the previous-resolution features
    go through a 1x1 to-RGB conv ('nocopy_rgb<i>') and are upsampled ('nocopy_upsamp<i>'), then
    blended with the new block's to-RGB output by a WeightedSum ('nocopy_wsum<i>'); train() drives
    its alpha. The to-RGB convs have no output activation.
    Presumably ``target_res`` must be a power of two >= 4 (``blocks`` uses log2) -- confirm.

    NOTE(review): copy_generator_weights addresses layers by position (layers[-5], [-1], [-2]), so
    the order in which layers are created here must not change.
    """
    input_layer = Input(shape=(latent_dim,))
    g = PixelNorm()(input_layer)
    g = Dense(4 * 4 * filters[0], kernel_initializer=init, kernel_constraint=max_norm(3.0))(g)
    g = Reshape((4, 4, filters[0]))(g)
    g = LeakyReLU(alpha=0.2)(g)
    g = PixelNorm()(g)
    g = Conv2D(filters[0], (4, 4), padding='same', kernel_initializer=init, kernel_constraint=max_norm(3.0))(g)
    g = LeakyReLU(alpha=0.2)(g)
    g = PixelNorm()(g)
    g = Conv2D(filters[0], (3, 3), padding='same', kernel_initializer=init, kernel_constraint=max_norm(3.0))(g)
    g = LeakyReLU(alpha=0.2)(g)
    g = PixelNorm()(g)
    blocks = int(np.log2(target_res / 4))  # 4x4 is our starting resolution
    block_i = 0
    # Fully faded-in intermediate blocks 1 .. blocks-1 (the last block is built below as a fade-in)
    for block_i in range(1, blocks):
        g = UpSampling2D()(g)
        g = Conv2D(filters[block_i], (3, 3), padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0))(g)
        g = LeakyReLU(alpha=0.2)(g)
        g = PixelNorm()(g)
        g = Conv2D(filters[block_i], (3, 3), padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0))(g)
        g = LeakyReLU(alpha=0.2)(g)
        g = PixelNorm()(g)
    # Fade-in block; after this increment block_i == blocks whether or not the loop ran
    block_i += 1
    if blocks != 0:
        # New path: one more 2x upsample + two convs at the new resolution
        new_b = UpSampling2D()(g)
        new_b = Conv2D(filters[block_i], (3, 3), padding='same', kernel_initializer=init,
                       kernel_constraint=max_norm(3.0))(new_b)
        new_b = LeakyReLU(alpha=0.2)(new_b)
        new_b = PixelNorm()(new_b)
        new_b = Conv2D(filters[block_i], (3, 3), padding='same', kernel_initializer=init,
                       kernel_constraint=max_norm(3.0))(new_b)
        new_b = LeakyReLU(alpha=0.2)(new_b)
        new_b = PixelNorm()(new_b)
        # This is the *new* to-RGB layer so keep its weights for the next model
        new_b = Conv2D(3, (1, 1), padding='same', kernel_initializer=init, kernel_constraint=max_norm(3.0),
                       dtype='float32')(new_b)
        # Old path: to-RGB at the previous resolution, then a plain 2x upsample to match new_b
        g = Conv2D(3, (1, 1), name='nocopy_rgb%d' % block_i, padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0), dtype='float32')(g)
        g = UpSampling2D(name='nocopy_upsamp%d' % block_i)(g)
        # (1 - alpha) * old path + alpha * new path
        g = WeightedSum(name='nocopy_wsum%d' % block_i)([g, new_b])
    else:
        # 4x4 model: a single to-RGB conv is the output layer
        g = Conv2D(3, (1, 1), name='nocopy_rgb0', padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0), dtype='float32')(g)
    return Model(inputs=input_layer, outputs=g)
def build_discriminator(target_res):
    """Build the discriminator (critic) for ``target_res x target_res x 3`` images.

    For ``target_res > 4`` the first block is a fade-in. The new path (1x1 from-RGB conv, two 3x3
    convs, 2x average pool) is blended by a WeightedSum ('nocopy_wsum<i>') with an old path
    (2x average pool, 1x1 from-RGB conv 'nocopy_rgb<i>', LeakyReLU); train() drives the alpha.
    The remaining blocks halve the resolution down to 4x4. Then come a minibatch-stddev feature,
    two convs, Flatten, and a single-unit Dense with no activation (a raw critic score).
    Presumably ``target_res`` must be a power of two >= 4 -- confirm.

    NOTE(review): copy_discriminator_weights addresses layers by position (layers[1], layers[7]),
    so the order in which layers are created here must not change.
    """
    input_layer = Input(shape=(target_res, target_res, 3))
    num_blocks = int(np.log2(target_res / 4))  # 4x4 is our starting resolution
    d = input_layer
    if num_blocks != 0:
        # New path at full resolution
        new_b = Conv2D(filters[num_blocks], (1, 1), padding='same', kernel_initializer=init,
                       kernel_constraint=max_norm(3.0))(input_layer)
        new_b = LeakyReLU(alpha=0.2)(new_b)
        new_b = Conv2D(filters[num_blocks], (3, 3), padding='same', kernel_initializer=init,
                       kernel_constraint=max_norm(3.0))(new_b)
        new_b = LeakyReLU(alpha=0.2)(new_b)
        new_b = Conv2D(filters[num_blocks], (3, 3), padding='same', kernel_initializer=init,
                       kernel_constraint=max_norm(3.0))(new_b)
        new_b = LeakyReLU(alpha=0.2)(new_b)
        d = AveragePooling2D()(new_b)
        # Old path: downsample the image first, then from-RGB at half resolution
        old_b = AveragePooling2D(name='nocopy_avgpool2d%d' % num_blocks)(input_layer)
        old_b = Conv2D(filters[num_blocks], (1, 1), name='nocopy_rgb%d' % num_blocks, padding='same',
                       kernel_initializer=init, kernel_constraint=max_norm(3.0))(old_b)
        old_b = LeakyReLU(alpha=0.2, name='nocopy_lrelu%d' % num_blocks)(old_b)
        # (1 - alpha) * old path + alpha * new path
        d = WeightedSum(name='nocopy_wsum%d' % num_blocks)([old_b, d])
    else:
        # 4x4 model: plain from-RGB conv
        d = Conv2D(filters[0], (1, 1), name='nocopy_rgb0', padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0))(d)
        d = LeakyReLU(alpha=0.2)(d)
    # Add intermediate discriminator layers
    # Do not use blocks+1 here because the output layer differs from the intermediate
    # Each iteration halves the spatial size, ending at 4x4
    for block_i in range(1, num_blocks):
        d = Conv2D(filters[num_blocks - block_i], (3, 3), padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0))(d)
        d = LeakyReLU(alpha=0.2)(d)
        d = Conv2D(filters[num_blocks - block_i], (3, 3), padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0))(d)
        d = LeakyReLU(alpha=0.2)(d)
        d = AveragePooling2D()(d)
    # Minibatch standard-deviation layer
    d = MinibatchStd()(d)
    d = Conv2D(filters[0], (3, 3), padding='same', kernel_initializer=init, kernel_constraint=max_norm(3.0))(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = Conv2D(filters[0], (4, 4), padding='same', kernel_initializer=init, kernel_constraint=max_norm(3.0))(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = Flatten()(d)
    # Unbounded scalar score (no activation); the WGAN-style losses consume it directly
    d = Dense(1, dtype='float32')(d)
    return Model(inputs=input_layer, outputs=d)
# TODO: Merge these two functions
def copy_generator_weights(old_gen, new_gen):
    """Copy weights from ``old_gen`` into the next-resolution generator ``new_gen``, in place.

    Layers are paired position-by-position from the input side. A pair is copied when the old
    layer is an instance of the new layer's type and both have the same output shape.
    NOTE(review): the 'nocopy_' layer names set in build_generator are not consulted here.
    """
    old_layers = old_gen.layers
    new_layers = new_gen.layers
    # zip stops at the shorter layer list
    shared_layers = [layer for layer in zip(old_layers, new_layers)
                     if isinstance(layer[0], type(layer[1])) and layer[0].output_shape == layer[1].output_shape]
    for old_layer, new_layer in shared_layers:
        new_layer.set_weights(old_layer.get_weights())
    # print('copy g old_rgb to new_rgb')
    # Old model's most recent to-RGB conv. In a 4x4 model it is the output layer (-1). For larger
    # models, -2 is presumably the new-path to-RGB conv that precedes the WeightedSum (depends on
    # Keras' functional layer ordering -- verify with old_gen.summary()).
    old_rgb_idx = -1 if old_gen.output_shape[1] == 4 else -2
    # copy rgb layers if they are compatible
    # NOTE(review): layers[-5] is presumably new_gen's old-path 'nocopy_rgb' conv; confirm with
    # new_gen.summary() if the generator architecture changes.
    if new_gen.layers[-5].output_shape == old_gen.layers[old_rgb_idx].output_shape:
        print('copy g old_rgb to new_rgb')
        new_gen.layers[-5].set_weights(old_gen.layers[old_rgb_idx].get_weights())
def copy_discriminator_weights(old_disc, new_disc):
    """Copy weights from ``old_disc`` into the next-resolution discriminator ``new_disc``, in place.

    Both layer lists are reversed, so layers are paired position-by-position from the output side.
    A pair is copied when the old layer is an instance of the new layer's type and both have the
    same output shape.
    """
    old_layers = reversed(old_disc.layers)
    new_layers = reversed(new_disc.layers)
    shared_layers = [layer for layer in zip(old_layers, new_layers)
                     if isinstance(layer[0], type(layer[1])) and layer[0].output_shape == layer[1].output_shape]
    for old_layer, new_layer in shared_layers:
        new_layer.set_weights(old_layer.get_weights())
    # copy rgb layers if they are compatible
    # NOTE(review): layers[1] is presumably the first layer after the Input (the from-RGB conv in a
    # 4x4 model), and layers[7] presumably new_disc's old-path 'nocopy_rgb' conv; both indices
    # depend on Keras' functional layer ordering -- confirm with .summary().
    # TODO: Potentially don't need to copy RGB layers. they only have ~2000 parameters
    if new_disc.layers[7].output_shape == old_disc.layers[1].output_shape:
        print('copy d old_rgb to new_rgb')
        new_disc.layers[7].set_weights(old_disc.layers[1].get_weights())
def pl(la):
    """Print each item of ``la`` on its own line (debugging helper); returns None."""
    # A plain loop instead of a list comprehension built only for its print side effects
    for item in la:
        print(item)
# Models built at import time at 4x4, 8x8 and 16x16. They are not referenced elsewhere in this
# file; presumably they are meant for interactive inspection.
g4, g8, g16 = (build_generator(side) for side in (4, 8, 16))
d4, d8, d16 = (build_discriminator(side) for side in (4, 8, 16))
# Fixed latent batch (one vector per cell of the 3x3 grid) that save_img_panel renders.
seed = tf.random.normal((3 * 3, latent_dim))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment