Last active
December 1, 2018 18:23
-
-
Save heronyang/f008ca43b38b33209cceb8cd999fe0da to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""This script loads a pretrained InceptionV3 model, then retrains the | |
fully-connected layer using the given training data. After each epoch, the
model weights and one prediction on the given test data are saved to disk.
Input: | |
- X_TRAIN_FILENAME: the location of training dataset file (.npy file) | |
- Y_TRAIN_FILENAME: the location of training label file (.npy file) | |
- X_TEST_FILENAME: the location of the dataset to predict (.npy file) | |
- N_CLASSES: number of classes to classify | |
Output: | |
- A model file with its weight is generated per epoch. | |
- A prediction file is generated per epoch. | |
""" | |
from datetime import datetime | |
import csv | |
import numpy as np | |
from keras.applications import InceptionV3 | |
from keras.callbacks import Callback | |
from keras.layers import Dense, Dropout, GlobalAveragePooling2D | |
from keras.models import Sequential | |
from keras.optimizers import SGD | |
# Locations of the input arrays; each trailing comment is the expected shape.
X_TRAIN_FILENAME = 'X_train.npy' # training images: (14424, 224, 224, 3)
Y_TRAIN_FILENAME = 'y_train.npy' # training labels: (14424,)
X_TEST_FILENAME = 'X_test.npy' # images to predict: (3000, 224, 224, 3)
N_CLASSES = 3 # Number of classes; sets the width of the softmax output layer
class LogAndPredictCallback(Callback):
    """Keras callback that saves a prediction and the weights every epoch.

    At the end of each epoch the current model predicts on ``test_dataset``;
    the arg-max class index of every sample is written to a CSV file and the
    model weights are written to a ``.h5`` file. Both file names embed
    ``task_id`` and the epoch index received from Keras.
    """

    def __init__(self, test_dataset, task_id):
        """Remember what to predict on and how to tag the output files.

        Args:
            test_dataset: Samples handed to ``model.predict`` each epoch.
            task_id: String embedded in every output file name.
        """
        # Let the Keras base class initialise its own state before ours.
        super().__init__()
        self.test_dataset = test_dataset
        self.task_id = task_id

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the test data, persist results, and log the epoch.

        Args:
            epoch: Epoch index supplied by Keras; used in the file names.
            logs: Metrics reported by Keras for this epoch; only printed.
        """
        # Collapse the per-class scores to one predicted class index per sample.
        scores = self.model.predict(
            self.test_dataset, batch_size=16, verbose=1)
        predict = np.argmax(scores, 1)
        self.__save_predict(predict, epoch)

        # Saves model weights
        self.__save_weights(epoch)

        # Prints logs information on screen
        print('epoch', epoch, 'logs', logs, 'model')
        self.model.summary()

    def __save_predict(self, predict, epoch):
        """Write one ``index,label`` CSV row per predicted sample."""
        output_file = 'predict.%s.epoch-%02d.csv' % (self.task_id, epoch)
        # newline='' is what the csv module expects for files it writes to.
        with open(output_file, 'w', newline='') as f_out:
            csv.writer(f_out).writerows(
                [str(index), str(label)]
                for index, label in enumerate(predict))
        print(output_file + ' was saved.')

    def __save_weights(self, epoch):
        """Save the current model weights to a per-epoch ``.h5`` file."""
        output_file = 'weight.%s.epoch-%02d.h5' % (self.task_id, epoch)
        self.model.save_weights(output_file)
        print(output_file + ' was saved.')
def main():
    """Run the pipeline end to end: load the data, build the model, train.
    """
    train_dataset, train_label, test_dataset = load_data()
    model = build_model()

    # The callback writes a prediction file and a weight file after every
    # epoch, each tagged with the timestamp taken here.
    per_epoch_hook = LogAndPredictCallback(test_dataset, get_current_time())
    fit_options = dict(
        batch_size=32,
        epochs=60,
        validation_split=0.1,
        callbacks=[per_epoch_hook],
    )
    train_history = model.fit(train_dataset, train_label, **fit_options)
    print('Done', train_history)
def load_data():
    """Load the training and test arrays from disk.

    Reads ``X_TRAIN_FILENAME``, ``Y_TRAIN_FILENAME`` and ``X_TEST_FILENAME``
    with ``np.load`` and one-hot encodes the training labels.

    Returns:
        A ``(train_dataset, train_label, test_dataset)`` tuple, where
        ``train_label`` is a boolean array with one column per class
        (``N_CLASSES`` columns in total).
    """
    train_dataset = np.load(X_TRAIN_FILENAME)
    class_ids = np.load(Y_TRAIN_FILENAME)
    # One-hot encode via broadcasting: each label, as a column vector, is
    # compared against the row of class indices. The width comes from
    # N_CLASSES so the labels always match the model's output layer.
    # Assumes labels are a 1-D array of integer ids in [0, N_CLASSES)
    # -- TODO confirm against the label file.
    train_label = np.arange(N_CLASSES) == class_ids[:, None]
    test_dataset = np.load(X_TEST_FILENAME)
    return train_dataset, train_label, test_dataset
def build_model():
    """Assemble and compile the classifier used for training.

    A frozen InceptionV3 (ImageNet weights, no top) is followed by global
    average pooling, a 1024-unit ReLU layer and an ``N_CLASSES``-way softmax.
    The model is compiled with Nesterov-momentum SGD and categorical
    cross-entropy.

    Returns:
        The compiled ``Sequential`` model.
    """
    # Pre-trained feature extractor without its original classification top.
    feature_extractor = InceptionV3(
        include_top=False,
        weights='imagenet',
        input_shape=(224, 224, 3),
    )
    # Frozen, so only the layers stacked on top of it are trained.
    feature_extractor.trainable = False

    model = Sequential([
        feature_extractor,
        GlobalAveragePooling2D(),
        Dense(1024, activation='relu'),
        Dense(N_CLASSES, activation='softmax'),
    ])

    # NOTE: To resume from an earlier run, load its saved weights here, e.g.
    # model.load_weights('weight.epoch-xxxx.h5')

    model.compile(
        optimizer=SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True),
        loss='categorical_crossentropy',
        metrics=['categorical_accuracy'],
    )
    return model
def get_current_time():
    """Return the current local time formatted as ``YYYY-MM-DD.HH-MM``.
    """
    # datetime's format protocol applies the same directives as strftime.
    return format(datetime.now(), "%Y-%m-%d.%H-%M")
# Start training only when run as a script, not when imported as a module.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment