@heronyang
Last active December 1, 2018 18:23
#!/usr/bin/env python3
"""This script loads a pretrained InceptionV3 model, then retrains the
fully-connected layer using a given training data. Per epoch, one model with
weight values and one prediction on a given test data are saved to disk.
Input:
- X_TRAIN_FILENAME: the location of training dataset file (.npy file)
- Y_TRAIN_FILENAME: the location of training label file (.npy file)
- X_TEST_FILENAME: the location of the dataset to predict (.npy file)
- N_CLASSES: number of classes to classify
Output:
- A model file with its weight is generated per epoch.
- A prediction file is generated per epoch.
"""
from datetime import datetime
import csv
import numpy as np
from keras.applications import InceptionV3
from keras.callbacks import Callback
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Sequential
from keras.optimizers import SGD

X_TRAIN_FILENAME = 'X_train.npy'  # (14424, 224, 224, 3)
Y_TRAIN_FILENAME = 'y_train.npy'  # (14424,)
X_TEST_FILENAME = 'X_test.npy'    # (3000, 224, 224, 3)
N_CLASSES = 3                     # Number of classes to classify
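

# Illustrative helper, not part of the original gist: a sketch of how the three
# .npy inputs above could be produced for a quick smoke test with random data.
# The helper name and the sample counts are assumptions for illustration only.
def make_dummy_inputs(n_train=32, n_test=8):
    """Writes small random X_train / y_train / X_test files for a dry run."""
    np.save(X_TRAIN_FILENAME,
            np.random.rand(n_train, 224, 224, 3).astype('float32'))
    np.save(Y_TRAIN_FILENAME, np.random.randint(0, N_CLASSES, size=n_train))
    np.save(X_TEST_FILENAME,
            np.random.rand(n_test, 224, 224, 3).astype('float32'))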


class LogAndPredictCallback(Callback):
    """Per-epoch callback: saves the model weights and a prediction on the
    test data, and prints a short log.
    """

    def __init__(self, test_dataset, task_id):
        super().__init__()
        self.test_dataset = test_dataset
        self.task_id = task_id

    def on_epoch_end(self, epoch, logs=None):
        # Makes a prediction and saves it to disk.
        predict = np.argmax(
            self.model.predict(self.test_dataset, batch_size=16, verbose=1), 1)
        self.__save_predict(predict, epoch)
        # Saves the model weights.
        self.__save_weights(epoch)
        # Prints log information on screen.
        print('epoch', epoch, 'logs', logs, 'model')
        self.model.summary()

    def __save_predict(self, predict, epoch):
        output_file = 'predict.%s.epoch-%02d.csv' % (self.task_id, epoch)
        with open(output_file, 'w', newline='') as f_out:
            writer = csv.writer(f_out)
            for i, label in enumerate(predict):
                writer.writerow([str(i), str(label)])
        print(output_file + ' was saved.')

    def __save_weights(self, epoch):
        output_file = 'weight.%s.epoch-%02d.h5' % (self.task_id, epoch)
        self.model.save_weights(output_file)
        print(output_file + ' was saved.')


def main():
    """Script starts here.
    """
    # Loads data
    train_dataset, train_label, test_dataset = load_data()
    # Builds model
    model = build_model()
    # Trains the model
    train_history = model.fit(
        train_dataset, train_label,
        batch_size=32, epochs=60, validation_split=0.1,
        callbacks=[LogAndPredictCallback(test_dataset, get_current_time())]
    )
    print('Done', train_history)


def load_data():
    """Loads data from disk into numpy arrays.
    """
    train_dataset = np.load(X_TRAIN_FILENAME)
    # Converts the integer labels into a one-hot encoded boolean matrix.
    train_label = (np.arange(N_CLASSES) == np.load(Y_TRAIN_FILENAME)[:, None])
    test_dataset = np.load(X_TEST_FILENAME)
    return train_dataset, train_label, test_dataset


def build_model():
    """Builds the model.
    """
    # Initializes the base model with weights pretrained on ImageNet.
    base_model = InceptionV3(
        weights='imagenet',
        include_top=False,
        input_shape=(224, 224, 3))
    # Freezes the weights of the pre-trained model so they are not trained.
    base_model.trainable = False
    # Builds the model we will train on.
    model = Sequential()
    model.add(base_model)
    model.add(GlobalAveragePooling2D())
    model.add(Dense(1024, activation='relu'))
    model.add(Dense(N_CLASSES, activation='softmax'))
    # NOTE: Specify the previous model filename to resume from it.
    # model.load_weights('weight.epoch-xxxx.h5')
    sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(
        loss='categorical_crossentropy',
        optimizer=sgd,
        metrics=['categorical_accuracy']
    )
    return model


def get_current_time():
    """Returns a nicely printed timestamp.
    """
    return datetime.now().strftime("%Y-%m-%d.%H-%M")


if __name__ == '__main__':
    main()
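
# Illustrative usage sketch (the checkpoint filename below is hypothetical):
# a weights file saved by LogAndPredictCallback can later be reloaded for
# prediction only, e.g.:
#
#   model = build_model()
#   model.load_weights('weight.<task-id>.epoch-<nn>.h5')
#   predictions = np.argmax(model.predict(np.load(X_TEST_FILENAME)), 1)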