Skip to content

Instantly share code, notes, and snippets.

@SuoXC
Created December 19, 2018 06:23
Show Gist options
Save SuoXC/82c0d52b1519fb83a79e4fd2a6afb1bd to your computer and use it in GitHub Desktop.
WGAN-GP loss definition
import tensorflow as tf
def gradient_panalty(real, fake, discriminator, alpha, gp_lambda=10, eps=1e-12):
    """Return the WGAN-GP gradient-penalty term for a discriminator (critic).

    Evaluates ``discriminator`` at points interpolated between ``real`` and
    ``fake``. It then penalises the squared deviation of each sample's
    gradient L2 norm from 1. (The public name keeps its original "panalty"
    spelling so existing callers are unaffected.)

    Uses ``tf.gradients``, so it must be built in graph mode (TF 1.x, or
    inside ``tf.function`` in TF 2.x).

    Args:
        real: Batch of real samples; the first axis is treated as the batch.
        fake: Batch of generated samples, broadcast-compatible with ``real``.
        discriminator: Callable mapping a batch of samples to logits.
        alpha: Interpolation weights. Presumably drawn uniformly from [0, 1]
            with a per-sample broadcastable shape, e.g.
            ``[batch_size, 1, 1, 1]`` for image batches -- confirm against
            the caller.
        gp_lambda: Multiplier applied to the penalty (default 10).
        eps: Small constant added under the square root. It keeps the norm
            differentiable when a sample's gradient is exactly zero.

    Returns:
        Scalar tensor ``gp_lambda * mean((||dD/dx_hat||_2 - 1) ** 2)``.

    Raises:
        ValueError: If the discriminator output does not depend on its
            input, so no gradient can be computed.
    """
    # Straight-line interpolation between real and fake samples.
    interpolated = real + alpha * (fake - real)
    logit = discriminator(interpolated)

    # tf.gradients differentiates the sum of `logit` w.r.t. `interpolated`;
    # it yields None when `logit` is not connected to `interpolated`.
    grad = tf.gradients(logit, interpolated)[0]
    if grad is None:
        raise ValueError(
            "discriminator output does not depend on its input; "
            "cannot compute the gradient penalty"
        )

    # Collapse every non-batch axis so the norm is taken per sample
    # (replaces the deprecated tf.layers.flatten).
    flat_grad = tf.reshape(grad, [tf.shape(grad)[0], -1])

    # Equivalent to tf.norm(flat_grad, axis=1), but tf.norm's gradient is NaN
    # at the zero vector; the eps keeps the backward pass finite.
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(flat_grad), axis=1) + eps)

    return gp_lambda * tf.reduce_mean(tf.square(grad_norm - 1.0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment