Skip to content

Instantly share code, notes, and snippets.

@pythonlessons
Created May 30, 2023 08:45
Show Gist options
  • Select an option

  • Save pythonlessons/13e4aeaf349fecdd9d7c01e656169927 to your computer and use it in GitHub Desktop.

Select an option

Save pythonlessons/13e4aeaf349fecdd9d7c01e656169927 to your computer and use it in GitHub Desktop.
wgan_gp
class WGAN_GP(tf.keras.models.Model):
    """Wasserstein GAN with gradient penalty (WGAN-GP) as a Keras model.

    Each call to :meth:`train_step` performs ``discriminator_extra_steps``
    critic (discriminator) updates followed by one generator update. The
    critic loss is ``discriminator_loss(pred_real, pred_fake)`` plus
    ``gp_weight`` times the gradient penalty computed on random
    interpolations between real and generated samples.

    Args:
        discriminator: Critic network mapping samples to scores.
        generator: Network fed latent vectors of shape ``(batch, noise_dim)``.
        noise_dim: Size of the latent vector fed to ``generator``.
        discriminator_extra_steps: Critic updates per generator update;
            must be >= 1.
        gp_weight: Multiplier applied to the gradient penalty.
        instance_noise_stddev: Standard deviation of the Gaussian instance
            noise added to the real/fake samples used for the gradient
            penalty. ``0`` effectively disables it. Defaults to ``0.1``.
        **kwargs: Forwarded to ``tf.keras.models.Model.__init__`` (e.g. ``name``).

    Raises:
        ValueError: If ``discriminator_extra_steps`` is smaller than 1.
    """

    def __init__(
        self,
        discriminator: tf.keras.models.Model,
        generator: tf.keras.models.Model,
        noise_dim: int,
        discriminator_extra_steps: int = 5,
        gp_weight: typing.Union[float, int] = 10.0,
        instance_noise_stddev: float = 0.1,
        **kwargs,
    ) -> None:
        if discriminator_extra_steps < 1:
            # train_step reports the loss of the last critic step, so at least one must run.
            raise ValueError(
                f"discriminator_extra_steps must be >= 1, got {discriminator_extra_steps}"
            )
        super().__init__(**kwargs)
        self.discriminator = discriminator
        self.generator = generator
        self.noise_dim = noise_dim
        self.discriminator_extra_steps = discriminator_extra_steps
        self.gp_weight = gp_weight
        self.instance_noise_stddev = instance_noise_stddev

    def compile(
        self,
        discriminator_opt: tf.keras.optimizers.Optimizer,
        generator_opt: tf.keras.optimizers.Optimizer,
        discriminator_loss: typing.Callable,
        generator_loss: typing.Callable,
        **kwargs,
    ) -> None:
        """Store the optimizers and loss functions used by :meth:`train_step`.

        Args:
            discriminator_opt: Optimizer applied to the critic's variables.
            generator_opt: Optimizer applied to the generator's variables.
            discriminator_loss: Called as ``discriminator_loss(pred_real, pred_fake)``.
            generator_loss: Called as ``generator_loss(pred_fake)``.
            **kwargs: Forwarded to ``tf.keras.models.Model.compile`` (e.g. ``metrics``).
        """
        super().compile(**kwargs)
        self.discriminator_opt = discriminator_opt
        self.generator_opt = generator_opt
        self.discriminator_loss = discriminator_loss
        self.generator_loss = generator_loss

    def add_instance_noise(self, x: tf.Tensor, stddev: float = 0.1) -> tf.Tensor:
        """Return ``x`` plus zero-mean Gaussian noise with the given stddev (same dtype as ``x``)."""
        noise = tf.random.normal(tf.shape(x), mean=0.0, stddev=stddev, dtype=x.dtype)
        return x + noise

    def gradient_penalty(
        self,
        real_samples: tf.Tensor,
        fake_samples: tf.Tensor,
        discriminator: tf.keras.models.Model,
    ) -> tf.Tensor:
        """Compute the WGAN-GP gradient penalty.

        Samples are interpolated per example with a uniform random weight.
        The penalty is the batch mean of ``(||d critic / d x||_2 - 1)^2``,
        evaluated at those interpolated points. Works for inputs of any rank
        whose first axis is the batch axis.

        Args:
            real_samples: Batch of real samples.
            fake_samples: Batch of generated samples, same shape/dtype as
                ``real_samples``.
            discriminator: Critic used to score the interpolated samples.

        Returns:
            Scalar tensor with the gradient penalty.
        """
        sample_shape = tf.shape(real_samples)
        # One mixing weight per example, broadcast over all non-batch axes
        # (e.g. [batch, 1, 1, 1] for 4-D image batches).
        epsilon_shape = tf.concat([sample_shape[:1], tf.ones_like(sample_shape[1:])], axis=0)
        epsilon = tf.random.uniform(
            shape=epsilon_shape, minval=0, maxval=1, dtype=real_samples.dtype
        )
        # 1. Interpolate between real and fake samples.
        interpolated_samples = epsilon * real_samples + ((1 - epsilon) * fake_samples)

        with tf.GradientTape() as tape:
            tape.watch(interpolated_samples)
            # 2. Get the critic's output for the interpolated samples.
            logits = discriminator(interpolated_samples, training=True)

        # 3. Gradients of the critic output w.r.t. the interpolated samples.
        # When this runs inside the caller's tape, that tape records this
        # computation, so the penalty is differentiable w.r.t. the critic's weights.
        gradients = tape.gradient(logits, interpolated_samples)

        # 4. Per-example L2 norm over every non-batch axis. The small constant
        # keeps sqrt differentiable when a gradient is exactly zero; without it
        # d/dx sqrt(0) is inf and the critic update becomes NaN.
        reduce_axes = tf.range(1, tf.rank(gradients))
        gradients_norm = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), axis=reduce_axes) + 1e-12
        )

        # 5. Penalize deviation of the norm from 1.
        return tf.reduce_mean((gradients_norm - 1.0) ** 2)

    def _train_discriminator_once(
        self, real_samples: tf.Tensor, batch_size: tf.Tensor
    ) -> typing.Tuple[tf.Tensor, tf.Tensor]:
        """Run one critic update and return ``(critic loss, gradient penalty)``."""
        # Fresh latent vectors every critic step, so each update sees a new batch of fakes.
        noise = tf.random.normal([batch_size, self.noise_dim])
        with tf.GradientTape() as tape:
            fake_samples = self.generator(noise, training=True)
            pred_real = self.discriminator(real_samples, training=True)
            pred_fake = self.discriminator(fake_samples, training=True)
            # Noise separate copies. Rebinding real_samples would compound the
            # noise on every critic step and leak it into the metrics.
            # NOTE(review): instance noise only affects the gradient-penalty
            # inputs, not the samples scored above — confirm this is intended.
            noisy_real = self.add_instance_noise(real_samples, self.instance_noise_stddev)
            noisy_fake = self.add_instance_noise(fake_samples, self.instance_noise_stddev)
            gp = self.gradient_penalty(noisy_real, noisy_fake, self.discriminator)
            disc_loss = self.discriminator_loss(pred_real, pred_fake) + gp * self.gp_weight
        grads = tape.gradient(disc_loss, self.discriminator.trainable_variables)
        self.discriminator_opt.apply_gradients(
            zip(grads, self.discriminator.trainable_variables)
        )
        return disc_loss, gp

    def _train_generator_once(
        self, batch_size: tf.Tensor
    ) -> typing.Tuple[tf.Tensor, tf.Tensor]:
        """Run one generator update and return ``(generator loss, generated samples)``."""
        noise = tf.random.normal([batch_size, self.noise_dim])
        with tf.GradientTape() as tape:
            fake_samples = self.generator(noise, training=True)
            pred_fake = self.discriminator(fake_samples, training=True)
            gen_loss = self.generator_loss(pred_fake)
        grads = tape.gradient(gen_loss, self.generator.trainable_variables)
        # Update generator weights.
        self.generator_opt.apply_gradients(zip(grads, self.generator.trainable_variables))
        return gen_loss, fake_samples

    def train_step(self, real_samples: tf.Tensor) -> typing.Dict[str, tf.Tensor]:
        """Run ``discriminator_extra_steps`` critic updates, then one generator update.

        Args:
            real_samples: Batch of real samples; the first axis is the batch axis.

        Returns:
            The compiled metrics, plus ``d_loss`` (last critic loss), ``g_loss``
            (generator loss) and ``gp`` (mean gradient penalty across critic steps).
        """
        batch_size = tf.shape(real_samples)[0]

        # Step 1. Train the critic more often than the generator.
        gps = []
        for _ in range(self.discriminator_extra_steps):
            disc_loss, gp = self._train_discriminator_once(real_samples, batch_size)
            gps.append(gp)

        # Step 2. Train the generator.
        gen_loss, fake_samples = self._train_generator_once(batch_size)

        # Update the metrics configured in `compile()` with the clean real batch.
        self.compiled_metrics.update_state(real_samples, fake_samples)
        results = {m.name: m.result() for m in self.metrics}
        results.update({"d_loss": disc_loss, "g_loss": gen_loss, "gp": tf.reduce_mean(gps)})
        return results
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment