Cat DCGAN
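The snippets below assume the usual notebook imports and the Udacity-style helper module (providing Dataset and images_square_grid). A minimal sketch of what this gist expects to already be in scope; the value of data_resized_dir is a hypothetical placeholder, not taken from the original notebook:

# Assumed imports for the snippets below (not part of the original gist cells)
import os
from glob import glob

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf   # TensorFlow 1.x API (tf.placeholder, tf.Session, ...)

import helper              # Udacity-style helper with Dataset and images_square_grid

data_resized_dir = "./resized_cats"  # hypothetical path to the preprocessed cat images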
# Size of the input images for the discriminator
real_size = (128, 128, 3)

# Size of the latent vector fed to the generator
z_dim = 100

# Two different learning rates, thanks to Alexia Jolicoeur-Martineau https://ajolicoeur.wordpress.com/cats/
learning_rate_D = 0.00005
learning_rate_G = 2e-4

batch_size = 64
epochs = 215

alpha = 0.2   # Leak slope for the leaky ReLU
beta1 = 0.5   # Exponential decay rate for the 1st moment in Adam
def model_inputs(real_dim, z_dim):
    """
    Create the model inputs
    :param real_dim: Tuple containing width, height and channels
    :param z_dim: The dimension of Z
    :return: Tuple of (tensor of real input images, tensor of z data, learning rate G, learning rate D)
    """
    # inputs_real for the Discriminator
    inputs_real = tf.placeholder(tf.float32, (None, *real_dim), name='inputs_real')

    # inputs_z for the Generator
    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name="input_z")

    # Two different learning rates: one for the generator, one for the discriminator
    learning_rate_G = tf.placeholder(tf.float32, name="learning_rate_G")
    learning_rate_D = tf.placeholder(tf.float32, name="learning_rate_D")

    return inputs_real, inputs_z, learning_rate_G, learning_rate_D
def model_loss(input_real, input_z, output_channel_dim, alpha):
    """
    Get the loss for the discriminator and generator
    :param input_real: Images from the real dataset
    :param input_z: Z input
    :param output_channel_dim: The number of channels in the output image
    :param alpha: Leak slope for the leaky ReLU
    :return: A tuple of (discriminator loss, generator loss)
    """
    # Generator network here
    g_model = generator(input_z, output_channel_dim)
    # g_model is the generator output

    # Discriminator network here
    d_model_real, d_logits_real = discriminator(input_real, alpha=alpha)
    d_model_fake, d_logits_fake = discriminator(g_model, is_reuse=True, alpha=alpha)

    # Calculate losses
    # The discriminator should label real images as 1...
    d_loss_real = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                labels=tf.ones_like(d_model_real)))
    # ...and generated images as 0
    d_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                labels=tf.zeros_like(d_model_fake)))
    d_loss = d_loss_real + d_loss_fake

    # The generator wants the discriminator to label its images as 1
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                labels=tf.ones_like(d_model_fake)))

    return d_loss, g_loss
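generator and discriminator themselves are defined elsewhere in the notebook. As a rough sketch of the kind of DCGAN architecture these loss functions expect for 128x128x3 images (the exact layer sizes and kernel choices here are assumptions, not the original code), something like the following would satisfy the signatures used above:

def generator(z, output_channel_dim, is_train=True, alpha=0.2):
    """Sketch of a DCGAN generator for 128x128 images (layer sizes are assumptions)."""
    with tf.variable_scope("generator", reuse=not is_train):
        # Project and reshape the latent vector to a 4x4x512 feature map
        x = tf.layers.dense(z, 4 * 4 * 512)
        x = tf.reshape(x, (-1, 4, 4, 512))
        x = tf.layers.batch_normalization(x, training=is_train)
        x = tf.maximum(alpha * x, x)
        # Upsample 4 -> 8 -> 16 -> 32 -> 64 with strided transposed convolutions
        for filters in [256, 128, 64, 32]:
            x = tf.layers.conv2d_transpose(x, filters, 5, strides=2, padding='same')
            x = tf.layers.batch_normalization(x, training=is_train)
            x = tf.maximum(alpha * x, x)
        # Final upsample 64 -> 128 and map to the output channels
        logits = tf.layers.conv2d_transpose(x, output_channel_dim, 5, strides=2, padding='same')
        return tf.tanh(logits)


def discriminator(x, is_reuse=False, alpha=0.2):
    """Sketch of a DCGAN discriminator; returns (sigmoid output, logits)."""
    with tf.variable_scope("discriminator", reuse=is_reuse):
        # Downsample 128 -> 64 -> 32 -> 16 -> 8 -> 4 with strided convolutions
        for i, filters in enumerate([32, 64, 128, 256, 512]):
            x = tf.layers.conv2d(x, filters, 5, strides=2, padding='same')
            if i > 0:  # no batch norm on the first layer, a common DCGAN convention
                x = tf.layers.batch_normalization(x, training=True)
            x = tf.maximum(alpha * x, x)
        flat = tf.reshape(x, (-1, 4 * 4 * 512))
        logits = tf.layers.dense(flat, 1)
        return tf.sigmoid(logits), logits

The variable scope names matter: model_optimizers below splits the trainable variables by the "generator" and "discriminator" name prefixes.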
def model_optimizers(d_loss, g_loss, lr_D, lr_G, beta1):
    """
    Get optimization operations
    :param d_loss: Discriminator loss Tensor
    :param g_loss: Generator loss Tensor
    :param lr_D: Discriminator learning rate placeholder
    :param lr_G: Generator learning rate placeholder
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :return: A tuple of (discriminator training operation, generator training operation)
    """
    # Get the trainable_variables, split into G and D parts
    t_vars = tf.trainable_variables()
    g_vars = [var for var in t_vars if var.name.startswith("generator")]
    d_vars = [var for var in t_vars if var.name.startswith("discriminator")]

    # Generator batch-norm update ops (moving mean / variance)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    gen_updates = [op for op in update_ops if op.name.startswith('generator')]

    # Optimizers: run the generator's update ops before each training step
    with tf.control_dependencies(gen_updates):
        d_train_opt = tf.train.AdamOptimizer(learning_rate=lr_D, beta1=beta1).minimize(d_loss, var_list=d_vars)
        g_train_opt = tf.train.AdamOptimizer(learning_rate=lr_G, beta1=beta1).minimize(g_loss, var_list=g_vars)

    return d_train_opt, g_train_opt
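The tf.GraphKeys.UPDATE_OPS collection holds the batch-normalization moving-average updates, and wrapping the optimizers in tf.control_dependencies(gen_updates) forces the generator's updates to run on every training step. If the discriminator also uses batch normalization, one possible variant (an assumption, not the gist's code) is to gate each optimizer on its own network's update ops:

    # Possible variant: run each network's batch-norm updates before its own optimizer step
    disc_updates = [op for op in update_ops if op.name.startswith('discriminator')]
    with tf.control_dependencies(disc_updates):
        d_train_opt = tf.train.AdamOptimizer(learning_rate=lr_D, beta1=beta1).minimize(d_loss, var_list=d_vars)
    with tf.control_dependencies(gen_updates):
        g_train_opt = tf.train.AdamOptimizer(learning_rate=lr_G, beta1=beta1).minimize(g_loss, var_list=g_vars)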
# Load the data and train the network here
dataset = helper.Dataset(glob(os.path.join(data_resized_dir, '*.jpg')))

with tf.Graph().as_default():
    losses, samples = train(epochs, batch_size, z_dim, learning_rate_D, learning_rate_G, beta1,
                            dataset.get_batches, dataset.shape, dataset.image_mode, alpha)
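train writes checkpoints to ./models/model.ckpt and sample grids to ./images/, so those directories need to exist before training starts. A small setup sketch, assuming the default paths used in the training loop:

import os

# Create the output directories the training loop writes to (paths taken from train())
for directory in ("./models", "./images"):
    if not os.path.exists(directory):
        os.makedirs(directory)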
def show_generator_output(sess, n_images, input_z, out_channel_dim, image_mode, image_path, save, show):
    """
    Show example output for the generator
    :param sess: TensorFlow session
    :param n_images: Number of images to display
    :param input_z: Input Z Tensor
    :param out_channel_dim: The number of channels in the output image
    :param image_mode: The mode to use for images ("RGB" or "L")
    :param image_path: Path to save the image
    :param save: If True, save the grid of samples to image_path
    :param show: If True, display the grid of samples with matplotlib
    """
    cmap = None if image_mode == 'RGB' else 'gray'
    z_dim = input_z.get_shape().as_list()[-1]

    # Sample random noise and run it through the generator (in inference mode)
    example_z = np.random.uniform(-1, 1, size=[n_images, z_dim])
    samples = sess.run(
        generator(input_z, out_channel_dim, False),
        feed_dict={input_z: example_z})

    # Arrange the generated images on a square grid
    images_grid = helper.images_square_grid(samples, image_mode)

    if save:
        # Save image
        images_grid.save(image_path, 'JPEG')

    if show:
        plt.imshow(images_grid, cmap=cmap)
        plt.show()
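Outside of the training loop, the same helper can be used to render a grid from a restored checkpoint. A minimal usage sketch, assuming the placeholders, hyperparameters and checkpoint path used elsewhere in this gist (the preview path is only illustrative):

# Minimal sketch: restore the trained model and display 4 generated cats
with tf.Graph().as_default():
    input_images, input_z, lr_G, lr_D = model_inputs(real_size, z_dim)
    d_loss, g_loss = model_loss(input_images, input_z, real_size[2], alpha)  # builds generator/discriminator

    with tf.Session() as sess:
        tf.train.Saver().restore(sess, "./models/model.ckpt")
        show_generator_output(sess, 4, input_z, real_size[2], "RGB", "./images/preview.jpg",
                              save=False, show=True)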
def train(epoch_count, batch_size, z_dim, learning_rate_D, learning_rate_G, beta1,
          get_batches, data_shape, data_image_mode, alpha, from_checkpoint=False):
    """
    Train the GAN
    :param epoch_count: Number of epochs
    :param batch_size: Batch size
    :param z_dim: Z dimension
    :param learning_rate_D: Discriminator learning rate
    :param learning_rate_G: Generator learning rate
    :param beta1: The exponential decay rate for the 1st moment in the optimizer
    :param get_batches: Function to get batches
    :param data_shape: Shape of the data
    :param data_image_mode: The image mode to use for images ("RGB" or "L")
    :param alpha: Leak slope for the leaky ReLU
    :param from_checkpoint: If True, restore the saved model and only generate samples
    :return: A tuple of (losses, samples)
    """
    # Create our input placeholders
    input_images, input_z, lr_G, lr_D = model_inputs(data_shape[1:], z_dim)

    # Losses
    d_loss, g_loss = model_loss(input_images, input_z, data_shape[3], alpha)

    # Optimizers
    d_opt, g_opt = model_optimizers(d_loss, g_loss, lr_D, lr_G, beta1)

    i = 0
    version = "firstTrain"
    losses = []   # (discriminator loss, generator loss) pairs logged during training
    samples = []  # kept for the return signature; sample grids are written to ./images instead

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Saver
        saver = tf.train.Saver()

        num_epoch = 0

        if from_checkpoint:
            # Restore the trained model and generate one grid of samples
            saver.restore(sess, "./models/model.ckpt")
            image_path = "./images/from_checkpoint.jpg"
            show_generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path, True, False)
        else:
            for epoch_i in range(epoch_count):
                num_epoch += 1

                if num_epoch % 5 == 0:
                    # Save model every 5 epochs
                    #if not os.path.exists("models/" + version):
                    #    os.makedirs("models/" + version)
                    save_path = saver.save(sess, "./models/model.ckpt")
                    print("Model saved")

                for batch_images in get_batches(batch_size):
                    # Random noise
                    batch_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))
                    i += 1

                    # Run optimizers
                    _ = sess.run(d_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_D: learning_rate_D})
                    _ = sess.run(g_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_G: learning_rate_G})

                    if i % 10 == 0:
                        train_loss_d = d_loss.eval({input_z: batch_z, input_images: batch_images})
                        train_loss_g = g_loss.eval({input_z: batch_z})
                        losses.append((train_loss_d, train_loss_g))

                        # Save a grid of generated images
                        image_name = str(i) + ".jpg"
                        image_path = "./images/" + image_name
                        show_generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path, True, False)

                    # Print every 1500 batches only (otherwise the Jupyter notebook becomes unstable)
                    if i % 1500 == 0:
                        image_name = str(i) + ".jpg"
                        image_path = "./images/" + image_name
                        print("Epoch {}/{}...".format(epoch_i + 1, epoch_count),
                              "Discriminator Loss: {:.4f}...".format(train_loss_d),
                              "Generator Loss: {:.4f}".format(train_loss_g))
                        show_generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path, False, True)

    return losses, samples
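Since train returns the (discriminator loss, generator loss) pairs logged every 10 batches, a quick way to inspect training stability afterwards; a sketch, assuming the losses list returned by the call above:

# Plot the discriminator and generator losses collected during training
losses_arr = np.array(losses)          # shape: (num_logged_steps, 2)
plt.plot(losses_arr[:, 0], label="Discriminator")
plt.plot(losses_arr[:, 1], label="Generator")
plt.xlabel("Logged step (every 10 batches)")
plt.ylabel("Loss")
plt.legend()
plt.show()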