Skip to content

Instantly share code, notes, and snippets.

@MLWhiz
Created June 15, 2019 04:14
Show Gist options
  • Select an option

  • Save MLWhiz/fc08cf3e3312e2f38581321c88c68492 to your computer and use it in GitHub Desktop.

Select an option

Save MLWhiz/fc08cf3e3312e2f38581321c88c68492 to your computer and use it in GitHub Desktop.
def get_disc_normal(image_shape=(64, 64, 3)):
    """Build, compile and return the convolutional discriminator model.

    Architecture:
      * four 4x4 convolutions with 64, 128, 256 and 512 filters, each using
        ``strides=(2, 2)`` and ``"same"`` padding;
      * batch normalization (momentum 0.5) after every convolution except the
        first;
      * LeakyReLU(0.2) after every block;
      * a Flatten, then a Dense(1) layer with a sigmoid activation.

    The model is compiled with binary cross-entropy loss, an
    ``Adam(learning_rate=0.0002, beta_1=0.5)`` optimizer and an accuracy
    metric. ``summary()`` is printed before returning.

    Args:
        image_shape: Shape of a single input image in channels-last order
            (the convolutions use ``data_format="channels_last"``).
            Defaults to ``(64, 64, 3)``.

    Returns:
        The compiled Keras ``Model``. Its output has one unit per sample
        (from ``Dense(1)``), squashed by a sigmoid.
    """
    kernel_init = "glorot_uniform"
    dis_input = Input(shape=image_shape)

    x = dis_input
    for block_idx, n_filters in enumerate((64, 128, 256, 512)):
        x = Conv2D(
            filters=n_filters,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            data_format="channels_last",
            kernel_initializer=kernel_init,
        )(x)
        # The first (input-facing) block has no batch normalization;
        # the remaining three do.
        if block_idx > 0:
            x = BatchNormalization(momentum=0.5)(x)
        x = LeakyReLU(0.2)(x)

    x = Flatten()(x)
    x = Dense(1)(x)
    validity = Activation("sigmoid")(x)

    # `learning_rate` replaces the deprecated `lr` keyword, which newer
    # Keras optimizers reject.
    dis_opt = Adam(learning_rate=0.0002, beta_1=0.5)
    # Keras 2+ functional API takes `inputs`/`outputs`; the Keras 1
    # `input`/`output` keywords are no longer accepted.
    discriminator_model = Model(inputs=dis_input, outputs=validity)
    discriminator_model.compile(
        loss="binary_crossentropy",
        optimizer=dis_opt,
        metrics=["accuracy"],
    )
    discriminator_model.summary()
    return discriminator_model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment