Last active
April 7, 2018 15:53
-
-
Save SohanChy/fb1a2bb485387addf8de694bc5f11c1b to your computer and use it in GitHub Desktop.
Transfer Learning
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
Created on Sat Aug 26 20:40:38 2017 | |
@author: dhaval | |
Updated to KERAS 2 API on Sat Apr 7 9:53:00 2018 | |
@contrib: sohanchy | |
""" | |
import os | |
import sys | |
import glob | |
import argparse | |
import matplotlib | |
#matplotlib.use('agg') | |
import matplotlib.pyplot as plt | |
from keras import backend as K | |
from keras import __version__ | |
from keras.applications.inception_v3 import InceptionV3, preprocess_input | |
from keras.models import Model | |
from keras.layers import Dense, AveragePooling2D, GlobalAveragePooling2D, Input, Flatten, Dropout | |
from keras.preprocessing.image import ImageDataGenerator | |
from keras.optimizers import SGD | |
IM_WIDTH, IM_HEIGHT = 299, 299 #fixed size for InceptionV3 | |
NB_EPOCHS = 3 | |
BAT_SIZE = 32 | |
FC_SIZE = 1024 | |
#NB_IV3_LAYERS_TO_FREEZE = 172 | |
def get_nb_files(directory):
    """Count the files stored inside the sub-folders of ``directory``.

    Every sub-folder, at any depth, is inspected and the regular files it
    directly contains are counted. Files lying directly in ``directory``
    itself are not counted, and names starting with a dot are skipped
    because the ``*`` glob pattern does not match them.

    Args:
        directory: path of the root folder to scan.

    Returns:
        The number of files found, or 0 when ``directory`` does not exist.
    """
    if not os.path.exists(directory):
        return 0
    cnt = 0
    for root, dirs, _files in os.walk(directory):
        for sub in dirs:
            # The glob also matches nested folders; those are visited on
            # their own by os.walk, so only regular files are counted here.
            cnt += sum(
                1
                for path in glob.glob(os.path.join(root, sub, "*"))
                if os.path.isfile(path)
            )
    return cnt
def setup_to_transfer_learn(model, base_model):
    """Mark every layer of ``base_model`` as non-trainable, then compile ``model``.

    Args:
        model: model to compile; it is compiled with the 'rmsprop' optimizer,
            categorical cross-entropy loss and the accuracy metric.
        base_model: model whose ``layers`` are all frozen in place.
    """
    for pretrained_layer in base_model.layers:
        pretrained_layer.trainable = False

    compile_options = dict(
        optimizer='rmsprop',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    model.compile(**compile_options)
def add_new_last_layer(base_model, nb_classes):
    """Attach a fresh classification head to a headless convnet.

    The head is: 8x8 average pooling -> dropout (0.4) -> flatten ->
    softmax dense layer with ``nb_classes`` units.

    Args:
        base_model: keras model excluding top
        nb_classes: # of classes
    Returns:
        new keras model mapping ``base_model.input`` to the softmax output
    """
    # NOTE(review): the fixed 8x8 pool is presumably sized for the feature
    # map InceptionV3 produces from 299x299 inputs - confirm before changing
    # the input size.
    pooled = AveragePooling2D(pool_size=(8, 8), padding='valid', name='avg_pool')(base_model.output)
    regularised = Dropout(0.4)(pooled)
    flattened = Flatten()(regularised)
    class_probabilities = Dense(nb_classes, activation='softmax')(flattened)
    return Model(inputs=base_model.input, outputs=class_probabilities)
def plot_training(history):
    """Plot accuracy and loss curves of a training run and save them as PNGs.

    Two figures are produced and written to the current working directory:
    ``accuracy.png`` and ``loss.png``. Training values are drawn as red dots
    and validation values as a red line.

    Args:
        history: object whose ``history`` attribute is a dict of per-epoch
            metric lists, such as the value returned by
            ``model.fit_generator``. It must contain ``loss`` and
            ``val_loss`` plus either ``acc``/``val_acc`` or
            ``accuracy``/``val_accuracy``.

    Raises:
        KeyError: if a required metric is missing from ``history.history``.
    """
    metrics = history.history
    # Older Keras 2 releases log the accuracy metric as 'acc'; newer ones
    # log it under the name given to compile(), i.e. 'accuracy'.
    acc_key = 'acc' if 'acc' in metrics else 'accuracy'
    acc = metrics[acc_key]
    val_acc = metrics['val_' + acc_key]
    loss = metrics['loss']
    val_loss = metrics['val_loss']
    epochs = range(len(acc))

    # Start on a fresh figure so a repeated call does not draw the accuracy
    # curves on top of the loss figure left over from the previous call.
    plt.figure()
    plt.plot(epochs, acc, 'r.')
    plt.plot(epochs, val_acc, 'r')
    plt.title('Training and validation accuracy')
    plt.savefig('accuracy.png')

    plt.figure()
    plt.plot(epochs, loss, 'r.')
    plt.plot(epochs, val_loss, 'r-')
    plt.title('Training and validation loss')
    plt.savefig('loss.png')
""" | |
def setup_to_finetune(model): | |
Freeze the bottom NB_IV3_LAYERS_TO_FREEZE layers and retrain the remaining top layers.
note: NB_IV3_LAYERS_TO_FREEZE is chosen so that the top 2 inception blocks in the inceptionv3 arch remain trainable
Args: | |
model: keras model | |
for layer in model.layers[:NB_IV3_LAYERS_TO_FREEZE]: | |
layer.trainable = False | |
for layer in model.layers[NB_IV3_LAYERS_TO_FREEZE:]: | |
layer.trainable = True | |
model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy']) | |
""" | |
def train(output_model_file, nb_epoch=5, doPlot=False, batch_size=32,
          train_dir="train", test_dir="test", steps_per_epoch=None):
    """Train a new classifier head on top of a frozen ImageNet InceptionV3.

    Images are read from ``train_dir`` and ``test_dir``; each sub-folder of
    ``train_dir`` is treated as one class. All InceptionV3 layers are frozen,
    only the newly added head is trained, and the resulting model is saved
    to ``output_model_file``. No fine-tuning of the base network is done.

    Args:
        output_model_file: path the trained model is saved to.
        nb_epoch: number of training epochs.
        doPlot: when true, save accuracy/loss plots via ``plot_training``.
        batch_size: number of images per batch for both generators.
        train_dir: folder with the training images, one sub-folder per class.
        test_dir: folder with the validation images, same layout.
        steps_per_epoch: batches drawn per epoch. ``None`` (default) covers
            the whole training set once per epoch; pass ``8`` to get the
            former hard-coded behaviour.
    """
    train_img = train_dir
    validation_img = test_dir
    nb_train_samples = get_nb_files(train_img)
    nb_val_samples = get_nb_files(validation_img)
    # Only sub-folders are classes; a stray file next to them must not
    # widen the softmax layer.
    nb_classes = sum(
        1 for entry in glob.glob(os.path.join(train_img, "*"))
        if os.path.isdir(entry)
    )

    # data prep
    # NOTE(review): pixels are rescaled to [0, 1] here, while the imported
    # (unused) inception_v3.preprocess_input presumably scales differently -
    # confirm which one the code that loads the saved model expects.
    train_datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')
    # Validation images are only rescaled: random augmentation would make
    # the validation metrics noisy and not comparable between epochs.
    validation_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(
        train_img,
        batch_size=batch_size,
        target_size=(IM_HEIGHT, IM_WIDTH),
        class_mode='categorical'
    )
    validation_generator = validation_datagen.flow_from_directory(
        validation_img,
        batch_size=batch_size,
        target_size=(IM_HEIGHT, IM_WIDTH),
        class_mode='categorical'
    )

    # Keras 2 backend API (replaces the Keras 1 image_dim_ordering() call).
    if K.image_data_format() == 'channels_first':
        input_tensor = Input(shape=(3, IM_HEIGHT, IM_WIDTH))
    else:
        input_tensor = Input(shape=(IM_HEIGHT, IM_WIDTH, 3))

    # setup model
    # include_top=False excludes the final FC layer
    base_model = InceptionV3(input_tensor=input_tensor, weights='imagenet', include_top=False)
    model = add_new_last_layer(base_model, nb_classes)

    # transfer learning
    setup_to_transfer_learn(model, base_model)

    # Ceiling division so that the last, partially filled batch is included;
    # at least one step so an undercounted folder does not yield zero steps.
    if steps_per_epoch is None:
        steps_per_epoch = max(1, (nb_train_samples + batch_size - 1) // batch_size)
    validation_steps = max(1, (nb_val_samples + batch_size - 1) // batch_size)

    history_tl = model.fit_generator(train_generator,
                                     steps_per_epoch=steps_per_epoch,
                                     epochs=nb_epoch,
                                     validation_data=validation_generator,
                                     validation_steps=validation_steps)
    model.save(output_model_file)
    if doPlot:
        plot_training(history_tl)
#EXAMPLE USE | |
# nb_epoch = 5 | |
# output_model_file = "dhakaia_pola.model" | |
# batch_size = 32 | |
# doPlot = True | |
# train(args) | |
# train(output_model_file,nb_epoch,doPlot,batch_size) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment