Last active
March 14, 2019 12:29
-
-
Save Breta01/e434d971b77d42038d5635831394a367 to your computer and use it in GitHub Desktop.
GANEstimator with a Combined Adversarial Loss (`combine_adversarial_loss`)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def combined_loss(gan_model, **kwargs):
    """Generator loss combining an L1 term with a least-squares GAN term.

    Intended to be passed as ``generator_loss_fn`` to
    ``tf.contrib.gan.estimator.GANEstimator``.

    Args:
        gan_model: The TFGAN model tuple handed to the loss function
            (presumably ``tf.contrib.gan.GANModel``). Its ``real_data``,
            ``generated_data`` and ``generator_variables`` fields are read.
        **kwargs: Keyword arguments supplied by the TFGAN training code.
            They are forwarded unchanged to
            ``least_squares_generator_loss``. ``add_summaries`` (treated as
            ``True`` when absent) also controls whether
            ``combine_adversarial_loss`` emits scalar and gradient summaries.

    Returns:
        The loss tensor produced by ``combine_adversarial_loss``.
    """
    # Non-adversarial term: L1 (absolute difference) between the real data
    # (labels) and the generated data (predictions).
    non_adversarial_loss = tf.losses.absolute_difference(
        gan_model.real_data, gan_model.generated_data)

    # Adversarial term: least-squares GAN generator loss.
    generator_loss = tf.contrib.gan.losses.least_squares_generator_loss(
        gan_model, **kwargs)

    # The kwargs TFGAN passes differ between versions, so default to
    # emitting summaries when `add_summaries` is not supplied. (Replaces a
    # bare `except:` that would also have hidden unrelated errors.)
    add_summaries = kwargs.get('add_summaries', True)

    # Exactly one of weight_factor and gradient_ratio must be non-None.
    return tf.contrib.gan.losses.wargs.combine_adversarial_loss(
        non_adversarial_loss,
        generator_loss,
        weight_factor=1.0,
        gradient_ratio=None,
        variables=gan_model.generator_variables,
        scalar_summaries=add_summaries,
        gradient_summaries=add_summaries)
# Assemble the GAN estimator:
# - generator loss: `combined_loss` (L1 + least-squares GAN);
# - discriminator loss: the least-squares GAN discriminator loss;
# - optimizers: a separate Adam optimizer for each network
#   (learning rate 1e-4, beta1 0.5).
# NOTE(review): model_dir, generator_fn and discriminator_fn are not defined
# in this snippet and are assumed to come from elsewhere.
gan_estimator = tf.contrib.gan.estimator.GANEstimator(
    model_dir=model_dir,
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=combined_loss,
    discriminator_loss_fn=(
        tf.contrib.gan.losses.least_squares_discriminator_loss),
    generator_optimizer=tf.train.AdamOptimizer(
        learning_rate=1e-4, beta1=0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(
        learning_rate=1e-4, beta1=0.5),
)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment