Last active
November 4, 2018 22:34
-
-
Save breeko/bb7a224c501a8bcc45ad6f75ca5d6833 to your computer and use it in GitHub Desktop.
Training function for GAN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(gan,
          num_iters,
          batch_size=128,
          print_every=1,
          epochs=1,
          min_disc_acc=0.5,
          min_gan_acc=0.2,
          normal_noise=True,
          save_example_name=None,
          save_model_name=None,
          validation_split=0.2,
          callbacks=None
          ):
    """Alternately fit the combined GAN model and its discriminator.

    Each iteration first fits ``gan`` on freshly generated noise and then
    fits the discriminator on a half-fake / half-real batch. Each phase
    runs for at most ``epochs`` single-epoch fits and stops early as soon
    as the latest ``val_acc`` exceeds the phase's threshold.

    Real samples are drawn from the module-level ``X_train``.

    gan: compiled model whose ``layers[1]`` and ``layers[2]`` are used as
        the generator and the discriminator respectively
    num_iters: number of iterations to run
    batch_size: number of training samples per fit, after the validation
        split has been held out
    print_every: either an int that prints every n iterations (the final
        iteration included when it is a multiple of n) or an iterable of
        iteration numbers to print at
    epochs: maximum number of epochs to run on each iteration
    min_disc_acc: validation accuracy the discriminator must exceed to
        stop its phase early
    min_gan_acc: validation accuracy the combined model must exceed to
        stop the generator phase early
    normal_noise: boolean that states whether the random noise input
        should be normally distributed or uniform
    save_example_name: name suffix for the example image saved at each
        print step (nothing is saved when falsy)
    save_model_name: name suffix for the model file saved at each print
        step (nothing is saved when falsy)
    validation_split: fraction of each batch held out for validation;
        must satisfy 0 <= validation_split < 1
    callbacks: callbacks passed to every ``fit`` call (default: none)

    Returns the tuple ``(gen_history, disc_history)``.

    Raises ValueError if ``validation_split`` is outside ``[0, 1)``.
    """
    # Validate up front: a split of 1.0 would divide by zero below and a
    # split above 1.0 would yield a negative batch size.
    if not 0.0 <= validation_split < 1.0:
        raise ValueError(
            "validation_split must be in [0, 1), got {!r}".format(validation_split)
        )
    # A fresh list per call avoids sharing one mutable default across calls.
    if callbacks is None:
        callbacks = []

    gen_history = History()
    disc_history = History()

    gen, disc = gan.layers[1:3]
    # Compile the discriminator as a trainable stand-alone model that
    # reuses the combined model's loss, optimizer and metrics.
    disc.trainable = True
    disc.compile(loss=gan.loss, optimizer=gan.optimizer, metrics=gan.metrics)

    if isinstance(print_every, int):
        # Upper bound is inclusive so the last iteration is reported too.
        print_every = range(print_every, num_iters + 1, print_every)
    print_every = set(print_every)

    # Oversize the batch so that about `batch_size` samples remain for
    # training once the validation fraction has been held out.
    actual_batch_size = int(batch_size / (1.0 - validation_split))
    half_batch = actual_batch_size // 2

    for n in range(1, num_iters + 1):
        print("\rIteration: {}/{}".format(n, num_iters), end="")

        # Training the generator (through the combined model)
        for _ in range(epochs):
            noise = gen_noise(actual_batch_size, normal=normal_noise)
            x_batch, y_batch = get_training_set(real=noise)
            h = gan.fit(x_batch, y_batch, epochs=1, verbose=0,
                        validation_split=validation_split, callbacks=callbacks)
            gen_history.update(h)
            # NOTE(review): assumes the accuracy metric is reported under
            # the key "val_acc" -- confirm against the Keras version in use.
            if gen_history.get_latest("val_acc") > min_gan_acc:
                break

        # Training the discriminator
        for _ in range(epochs):
            fake = gen_fake(gen, half_batch)
            real = sample_real(X_train, half_batch)
            x_batch, y_batch = get_training_set(real=real, fake=fake)
            # Match the discriminator's expected per-sample input shape.
            x_batch = x_batch.reshape([-1, *disc.input_shape[1:]])
            h = disc.fit(x_batch, y_batch, epochs=1, verbose=0,
                         validation_split=validation_split, callbacks=callbacks)
            disc_history.update(h)
            if disc_history.get_latest("val_acc") > min_disc_acc:
                break

        if n in print_every:
            print("""\nGen status: {}. \nDisc status: {}""".format(gen_history, disc_history))
            if save_model_name:
                gan.save("{:05}_{}.h5".format(n, save_model_name))
            display_gen(gan, 10)
            if save_example_name:
                fname = "{:05}_{}.png".format(n, save_example_name)
                display_gen(gan, 15 * 15, real=X_train, fname=fname)

    return gen_history, disc_history
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment