Skip to content

Instantly share code, notes, and snippets.

@pythonlessons
Created April 18, 2023 13:04
Show Gist options
  • Select an option

  • Save pythonlessons/d2308545d4af15840a46ff125308085f to your computer and use it in GitHub Desktop.

Select an option

Save pythonlessons/d2308545d4af15840a46ff125308085f to your computer and use it in GitHub Desktop.
gan_introduction
class GAN(tf.keras.models.Model):
    """A Generative Adversarial Network (GAN) implementation.

    This class inherits from `tf.keras.models.Model` and provides a
    straightforward implementation of the GAN algorithm. Each call to
    `train_step` runs one forward pass of generator and discriminator and
    applies one optimizer update to each of them.
    """

    def __init__(
        self,
        discriminator: tf.keras.models.Model,
        generator: tf.keras.models.Model,
        noise_dim: int
    ) -> None:
        """Initializes the GAN class.

        Args:
            discriminator (tf.keras.models.Model): A `tf.keras.models.Model`
                instance that acts as the discriminator in the GAN algorithm.
            generator (tf.keras.models.Model): A `tf.keras.models.Model`
                instance that acts as the generator in the GAN algorithm.
            noise_dim (int): The dimensionality of the noise vector that is
                fed to the generator.
        """
        super(GAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.noise_dim = noise_dim

    def compile(
        self,
        discriminator_opt: tf.keras.optimizers.Optimizer,
        generator_opt: tf.keras.optimizers.Optimizer,
        discriminator_loss: typing.Callable,
        generator_loss: typing.Callable,
        **kwargs
    ) -> None:
        """Configures the model for training.

        Args:
            discriminator_opt (tf.keras.optimizers.Optimizer): The optimizer
                for the discriminator.
            generator_opt (tf.keras.optimizers.Optimizer): The optimizer for
                the generator.
            discriminator_loss (typing.Callable): The loss function for the
                discriminator, called as
                `discriminator_loss(real_output, fake_output)`.
            generator_loss (typing.Callable): The loss function for the
                generator, called as `generator_loss(fake_output)`.
            **kwargs: Forwarded unchanged to `tf.keras.models.Model.compile`.
        """
        super(GAN, self).compile(**kwargs)
        self.discriminator_opt = discriminator_opt
        self.generator_opt = generator_opt
        self.discriminator_loss = discriminator_loss
        self.generator_loss = generator_loss

    def train_step(self, real_images: tf.Tensor) -> typing.Dict[str, tf.Tensor]:
        """Executes one training step and returns the loss.

        Args:
            real_images (tf.Tensor): A batch of real images from the dataset.
                If the dataset yields tuples (e.g. `(images, labels)`), only
                the first element is used as the image batch.

        Returns:
            typing.Dict[str, tf.Tensor]: A dictionary of metric values plus
                the discriminator loss (`"d_loss"`) and generator loss
                (`"g_loss"`).
        """
        # Keras passes the raw dataset element here; datasets that yield
        # (images, labels) pairs arrive as a tuple, and a GAN needs only the
        # images.
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        # Use the dynamic batch size so a smaller final batch still works.
        batch_size = tf.shape(real_images)[0]

        # Generate random noise for the generator
        noise = tf.random.normal([batch_size, self.noise_dim])

        # Both tapes record the same forward pass. gen_tape is later
        # differentiated w.r.t. the generator's weights (the gradient flows
        # back through the discriminator), disc_tape w.r.t. the
        # discriminator's weights. How real/fake samples are labelled is
        # decided by the user-supplied loss functions.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # Generate fake images using the generator
            generated_images = self.generator(noise, training=True)

            # Get the discriminator's predictions for real and fake images
            real_output = self.discriminator(real_images, training=True)
            fake_output = self.discriminator(generated_images, training=True)

            # Calculate generator and discriminator losses
            gen_loss = self.generator_loss(fake_output)
            disc_loss = self.discriminator_loss(real_output, fake_output)

        # Calculate gradients of generator and discriminator. Both are
        # computed before either optimizer step, so each update sees the
        # same forward pass.
        gradients_of_generator = gen_tape.gradient(
            gen_loss, self.generator.trainable_variables
        )
        gradients_of_discriminator = disc_tape.gradient(
            disc_loss, self.discriminator.trainable_variables
        )

        # Apply gradients to generator and discriminator optimizer
        self.generator_opt.apply_gradients(
            zip(gradients_of_generator, self.generator.trainable_variables)
        )
        self.discriminator_opt.apply_gradients(
            zip(gradients_of_discriminator, self.discriminator.trainable_variables)
        )

        # Update the metrics.
        # NOTE(review): this passes the discriminator's scores on real images
        # as `y_true` and on fake images as `y_pred`; confirm this pairing is
        # what any metrics passed to `compile` expect.
        self.compiled_metrics.update_state(real_output, fake_output)

        # Construct a dictionary of metric results and losses
        results = {m.name: m.result() for m in self.metrics}
        results.update({"d_loss": disc_loss, "g_loss": gen_loss})
        return results
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment