Created
April 18, 2023 13:04
-
-
Save pythonlessons/d2308545d4af15840a46ff125308085f to your computer and use it in GitHub Desktop.
gan_introduction
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class GAN(tf.keras.models.Model):
    """A Generative Adversarial Network (GAN) implementation.

    This class inherits from `tf.keras.models.Model` and combines a generator
    and a discriminator so that each `fit()` step updates both networks via
    the custom `train_step` below.
    """

    def __init__(
        self,
        discriminator: tf.keras.models.Model,
        generator: tf.keras.models.Model,
        noise_dim: int,
    ) -> None:
        """Initializes the GAN class.

        Args:
            discriminator (tf.keras.models.Model): A `tf.keras.models.Model`
                instance that acts as the discriminator in the GAN algorithm.
            generator (tf.keras.models.Model): A `tf.keras.models.Model`
                instance that acts as the generator in the GAN algorithm.
            noise_dim (int): The dimensionality of the noise vector fed to
                the generator.
        """
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.noise_dim = noise_dim

    def compile(
        self,
        discriminator_opt: tf.keras.optimizers.Optimizer,
        generator_opt: tf.keras.optimizers.Optimizer,
        discriminator_loss: typing.Callable,
        generator_loss: typing.Callable,
        **kwargs,
    ) -> None:
        """Configures the model for training.

        Args:
            discriminator_opt (tf.keras.optimizers.Optimizer): The optimizer
                used to update the discriminator's trainable variables.
            generator_opt (tf.keras.optimizers.Optimizer): The optimizer used
                to update the generator's trainable variables.
            discriminator_loss (typing.Callable): Loss for the discriminator;
                called as `discriminator_loss(real_output, fake_output)` with
                the discriminator's outputs on real and generated images.
            generator_loss (typing.Callable): Loss for the generator; called as
                `generator_loss(fake_output)` with the discriminator's outputs
                on generated images.
            **kwargs: Forwarded unchanged to `tf.keras.models.Model.compile`
                (e.g. `metrics=`).
        """
        super().compile(**kwargs)
        self.discriminator_opt = discriminator_opt
        self.generator_opt = generator_opt
        self.discriminator_loss = discriminator_loss
        self.generator_loss = generator_loss

    def train_step(self, real_images: tf.Tensor) -> typing.Dict[str, tf.Tensor]:
        """Executes one training step for both networks.

        Args:
            real_images (tf.Tensor): A batch of real images from the dataset.
                If the dataset yields tuples (e.g. `(images, labels)`), only
                the first element is used.

        Returns:
            typing.Dict[str, tf.Tensor]: The current results of every tracked
            metric, plus the discriminator loss under `"d_loss"` and the
            generator loss under `"g_loss"`.
        """
        # Keras passes each dataset element to train_step as-is, so a dataset
        # of (images, labels) pairs arrives here as a tuple. The GAN only
        # needs the images.
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        batch_size = tf.shape(real_images)[0]
        # One standard-normal noise vector per real image in the batch.
        noise = tf.random.normal([batch_size, self.noise_dim])

        # Two separate tapes, so that each network's gradients are taken only
        # with respect to its own loss.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = self.generator(noise, training=True)

            # Discriminator scores for real and generated images.
            real_output = self.discriminator(real_images, training=True)
            fake_output = self.discriminator(generated_images, training=True)

            # The real=1 / fake=0 labelling convention is implemented by the
            # user-supplied loss callables, not here.
            gen_loss = self.generator_loss(fake_output)
            disc_loss = self.discriminator_loss(real_output, fake_output)

        # Compute both gradient sets before applying either update, so that
        # both networks are updated from the same forward pass.
        gradients_of_generator = gen_tape.gradient(
            gen_loss, self.generator.trainable_variables
        )
        gradients_of_discriminator = disc_tape.gradient(
            disc_loss, self.discriminator.trainable_variables
        )

        self.generator_opt.apply_gradients(
            zip(gradients_of_generator, self.generator.trainable_variables)
        )
        self.discriminator_opt.apply_gradients(
            zip(gradients_of_discriminator, self.discriminator.trainable_variables)
        )

        # NOTE(review): compiled metrics receive the real-image scores as
        # y_true and the fake-image scores as y_pred; confirm this pairing is
        # what the configured metrics expect. `compiled_metrics` is the
        # tf.keras 2.x API.
        self.compiled_metrics.update_state(real_output, fake_output)

        results = {m.name: m.result() for m in self.metrics}
        results.update({"d_loss": disc_loss, "g_loss": gen_loss})
        return results
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment