Created
February 11, 2020 01:35
-
-
Save s33a11ev1l/7917feed3dd59bb3bc2efcdaa02117a3 to your computer and use it in GitHub Desktop.
Easy GAN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.datasets import mnist | |
from keras.layers import Input, Dense, Reshape, Flatten | |
from keras.layers import BatchNormalization | |
from keras.layers.advanced_activations import LeakyReLU | |
from keras.models import Sequential, Model | |
from keras.optimizers import Adam | |
import matplotlib.pyplot as plt | |
import numpy as np | |
class GAN():
    """Fully-connected GAN that generates 28x28 single-channel images.

    The instance owns three Keras models:

    * ``discriminator`` -- maps an image to a single sigmoid "validity" score.
    * ``generator`` -- maps a ``latent_dim``-sized noise vector to an image
      with ``tanh`` output.
    * ``combined`` -- generator followed by the discriminator; used to train
      the generator against the discriminator's judgement.
    """

    def __init__(self):
        """Build and compile the discriminator, generator and stacked model."""
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # The discriminator is compiled on its own first, while it is still
        # trainable, so that it can be trained directly in `train`.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        self.generator = self.build_generator()

        # Stack generator -> discriminator: noise in, validity score out.
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # The flag is flipped only after the discriminator's own compile and
        # before the combined model's compile; the intent is that the combined
        # model updates the generator only.
        # NOTE(review): older Keras releases emit a "Discrepancy between
        # trainable weights and collected trainable weights" UserWarning for
        # this pattern (keras-team/keras issue 8585) -- confirm behaviour
        # against the Keras version in use.
        self.discriminator.trainable = False
        validity = self.discriminator(img)

        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):
        """Return a Model mapping a noise vector to an image of ``img_shape``.

        Three Dense/LeakyReLU/BatchNormalization stages (256, 512, 1024 units)
        feed a ``tanh`` Dense layer that is reshaped to ``self.img_shape``.
        """
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        # np.prod returns a NumPy integer; hand Dense a plain Python int.
        model.add(Dense(int(np.prod(self.img_shape)), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Return a Model mapping an image of ``img_shape`` to one sigmoid output."""
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Alternate generator and discriminator updates on MNIST.

        Args:
            epochs: Number of training iterations. Each iteration uses a
                single randomly drawn batch, not a full pass over the data.
            batch_size: Number of real images (and of noise vectors) used per
                iteration.
            sample_interval: A grid of generated samples is saved whenever the
                iteration index is a multiple of this value (including 0).
        """
        # Only the training images are used; labels and the test split are
        # discarded.
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale to match the generator's tanh output range.
        # NOTE(review): assumes pixel values are in 0..255 -- confirm.
        X_train = X_train / 127.5 - 1.
        # Append the channel axis.
        X_train = np.expand_dims(X_train, axis=3)

        # Target labels: 1 for "real", 0 for "generated".
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            # Random batch of real images (sampled with replacement).
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generator step: push the stacked model towards "valid" labels.
            g_loss = self.combined.train_on_batch(noise, valid)

            # Discriminator step, using images from the just-updated generator.
            gen_imgs = self.generator.predict(noise)
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            # Element-wise mean of [loss, accuracy] over the two batches.
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a 5x5 grid of generated images to ``images/<epoch>.png``.

        Args:
            epoch: Iteration index, used as the output file name.
        """
        import os

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Map the tanh range [-1, 1] to [0, 1] for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        # axs.flat walks the grid row by row, matching the sample order.
        for ax, sample in zip(axs.flat, gen_imgs):
            ax.imshow(sample[:, :, 0], cmap='gray')
            ax.axis('off')

        # savefig does not create missing directories; make sure it exists so
        # the first sample does not abort training.
        os.makedirs("images", exist_ok=True)
        fig.savefig("images/%d.png" % epoch)
        # Close this specific figure so figures do not accumulate.
        plt.close(fig)
if __name__ == '__main__':
    # Script entry point: build the models, then run the training loop,
    # saving a sample grid every 10000 iterations.
    gan = GAN()
    gan.train(
        epochs=100000,
        batch_size=132,
        sample_interval=10000,
    )
Hey,
Thanks for providing a neat implementation of DCNN. However, I tried but failed to run the code. It gives a warning
UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set model.trainable without calling model.compile after ? 'Discrepancy between trainable weights and collected trainable'
. I tried to find an answer but failed. Could you please help me with this? Many thanks, and best wishes,
Pawan
Hi, this happens because you set "model.trainable=False" after compiling the model.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @pawan,
That warning was implemented by the Keras team. There is an open bug about this, please refer to: https://github.com/keras-team/keras/issues/8585
The replies there have some examples of how to bypass the warning, while they're still discussing the issue as it relates to GAN implementations.
I hope it helps.
Cheers,
Gus