# Working (up to 1024x1024) version of my progressively growing GAN implementation in TensorFlow + Keras
import glob
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import progressbar as pb
import tensorflow as tf
from tensorflow.keras import layers, Model, optimizers, initializers, constraints
from tensorflow.keras import backend as k
from tensorflow.keras.layers import Dense, Conv2D, Input, LeakyReLU, Reshape, Flatten
from tensorflow.keras.layers import AveragePooling2D, UpSampling2D
from tensorflow.keras.constraints import max_norm

# Let GPU memory grow instead of pre-allocating it all up front (skipped on CPU-only machines)
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)

# Feature map counts per resolution block (4x4 up through 1024x1024)
filters = [512, 512, 512, 512, 256, 128, 64, 32, 16]
# Batch size for each epoch of the progressive schedule (shrinks as resolution grows)
batch_sizes = [64, 32, 32, 16, 16, 16, 16, 8, 8, 4, 4, 2, 2, 2, 2, 1, 1]
# Number of real images shown to the discriminator per epoch
epoch_images = [8e5] * 17
init = initializers.he_normal()
latent_dim = 128


# Used in fade-in layers.
# Fades from input 1 to input 2 as the scaling factor alpha goes from 0 to 1.
class WeightedSum(layers.Add):
    def __init__(self, alpha=0.0, **kwargs):
        super(WeightedSum, self).__init__(**kwargs)
        self.alpha = k.variable(alpha, name='ws_alpha')

    def _merge_function(self, inputs):
        assert len(inputs) == 2
        return ((1.0 - self.alpha) * inputs[0]) + (self.alpha * inputs[1])
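
# Example: with alpha = 0.25, WeightedSum returns 0.75 * inputs[0] + 0.25 * inputs[1];
# ramping alpha from 0 to 1 slides the output smoothly from the old path (inputs[0])
# to the newly added path (inputs[1]).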


# ProGAN paper: pixelwise feature vector normalization layer
class PixelNorm(layers.Layer):
    def __init__(self, **kwargs):
        super(PixelNorm, self).__init__(**kwargs)

    def build(self, input_shape):
        super(PixelNorm, self).build(input_shape)

    def call(self, inputs, **kwargs):
        inputs *= tf.math.rsqrt(tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True) + 1e-8)
        return inputs

    def compute_output_shape(self, input_shape):
        return input_shape
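
# PixelNorm implements b_{x,y} = a_{x,y} / sqrt(mean_j (a^j_{x,y})^2 + 1e-8) from the
# ProGAN paper: each pixel's feature vector is rescaled to roughly unit magnitude,
# which keeps generator activation magnitudes from escalating during training.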


class MinibatchStd(layers.Layer):
    # initialize the layer
    def __init__(self, **kwargs):
        super(MinibatchStd, self).__init__(**kwargs)

    # perform the operation
    def call(self, inputs, **kwargs):
        # calculate the mean value of each pixel position/channel across the batch
        mean = tf.reduce_mean(inputs, axis=0, keepdims=True)
        # calculate the average of the squared differences (variance) and add a small constant for numerical stability
        variance = tf.reduce_mean(tf.square(inputs - mean), axis=0, keepdims=True) + 1e-8
        # average the standard deviations (sqrt of variance) over all positions and channels into one scalar
        average_stddev = tf.reduce_mean(tf.sqrt(variance), keepdims=True)
        # Scale this up to be the size of one input feature map for each sample
        shape = tf.shape(inputs)
        minibatch_stddev = tf.tile(average_stddev, (shape[0], shape[1], shape[2], 1))
        # concatenate the minibatch std feature map with the input feature maps (axis=-1 for NHWC data)
        return tf.concat([inputs, minibatch_stddev], axis=-1)

    # define the output shape of the layer
    def compute_output_shape(self, input_shape):
        input_shape = list(input_shape)
        input_shape[-1] += 1  # the batch-wide std adds one additional channel
        return tuple(input_shape)
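
# Example: for a (16, 4, 4, 512) input, MinibatchStd computes one scalar (the batch's
# average per-feature stddev), tiles it to (16, 4, 4, 1), and concatenates it to
# produce a (16, 4, 4, 513) output.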


# Stub for ProGAN's equalized learning rate, which scales each layer's weights at
# runtime by its per-layer He constant (gain / sqrt(fan_in)). As written, this
# constraint is a no-op; the models below rely on max_norm constraints instead.
class EqualizedLearningRate(constraints.Constraint):
    def __init__(self, gain=np.sqrt(2.0)):
        self.gain = gain

    def __call__(self, w):
        # TODO: Finish this
        return w


def plot_model(model, save_name):
    tf.keras.utils.plot_model(model, to_file='progan_models/' + save_name + '.png',
                              expand_nested=True, show_shapes=True)


# StyleGAN paper: use a learning rate two orders of magnitude lower for the mapping network
# lr_mapnet = 0.01 * lr_gen


def view_image(image, title=None):
    if image.ndim == 4:
        image = image[0, :, :, :]
    # Rescale to [0, 1] for display
    image = (image - np.min(image)) / np.ptp(image)
    plt.figure()
    if title:
        plt.title(title)
    plt.imshow(image)
    plt.show()


def random_image(generator, dseen=None, save=False, fade_epoch=None):
    if save:
        assert dseen is not None, 'Random image saving requires providing discriminator seen stats'
        assert fade_epoch is not None, 'Random image saving requires providing fade_epoch'
        save_img_panel(generator, dseen, fade_epoch)
    else:
        r_img = generator(tf.random.normal((1, latent_dim)), training=False)
        view_image(r_img)


def save_img_panel(generator, dseen, fade_epoch):
    plt.figure(figsize=(3, 3))
    # 'seed' is a fixed module-level latent batch, so panels are comparable across epochs
    predictions = generator(seed, training=False)
    for i in range(predictions.shape[0]):
        plt.subplot(3, 3, i + 1)
        pred_image = predictions[i]
        pred_image = (pred_image - np.min(pred_image)) / np.ptp(pred_image)
        plt.imshow(pred_image)
        plt.axis('off')
    res = generator.output_shape[1]
    # Make the save directory if it doesn't exist
    pathlib.Path('celeba_output/{}x{}_run1'.format(res, res)).mkdir(parents=True, exist_ok=True)
    plt.savefig('celeba_output/{}x{}_run1/{}_image_at_epoch_{:04d}.png'.format(
        res, res, 'fade' if fade_epoch else 'straight', dseen
    ))
    plt.show()


def parse_image(image_path):
    image_string = tf.io.read_file(image_path)
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)
    image = tf.cast(image_decoded, tf.float32)
    # Scale the image to [-1, 1] (the output range of the tanh activation function)
    image = (image - 127.5) / 127.5
    return image


def load_dataset(data_dir):
    print('Loading dataset %s' % data_dir)
    images = glob.glob(os.path.join(data_dir, '*.jpg'))
    n_images = len(images)
    dataset = tf.data.Dataset.from_tensor_slices(images)
    dataset = dataset.shuffle(n_images).repeat()
    dataset = dataset.map(parse_image, num_parallel_calls=12)
    return dataset
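
# Assumed on-disk layout, matching the load_dataset() calls in train() below: one
# directory of pre-resized JPEGs per resolution, e.g. celeba/data4x4/*.jpg,
# celeba/data8x8/*.jpg, and so on up to the target resolution.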


def get_weighted_sum_layer(model):
    # Return the index of the model's WeightedSum (fade-in) layer, or -1 if there is none
    sum_layer = -1
    for i, layer in enumerate(model.layers):
        if isinstance(layer, WeightedSum):
            sum_layer = i
            break
    return sum_layer


def get_disc_train_fn():
    @tf.function
    def disc_train_on_batch(disc_model, gen_model, disc_opt, real_batch, batch_size, penalty_coeff=10,
                            drift_epsilon=0.001, reg_factor=1, margin_factor=10):
        # reg_factor: regularization factor from the WGAN-TV paper
        # margin_factor: the margin δ, explained in the docstring below
        """
        The margin factor δ controls the trade-off between generative diversity and visual quality.
        Higher values of δ lead to higher visual quality because they help distinguish real data
        from fake data, so the generator has to output vivid images with more detail. Lower values
        of δ lead to higher image diversity.
        """
        # Batch-sized sample from the latent space to seed the generator.
        # Keep in mind that all of the following math operates on *batches*.
        # Kept for reference: the WGAN-GP loss that the WGAN-TV loss below replaced.
        # epsilon = tf.random.uniform(shape=(batch_size, 1, 1, 1), maxval=1)
        # latent_sample = tf.random.normal(shape=(batch_size, latent_dim))
        #
        # fake_batch = gen_model(latent_sample, training=True)
        # mixed_batch = (epsilon * real_batch) + ((1.0 - epsilon) * fake_batch)
        #
        # mixed_preds = disc_model(mixed_batch, training=True)
        # fake_preds = disc_model(fake_batch, training=True)
        # real_preds = disc_model(real_batch, training=True)
        # loss = tf.reduce_mean(fake_preds) - tf.reduce_mean(real_preds)
        #
        # # WGAN-GP gradient penalty
        # mixed_grads = tf.gradients(mixed_preds, mixed_batch)[0]
        # mixed_norms = tf.sqrt(tf.reduce_sum(tf.square(mixed_grads), axis=[1, 2, 3]))
        # gradient_penalty = tf.reduce_mean((mixed_norms - 1.0) ** 2)
        # loss += penalty_coeff * gradient_penalty
        #
        # # Penalize discriminator loss drifting from 0
        # loss += drift_epsilon * tf.reduce_mean(tf.square(real_preds))
        #
        # disc_gradients = disc_opt.get_gradients(loss, disc_model.trainable_variables)
        # disc_opt.apply_gradients(zip(disc_gradients, disc_model.trainable_variables))
        latent_sample = tf.random.normal(shape=(batch_size, latent_dim))
        fake_batch = gen_model(latent_sample, training=True)
        fake_preds = disc_model(fake_batch, training=True)
        real_preds = disc_model(real_batch, training=True)
        loss = tf.reduce_mean(fake_preds) - tf.reduce_mean(real_preds)
        # WGAN with total variation regularization: avoids calculating mixed-batch gradients
        # (from https://arxiv.org/pdf/1812.00810.pdf)
        loss += reg_factor * tf.reduce_mean(tf.abs(real_preds - fake_preds - margin_factor))
        # Penalize the discriminator's real-image output drifting from 0 (from the ProGAN paper)
        loss += drift_epsilon * tf.reduce_mean(tf.square(real_preds))
        disc_gradients = disc_opt.get_gradients(loss, disc_model.trainable_variables)
        disc_opt.apply_gradients(zip(disc_gradients, disc_model.trainable_variables))
        return loss
    return disc_train_on_batch


def get_gen_train_fn():
    @tf.function
    def gen_train_on_batch(disc_model, gen_model, gen_opt, batch_size):
        latent_sample = tf.random.normal(shape=(batch_size, latent_dim))
        fake_batch = gen_model(latent_sample, training=True)
        fake_preds = disc_model(fake_batch, training=True)
        loss = -tf.reduce_mean(fake_preds)
        # Get loss-scaled gradients and apply them
        gen_gradients = gen_opt.get_gradients(loss, gen_model.trainable_variables)
        gen_opt.apply_gradients(zip(gen_gradients, gen_model.trainable_variables))
        return loss
    return gen_train_on_batch
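
# Note: gen_train_on_batch minimizes -E[D(G(z))], the standard WGAN generator objective,
# i.e. the generator improves by pushing the critic's score on its fakes upward.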


def get_pb_format():
    return [
        '[', pb.Variable('cur_res', format='{formatted_value}', width=1), '/',
        pb.Variable('target_res', format='{formatted_value}', width=1), ']', ' ',
        pb.Variable('epoch_type', format='{formatted_value}', width=1), ' ',
        pb.Variable('g_loss'), ', ', pb.Variable('d_loss'), ' [', pb.SimpleProgress(), '] ',
        pb.FileTransferSpeed(unit='img', prefixes=['', 'k', 'm']),
        '\t', pb.ETA()
    ]


def train(gen_model, disc_model, target_res=16, update_batches=1, fade_epoch=False, model_restore=None):
    # Train schedule:
    # 4x4_straight, 8x8_fade, 8x8_straight, 16x16_fade, 16x16_straight, ..., <res>x<res>_fade, <res>x<res>_straight
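    # For example, target_res=16 starting from a 4x4 generator gives
    # epochs = log2(16 / 4) * 2 + 1 = 5: 4x4_straight, 8x8_fade, 8x8_straight,
    # 16x16_fade, 16x16_straight.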
    print('Starting training...')
    # Get the initial training functions
    disc_train_on_batch = get_disc_train_fn()
    gen_train_on_batch = get_gen_train_fn()
    # Define generator and discriminator optimizers
    gen_opt = optimizers.Adam(learning_rate=0.00001, beta_1=0, beta_2=0.99, epsilon=1e-8)
    disc_opt = optimizers.Adam(learning_rate=0.00001, beta_1=0, beta_2=0.99, epsilon=1e-8)
    # Enable mixed-precision (FP16) training
    gen_opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(gen_opt)
    disc_opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(disc_opt)
    # Get the starting resolution from the generator model
    cur_res = gen_model.output_shape[1]  # [N<H>WC]; output_shape[2] works equally well
    print_stats_freq = 25  # every print_stats_freq batches shown to the discriminator
    epochs = int(np.log2(target_res / cur_res)) * 2 + 1
    print('Goal resolution %dx%d, starting at %dx%d' % (target_res, target_res, cur_res, cur_res))
    for epoch in range(epochs):
        cur_batch_size = batch_sizes[epoch]
        cur_dataset = load_dataset('celeba/data%dx%d' % (cur_res, cur_res)).batch(cur_batch_size).as_numpy_iterator()
        discriminator_seen = 0
        gen_sum_layer_i = get_weighted_sum_layer(gen_model)
        disc_sum_layer_i = get_weighted_sum_layer(disc_model)
        # Save model weights at the beginning of each epoch
        gen_model.save_weights('checkpoints/%dx%d_%s_gen' % (cur_res, cur_res, 'fade' if fade_epoch else 'straight'))
        disc_model.save_weights('checkpoints/%dx%d_%s_disc' % (cur_res, cur_res, 'fade' if fade_epoch else 'straight'))
        # Main training loop
        with pb.ProgressBar(widgets=get_pb_format(), max_value=epoch_images[epoch]) as progress:
            while discriminator_seen < epoch_images[epoch]:
                disc_losses = []
                for disc_update_batch in range(update_batches):
                    disc_batch = next(cur_dataset)
                    # If this is a fade-in epoch, fade real images in progressively from cur_res / 2
                    # to cur_res, and update the fade-in layer alphas to match
                    if fade_epoch:
                        # 2x downscale -> 2x upscale gives cur_res/2 detail at cur_res resolution
                        disc_batch_downscaled = tf.image.resize(disc_batch, [cur_res // 2, cur_res // 2])
                        disc_batch_downscaled = tf.image.resize(disc_batch_downscaled, [cur_res, cur_res])
                        alpha = min(1.0, discriminator_seen / epoch_images[epoch])
                        # Update the alphas
                        k.set_value(gen_model.get_layer(index=gen_sum_layer_i).alpha, alpha)
                        k.set_value(disc_model.get_layer(index=disc_sum_layer_i).alpha, alpha)
                        # Linearly interpolate between the old and new resolutions
                        disc_batch = (disc_batch * alpha) + (disc_batch_downscaled * (1 - alpha))
                    disc_losses.append(disc_train_on_batch(disc_model, gen_model, disc_opt, disc_batch, cur_batch_size))
                    discriminator_seen += cur_batch_size
                disc_loss = np.mean(disc_losses)
                gen_loss = gen_train_on_batch(disc_model, gen_model, gen_opt, cur_batch_size)
                if discriminator_seen % (cur_batch_size * print_stats_freq) == 0:
                    progress.update(discriminator_seen, cur_res=cur_res, target_res=target_res, g_loss=gen_loss,
                                    d_loss=disc_loss, epoch_type='f' if fade_epoch else 's')
                # Save a grid of images generated from the fixed start seed
                if discriminator_seen % (cur_batch_size * print_stats_freq * 12) == 0:
                    random_image(gen_model, dseen=discriminator_seen, save=True, fade_epoch=fade_epoch)
        # Show a random sample
        random_image(gen_model)
        # Double the resolution and generate new models for the new resolution
        if not fade_epoch and cur_res != target_res:
            cur_res *= 2
            print('Rebuilding models at res=%d' % cur_res)
            new_gen = build_generator(cur_res)
            new_disc = build_discriminator(cur_res)
            # Copy weights from the old models to the new ones
            copy_generator_weights(gen_model, new_gen)
            copy_discriminator_weights(disc_model, new_disc)
            gen_model = new_gen
            disc_model = new_disc
            # Reset optimizer internal states
            reset_optimizer(gen_opt)
            reset_optimizer(disc_opt)
            # Generate new training functions (basically rebuilding the graphs to fit the new models/data sizes)
            # TODO: Do we really need to do this with eager execution enabled? Isn't the point supposed to be we
            # TODO: don't have to do this?
            disc_train_on_batch = get_disc_train_fn()
            gen_train_on_batch = get_gen_train_fn()
        # Alternate between fade and straight epochs
        fade_epoch = not fade_epoch
    # End main training loop
    return gen_model, disc_model


def reset_optimizer(opt):
    # Zero out the optimizer's slot variables (e.g. Adam moment estimates)
    for var in opt.variables():
        var.assign(tf.zeros_like(var))


# TODO: Have a function that returns custom layers so we don't repeat the
# kernel_initializer / kernel_constraint boilerplate everywhere...
def build_generator(target_res):
    input_layer = Input(shape=(latent_dim,))
    g = PixelNorm()(input_layer)
    g = Dense(4 * 4 * filters[0], kernel_initializer=init, kernel_constraint=max_norm(3.0))(g)
    g = Reshape((4, 4, filters[0]))(g)
    g = LeakyReLU(alpha=0.2)(g)
    g = PixelNorm()(g)
    g = Conv2D(filters[0], (4, 4), padding='same', kernel_initializer=init, kernel_constraint=max_norm(3.0))(g)
    g = LeakyReLU(alpha=0.2)(g)
    g = PixelNorm()(g)
    g = Conv2D(filters[0], (3, 3), padding='same', kernel_initializer=init, kernel_constraint=max_norm(3.0))(g)
    g = LeakyReLU(alpha=0.2)(g)
    g = PixelNorm()(g)
    blocks = int(np.log2(target_res / 4))  # 4x4 is our starting resolution
    block_i = 0  # ensures block_i is defined even when the loop below doesn't run
    # Fully faded-in upsampling blocks (everything below the newest resolution)
    for block_i in range(1, blocks):
        g = UpSampling2D()(g)
        g = Conv2D(filters[block_i], (3, 3), padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0))(g)
        g = LeakyReLU(alpha=0.2)(g)
        g = PixelNorm()(g)
        g = Conv2D(filters[block_i], (3, 3), padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0))(g)
        g = LeakyReLU(alpha=0.2)(g)
        g = PixelNorm()(g)
    # Fade-in block
    block_i += 1
    if blocks != 0:
        new_b = UpSampling2D()(g)
        new_b = Conv2D(filters[block_i], (3, 3), padding='same', kernel_initializer=init,
                       kernel_constraint=max_norm(3.0))(new_b)
        new_b = LeakyReLU(alpha=0.2)(new_b)
        new_b = PixelNorm()(new_b)
        new_b = Conv2D(filters[block_i], (3, 3), padding='same', kernel_initializer=init,
                       kernel_constraint=max_norm(3.0))(new_b)
        new_b = LeakyReLU(alpha=0.2)(new_b)
        new_b = PixelNorm()(new_b)
        # This is the *new* to-RGB layer, so keep its weights for the next model
        new_b = Conv2D(3, (1, 1), padding='same', kernel_initializer=init, kernel_constraint=max_norm(3.0),
                       dtype='float32')(new_b)
        g = Conv2D(3, (1, 1), name='nocopy_rgb%d' % block_i, padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0), dtype='float32')(g)
        g = UpSampling2D(name='nocopy_upsamp%d' % block_i)(g)
        g = WeightedSum(name='nocopy_wsum%d' % block_i)([g, new_b])
    else:
        g = Conv2D(3, (1, 1), name='nocopy_rgb0', padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0), dtype='float32')(g)
    return Model(inputs=input_layer, outputs=g)


def build_discriminator(target_res):
    input_layer = Input(shape=(target_res, target_res, 3))
    num_blocks = int(np.log2(target_res / 4))  # 4x4 is our starting resolution
    d = input_layer
    if num_blocks != 0:
        # New high-resolution input path, faded in against the downsampled old path
        new_b = Conv2D(filters[num_blocks], (1, 1), padding='same', kernel_initializer=init,
                       kernel_constraint=max_norm(3.0))(input_layer)
        new_b = LeakyReLU(alpha=0.2)(new_b)
        new_b = Conv2D(filters[num_blocks], (3, 3), padding='same', kernel_initializer=init,
                       kernel_constraint=max_norm(3.0))(new_b)
        new_b = LeakyReLU(alpha=0.2)(new_b)
        new_b = Conv2D(filters[num_blocks], (3, 3), padding='same', kernel_initializer=init,
                       kernel_constraint=max_norm(3.0))(new_b)
        new_b = LeakyReLU(alpha=0.2)(new_b)
        d = AveragePooling2D()(new_b)
        old_b = AveragePooling2D(name='nocopy_avgpool2d%d' % num_blocks)(input_layer)
        old_b = Conv2D(filters[num_blocks], (1, 1), name='nocopy_rgb%d' % num_blocks, padding='same',
                       kernel_initializer=init, kernel_constraint=max_norm(3.0))(old_b)
        old_b = LeakyReLU(alpha=0.2, name='nocopy_lrelu%d' % num_blocks)(old_b)
        d = WeightedSum(name='nocopy_wsum%d' % num_blocks)([old_b, d])
    else:
        d = Conv2D(filters[0], (1, 1), name='nocopy_rgb0', padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0))(d)
        d = LeakyReLU(alpha=0.2)(d)
    # Add intermediate discriminator layers.
    # Do not use num_blocks + 1 here because the output block differs from the intermediate ones.
    for block_i in range(1, num_blocks):
        d = Conv2D(filters[num_blocks - block_i], (3, 3), padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0))(d)
        d = LeakyReLU(alpha=0.2)(d)
        d = Conv2D(filters[num_blocks - block_i], (3, 3), padding='same', kernel_initializer=init,
                   kernel_constraint=max_norm(3.0))(d)
        d = LeakyReLU(alpha=0.2)(d)
        d = AveragePooling2D()(d)
    # Minibatch standard-deviation layer
    d = MinibatchStd()(d)
    d = Conv2D(filters[0], (3, 3), padding='same', kernel_initializer=init, kernel_constraint=max_norm(3.0))(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = Conv2D(filters[0], (4, 4), padding='same', kernel_initializer=init, kernel_constraint=max_norm(3.0))(d)
    d = LeakyReLU(alpha=0.2)(d)
    d = Flatten()(d)
    d = Dense(1, dtype='float32')(d)
    return Model(inputs=input_layer, outputs=d)


# TODO: Merge these two functions
def copy_generator_weights(old_gen, new_gen):
    old_layers = old_gen.layers
    new_layers = new_gen.layers
    # Pair up layers of the same type and output shape, front to back, and copy weights across
    shared_layers = [layer for layer in zip(old_layers, new_layers)
                     if isinstance(layer[0], type(layer[1])) and layer[0].output_shape == layer[1].output_shape]
    for old_layer, new_layer in shared_layers:
        new_layer.set_weights(old_layer.get_weights())
    # A 4x4 generator ends in its to-RGB layer; grown generators end in the fade-in layers
    old_rgb_idx = -1 if old_gen.output_shape[1] == 4 else -2
    # Copy RGB layers if they are compatible
    if new_gen.layers[-5].output_shape == old_gen.layers[old_rgb_idx].output_shape:
        print('copy g old_rgb to new_rgb')
        new_gen.layers[-5].set_weights(old_gen.layers[old_rgb_idx].get_weights())


def copy_discriminator_weights(old_disc, new_disc):
    # The discriminator grows at its input end, so pair layers back to front
    old_layers = reversed(old_disc.layers)
    new_layers = reversed(new_disc.layers)
    shared_layers = [layer for layer in zip(old_layers, new_layers)
                     if isinstance(layer[0], type(layer[1])) and layer[0].output_shape == layer[1].output_shape]
    for old_layer, new_layer in shared_layers:
        new_layer.set_weights(old_layer.get_weights())
    # Copy RGB layers if they are compatible
    # TODO: Potentially don't need to copy RGB layers, they only have ~2000 parameters
    if new_disc.layers[7].output_shape == old_disc.layers[1].output_shape:
        print('copy d old_rgb to new_rgb')
        new_disc.layers[7].set_weights(old_disc.layers[1].get_weights())


def pl(la):
    # Debug helper: print each element of a list (e.g. model layers) on its own line
    for item in la:
        print(item)


g4, g8, g16 = build_generator(4), build_generator(8), build_generator(16)
d4, d8, d16 = build_discriminator(4), build_discriminator(8), build_discriminator(16)
# Fixed latent batch used by save_img_panel for a 3x3 image grid
seed = tf.random.normal((3 * 3, latent_dim))
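

# Example entry point: run the progressive schedule from the 4x4 models up to 16x16.
# Assumes the celeba/data<res>x<res> directories read by load_dataset() exist on disk.
if __name__ == '__main__':
    trained_gen, trained_disc = train(g4, d4, target_res=16)
    # Show a random sample from the trained generator
    random_image(trained_gen)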