Skip to content

Instantly share code, notes, and snippets.

@mlzxy
Created January 12, 2017 20:14
Show Gist options
  • Save mlzxy/a4370046470da082006165de42da778d to your computer and use it in GitHub Desktop.
Save mlzxy/a4370046470da082006165de42da778d to your computer and use it in GitHub Desktop.
"""
make trainable
"""
def make_trainable(net, val):
    """Set the ``trainable`` flag on a model and on every one of its layers.

    Args:
        net: an object exposing a ``trainable`` attribute and an iterable
            ``layers`` attribute (a Keras ``Model`` in this file).
        val: the value assigned to ``net.trainable`` and to each layer's
            ``trainable`` attribute.

    Returns:
        None. ``net`` and its layers are modified in place.
    """
    # NOTE(review): Keras picks up `trainable` when a model is compiled / its
    # training function is built, so call this before that point for the
    # change to take effect — confirm against the Keras version in use.
    net.trainable = val
    for layer in net.layers:
        layer.trainable = val
"""
make the Generator
"""
# Generator: latent vector (latent_dim, defined elsewhere) -> single-channel
# image. Layout is channels-first: Reshape to [channels, rows, cols] and
# BatchNormalization over axis=1. Uses the Keras 1.x layer API
# (init=, border_mode=, Convolution2D(filters, rows, cols)).
initial_channel = 100
initial_2d_shape = [14, 14]
g_input = Input(shape=[latent_dim])
# Project the latent vector to initial_channel * 14 * 14 units so it can be
# reshaped into a (100, 14, 14) feature map.
H = Dense(initial_channel * initial_2d_shape[0] * initial_2d_shape[1],
          init='glorot_normal')(g_input)
# NOTE: BatchNormalization directly after this Dense layer was tried by the
# author and left disabled.
H = Activation('relu')(H)
H = Reshape([initial_channel] + initial_2d_shape)(H)
# Double the spatial size: 14x14 -> 28x28.
H = UpSampling2D(size=(2, 2))(H)
# Halve, then quarter, the channel count with 'same'-padded 3x3 convolutions
# (floor division keeps the filter counts integral).
H = Convolution2D(initial_channel // 2, 3, 3, border_mode='same',
                  init='glorot_uniform')(H)
H = BatchNormalization(mode=2, axis=1)(H)
H = Activation('relu')(H)
H = Convolution2D(initial_channel // 4, 3, 3, border_mode='same',
                  init='glorot_uniform')(H)
# NOTE: a BatchNormalization here was also tried and left disabled.
H = Activation('relu')(H)
# Collapse to one output channel; sigmoid bounds pixel values to (0, 1).
H = Convolution2D(1, 3, 3, border_mode='same', init='glorot_uniform')(H)
g_output = Activation('sigmoid')(H)
generator = Model(g_input, g_output)
generator.compile(loss='binary_crossentropy', optimizer=decoding_optimizer_g)
generator.summary()
"""
Make the discriminator
"""
# Discriminator: image (image_shape, defined elsewhere) -> 2-way softmax,
# compiled with categorical cross-entropy. Uses the Keras 1.x layer API
# (subsample=, border_mode=, Convolution2D(filters, rows, cols)).
d_input = Input(shape=image_shape)
# The strided convolutions are deliberately linear (no activation=): a ReLU
# here would clip negatives to 0 and turn the LeakyReLU(0.2) that follows
# into an identity, so the leaky slope would never apply.
H = Convolution2D(256, 5, 5, subsample=(2, 2), border_mode='same')(d_input)
H = LeakyReLU(0.2)(H)
H = Dropout(env.dropout_rate.discriminator)(H)
H = Convolution2D(512, 5, 5, subsample=(2, 2), border_mode='same')(H)
H = LeakyReLU(0.2)(H)
H = Dropout(env.dropout_rate.discriminator)(H)
H = Flatten()(H)
H = Dense(256)(H)
H = LeakyReLU(0.2)(H)
H = Dropout(env.dropout_rate.discriminator)(H)
d_output = Dense(2, activation='softmax')(H)
discriminator = Model(d_input, d_output)
discriminator.compile(loss='categorical_crossentropy', optimizer=discriminative_optimizer)
discriminator.summary()
"""
concat to make GAN, and make Discriminator not trainable in GAN
"""
# Stack generator -> discriminator into one model mapping a latent vector to
# the discriminator's 2-way softmax output.
gan_input = Input(shape=[latent_dim])
H = generator(gan_input)
# Freeze the discriminator so that compiling GAN below makes only the
# generator's weights trainable through this model. Order matters: this runs
# after discriminator.compile(...) and before GAN.compile(...).
# NOTE(review): Keras reads `trainable` when a training function is built, so
# the training loop (not shown) presumably re-enables the discriminator with
# make_trainable(discriminator, True) before its own updates — confirm.
make_trainable(discriminator, False)
gan_output = discriminator(H)
GAN = Model(gan_input, gan_output)
GAN.compile(loss='categorical_crossentropy', optimizer=gan_optimizer)
print("GAN: ")
GAN.summary()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment