GANEstimator with Combine Adversarial Loss
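This snippet shows how to combine a non-adversarial loss (here an L1 distance between real and generated data) with a least-squares adversarial loss via TF-GAN's combine_adversarial_loss, and how to pass the combined loss to tf.contrib.gan.estimator.GANEstimator as the generator loss.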
import tensorflow as tf


def combined_loss(gan_model, **kwargs):
    """Wrapper around combine_adversarial_loss; use it as the generator loss."""
    # Non-adversarial loss - for example L1 between real and generated data.
    non_adversarial_loss = tf.losses.absolute_difference(
        gan_model.real_data, gan_model.generated_data)
    # Standard least-squares adversarial loss for the generator.
    generator_loss = tf.contrib.gan.losses.least_squares_generator_loss(
        gan_model, **kwargs)
    # The kwargs GANEstimator passes in change between TF versions, so fall
    # back to a default when 'add_summaries' is absent.
    add_summaries = kwargs.get('add_summaries', True)
    # Combine the two losses - you can specify more parameters.
    # Exactly one of weight_factor and gradient_ratio must be non-None.
    return tf.contrib.gan.losses.wargs.combine_adversarial_loss(
        non_adversarial_loss,
        generator_loss,
        weight_factor=1.0,
        gradient_ratio=None,
        variables=gan_model.generator_variables,
        scalar_summaries=add_summaries,
        gradient_summaries=add_summaries)
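Instead of a fixed weight_factor, combine_adversarial_loss can balance the two terms by the magnitude of their gradients. As a sketch, the final call inside combined_loss above could instead read as follows; the gradient_ratio value of 2.0 is an arbitrary illustrative choice, not taken from the original:

    # Alternative: scale the losses so their gradient norms keep a fixed
    # ratio. weight_factor must then be None, since exactly one of the two
    # may be non-None. The 2.0 is only an illustrative value.
    return tf.contrib.gan.losses.wargs.combine_adversarial_loss(
        non_adversarial_loss,
        generator_loss,
        weight_factor=None,
        gradient_ratio=2.0,
        variables=gan_model.generator_variables,
        scalar_summaries=add_summaries,
        gradient_summaries=add_summaries)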
# model_dir, generator_fn and discriminator_fn are supplied by the caller.
gan_estimator = tf.contrib.gan.estimator.GANEstimator(
    model_dir,
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=combined_loss,
    discriminator_loss_fn=tf.contrib.gan.losses.least_squares_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5))
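Once constructed, the estimator trains like any other tf.estimator.Estimator. A minimal sketch, assuming a hypothetical train_input_fn that returns the generator-input and real-data batches your networks expect; the step count and predict_input_fn name are likewise illustrative:

# train_input_fn and predict_input_fn are hypothetical input functions.
gan_estimator.train(train_input_fn, max_steps=10000)

# After training, predict() yields data from the trained generator.
predictions = gan_estimator.predict(predict_input_fn)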