Skip to content

Instantly share code, notes, and snippets.

@nwatab
Last active January 21, 2019 11:15
Show Gist options
  • Save nwatab/0875a3d8ec08d81128308b2baaef838a to your computer and use it in GitHub Desktop.
Simple generative adversarial network (GAN) for MNIST
import matplotlib.pyplot as plt
import numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import Conv2D, MaxPooling2D, Reshape, UpSampling2D, InputLayer
from keras.optimizers import Adam
import os
class GAN():
    """Vanilla GAN for MNIST: a small conv generator vs. a conv discriminator.

    Real images are rescaled to [-1, 1] to match the generator's tanh output.
    Progress snapshots (10x10 grids of generated digits) are written to
    ``images/<epoch>.png``.
    """

    def __init__(self):
        self.img_shape = (28, 28, 1)   # MNIST image shape (H, W, channels)
        self.z_dim = 100               # dimensionality of the latent noise vector
        optimizer = Adam(0.0002, 0.5)  # lr=2e-4, beta_1=0.5 — standard DCGAN settings

        # Discriminator: trained directly on real-vs-fake batches.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Combined model (generator -> frozen discriminator) trains the generator.
        # The discriminator is frozen AFTER its own compile, so the stand-alone
        # discriminator keeps updating its weights while the combined model
        # only updates the generator.
        self.generator = self.build_generator()
        self.discriminator.trainable = False
        self.combined = Sequential([self.generator, self.discriminator])
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

        os.makedirs('images', exist_ok=True)
        self.noise = None  # fixed noise, lazily created and reused for snapshots

    def build_generator(self):
        """Return a model mapping (z_dim,) noise to a (28, 28, 1) tanh image."""
        model = Sequential()
        # Project the noise to a 7x7x16 feature map (28*28*1 * 16 / 16 = 784
        # units), then upsample twice (7 -> 14 -> 28).
        # NOTE: np.prod replaces np.product, which was removed in NumPy 2.0.
        model.add(Dense(np.prod(self.img_shape) * 16 // 2**4,
                        input_shape=(self.z_dim,)))
        model.add(Reshape((self.img_shape[0] // 2**2, self.img_shape[1] // 2**2, 16)))
        model.add(UpSampling2D(size=2))
        model.add(Conv2D(32, kernel_size=2, padding='same', activation='relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D(size=2))
        # tanh output matches the [-1, 1] rescaling applied to the real images.
        model.add(Conv2D(1, kernel_size=2, padding='same', activation='tanh'))
        model.summary()
        return model

    def build_discriminator(self):
        """Return a model mapping a (28, 28, 1) image to a real/fake sigmoid score."""
        model = Sequential()
        # Two strided convs downsample 28 -> 14 -> 7 before the dense head.
        model.add(Conv2D(32, kernel_size=2, strides=2, padding='same',
                         input_shape=(self.img_shape)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(64, kernel_size=2, strides=2, padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))
        model.summary()
        return model

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Run the adversarial training loop.

        Each "epoch" here is a single batch step: one discriminator update on
        a real batch, one on a fake batch, then one generator update through
        the combined model. A snapshot is saved every ``sample_interval`` steps.
        """
        (X_train, _), (_, _) = mnist.load_data()
        X_train = X_train / 127.5 - 1.  # rescale uint8 [0, 255] to [-1, 1]
        X_train = np.expand_dims(X_train, axis=3)  # add channel axis -> (N, 28, 28, 1)

        # Adversarial ground-truth labels.
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            # Sample a random real batch and synthesize a fake one.
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            real_imgs = X_train[idx]
            noise = np.random.normal(0, 1, (batch_size, self.z_dim))
            fake_imgs = self.generator.predict(noise)

            # Train the discriminator on real and fake halves, then the
            # generator (via the combined model, labels flipped to "valid").
            d_score_real = self.discriminator.train_on_batch(real_imgs, valid)
            d_score_fake = self.discriminator.train_on_batch(fake_imgs, fake)
            g_score = self.combined.train_on_batch(noise, valid)

            # Report progress (discriminator loss/acc averaged over both halves).
            d_score = 0.5 * np.add(d_score_real, d_score_fake)
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]"
                  % (epoch, d_score[0], 100 * d_score[1], g_score))

            # Save a fake-image snapshot periodically.
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch, r=10, c=10):
        """Save an r-by-c grid of generated digits to images/<epoch>.png.

        The same noise is reused across calls so successive snapshots show the
        evolution of the same latent points.
        """
        if self.noise is None:
            self.noise = np.random.normal(0, 1, (r * c, self.z_dim))
        gen_imgs = self.generator.predict(self.noise)
        gen_imgs = 0.5 * gen_imgs + 0.5  # rescale [-1, 1] tanh output to [0, 1]

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray', vmin=0, vmax=1)
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()
if __name__ == '__main__':
    # Train a GAN on MNIST, saving a snapshot grid every 100 steps.
    GAN().train(epochs=100001, batch_size=128, sample_interval=100)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment