Skip to content

Instantly share code, notes, and snippets.

@pythonlessons
Created May 30, 2023 08:45
Show Gist options
  • Select an option

  • Save pythonlessons/763e997a4ff93c308d220f2f8a8490eb to your computer and use it in GitHub Desktop.

Select an option

Save pythonlessons/763e997a4ff93c308d220f2f8a8490eb to your computer and use it in GitHub Desktop.
wgan_gp
def gradient_penalty(
    self,
    real_samples: tf.Tensor,
    fake_samples: tf.Tensor,
    discriminator: tf.keras.models.Model
) -> tf.Tensor:
    """Calculate the WGAN-GP gradient penalty on interpolated samples.

    Each sample in the batch is linearly interpolated between its real and
    fake counterpart using a per-sample random weight drawn from [0, 1).
    The critic is evaluated on the interpolations, and the penalty is the
    batch mean of ``(||grad||_2 - 1) ** 2``. The gradient is taken w.r.t.
    the interpolated input, and its norm is taken over all non-batch axes.
    The result is unweighted; presumably the caller scales it (lambda) and
    adds it to the discriminator loss — confirm against the training step.

    Args:
        self: Unused here; kept for the enclosing class's method interface.
        real_samples: Batch of real data. The first axis is the batch axis;
            any number of trailing axes is supported (e.g. images of shape
            (batch, H, W, C)).
        fake_samples: Batch of generated data, broadcast-compatible with
            ``real_samples`` and of the same dtype.
        discriminator: Critic model; it is called with ``training=True``.

    Returns:
        Scalar tensor holding the mean gradient penalty.

    Raises:
        ValueError: If the critic output does not depend on its input, so
            no gradient can be computed.
    """
    real_shape = tf.shape(real_samples)
    # Per-sample epsilon shaped (batch, 1, ..., 1) so it broadcasts over every
    # non-batch axis whatever the input rank. It is created in the data's own
    # dtype so mixed-precision / float64 inputs do not hit a dtype mismatch.
    epsilon_shape = tf.concat([real_shape[:1], tf.ones_like(real_shape[1:])], axis=0)
    epsilon = tf.random.uniform(
        epsilon_shape, minval=0.0, maxval=1.0, dtype=real_samples.dtype
    )

    # 1. Interpolate between real and fake samples.
    interpolated_samples = epsilon * real_samples + (1.0 - epsilon) * fake_samples

    with tf.GradientTape() as tape:
        # The interpolation is a plain tensor, not a Variable, so it must be
        # watched explicitly for the tape to record gradients w.r.t. it.
        tape.watch(interpolated_samples)
        # 2. Get the critic's output for the interpolated samples.
        logits = discriminator(interpolated_samples, training=True)

    # 3. Gradients of the critic output w.r.t. the interpolated samples.
    gradients = tape.gradient(logits, interpolated_samples)
    if gradients is None:
        raise ValueError(
            "Discriminator output does not depend on its input; "
            "cannot compute the gradient penalty."
        )

    # 4. Per-sample L2 norm over every non-batch axis. The tiny constant keeps
    # the sqrt differentiable when a gradient is exactly zero, because
    # d/dx sqrt(x) is infinite at 0 and would otherwise backpropagate NaN.
    reduce_axes = tf.range(1, tf.rank(gradients))
    squared_norm = tf.reduce_sum(tf.square(gradients), axis=reduce_axes)
    gradients_norm = tf.sqrt(squared_norm + 1e-12)

    # 5. Penalise deviation of the gradient norm from 1 (the 1-Lipschitz target).
    penalty = tf.reduce_mean((gradients_norm - 1.0) ** 2)
    return penalty
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment