Skip to content

Instantly share code, notes, and snippets.

@Bornlex
Created February 24, 2020 16:35
Show Gist options
  • Select an option

  • Save Bornlex/0357bcec4444151d47c3e0de51fddf6f to your computer and use it in GitHub Desktop.

Select an option

Save Bornlex/0357bcec4444151d47c3e0de51fddf6f to your computer and use it in GitHub Desktop.
# GAN training loop: alternately trains the discriminator on a half-real /
# half-generated batch, then the generator through the combined `gan` model.
EPOCHS = 50
BATCH_SIZE = 128
# Plot samples and loss curves every PLOT_EVERY epochs, and after the last one.
PLOT_EVERY = 5
# Number of full batches per epoch. The trailing x_train.shape[0] % BATCH_SIZE
# samples are not used in a given epoch (which ones varies with the shuffle).
BATCH_BLOCKS = x_train.shape[0] // BATCH_SIZE
if BATCH_BLOCKS == 0:
    # Without at least one full batch nothing would ever be trained and the
    # per-epoch losses below would be undefined.
    raise ValueError(
        f"x_train has {x_train.shape[0]} samples, fewer than BATCH_SIZE={BATCH_SIZE}"
    )

# Loop-invariant targets, built once instead of per batch.
# Discriminator input is [real; generated], so the first half gets the soft
# "real" target 0.98 and the second half the soft "fake" target 0.02.
discriminator_labels = np.concatenate(
    [np.full(BATCH_SIZE, 0.98), np.full(BATCH_SIZE, 0.02)]
)
# Target 1 ("real") for generated samples fed through `gan`; presumably `gan`
# is generator -> discriminator, so this pushes the generator to fool it.
gan_labels = np.ones(BATCH_SIZE)

# One entry per epoch: the mean batch loss over that epoch.
discriminator_losses = []
gan_losses = []
for epoch in range(EPOCHS):
    # Reshuffle each epoch so batch composition differs between epochs.
    order = np.random.permutation(x_train.shape[0])
    epoch_discriminator_losses = []
    epoch_gan_losses = []
    for batch_index in range(BATCH_BLOCKS):
        indices = order[batch_index * BATCH_SIZE:(batch_index + 1) * BATCH_SIZE]
        # Flatten each sample to one row; the discriminator presumably expects
        # 784 features (MNIST-sized images) — TODO confirm.
        batch = x_train[indices].reshape(BATCH_SIZE, -1)

        latent_noise = np.random.normal(0, 1, size=[BATCH_SIZE, LATENT_SIZE])
        # predict_on_batch avoids predict()'s per-call setup and progress bar.
        generated = generator.predict_on_batch(latent_noise)
        batch_images = np.concatenate([batch, generated])

        # NOTE(review): in Keras, `trainable` changes only take effect at
        # compile() time, so these toggles are likely no-ops; the freeze that
        # matters is whatever was in place when `gan` was compiled — confirm.
        discriminator.trainable = True
        discriminator_loss = discriminator.train_on_batch(
            batch_images, discriminator_labels
        )
        discriminator.trainable = False

        # Fresh noise for the generator step.
        latent_noise = np.random.normal(0, 1, size=[BATCH_SIZE, LATENT_SIZE])
        gan_loss = gan.train_on_batch(latent_noise, gan_labels)

        epoch_discriminator_losses.append(discriminator_loss)
        epoch_gan_losses.append(gan_loss)

    # Summarize the epoch by its mean batch loss rather than the last batch's.
    # train_on_batch is assumed to return a scalar loss (no extra metrics).
    discriminator_loss = float(np.mean(epoch_discriminator_losses))
    gan_loss = float(np.mean(epoch_gan_losses))
    discriminator_losses.append(discriminator_loss)
    gan_losses.append(gan_loss)
    print(f"[{epoch + 1}|{EPOCHS}] Generator loss: {gan_loss:.6f} :: Discriminator loss: {discriminator_loss:.6f}.")
    if epoch % PLOT_EVERY == 0 or epoch == EPOCHS - 1:
        plot_generated_images(epoch, generator, LATENT_SIZE)
        plot_losses(epoch, discriminator_losses, gan_losses)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment