Created
March 7, 2018 09:20
-
-
Save simoninithomas/ad6c006d32760ef5d5bb60b21208fade to your computer and use it in GitHub Desktop.
Cat DCGAN
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def model_loss(input_real, input_z, output_channel_dim, alpha):
    """
    Build the discriminator and generator losses.

    :param input_real: Images from the real dataset
    :param input_z: Z input
    :param output_channel_dim: The number of channels in the output image
    :param alpha: Value forwarded unchanged to both discriminator calls
    :return: A tuple of (discriminator loss, generator loss)
    """
    def _mean_sigmoid_xent(logits, labels):
        # Sigmoid cross-entropy between logits and labels, averaged to a scalar.
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                    labels=labels))

    # Generator output for the given Z input.
    generated = generator(input_z, output_channel_dim)

    # Run the discriminator on the real images, then on the generated ones.
    # The second call passes is_reuse=True (presumably to share the variables
    # created by the first call -- confirm against discriminator()).
    real_out, real_logits = discriminator(input_real, alpha=alpha)
    fake_out, fake_logits = discriminator(generated, is_reuse=True, alpha=alpha)

    # Discriminator loss: real samples are labelled 1, generated samples 0.
    loss_on_real = _mean_sigmoid_xent(real_logits, tf.ones_like(real_out))
    loss_on_fake = _mean_sigmoid_xent(fake_logits, tf.zeros_like(fake_out))
    d_loss = loss_on_real + loss_on_fake

    # Generator loss: generated samples are scored against the label 1.
    g_loss = _mean_sigmoid_xent(fake_logits, tf.ones_like(fake_out))

    return d_loss, g_loss
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment