Last active: March 30, 2018 23:41
-
-
Save lgeiger/0270c37b9f411e1031f5a7d98e8a9e3a to your computer and use it in GitHub Desktop.
Patch applying https://github.com/tensorflow/tensorflow/pull/18098 to TensorFlow 1.4.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
commit 1ebd8b170d31d64ad2523a1db81c5619fed24fc1 | |
Author: Lukas Geiger <[email protected]> | |
Date: Sat Mar 31 01:40:36 2018 +0200 | |
one_sided penalty | |
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py | |
index 2a40dbade6..77a86043d8 100644 | |
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py | |
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py | |
@@ -297,6 +297,7 @@ def wasserstein_gradient_penalty( | |
discriminator_fn, | |
discriminator_scope, | |
epsilon=1e-10, | |
+ one_sided=False, | |
weights=1.0, | |
scope=None, | |
loss_collection=ops.GraphKeys.LOSSES, | |
@@ -364,10 +365,13 @@ def wasserstein_gradient_penalty( | |
# For numerical stability, add epsilon to the sum before taking the square | |
# root. Note tf.norm does not add epsilon. | |
slopes = math_ops.sqrt(gradient_squares + epsilon) | |
- penalties = math_ops.square(slopes - 1.0) | |
+ penalties = slopes - 1.0 | |
+ if one_sided: | |
+ penalties = math_ops.maximum(0., penalties) | |
+ penalties_squared = math_ops.square(penalties) | |
penalty = losses.compute_weighted_loss( | |
- penalties, weights, scope=scope, loss_collection=loss_collection, | |
- reduction=reduction) | |
+ penalties_squared, weights, scope=scope, | |
+ loss_collection=loss_collection, reduction=reduction) | |
if add_summaries: | |
summary.scalar('gradient_penalty_loss', penalty) | |
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py | |
index 06dd281489..f72e10f060 100644 | |
--- a/tensorflow/contrib/gan/python/train.py | |
+++ b/tensorflow/contrib/gan/python/train.py | |
@@ -334,6 +334,7 @@ def gan_loss( | |
# Auxiliary losses. | |
gradient_penalty_weight=None, | |
gradient_penalty_epsilon=1e-10, | |
+ gradient_penalty_one_sided=False, | |
mutual_information_penalty_weight=None, | |
aux_cond_generator_weight=None, | |
aux_cond_discriminator_weight=None, | |
@@ -406,7 +407,10 @@ def gan_loss( | |
# Add optional extra losses. | |
if _use_aux_loss(gradient_penalty_weight): | |
gp_loss = tfgan_losses.wasserstein_gradient_penalty( | |
- model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries) | |
+ model, | |
+ epsilon=gradient_penalty_epsilon, | |
+ one_sided=gradient_penalty_one_sided, | |
+ add_summaries=add_summaries) | |
dis_loss += gradient_penalty_weight * gp_loss | |
if _use_aux_loss(mutual_information_penalty_weight): | |
info_loss = tfgan_losses.mutual_information_penalty( |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment