```python
from keras import mixed_precision

policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)
```
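Under the `mixed_float16` policy, a custom training loop such as the WGAN-GP train step below should scale its losses so float16 gradients do not underflow. A minimal sketch, assuming Adam optimizers whose hyperparameters are chosen purely for illustration:

```python
import tensorflow as tf
from keras import mixed_precision

# Wrap each optimizer so the loss is scaled before backprop and the
# gradients are unscaled before being applied (assumed Adam settings).
g_optimizer = mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(1e-4, beta_1=0.5))
d_optimizer = mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(1e-4, beta_1=0.5))
```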
```python
import os
import cv2
import typing
import imageio
import numpy as np
import tensorflow as tf

# Enable memory growth so TensorFlow does not reserve all GPU memory up front
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)

from keras.callbacks import TensorBoard
from keras.preprocessing.image import ImageDataGenerator
```
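ImageDataGenerator is imported above, but its use is not shown in this excerpt. A hedged example of how it could feed unlabeled images to the GAN, scaled into the [-1, 1] range a tanh generator expects (the `dataset/` path, 64x64 target size, and batch size are placeholders):

```python
# Placeholder data pipeline; path, image size, and batch size are assumptions
data_generator = ImageDataGenerator(preprocessing_function=lambda x: (x / 127.5) - 1.0)
train_images = data_generator.flow_from_directory(
    'dataset/',             # hypothetical directory; images must sit in subfolders
    target_size=(64, 64),   # assumed training resolution
    class_mode=None,        # unlabeled images only; GAN training needs no classes
    batch_size=128,
)
```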
```python
# Wasserstein loss for the discriminator
def discriminator_w_loss(pred_real, pred_fake):
    real_loss = tf.reduce_mean(pred_real)
    fake_loss = tf.reduce_mean(pred_fake)
    return fake_loss - real_loss

# Wasserstein loss for the generator
def generator_w_loss(pred_fake):
    return -tf.reduce_mean(pred_fake)
```
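These losses implement the Wasserstein objective: the critic is pushed to score real samples higher than generated ones, and the generator to raise the critic's score on its own samples. A quick sanity check with dummy critic outputs (the values are illustrative only):

```python
import tensorflow as tf

pred_real = tf.constant([0.9, 0.8, 1.1])    # critic scores for real images
pred_fake = tf.constant([-0.5, 0.1, -0.2])  # critic scores for generated images

d_loss = discriminator_w_loss(pred_real, pred_fake)  # negative when the critic separates real from fake
g_loss = generator_w_loss(pred_fake)                 # decreases as fake scores increase
print(float(d_loss), float(g_loss))
```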
```python
class LRSheduler(tf.keras.callbacks.Callback):
    """Learning rate scheduler for WGAN-GP"""
    def __init__(self, decay_epochs: int, tb_callback=None, min_lr: float=0.00001):
        super(LRSheduler, self).__init__()
        self.decay_epochs = decay_epochs
        self.min_lr = min_lr
        self.tb_callback = tb_callback
        self.compiled = False

    def on_epoch_end(self, epoch, logs=None):
```
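The snippet ends at the `on_epoch_end` signature. A continuation sketch of a body it might have, assuming a linear decay of both optimizers' learning rates down to `min_lr` over `decay_epochs` (the `g_optimizer`/`d_optimizer` attribute names on the attached model are also assumptions):

```python
        # Continuation sketch: step both learning rates towards min_lr so they reach it
        # after decay_epochs epochs (g_optimizer/d_optimizer attribute names are assumptions)
        remaining = max(self.decay_epochs - epoch, 1)
        for opt in (self.model.g_optimizer, self.model.d_optimizer):
            current_lr = float(tf.keras.backend.get_value(opt.learning_rate))
            new_lr = current_lr - (current_lr - self.min_lr) / remaining
            tf.keras.backend.set_value(opt.learning_rate, max(new_lr, self.min_lr))
```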
```python
class ResultsCallback(tf.keras.callbacks.Callback):
    """Callback for generating and saving images during training."""
    def __init__(
        self,
        noise_dim: int,
        output_path: str,
        examples_to_generate: int=16,
        grid_size: tuple=(4, 4),
        spacing: int=5,
        gif_size: tuple=(416, 416),
```
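The constructor is cut off after `gif_size`, so any remaining parameters are unknown. The continuation sketch below closes the signature as shown and illustrates what the callback plausibly does with these arguments: render a fixed batch of latent vectors into a spaced image grid and write one PNG per epoch (the `self.model.generator` attribute is an assumption, and assembling the saved frames into a GIF with imageio is omitted):

```python
    ) -> None:
        super().__init__()
        os.makedirs(output_path, exist_ok=True)
        self.output_path = output_path
        self.grid_size = grid_size
        self.spacing = spacing
        self.gif_size = gif_size
        # Fixed latent vectors so every epoch renders the same samples
        self.seed = tf.random.normal([examples_to_generate, noise_dim])

    def on_epoch_end(self, epoch, logs=None):
        # self.model.generator is an assumed attribute of the WGAN_GP model
        images = self.model.generator(self.seed, training=False).numpy()
        images = ((images + 1.0) * 127.5).astype(np.uint8)  # tanh output -> [0, 255]
        rows, cols = self.grid_size
        h, w, c = images.shape[1:]
        grid = np.full((rows * h + (rows - 1) * self.spacing,
                        cols * w + (cols - 1) * self.spacing, c), 255, dtype=np.uint8)
        for i, img in enumerate(images[:rows * cols]):
            r, col = divmod(i, cols)
            y, x = r * (h + self.spacing), col * (w + self.spacing)
            grid[y:y + h, x:x + w] = img
        cv2.imwrite(os.path.join(self.output_path, f"epoch_{epoch:04d}.png"),
                    cv2.cvtColor(grid, cv2.COLOR_RGB2BGR))
```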
```python
# Add instance noise to real and fake samples
real_samples = self.add_instance_noise(real_samples)
fake_samples = self.add_instance_noise(fake_samples)
```
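The `add_instance_noise` helper itself is not shown in this excerpt. A common choice, and a reasonable guess at what it does here, is to add small zero-mean Gaussian noise to both real and generated samples before they reach the critic, which keeps the two distributions overlapping early in training:

```python
# Sketch of the helper as a method on the WGAN_GP model (the stddev value is an assumption)
def add_instance_noise(self, x: tf.Tensor, stddev: float = 0.1) -> tf.Tensor:
    # Zero-mean Gaussian instance noise with an assumed standard deviation of 0.1
    return x + tf.random.normal(tf.shape(x), mean=0.0, stddev=stddev, dtype=x.dtype)
```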
```python
class WGAN_GP(tf.keras.models.Model):
    def __init__(
        self,
        discriminator: tf.keras.models.Model,
        generator: tf.keras.models.Model,
        noise_dim: int,
        discriminator_extra_steps: int=5,
        gp_weight: typing.Union[float, int]=10.0
    ) -> None:
        super(WGAN_GP, self).__init__()
```
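The constructor is cut off after the call to `super().__init__()`. In a typical subclassed WGAN-GP model it simply stores its arguments, and a `compile()` override registers the optimizers and the Wasserstein losses defined earlier for use in the custom train step; a continuation sketch under those assumptions:

```python
        # Continuation sketch of __init__ (attribute names are assumptions)
        self.discriminator = discriminator
        self.generator = generator
        self.noise_dim = noise_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn, **kwargs):
        # Register optimizers and the Wasserstein loss functions defined earlier
        super(WGAN_GP, self).compile(**kwargs)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn
```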
```python
    def gradient_penalty(
        self,
        real_samples: tf.Tensor,
        fake_samples: tf.Tensor,
        discriminator: tf.keras.models.Model
    ) -> tf.Tensor:
        """Calculates the gradient penalty.

        The gradient penalty is calculated on interpolated data
        and added to the discriminator loss.
        """
```
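The body of the method is cut off here. The standard WGAN-GP penalty evaluates the critic on random interpolations between real and fake samples and pushes the gradient norm at those points towards 1; a continuation sketch of that computation (not necessarily the author's exact code):

```python
        # Random interpolation points between real and fake samples
        batch_size = tf.shape(real_samples)[0]
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        interpolated = real_samples + alpha * (fake_samples - real_samples)

        # Gradient of the critic's output with respect to the interpolated images
        with tf.GradientTape() as tape:
            tape.watch(interpolated)
            pred = discriminator(interpolated, training=True)
        grads = tape.gradient(pred, [interpolated])[0]

        # Penalize deviation of the per-sample gradient norm from 1
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        return tf.reduce_mean((norm - 1.0) ** 2)
```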
```python
import tensorflow as tf
from keras import layers

# Define the discriminator model
def build_discriminator(img_shape, activation='linear', alpha=0.2):
    inputs = layers.Input(shape=img_shape, name="input")

    x = layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(inputs)
    x = layers.LeakyReLU(alpha)(x)
```
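The builder is truncated after the first convolution block. A typical DCGAN-style continuation keeps halving the spatial resolution while increasing the filter count, then flattens to a single unbounded score, which matches the linear activation a WGAN critic requires; the filter counts below are assumptions:

```python
    # Continuation sketch: deeper strided convolutions (filter counts are assumptions)
    for filters in (128, 256, 512):
        x = layers.Conv2D(filters, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
        x = layers.LeakyReLU(alpha)(x)

    x = layers.Flatten()(x)
    # Keep the score in float32 so the mixed-precision policy does not affect the loss
    outputs = layers.Dense(1, activation=activation, dtype='float32', name="output")(x)
    return tf.keras.Model(inputs, outputs, name="discriminator")
```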
```python
import tensorflow as tf
from keras import layers

# Define the generator model
def build_generator(noise_dim, output_channels=3, activation="tanh", alpha=0.2):
    inputs = layers.Input(shape=noise_dim, name="input")

    x = layers.Dense(4*4*512, use_bias=False)(inputs)
    x = layers.Reshape((4, 4, 512))(x)
```
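The generator is likewise cut off after reshaping the projected noise to a 4x4x512 tensor. The usual continuation upsamples with strided transposed convolutions until the target resolution is reached and finishes with the tanh activation passed in above; the sketch below assumes a 64x64 output and the listed filter counts:

```python
    # Continuation sketch: upsample 4x4 -> 64x64 (filter counts and output size are assumptions)
    for filters in (256, 128, 64):
        x = layers.Conv2DTranspose(filters, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(alpha)(x)

    # Final layer kept in float32 so tanh stays numerically stable under mixed precision
    outputs = layers.Conv2DTranspose(output_channels, (5, 5), strides=(2, 2), padding='same',
                                     activation=activation, dtype='float32', name="output")(x)
    return tf.keras.Model(inputs, outputs, name="generator")
```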