import os
import time
import numpy as np

# The models (generator, discriminator, gan), hyperparameters (num_steps, batch_size,
# noise_shape, image_shape), directories (data_dir, img_save_dir, log_dir, save_model_dir)
# and the helpers gen_noise, sample_from_dataset and save_img_batch are assumed to be
# defined earlier in the script.

# Use a fixed noise vector to see how the GAN images evolve over time on the same input.
fixed_noise = gen_noise(16, noise_shape)

# To keep track of losses
avg_disc_fake_loss = []
avg_disc_real_loss = []
avg_GAN_loss = []

# We will run for num_steps iterations
for step in range(num_steps):
    tot_step = step
    print("Begin step: ", tot_step)
    # To keep track of time per step
    step_begin_time = time.time()

    # Sample a batch of normalized images from the dataset
    real_data_X = sample_from_dataset(batch_size, image_shape, data_dir=data_dir)

    # Generate noise to send as input to the generator
    noise = gen_noise(batch_size, noise_shape)

    # Use the generator to create (predict) images
    fake_data_X = generator.predict(noise)

    # Save predicted images from the generator every 100th step
    if (tot_step % 100) == 0:
        step_num = str(tot_step).zfill(4)
        save_img_batch(fake_data_X, img_save_dir + step_num + "_image.png")

    # Create the labels for real and fake data. We don't use exact ones and zeros but
    # add a small amount of noise (label smoothing). This is an important GAN training trick.
    real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size) * 0.2
    fake_data_Y = np.random.random_sample(batch_size) * 0.2

    # Train the discriminator using data and labels
    discriminator.trainable = True
    generator.trainable = False

    # Train the discriminator separately on real data
    dis_metrics_real = discriminator.train_on_batch(real_data_X, real_data_Y)
    # Train the discriminator separately on fake data
    dis_metrics_fake = discriminator.train_on_batch(fake_data_X, fake_data_Y)

    print("Disc: real loss: %f fake loss: %f" % (dis_metrics_real[0], dis_metrics_fake[0]))

    # Save the losses to plot later
    avg_disc_fake_loss.append(dis_metrics_fake[0])
    avg_disc_real_loss.append(dis_metrics_real[0])

    # Train the generator using a random vector of noise and "real" labels (1's with noise)
    generator.trainable = True
    discriminator.trainable = False

    GAN_X = gen_noise(batch_size, noise_shape)
    GAN_Y = real_data_Y

    gan_metrics = gan.train_on_batch(GAN_X, GAN_Y)
    print("GAN loss: %f" % (gan_metrics[0]))

    # Log results by opening a file in append mode
    text_file = open(os.path.join(log_dir, "training_log.txt"), "a")
    text_file.write("Step: %d Disc: real loss: %f fake loss: %f GAN loss: %f\n"
                    % (tot_step, dis_metrics_real[0], dis_metrics_fake[0], gan_metrics[0]))
    text_file.close()

    # Save the GAN loss to plot later
    avg_GAN_loss.append(gan_metrics[0])

    end_time = time.time()
    diff_time = int(end_time - step_begin_time)
    print("Step %d completed. Time taken: %s secs." % (tot_step, diff_time))

    # Every 500 steps: report average losses, save images generated from the fixed noise,
    # and save both models
    if ((tot_step + 1) % 500) == 0:
        print("-----------------------------------------------------------------")
        print("Average Disc_fake loss: %f" % (np.mean(avg_disc_fake_loss)))
        print("Average Disc_real loss: %f" % (np.mean(avg_disc_real_loss)))
        print("Average GAN loss: %f" % (np.mean(avg_GAN_loss)))
        print("-----------------------------------------------------------------")
        discriminator.trainable = False
        generator.trainable = False
        # Predict on fixed_noise so progress is comparable across checkpoints
        fixed_noise_generate = generator.predict(fixed_noise)
        step_num = str(tot_step).zfill(4)
        save_img_batch(fixed_noise_generate, img_save_dir + step_num + "fixed_image.png")
        generator.save(save_model_dir + str(tot_step) + "_GENERATOR_weights_and_arch.hdf5")
        discriminator.save(save_model_dir + str(tot_step) + "_DISCRIMINATOR_weights_and_arch.hdf5")
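
The loop relies on three helpers (gen_noise, sample_from_dataset, save_img_batch) that are defined elsewhere in the accompanying script. Below is a minimal sketch of what they might look like, inferred only from how they are called above; the exact signatures, file pattern, noise shape, and [-1, 1] normalization are assumptions, not the author's original implementations.

import glob
import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing.image import load_img, img_to_array

def gen_noise(batch_size, noise_shape):
    # Hypothetical: sample a batch of Gaussian noise, assuming noise_shape is a tuple
    # such as (1, 1, 100) matching the generator's input.
    return np.random.normal(0, 1, size=(batch_size,) + tuple(noise_shape))

def sample_from_dataset(batch_size, image_shape, data_dir):
    # Hypothetical: pick batch_size random JPEGs from data_dir, resize them to
    # image_shape and scale pixels to [-1, 1] (a common convention when the
    # generator's output layer uses tanh).
    files = np.random.choice(glob.glob(data_dir + "*.jpg"), batch_size)
    batch = np.empty((batch_size,) + tuple(image_shape), dtype=np.float32)
    for i, f in enumerate(files):
        img = img_to_array(load_img(f, target_size=image_shape[:2]))
        batch[i] = (img / 127.5) - 1.0
    return batch

def save_img_batch(img_batch, img_path):
    # Hypothetical: tile up to 16 images into a 4x4 grid and save it, rescaling
    # pixel values from [-1, 1] back to [0, 1] for display.
    plt.figure(figsize=(4, 4))
    for i in range(min(16, len(img_batch))):
        plt.subplot(4, 4, i + 1)
        plt.imshow(((img_batch[i] + 1.0) / 2.0).clip(0.0, 1.0))
        plt.axis("off")
    plt.tight_layout()
    plt.savefig(img_path)
    plt.close()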
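
The lists filled in "to plot later" (avg_disc_real_loss, avg_disc_fake_loss, avg_GAN_loss) are not plotted in this snippet. A minimal sketch of how they could be visualised after the loop finishes, assuming matplotlib is available and reusing img_save_dir for the output file:

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
plt.plot(avg_disc_real_loss, label="Discriminator loss (real)")
plt.plot(avg_disc_fake_loss, label="Discriminator loss (fake)")
plt.plot(avg_GAN_loss, label="Generator (GAN) loss")
plt.xlabel("Training step")
plt.ylabel("Loss")
plt.legend()
plt.savefig(img_save_dir + "loss_curves.png")
plt.close()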