Skip to content

Instantly share code, notes, and snippets.

@breeko
Last active November 4, 2018 22:34
Show Gist options
  • Save breeko/bb7a224c501a8bcc45ad6f75ca5d6833 to your computer and use it in GitHub Desktop.
Save breeko/bb7a224c501a8bcc45ad6f75ca5d6833 to your computer and use it in GitHub Desktop.
Training function for a GAN (generative adversarial network)
def train(gan,
          num_iters,
          batch_size=128,
          print_every=1,
          epochs=1,
          min_disc_acc=0.5,
          min_gan_acc=0.2,
          normal_noise=True,
          save_example_name=None,
          save_model_name=None,
          validation_split=0.2,
          callbacks=None,
          real_data=None):
    """Alternately train the generator and discriminator of a GAN.

    Each of the ``num_iters`` iterations runs two phases:

    1. Generator phase: fit the combined ``gan`` model on noise batches for up
       to ``epochs`` epochs, stopping early as soon as its latest validation
       accuracy exceeds ``min_gan_acc``.
    2. Discriminator phase: fit ``disc`` on a batch that is half generated and
       half real samples for up to ``epochs`` epochs, stopping early as soon
       as its latest validation accuracy exceeds ``min_disc_acc``.

    gan: compiled combined model whose ``layers[1]`` is the generator and
        ``layers[2]`` is the discriminator (assumes ``layers[0]`` is the input
        layer -- TODO confirm against the model builder).
    num_iters: number of generator/discriminator iterations to run.
    batch_size: number of samples per ``fit`` call that remain for training
        after the validation split is carved off; the batch actually drawn is
        ``int(batch_size / (1 - validation_split))``.
    print_every: either an int ``k`` (print status every ``k`` iterations,
        including iteration ``num_iters`` when it is a multiple of ``k``;
        ``k <= 0`` disables printing) or an iterable of 1-based iteration
        numbers at which to print.
    epochs: maximum number of epochs to run in each phase of an iteration.
    min_disc_acc: discriminator validation accuracy above which the
        discriminator phase stops early.
    min_gan_acc: combined-model validation accuracy above which the generator
        phase stops early.
    normal_noise: passed to ``gen_noise`` as ``normal``; True for normally
        distributed noise, False for uniform noise.
    save_example_name: if set, at each print iteration ``display_gen`` is
        called with ``15 * 15`` examples and
        ``fname="{iteration:05}_{save_example_name}.png"``.
    save_model_name: if set, at each print iteration ``gan`` is saved to
        ``"{iteration:05}_{save_model_name}.h5"``.
    validation_split: fraction in ``[0, 1)`` of each batch held out for
        validation. The early-stopping checks read ``"val_acc"``, so a
        positive value is presumably required -- confirm against ``History``.
    callbacks: callbacks passed to every ``fit`` call (default: none).
    real_data: real samples used for discriminator batches and for the saved
        example comparison; defaults to the module-level ``X_train``.

    Returns ``(gen_history, disc_history)``: the ``History`` objects that
    accumulated every generator-phase and discriminator-phase ``fit`` result.

    Raises ValueError if ``validation_split`` is outside ``[0, 1)``.
    """
    if not 0.0 <= validation_split < 1.0:
        raise ValueError(
            "validation_split must be in [0, 1), got {!r}".format(validation_split))
    # A fresh list per call instead of a shared mutable default argument.
    if callbacks is None:
        callbacks = []
    if real_data is None:
        real_data = X_train

    gen_history = History()
    disc_history = History()
    gen, disc = gan.layers[1:3]
    # Compile the discriminator standalone with trainable weights. Presumably
    # `gan` was compiled with the discriminator frozen; Keras 2 captures the
    # trainable flag at compile time, so gan.fit keeps it frozen while
    # disc.fit updates it -- confirm against the model builder.
    # NOTE(review): assumes gan.metrics is the metrics list given to
    # gan.compile (older Keras); newer Keras returns Metric objects here.
    disc.trainable = True
    disc.compile(loss=gan.loss, optimizer=gan.optimizer, metrics=gan.metrics)

    if isinstance(print_every, (int, np.integer)):
        # Iterations run 1..num_iters inclusive, so the stop must be
        # num_iters + 1 for the final iteration to be reachable.
        print_every = (range(print_every, num_iters + 1, print_every)
                       if print_every > 0 else ())
    print_every = set(print_every)

    # Enlarge the drawn batch so that `batch_size` samples are left for
    # training once `validation_split` of them is held out.
    actual_batch_size = int(batch_size / (1.0 - validation_split))
    half_batch = actual_batch_size // 2

    for n in range(1, num_iters + 1):
        print("\rIteration: {}/{}".format(n, num_iters), end="")

        # Generator phase: fit the combined model on noise. The noise is
        # passed as `real`, presumably so it gets "real" labels and the
        # generator learns to fool the discriminator -- confirm against
        # get_training_set.
        for _ in range(epochs):
            noise = gen_noise(actual_batch_size, normal=normal_noise)
            x_batch, y_batch = get_training_set(real=noise)
            h = gan.fit(x_batch, y_batch, epochs=1, verbose=0,
                        validation_split=validation_split, callbacks=callbacks)
            gen_history.update(h)
            if gen_history.get_latest("val_acc") > min_gan_acc:
                break

        # Discriminator phase: half generated, half real samples.
        for _ in range(epochs):
            fake = gen_fake(gen, half_batch)
            real = sample_real(real_data, half_batch)
            x_batch, y_batch = get_training_set(real=real, fake=fake)
            # Restore the discriminator's per-sample input shape (the batch
            # dimension is inferred).
            x_batch = x_batch.reshape([-1, *disc.input_shape[1:]])
            h = disc.fit(x_batch, y_batch, epochs=1, verbose=0,
                         validation_split=validation_split, callbacks=callbacks)
            disc_history.update(h)
            if disc_history.get_latest("val_acc") > min_disc_acc:
                break

        if n in print_every:
            print("""\nGen status: {}. \nDisc status: {}""".format(gen_history, disc_history))
            if save_model_name:
                gan.save("{:05}_{}.h5".format(n, save_model_name))
            display_gen(gan, 10)
            if save_example_name:
                fname = "{:05}_{}.png".format(n, save_example_name)
                display_gen(gan, 15 * 15, real=real_data, fname=fname)

    return gen_history, disc_history
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment