Skip to content

Instantly share code, notes, and snippets.

@dipanjanS
Created August 20, 2019 17:14
Show Gist options
  • Save dipanjanS/223a8fc0ad58f9bc3dcec15a921b3526 to your computer and use it in GitHub Desktop.
Save dipanjanS/223a8fc0ad58f9bc3dcec15a921b3526 to your computer and use it in GitHub Desktop.
# Network input: 192x192 images with 3 channels.
INPUT_SHAPE = (192, 192, 3)

# Load VGG19 with ImageNet weights, without its original classifier top.
vgg = keras.applications.vgg19.VGG19(
    input_shape=INPUT_SHAPE,
    weights='imagenet',
    include_top=False,
)

# Fine-tune the entire backbone: flag the model and each of its layers
# as trainable.
vgg.trainable = True
for layer in vgg.layers:
    layer.trainable = True

# Classification head stacked on the backbone output:
# Flatten -> Dense(1024) -> Dropout(0.2) -> Dense(512) -> Dropout(0.2)
# -> Dense(7, softmax).
base_vgg = vgg
base_out = base_vgg.output
pool_out = keras.layers.Flatten()(base_out)
hidden1 = keras.layers.Dense(units=1024, activation='relu')(pool_out)
drop1 = keras.layers.Dropout(0.2)(hidden1)
hidden2 = keras.layers.Dense(units=512, activation='relu')(drop1)
drop2 = keras.layers.Dropout(0.2)(hidden2)
out = keras.layers.Dense(units=7, activation='softmax')(drop2)

# End-to-end model from the VGG19 input to the 7-way softmax output.
model = keras.Model(outputs=out, inputs=base_vgg.input)

# A very small learning rate is used because all pre-trained layers are
# being updated.
model.compile(
    loss='categorical_crossentropy',
    metrics=[categorical_accuracy],
    optimizer=keras.optimizers.RMSprop(lr=1e-6),
)
# train model
class EpochModelSaver(keras.callbacks.Callback):
    """Keras callback that saves the full model at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Save the attached model to an epoch-numbered ``.h5`` file.

        Args:
            epoch: Epoch index supplied by Keras. The file name uses
                ``epoch + 1`` (Keras documents this index as zero-based,
                so file numbering starts at 1).
            logs: Metrics dict supplied by Keras; not used here. Defaults
                to ``None`` instead of a shared mutable ``{}``.
        """
        self.model.save('vgg19_finetune_full_seti_epoch_{}.h5'.format(epoch + 1))
# Callbacks: save a checkpoint after each epoch and append per-epoch
# metrics to a CSV log file.
ms_epoch = EpochModelSaver()
csv_logger = keras.callbacks.CSVLogger(
    'vgg19_finetune_full_seti_log.csv',
    separator=',',
    append=True,
)

# Run training for 100 epochs. Step counts are rounded up so a final
# partial batch is still consumed.
# NOTE(review): 5600 and 700 are presumably the train / validation sample
# counts — confirm against the generators.
history2 = model.fit_generator(
    train_generator,
    epochs=100,
    steps_per_epoch=math.ceil(5600 / TRAIN_BATCH_SIZE),
    validation_data=val_generator,
    validation_steps=math.ceil(700 / VAL_BATCH_SIZE),
    verbose=1,
    callbacks=[ms_epoch, csv_logger],
)

# Persist the final model once training has finished.
model.save('vgg19_finetune_full_seti.h5')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment