# wgan_gp — GitHub Gist pythonlessons/4431825ad789be36b239a6406489daf2
# Author: @pythonlessons, created May 30, 2023 08:45.
import os
import cv2
import typing
import imageio
import numpy as np
import tensorflow as tf
# Let TensorFlow allocate GPU memory on demand instead of reserving the whole
# device up front. This must run before any GPU has been initialised.
# Iterate over every visible GPU rather than indexing [0]:
#   - indexing [0] raises IndexError on a machine with no visible GPU;
#   - TensorFlow refuses to initialise GPUs whose memory-growth settings
#     differ, so a multi-GPU machine needs the flag on all of them.
for _gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(_gpu, True)
from keras.callbacks import TensorBoard
from keras.preprocessing.image import ImageDataGenerator
from model import build_generator, build_discriminator
from keras import mixed_precision
# Make 'mixed_float16' the default dtype policy for every Keras layer built
# after this point.
# NOTE(review): per the Keras mixed-precision guide, float16 training usually
# needs loss scaling (LossScaleOptimizer) when gradients are computed in a
# custom train_step — confirm that WGAN_GP handles this.
policy = mixed_precision.Policy(name='mixed_float16')
mixed_precision.set_global_policy(policy=policy)
# Model / training hyper-parameters.
batch_size = 128
noise_dim = 128  # length of the latent (noise) vector fed to the generator
img_shape = (64, 64, 3)  # (height, width, channels) of the images the discriminator sees

# Input data location (celebA aligned faces) and output folder for this run.
dataset_path = "Dataset/img_align_celeba"
model_path = 'Models/02_WGANGP_faces'
os.makedirs(name=model_path, exist_ok=True)  # no error if the folder already exists
def _scale_to_signed_unit(image):
    """Rescale pixel values from [0, 255] to [-1, 1]."""
    return image / 127.5 - 1.0


# Image loader: random horizontal flips for augmentation, then pixel
# normalisation via _scale_to_signed_unit.
datagen = ImageDataGenerator(
    horizontal_flip=True,
    preprocessing_function=_scale_to_signed_unit,
)

# Shuffled, label-free batches of `batch_size` images resized to
# img_shape[:2]. Per the Keras docs, flow_from_directory discovers images in
# the subdirectories of `dataset_path`.
train_generator = datagen.flow_from_directory(
    dataset_path,
    target_size=(img_shape[0], img_shape[1]),
    class_mode=None,
    batch_size=batch_size,
    shuffle=True,
)
# Build both networks with the project factories from `model` and print each
# network's layer summary (the generator is built first, as before).
generator = build_generator(noise_dim)
generator.summary()

discriminator = build_discriminator(img_shape)
discriminator.summary()

# Two independent Adam instances sharing the same settings
# (lr=2e-4, beta_1=0.5, beta_2=0.9); the generator's is created first.
# NOTE(review): the global policy is mixed_float16; if WGAN_GP computes
# gradients itself, these optimizers may need LossScaleOptimizer — confirm.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
# Wasserstein loss for the discriminator
def discriminator_w_loss(pred_real, pred_fake):
    """Return the WGAN critic loss ``mean(pred_fake) - mean(pred_real)``.

    Minimising it raises the critic's scores on real images and lowers them
    on generated ones.

    Args:
        pred_real: critic outputs for a batch of real images.
        pred_fake: critic outputs for a batch of generated images.

    Returns:
        A scalar float32 tensor.
    """
    # The global policy is mixed_float16, so critic scores may arrive as
    # float16. WGAN scores are unbounded, so reduce in float32 to avoid
    # float16 range/precision problems. A float32 loss can also be safely
    # combined with float32 terms such as a gradient penalty.
    real_loss = tf.reduce_mean(tf.cast(pred_real, tf.float32))
    fake_loss = tf.reduce_mean(tf.cast(pred_fake, tf.float32))
    return fake_loss - real_loss
# Wasserstein loss for the generator
def generator_w_loss(pred_fake):
    """Return the WGAN generator loss ``-mean(pred_fake)``.

    Minimising it raises the critic's scores on generated images.

    Args:
        pred_fake: critic outputs for a batch of generated images.

    Returns:
        A scalar float32 tensor.
    """
    # Reduce in float32: under the global mixed_float16 policy the critic
    # scores may be float16, and WGAN scores are unbounded.
    return -tf.reduce_mean(tf.cast(pred_fake, tf.float32))
# Wire up callbacks and the WGAN-GP model, then train.
# NOTE(review): ResultsCallback, LRSheduler and WGAN_GP are neither defined
# nor imported in this snippet; they must be provided elsewhere in the full
# script. Confirm before running.

# Single source of truth for the training length. The LR scheduler's
# decay_epochs is kept equal to it, matching the original 500/500 pairing.
epochs = 500

callback = ResultsCallback(noise_dim=noise_dim, output_path=model_path, duration=0.04)
tb_callback = TensorBoard(os.path.join(model_path, 'logs'))
lr_scheduler = LRSheduler(decay_epochs=epochs, tb_callback=tb_callback)

# discriminator_extra_steps=5: presumably the number of critic updates per
# generator update (the usual WGAN-GP setting). Verify against WGAN_GP.
gan = WGAN_GP(discriminator, generator, noise_dim, discriminator_extra_steps=5)
gan.compile(discriminator_optimizer, generator_optimizer, discriminator_w_loss, generator_w_loss, run_eagerly=False)
gan.fit(train_generator, epochs=epochs, callbacks=[callback, tb_callback, lr_scheduler])
# End of gist.