Skip to content

Instantly share code, notes, and snippets.

@simoninithomas
Created March 7, 2018 09:21
Show Gist options
  • Save simoninithomas/02c6a65938c3a7fabf80fee114baf03e to your computer and use it in GitHub Desktop.
Save simoninithomas/02c6a65938c3a7fabf80fee114baf03e to your computer and use it in GitHub Desktop.
Cat DCGAN
def model_optimizers(d_loss, g_loss, lr_D, lr_G, beta1):
    """
    Get optimization operations for the discriminator and the generator.

    Each Adam optimizer only updates the trainable variables of its own
    network, selected by variable-name prefix ("discriminator" / "generator").

    :param d_loss: Discriminator loss Tensor
    :param g_loss: Generator loss Tensor
    :param lr_D: Discriminator learning rate (value or placeholder passed to AdamOptimizer)
    :param lr_G: Generator learning rate (value or placeholder passed to AdamOptimizer)
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :return: A tuple of (discriminator training operation, generator training operation)
    """
    # Get the trainable_variables, split into G and D parts by name prefix so
    # each optimizer only touches its own network.
    t_vars = tf.trainable_variables()
    g_vars = [var for var in t_vars if var.name.startswith("generator")]
    d_vars = [var for var in t_vars if var.name.startswith("discriminator")]

    # Ops in UPDATE_OPS (e.g. batch-norm moving mean/variance updates) are not
    # run automatically; they must be made control dependencies of a train op.
    # Split them with the same name prefixes used for the variables above.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    g_updates = [op for op in update_ops if op.name.startswith("generator")]
    d_updates = [op for op in update_ops if op.name.startswith("discriminator")]

    # Gate each train op on its own network's update ops, so the
    # discriminator's update ops are executed as well as the generator's.
    with tf.control_dependencies(d_updates):
        d_train_opt = tf.train.AdamOptimizer(
            learning_rate=lr_D, beta1=beta1
        ).minimize(d_loss, var_list=d_vars)
    with tf.control_dependencies(g_updates):
        g_train_opt = tf.train.AdamOptimizer(
            learning_rate=lr_G, beta1=beta1
        ).minimize(g_loss, var_list=g_vars)

    return d_train_opt, g_train_opt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment