Cat DCGAN
import tensorflow as tf  # TensorFlow 1.x API

def model_optimizers(d_loss, g_loss, lr_D, lr_G, beta1):
    """
    Get optimization operations
    :param d_loss: Discriminator loss Tensor
    :param g_loss: Generator loss Tensor
    :param lr_D: Discriminator learning rate placeholder
    :param lr_G: Generator learning rate placeholder
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :return: A tuple of (discriminator training operation, generator training operation)
    """
    # Get the trainable variables and split them into G and D parts by scope name
    t_vars = tf.trainable_variables()
    g_vars = [var for var in t_vars if var.name.startswith("generator")]
    d_vars = [var for var in t_vars if var.name.startswith("discriminator")]

    # Collect the generator's batch-norm update ops so its moving averages
    # are refreshed whenever a training op runs
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    gen_updates = [op for op in update_ops if op.name.startswith('generator')]

    # Optimizers (run only after the generator update ops have executed)
    with tf.control_dependencies(gen_updates):
        d_train_opt = tf.train.AdamOptimizer(learning_rate=lr_D, beta1=beta1).minimize(d_loss, var_list=d_vars)
        g_train_opt = tf.train.AdamOptimizer(learning_rate=lr_G, beta1=beta1).minimize(g_loss, var_list=g_vars)

    return d_train_opt, g_train_opt
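For context, a minimal, self-contained usage sketch of this function. The tiny variables, stand-in losses, and learning-rate placeholders below are hypothetical and only illustrate how the scope-based variable split and the returned training ops are meant to be used; the real generator/discriminator graph is built elsewhere in the Cat DCGAN notebook.

import tensorflow as tf

# Hypothetical stand-in graph: one variable per scope so the
# "generator" / "discriminator" name split above has something to optimize.
with tf.variable_scope("generator"):
    g_w = tf.get_variable("w", shape=[4], initializer=tf.zeros_initializer())
with tf.variable_scope("discriminator"):
    d_w = tf.get_variable("w", shape=[4], initializer=tf.zeros_initializer())

d_loss = tf.reduce_sum(tf.square(d_w - 1.0))  # stand-in discriminator loss
g_loss = tf.reduce_sum(tf.square(g_w + 1.0))  # stand-in generator loss

# Learning-rate placeholders, fed per step (DCGANs commonly use lr ~ 2e-4, beta1 = 0.5)
lr_D = tf.placeholder(tf.float32, name="lr_D")
lr_G = tf.placeholder(tf.float32, name="lr_G")

d_train_opt, g_train_opt = model_optimizers(d_loss, g_loss, lr_D, lr_G, beta1=0.5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # One alternating update of the discriminator and the generator
    sess.run(d_train_opt, feed_dict={lr_D: 2e-4})
    sess.run(g_train_opt, feed_dict={lr_G: 2e-4})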