Created November 15, 2017 18:07
Checking on DCGAN implementation and training
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.models import Model
from keras.layers import Input
from keras.optimizers import Adam
from keras.utils import generic_utils

def train_dcgan(generator_func, discriminator_func, data_in, model_weights_name_prefix, load_model=False,
                epochs=10, batch_size=32, lr_opt=1e-3, lr_d_opt=1e-4, plot_epoch=None):
    # get image shape
    image_shape = data_in[0, :, :, :].shape
    print("image shape:", image_shape)
    K.clear_session()

    # use the generator and discriminator functions and build the stacked GAN model
    gan_input = Input(shape=image_shape)
    generator_model = generator_func(image_shape)
    discriminator_model = discriminator_func(image_shape)
    discriminator_input = generator_model(gan_input)
    gan_output = discriminator_model(discriminator_input)
    gan_model = Model(inputs=gan_input, outputs=gan_output)

    # used for batching
    n = data_in.shape[0]

    # model compilation
    # note: Keras bakes the `trainable` flag in at compile time, so the
    # discriminator is compiled trainable on its own, then frozen before the
    # stacked gan_model is compiled; toggling the flag per batch after
    # compilation has no effect
    optimizer = Adam(lr=lr_opt)      # for the generator and the full GAN
    d_optimizer = Adam(lr=lr_d_opt)  # just for the discriminator
    generator_model.compile(loss='binary_crossentropy', optimizer=optimizer)
    discriminator_model.compile(loss='binary_crossentropy', optimizer=d_optimizer)
    discriminator_model.trainable = False
    gan_model.compile(loss='binary_crossentropy', optimizer=optimizer)

    # number of batches per epoch
    batch_count = data_in.shape[0] // batch_size
    noise_batch_shape = tuple([batch_size] + list(data_in.shape[1:]))

    # as suggested in the DCGAN paper, section 4 on page 3: "No pre-processing was
    # applied to training images besides scaling to the range of the tanh
    # activation function [-1, 1]."
    # https://stackoverflow.com/questions/5294955/how-to-scale-down-a-range-of-numbers-with-a-known-min-and-max-value
    min_intensity, max_intensity = np.min(data_in), np.max(data_in)
    data_in = -1 + 2.0 * (data_in - min_intensity) / (max_intensity - min_intensity)

    # track losses
    losses = {'d': [], 'g': []}

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # option to load existing weights
        if load_model:
            try:
                generator_model.load_weights(model_weights_name_prefix + "_generator.hdf5")
                discriminator_model.load_weights(model_weights_name_prefix + "_discriminator.hdf5")
            except IOError:
                print("No weights found with name prefix " + model_weights_name_prefix)

        for epoch in range(epochs):
            d_losses = []
            g_losses = []
            progbar = generic_utils.Progbar(n)
            for batch_index in range(batch_count):
                # noise input sampled from the noise prior for the generator
                noise_prior_input = np.random.uniform(-1, 1, size=noise_batch_shape)
                # random sample of size=batch_size from data_in; these are the
                # real images that will be fed to the discriminator
                real_image_batch = data_in[np.random.randint(0, data_in.shape[0], size=batch_size)]
                # fake images predicted by the generator
                generator_predictions = generator_model.predict(noise_prior_input, batch_size=batch_size)
                # the discriminator takes in the real images and the generated fake images
                X = np.concatenate([generator_predictions, real_image_batch])
                # labels (in the same order as X) for the discriminator, with
                # one-sided label smoothing: real labels are drawn from [0.8, 1.0]
                fake_y = [0] * batch_size
                real_y = list(np.ones(batch_size) - np.random.random_sample(batch_size) * 0.2)
                y_discriminator = fake_y + real_y

                # training the discriminator: it learns to distinguish
                # between real and fake images
                d_loss = discriminator_model.train_on_batch(X, y_discriminator)
                d_losses += [d_loss]

                # training the generator-discriminator stack: the generator tries to
                # fool the (frozen) discriminator into labeling generated images as real
                noise_prior_input = np.random.uniform(-1, 1, size=noise_batch_shape)
                y_generator = [1] * batch_size
                g_loss = gan_model.train_on_batch(noise_prior_input, y_generator)
                g_losses += [g_loss]

                epoch_header = "Epoch:%d d_loss" % (epoch)
                progbar.add(batch_size, values=[(epoch_header, np.mean(d_losses)), ("g_loss", np.mean(g_losses))])

                # save weights every batch
                generator_model.save_weights(model_weights_name_prefix + "_generator.hdf5")
                discriminator_model.save_weights(model_weights_name_prefix + "_discriminator.hdf5")

            # update losses
            losses['d'] += d_losses
            losses['g'] += g_losses
            if plot_epoch is not None and epoch % plot_epoch == 0:
                # plot_output_nosess is a helper defined elsewhere in this gist
                plot_output_nosess(generator_model, model_weights_name_prefix + "_generator.hdf5")
    return losses
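For reference, a minimal usage sketch, under stated assumptions: it pairs train_dcgan with the get_generator_model/get_discriminator_model builders defined below, relies on TF 1.x-era Keras (the function uses tf.Session and tf.global_variables_initializer), and substitutes random 32x32 RGB arrays for a real dataset such as CIFAR-10. The images array, the "dcgan_demo" prefix, and the epoch count are illustrative only, not part of the original gist.

import numpy as np

# hypothetical stand-in dataset: 1024 random 32x32 RGB "images";
# train_dcgan rescales intensities to [-1, 1] itself
images = np.random.rand(1024, 32, 32, 3).astype('float32')

losses = train_dcgan(get_generator_model, get_discriminator_model, images,
                     model_weights_name_prefix="dcgan_demo",  # illustrative prefix
                     epochs=2, batch_size=32, plot_epoch=None)
print("final d_loss %.4f, g_loss %.4f" % (losses['d'][-1], losses['g'][-1]))

A 32x32 input is chosen because the generator's three UpSampling2D doublings are each undone by a strides=2 convolution, so any spatial size that survives the discriminator's three stride-2 halvings works.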
from keras.models import Sequential
from keras.layers import (Conv2D, Activation, BatchNormalization, UpSampling2D,
                          LeakyReLU, Dropout, Flatten, Dense)

def get_generator_model(input_shape):
    # the generator consumes image-shaped noise and emits a tanh image of the
    # same shape: each UpSampling2D doubling is undone by a strides=2 convolution
    w, h, c = input_shape  # only the channel count c is used below
    model = Sequential()
    model.add(Conv2D(256, (4, 4), input_shape=input_shape, padding='same'))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(128, (8, 8), strides=2, padding='same'))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(64, (8, 8), strides=2, padding='same'))
    model.add(Activation('relu'))
    model.add(BatchNormalization())
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(c, (4, 4), strides=2, padding='same'))
    model.add(Activation('tanh'))
    return model

def get_discriminator_model(input_shape):
    # strided convolutions downsample the image to a single sigmoid
    # real/fake probability
    model = Sequential()
    model.add(Conv2D(128, (4, 4), strides=(2, 2), padding='same', input_shape=input_shape))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(64, (8, 8), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.2))
    model.add(Conv2D(128, (4, 4), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(64))
    model.add(LeakyReLU(0.2))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    return model
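As a quick sanity check (a sketch, assuming a hypothetical 32x32x3 input shape), the two builders can be instantiated stand-alone to confirm that the generator maps image-shaped noise back to an image of the same shape and that the discriminator reduces it to a single probability:

g = get_generator_model((32, 32, 3))
d = get_discriminator_model((32, 32, 3))
print(g.output_shape)  # (None, 32, 32, 3): tanh image, same shape as the noise input
print(d.output_shape)  # (None, 1): sigmoid real/fake probability

This check matters because train_dcgan samples its noise prior with the image shape itself rather than a flat latent vector, so the generator's output shape must match its input shape exactly for the stacked GAN to wire together.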