Skip to content

Instantly share code, notes, and snippets.

@lgeiger
Last active March 30, 2018 23:41
Show Gist options
  • Save lgeiger/0270c37b9f411e1031f5a7d98e8a9e3a to your computer and use it in GitHub Desktop.
Save lgeiger/0270c37b9f411e1031f5a7d98e8a9e3a to your computer and use it in GitHub Desktop.
commit 1ebd8b170d31d64ad2523a1db81c5619fed24fc1
Author: Lukas Geiger <[email protected]>
Date: Sat Mar 31 01:40:36 2018 +0200
Add one_sided option to the Wasserstein gradient penalty
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index 2a40dbade6..77a86043d8 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
@@ -297,6 +297,7 @@ def wasserstein_gradient_penalty(
discriminator_fn,
discriminator_scope,
epsilon=1e-10,
+ one_sided=False,
weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@@ -364,10 +365,13 @@ def wasserstein_gradient_penalty(
# For numerical stability, add epsilon to the sum before taking the square
# root. Note tf.norm does not add epsilon.
slopes = math_ops.sqrt(gradient_squares + epsilon)
- penalties = math_ops.square(slopes - 1.0)
+ penalties = slopes - 1.0
+ if one_sided:
+ penalties = math_ops.maximum(0., penalties)
+ penalties_squared = math_ops.square(penalties)
penalty = losses.compute_weighted_loss(
- penalties, weights, scope=scope, loss_collection=loss_collection,
- reduction=reduction)
+ penalties_squared, weights, scope=scope,
+ loss_collection=loss_collection, reduction=reduction)
if add_summaries:
summary.scalar('gradient_penalty_loss', penalty)
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index 06dd281489..f72e10f060 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -334,6 +334,7 @@ def gan_loss(
# Auxiliary losses.
gradient_penalty_weight=None,
gradient_penalty_epsilon=1e-10,
+ gradient_penalty_one_sided=False,
mutual_information_penalty_weight=None,
aux_cond_generator_weight=None,
aux_cond_discriminator_weight=None,
@@ -406,7 +407,10 @@ def gan_loss(
# Add optional extra losses.
if _use_aux_loss(gradient_penalty_weight):
gp_loss = tfgan_losses.wasserstein_gradient_penalty(
- model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries)
+ model,
+ epsilon=gradient_penalty_epsilon,
+ one_sided=gradient_penalty_one_sided,
+ add_summaries=add_summaries)
dis_loss += gradient_penalty_weight * gp_loss
if _use_aux_loss(mutual_information_penalty_weight):
info_loss = tfgan_losses.mutual_information_penalty(
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment