Skip to content

Instantly share code, notes, and snippets.

@MLWhiz
Created June 15, 2019 04:44
Show Gist options
  • Select an option

  • Save MLWhiz/356e039fdd5ebcc684a966560bf4f839 to your computer and use it in GitHub Desktop.

Select an option

Save MLWhiz/356e039fdd5ebcc684a966560bf4f839 to your computer and use it in GitHub Desktop.
import os

# Fixed noise vector reused at every checkpoint so the generator's output on the
# same latent inputs can be compared across training (how the images evolve over time).
fixed_noise = gen_noise(16, noise_shape)

# Per-step loss histories, kept so they can be plotted after training.
avg_disc_fake_loss = []
avg_disc_real_loss = []
avg_GAN_loss = []

# Log file path. os.path.join replaces the previously hard-coded Windows "\\"
# separator, which on POSIX systems produced a file literally named
# "<log_dir>\training_log.txt" instead of a file inside log_dir.
log_path = os.path.join(log_dir, "training_log.txt")

# Alternate discriminator and generator updates for num_steps iterations.
for step in range(num_steps):
    tot_step = step
    print("Begin step: ", tot_step)
    # Wall-clock start of this step, used for the per-step timing printout.
    step_begin_time = time.time()

    # Sample a batch of normalized real images from the dataset.
    real_data_X = sample_from_dataset(batch_size, image_shape, data_dir=data_dir)
    # Generate noise as input for the generator, then let it create (predict) fake images.
    noise = gen_noise(batch_size, noise_shape)
    fake_data_X = generator.predict(noise)

    # Save the generator's predicted images every 100th step.
    if (tot_step % 100) == 0:
        step_num = str(tot_step).zfill(4)
        save_img_batch(fake_data_X, img_save_dir + step_num + "_image.png")

    # Soft labels instead of exact ones and zeros: real targets fall in (0.8, 1.0]
    # and fake targets in [0.0, 0.2). This is an important GAN training trick.
    real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size) * 0.2
    fake_data_Y = np.random.random_sample(batch_size) * 0.2

    # Discriminator update. Real and fake batches are trained separately.
    # NOTE(review): in Keras a change to `trainable` only takes effect once the
    # model is (re)compiled. Confirm these toggles behave as intended here.
    discriminator.trainable = True
    generator.trainable = False
    dis_metrics_real = discriminator.train_on_batch(real_data_X, real_data_Y)
    dis_metrics_fake = discriminator.train_on_batch(fake_data_X, fake_data_Y)
    # Indexing [0] assumes train_on_batch returns a list (the model was compiled
    # with at least one metric) whose first entry is the loss. TODO confirm.
    print("Disc: real loss: %f fake loss: %f" % (dis_metrics_real[0], dis_metrics_fake[0]))
    # Save the losses to plot later.
    avg_disc_fake_loss.append(dis_metrics_fake[0])
    avg_disc_real_loss.append(dis_metrics_real[0])

    # Generator update through the combined `gan` model on fresh noise. The
    # targets are the noisy near-1 "real" labels from above.
    # (Presumably `gan` stacks generator -> discriminator; verify where it is built.)
    generator.trainable = True
    discriminator.trainable = False
    GAN_X = gen_noise(batch_size, noise_shape)
    GAN_Y = real_data_Y
    gan_metrics = gan.train_on_batch(GAN_X, GAN_Y)
    print("GAN loss: %f" % (gan_metrics[0]))

    # Append this step's losses to the log. `with` guarantees the file is closed
    # even if the write raises.
    with open(log_path, "a") as text_file:
        text_file.write("Step: %d Disc: real loss: %f fake loss: %f GAN loss: %f\n" % (tot_step, dis_metrics_real[0], dis_metrics_fake[0], gan_metrics[0]))
    # Save GAN loss to plot later.
    avg_GAN_loss.append(gan_metrics[0])

    end_time = time.time()
    diff_time = int(end_time - step_begin_time)
    print("Step %d completed. Time took: %s secs." % (tot_step, diff_time))

    # Every 500 steps: print loss averages over all steps so far, render the
    # fixed noise vector, and save both models.
    if ((tot_step + 1) % 500) == 0:
        print("-----------------------------------------------------------------")
        print("Average Disc_fake loss: %f" % (np.mean(avg_disc_fake_loss)))
        print("Average Disc_real loss: %f" % (np.mean(avg_disc_real_loss)))
        print("Average GAN loss: %f" % (np.mean(avg_GAN_loss)))
        print("-----------------------------------------------------------------")
        discriminator.trainable = False
        generator.trainable = False
        # Predict on fixed_noise (previously this used the per-step random `noise`,
        # which made checkpoint images incomparable across time and left
        # fixed_noise unused).
        fixed_noise_generate = generator.predict(fixed_noise)
        step_num = str(tot_step).zfill(4)
        save_img_batch(fixed_noise_generate, img_save_dir + step_num + "fixed_image.png")
        generator.save(save_model_dir + str(tot_step) + "_GENERATOR_weights_and_arch.hdf5")
        discriminator.save(save_model_dir + str(tot_step) + "_DISCRIMINATOR_weights_and_arch.hdf5")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment