Last active
March 29, 2017 01:26
-
-
Save t-ae/006f0cf0f2299162b556a731aa43909f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import numpy as np | |
from keras.models import Sequential | |
from keras.layers import InputLayer, Reshape, Flatten, Dense | |
from keras.layers.convolutional import Conv2D, MaxPooling2D, Conv2DTranspose | |
from keras.layers.normalization import BatchNormalization | |
from keras.layers.advanced_activations import ELU | |
from keras.datasets import cifar10 | |
from keras.optimizers import Adam | |
# Length of the latent noise vector fed to the generator.
Z_DIM = 32
# Number of real images drawn per training minibatch.
BATCH_SIZE = 128
def create_z(size, z_dim=None):
    """Sample a batch of latent vectors for the generator.

    Args:
        size: Number of latent vectors (rows) to draw.
        z_dim: Length of each latent vector. When omitted, the module-level
            ``Z_DIM`` is used, which is the original behavior.

    Returns:
        An ``np.ndarray`` of shape ``(size, z_dim)`` whose entries are drawn
        independently from the uniform distribution over ``[-1, 1)``.
    """
    if z_dim is None:
        z_dim = Z_DIM
    return np.random.uniform(-1, 1, size=(size, z_dim))
def generate_minibatches(x, batch_size=None):
    """Yield minibatches of ``x`` forever, reshuffling between passes.

    The first pass walks ``x`` in the order given. After every complete pass
    ``x`` is shuffled **in place** along its first axis, so the caller's
    object is reordered as a side effect. The last batch of a pass is
    shorter when ``len(x)`` is not a multiple of the batch size.

    Args:
        x: Sliceable sequence of samples that ``np.random.shuffle`` can
            reorder in place (e.g. an ``np.ndarray``).
        batch_size: Maximum number of samples per batch. When omitted, the
            module-level ``BATCH_SIZE`` is used, which is the original
            behavior.

    Yields:
        Consecutive slices ``x[start:start + batch_size]``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or ``x`` is empty.
            Either condition would otherwise make the generator loop
            forever without ever yielding a batch. The error surfaces on
            the first ``next()`` call because this is a generator.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    total = len(x)
    if total == 0:
        raise ValueError("cannot generate minibatches from an empty dataset")

    while True:
        for start in range(0, total, batch_size):
            yield x[start:start + batch_size]
        # Reorder the samples only once the whole pass has been consumed.
        np.random.shuffle(x)
def _build_discriminator():
    """Return the uncompiled CNN that maps a 32x32x3 image to a sigmoid score."""
    return Sequential([
        InputLayer([32, 32, 3]),
        Conv2D(32, (3, 3), padding='same'),
        ELU(),
        MaxPooling2D(),
        Conv2D(64, (3, 3), padding='same'),
        ELU(),
        MaxPooling2D(),
        Conv2D(128, (3, 3), padding='same'),
        ELU(),
        MaxPooling2D(),
        Flatten(),
        Dense(256),
        ELU(),
        Dense(1, activation='sigmoid'),
    ])


def _build_generator():
    """Return the uncompiled network that maps a Z_DIM latent vector to an image."""
    return Sequential([
        InputLayer([Z_DIM]),
        Dense(4*4*512),
        ELU(),
        BatchNormalization(),
        Reshape([4, 4, 512]),
        # Each stride-2 'same' transposed convolution doubles the spatial
        # size (4 -> 8 -> 16 -> 32) so the output matches the
        # discriminator's 32x32 input.
        Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same'),
        ELU(),
        BatchNormalization(),
        Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same'),
        ELU(),
        BatchNormalization(),
        Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same'),
        ELU(),
        BatchNormalization(),
        # Sigmoid keeps generated pixels in (0, 1), the same range as the
        # rescaled training images.
        Conv2D(3, (3, 3), padding='same', activation='sigmoid'),
    ])


def main():
    """Train a GAN on the CIFAR-10 training images, printing losses each step.

    Each step first updates the generator through the stacked ``gan`` model,
    then updates the discriminator on a batch of generated plus real images.
    The loop never terminates on its own because the minibatch generator is
    infinite; stop it by interrupting the process.
    """
    (x_train, _), _ = cifar10.load_data()
    # Rescale pixel values by 1/255. Converting to float32 first avoids
    # holding a float64 copy of the whole dataset in memory.
    x_train = x_train.astype(np.float32) / 255  # type: np.ndarray

    d = _build_discriminator()
    d_adam = Adam(1e-4, beta_1=0.1)
    d.compile(d_adam, 'binary_crossentropy')

    g = _build_generator()

    # d is frozen only after it has been compiled on its own and before the
    # stacked model is compiled. The intent is that d.train_on_batch still
    # updates d while gan.train_on_batch updates only g; this relies on
    # Keras fixing trainability at compile time.
    d.trainable = False
    gan = Sequential([g, d])
    g_adam = Adam(2e-4, beta_1=0.5)
    gan.compile(g_adam, "binary_crossentropy")

    for i, b in enumerate(generate_minibatches(x_train)):
        n_real = len(b)

        # Generator step: generated images are labelled 1 so the generator is
        # rewarded when the discriminator scores them as real. Labels are
        # (n, 1) arrays matching the Dense(1) output rather than Python
        # lists, whose handling by Keras depends on their length.
        loss_g = gan.train_on_batch(create_z(n_real), np.ones((n_real, 1)))

        # Discriminator step: generated images (label 0) stacked on top of
        # real images (label 1).
        pred = g.predict(create_z(n_real))
        x = np.vstack([pred, b])
        y = np.concatenate([np.zeros((len(pred), 1)), np.ones((n_real, 1))])
        loss_d = d.train_on_batch(x, y)

        print(f"{i}: g:{loss_g} d:{loss_d}")
# Start training only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment