# GitHub Gist: pythonlessons/cfab68af1e2bfb3f8a28a6332f296e5b
# Title: gan_introduction
# Author: @pythonlessons (created April 18, 2023 13:04)
# Train a GAN on MNIST digit images.
#
# NOTE(review): this script relies on names imported/defined elsewhere in the
# file, none of which are visible here: `os`, `tf`, `mnist`, `TensorBoard`,
# `build_generator`, `build_discriminator`, `ResultsCallback`, `GAN`,
# `discriminator_loss`, `generator_loss`.

# Set the input shape and size for the generator and discriminator
img_shape = (28, 28, 1)  # Shape of one image; input to the discriminator
noise_dim = 100  # Dimension of the noise vector; input to the generator

# Load MNIST dataset. Only x_train is used below; the labels and the test
# split are kept bound only so the module-level names stay available.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize uint8 pixel values from [0, 255] to [-1, 1]
# (presumably to match the generator's output range — confirm in
# build_generator).
x_train = (x_train.astype('float32') - 127.5) / 127.5

# mnist.load_data() returns images with no channel axis, i.e. (N, 28, 28).
# Reshape explicitly to (N, *img_shape) so the real images have the same shape
# as the discriminator input, instead of relying on Keras to add the missing
# trailing dimension implicitly.
x_train = x_train.reshape((-1, *img_shape))

model_path = 'Models/01_GANs_introduction'
os.makedirs(model_path, exist_ok=True)

generator = build_generator(noise_dim)
discriminator = build_discriminator(img_shape)

# One Adam optimizer per network (both are handed to GAN.compile below).
# beta_1=0.5 is lower than Adam's default of 0.9.
generator_optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.5)

# ResultsCallback is project-defined; presumably it writes generated samples
# under model_path — confirm against its definition.
callback = ResultsCallback(noise_dim=noise_dim, results_path=model_path)
# update_freq=1 writes TensorBoard logs every batch, which can slow training.
tb_callback = TensorBoard(os.path.join(model_path, 'logs'), update_freq=1)

gan = GAN(discriminator, generator, noise_dim)
gan.compile(
    discriminator_optimizer,
    generator_optimizer,
    discriminator_loss,
    generator_loss,
    run_eagerly=False,
)
gan.fit(x_train, epochs=100, batch_size=128, callbacks=[callback, tb_callback])
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.