Cat DCGAN
import numpy as np
import tensorflow as tf

# model_inputs, model_loss, model_optimizers and show_generator_output
# are defined elsewhere in the notebook.


def train(epoch_count, batch_size, z_dim, learning_rate_D, learning_rate_G, beta1,
          get_batches, data_shape, data_image_mode, alpha, from_checkpoint=False):
    """
    Train the GAN
    :param epoch_count: Number of epochs
    :param batch_size: Batch size
    :param z_dim: Dimension of the latent vector z
    :param learning_rate_D: Learning rate of the discriminator
    :param learning_rate_G: Learning rate of the generator
    :param beta1: Exponential decay rate for the 1st moment in the Adam optimizer
    :param get_batches: Function to get batches
    :param data_shape: Shape of the data (num_images, height, width, channels)
    :param data_image_mode: The image mode to use for images ("RGB" or "L")
    :param alpha: Leak slope of the leaky ReLU activations
    :param from_checkpoint: If True, restore ./models/model.ckpt and only generate samples
    """
    # Create our input placeholders
    input_images, input_z, lr_G, lr_D = model_inputs(data_shape[1:], z_dim)

    # Losses
    d_loss, g_loss = model_loss(input_images, input_z, data_shape[3], alpha)

    # Optimizers
    d_opt, g_opt = model_optimizers(d_loss, g_loss, lr_D, lr_G, beta1)

    i = 0  # Global batch counter
    version = "firstTrain"  # Used only by the commented-out per-version save path below

    losses = []   # (discriminator loss, generator loss) pairs
    samples = []  # Kept for the return signature; generated images are written to ./images/ instead

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Saver
        saver = tf.train.Saver()

        num_epoch = 0

        if from_checkpoint:
            # Restore a trained model and only generate a sample grid
            saver.restore(sess, "./models/model.ckpt")
            image_path = "./images/from_checkpoint.jpg"
            show_generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path, True, False)
        else:
            for epoch_i in range(epoch_count):
                num_epoch += 1

                if num_epoch % 5 == 0:
                    # Save the model every 5 epochs
                    # if not os.path.exists("models/" + version):
                    #     os.makedirs("models/" + version)
                    save_path = saver.save(sess, "./models/model.ckpt")
                    print("Model saved")

                for batch_images in get_batches(batch_size):
                    # Random noise fed to the generator
                    batch_z = np.random.uniform(-1, 1, size=(batch_size, z_dim))

                    i += 1

                    # Run optimizers (discriminator first, then generator)
                    _ = sess.run(d_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_D: learning_rate_D})
                    _ = sess.run(g_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_G: learning_rate_G})

                    if i % 10 == 0:
                        train_loss_d = d_loss.eval({input_z: batch_z, input_images: batch_images})
                        train_loss_g = g_loss.eval({input_z: batch_z})
                        losses.append((train_loss_d, train_loss_g))

                        # Save a grid of generated images
                        image_name = str(i) + ".jpg"
                        image_path = "./images/" + image_name
                        show_generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path, True, False)

                    # Print only every 1500 batches, otherwise the Jupyter notebook becomes unstable
                    if i % 1500 == 0:
                        image_name = str(i) + ".jpg"
                        image_path = "./images/" + image_name
                        print("Epoch {}/{}...".format(epoch_i + 1, epoch_count),
                              "Discriminator Loss: {:.4f}...".format(train_loss_d),
                              "Generator Loss: {:.4f}".format(train_loss_g))
                        show_generator_output(sess, 4, input_z, data_shape[3], data_image_mode, image_path, False, True)

    return losses, samples
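For context, here is a minimal sketch of how this function might be called from the notebook. The hyperparameter values and the `dataset` object are assumptions for illustration (a stand-in for whatever helper exposes get_batches, shape and image_mode for the cat images); they are not taken from the gist.

# Illustrative hyperparameters -- assumed values, not the gist author's settings
epochs = 10
batch_size = 64
z_dim = 100
learning_rate_D = 0.00005   # assumed: slightly slower discriminator
learning_rate_G = 0.0002
beta1 = 0.5
alpha = 0.2                 # leaky ReLU slope

# `dataset` is a hypothetical object providing get_batches, shape and image_mode.
with tf.Graph().as_default():
    losses, samples = train(epochs, batch_size, z_dim,
                            learning_rate_D, learning_rate_G, beta1,
                            dataset.get_batches, dataset.shape, dataset.image_mode,
                            alpha)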