WGAN-GP gradient penalty loss definition
import tensorflow as tf

def gradient_penalty(real, fake, discriminator, alpha, gp_lambda=10):
    # alpha: per-sample interpolation coefficients in [0, 1], e.g.
    # alpha = tf.placeholder(shape=[None, 1, 1, 1], dtype=tf.float32)
    # alpha = tf.random_uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=1.)

    # Random point on the line between a real and a fake sample
    interpolated = real + alpha * (fake - real)
    logit = discriminator(interpolated)

    # Gradient of D(interpolated) w.r.t. the interpolated input
    grad = tf.gradients(logit, interpolated)[0]
    # Per-sample L2 norm of the flattened gradient
    grad_norm = tf.norm(tf.layers.flatten(grad), axis=1)
    # Two-sided penalty: push the gradient norm towards 1
    gp = gp_lambda * tf.reduce_mean(tf.square(grad_norm - 1.))
    return gp
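A minimal usage sketch of how this penalty is typically folded into the WGAN-GP critic loss. The tensors real_images and fake_images and the discriminator callable are assumed names for illustration; they are not part of the gist.

# Assumed inputs: real_images, fake_images (same shape) and a `discriminator`
# callable that returns critic logits and reuses its variables across calls.
alpha = tf.random_uniform(shape=[tf.shape(real_images)[0], 1, 1, 1],
                          minval=0., maxval=1.)

real_logit = discriminator(real_images)
fake_logit = discriminator(fake_images)
gp = gradient_penalty(real_images, fake_images, discriminator, alpha)

# WGAN-GP critic loss: E[D(fake)] - E[D(real)] + lambda * gradient penalty
d_loss = tf.reduce_mean(fake_logit) - tf.reduce_mean(real_logit) + gp
# Generator loss: -E[D(fake)]
g_loss = -tf.reduce_mean(fake_logit)

Because the discriminator is called several times here, it must reuse its variables, for example by building it inside tf.variable_scope(..., reuse=tf.AUTO_REUSE) or wrapping it with tf.make_template.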