#!/usr/bin/env python
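"""Minimal GAN on CIFAR-10: a convolutional discriminator versus a
transposed-convolution generator, trained with alternating batches
(Keras API circa 2017)."""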
import numpy as np
from keras.models import Sequential
from keras.layers import InputLayer, Reshape, Flatten, Dense
from keras.layers.convolutional import Conv2D, MaxPooling2D, Conv2DTranspose
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import ELU
from keras.datasets import cifar10
from keras.optimizers import Adam

Z_DIM = 32  # dimensionality of the latent noise vector fed to the generator
BATCH_SIZE = 128

def create_z(size):
    # Sample latent vectors uniformly from [-1, 1).
    return np.random.uniform(-1, 1, [size, Z_DIM])


def generate_minibatches(x):
    # Yield mini-batches forever, reshuffling after each full pass over the data.
    while True:
        for i in range(0, len(x), BATCH_SIZE):
            yield x[i:i + BATCH_SIZE]
        np.random.shuffle(x)

def main():
    # Only the training images are needed; the labels are irrelevant for a GAN.
    (x_train, _), _ = cifar10.load_data()
    # Scale uint8 pixels to floats in [0, 1].
    x_train = x_train / 255  # type: np.ndarray
    # Discriminator: three conv/pool stages, then a dense head that outputs
    # the probability that its input is a real CIFAR-10 image.
    d = Sequential([
        InputLayer([32, 32, 3]),
        Conv2D(32, (3, 3), padding='same'),
        ELU(),
        MaxPooling2D(),
        Conv2D(64, (3, 3), padding='same'),
        ELU(),
        MaxPooling2D(),
        Conv2D(128, (3, 3), padding='same'),
        ELU(),
        MaxPooling2D(),
        Flatten(),
        Dense(256),
        ELU(),
        Dense(1, activation='sigmoid')
    ])
    d_adam = Adam(1e-4, beta_1=0.1)
    d.compile(d_adam, 'binary_crossentropy')
    # Generator: project the latent vector to a 4x4x512 tensor, then upsample
    # 4x4 -> 8x8 -> 16x16 -> 32x32 with strided transposed convolutions.
    g = Sequential([
        InputLayer([Z_DIM]),
        Dense(4 * 4 * 512),
        ELU(),
        BatchNormalization(),
        Reshape([4, 4, 512]),
        Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same'),
        ELU(),
        BatchNormalization(),
        Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same'),
        ELU(),
        BatchNormalization(),
        Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same'),
        ELU(),
        BatchNormalization(),
        Conv2D(3, (3, 3), padding='same', activation='sigmoid'),
    ])

    # Freeze the discriminator inside the stacked model so that training the
    # combined generator+discriminator updates only the generator's weights.
    d.trainable = False
    gan = Sequential([g, d])
    g_adam = Adam(2e-4, beta_1=0.5)
    gan.compile(g_adam, 'binary_crossentropy')
    for i, b in enumerate(generate_minibatches(x_train)):
        # Generator step: try to make the (frozen) discriminator output 1.
        loss_g = gan.train_on_batch(create_z(len(b)), np.ones(len(b)))

        # Discriminator step: generated images are labelled 0, real ones 1.
        pred = g.predict(create_z(len(b)))
        x = np.vstack([pred, b])
        y = np.concatenate([np.zeros(len(pred)), np.ones(len(b))])
        loss_d = d.train_on_batch(x, y)
        print(f"{i}: g:{loss_g} d:{loss_d}")
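

# Illustrative sketch, not part of the original gist: one way to eyeball the
# generator's progress, e.g. by calling this every few hundred steps from the
# training loop above. Assumes matplotlib is installed; the 4x4 grid and the
# 'samples.png' filename are arbitrary choices.
def save_samples(g, path='samples.png'):
    import matplotlib.pyplot as plt
    imgs = g.predict(create_z(16))  # sigmoid output, so values lie in [0, 1]
    fig, axes = plt.subplots(4, 4, figsize=(4, 4))
    for img, ax in zip(imgs, axes.flat):
        ax.imshow(img)
        ax.axis('off')
    fig.savefig(path)
    plt.close(fig)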

if __name__ == '__main__':
    main()