Skip to content

Instantly share code, notes, and snippets.

@ground0state
Created August 19, 2019 15:43
Show Gist options
  • Save ground0state/1e0e44cda58eeeddff9287a1fa999dc1 to your computer and use it in GitHub Desktop.
Save ground0state/1e0e44cda58eeeddff9287a1fa999dc1 to your computer and use it in GitHub Desktop.
import json
import math
import os
import pickle
from datetime import datetime

from tensorflow.python.keras.applications.vgg16 import VGG16
from tensorflow.python.keras.applications.vgg16 import preprocess_input
from tensorflow.python.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.python.keras.layers import Dropout, Flatten, Dense
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.optimizers import SGD
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator

from utils import *
# Pretrained VGG16 convolutional base for 224x224 RGB images. The fully
# connected classifier is dropped (include_top=False) so a custom binary head
# can be stacked on top by build_transfer_model.
vgg16 = VGG16(
    input_shape=(224, 224, 3),
    include_top=False,
    weights='imagenet',  # Keras' default, spelled out for clarity
)
def build_transfer_model(base_model, n_frozen=14):
    """Stack a binary-classification head on a pretrained conv base (Sequential API).

    The layers of ``base_model`` are reused (not copied) in a new
    ``Sequential`` model, followed by
    Flatten -> Dense(256, relu) -> Dropout(0.5) -> Dense(1, sigmoid).

    Args:
        base_model: Keras model whose ``layers`` become the feature extractor
            (the VGG16 conv base in this script).
        n_frozen: Number of leading entries of the new model's ``layers`` to
            mark non-trainable. With tf.keras, ``Sequential.layers`` does not
            list the InputLayer, so for the VGG16 base the default of 14
            should cover block1_conv1 .. block4_pool and leave block5
            trainable (verify with ``model.summary()``).

    Returns:
        The new, uncompiled ``Sequential`` model.

    Note:
        The layer objects are shared with ``base_model``, so the ``trainable``
        flags set here are visible through ``base_model`` as well.
    """
    model = Sequential(base_model.layers)
    # Freeze the leading feature-extraction layers; only the remaining conv
    # layers and the new head are updated during training.
    for layer in model.layers[:n_frozen]:
        layer.trainable = False
    model.add(Flatten())
    model.add(Dense(256, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(1, activation='sigmoid'))
    return model
def build_transfer_model_functional(base_model, n_frozen=14):
    """Functional-API counterpart of ``build_transfer_model``.

    Connects Flatten -> Dense(256, relu) -> Dropout(0.5) -> Dense(1, sigmoid)
    to ``base_model.output`` and wraps everything in a ``Model`` that shares
    ``base_model``'s input.

    Args:
        base_model: Keras model providing ``input``/``output`` tensors (the
            VGG16 conv base in this script).
        n_frozen: Number of leading entries of ``model.layers`` to mark
            non-trainable. Unlike ``Sequential.layers``, a functional
            ``Model.layers`` includes the InputLayer at index 0. For VGG16,
            the default of 14 therefore freezes the input layer plus
            block1_conv1 .. block4_conv3. block4_pool has no weights, so the
            frozen weights should match ``build_transfer_model``'s default.

    Returns:
        The new, uncompiled ``Model``.

    Note:
        Requires ``Model`` from ``tensorflow.python.keras.models``. The
        original file used it without importing it, which raised NameError.
    """
    x = base_model.output
    x = Flatten()(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    prediction = Dense(1, activation='sigmoid')(x)
    model = Model(inputs=base_model.input, outputs=prediction)
    # Freeze the leading feature-extraction layers (shared with base_model).
    for layer in model.layers[:n_frozen]:
        layer.trainable = False
    return model
# Build the VGG16-based classifier and configure it for a binary
# (sigmoid output) task. A small learning rate (1e-4) with momentum is a
# common choice when fine-tuning pretrained weights.
model = build_transfer_model(vgg16)
model.compile(
    optimizer=SGD(lr=1e-4, momentum=0.9),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
# summary() prints the table itself and returns None; wrapping it in print()
# only emitted an extra "None" line.
model.summary()
# Training images: light random geometric augmentation plus VGG16 preprocessing.
# NOTE(review): Keras runs `preprocessing_function` before `rescale`, so the
# network receives preprocess_input(x) / 255. That differs from the plain
# preprocess_input(x) scale the pretrained VGG16 weights presumably expect.
# The test-prediction code applies the same transform, so change both
# together if at all.
idg_train = ImageDataGenerator(
    rescale=1/255,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    preprocessing_function=preprocess_input,
)
# Validation images get the same normalization but NO random augmentation.
# Previously idg_train was reused here, so validation metrics were computed
# on randomly sheared/zoomed/flipped images.
idg_validation = ImageDataGenerator(
    rescale=1/255,
    preprocessing_function=preprocess_input,
)

# Settings shared by both iterators so train and validation can't drift apart.
flow_kwargs = dict(target_size=(224, 224), batch_size=16, class_mode='binary')
img_itr_train = idg_train.flow_from_directory(
    DATA_FOLDER + 'img/shrine_temple/train', **flow_kwargs)
img_itr_validation = idg_validation.flow_from_directory(
    DATA_FOLDER + 'img/shrine_temple/validation', **flow_kwargs)
# Every run gets its own timestamped folder, e.g. models/190819_1543/.
model_dir = os.path.join('models', datetime.now().strftime('%y%m%d_%H%M'))
dir_weights = os.path.join(model_dir, 'weights')        # checkpoint weight files
model_json = os.path.join(model_dir, 'model.json')       # model architecture
model_classes = os.path.join(model_dir, 'classes.pkl')   # class name -> index map

os.makedirs(model_dir, exist_ok=True)
print('model_dir:', model_dir)
os.makedirs(dir_weights, exist_ok=True)

# NOTE(review): model.to_json() already returns a JSON string, so json.dump
# stores it as a JSON *string literal*. A loader must json.load() the file
# before handing the result to model_from_json.
with open(model_json, 'w') as json_file:
    json.dump(model.to_json(), json_file)

# Keep the class-to-index mapping inferred by the training iterator so that
# predicted probabilities can be mapped back to class names later.
with open(model_classes, 'wb') as pkl_file:
    pickle.dump(img_itr_train.class_indices, pkl_file)
# Nominal batch size, kept for reference. The step counts below use each
# iterator's own batch_size, so they stay correct even if the generators are
# created with a different value.
batch_size = 16
# One full pass over each dataset per epoch: ceil(samples / batch size).
steps_per_epoch = math.ceil(img_itr_train.samples / img_itr_train.batch_size)
validation_steps = math.ceil(
    img_itr_validation.samples / img_itr_validation.batch_size)

# Save weights once per epoch; the filename is keyed by epoch and training loss.
# An integer save_freq is counted in samples (TF 1.14/2.0) or batches
# (TF >= 2.1), not epochs. The previous save_freq=5 therefore wrote a
# checkpoint every batch or every 5 batches, depending on the TF version.
cp_filepath = os.path.join(dir_weights, 'ep_{epoch:02d}_ls_{loss:.1f}.h5')
cp = ModelCheckpoint(cp_filepath, monitor='loss', verbose=0, save_best_only=False,
                     save_weights_only=True, mode='auto', save_freq='epoch')

# Per-epoch metrics are appended to loss.csv in the run directory.
csv_filepath = os.path.join(model_dir, 'loss.csv')
csv = CSVLogger(csv_filepath, append=True)

n_epoch = 10
history = model.fit_generator(img_itr_train,
                              steps_per_epoch=steps_per_epoch,
                              epochs=n_epoch,
                              validation_data=img_itr_validation,
                              validation_steps=validation_steps,
                              callbacks=[cp, csv])
# Predict on a reproducible (seed=1) random sample of test images loaded by
# the project helper load_random_imgs, then visualise the results.
test_data_dir = DATA_FOLDER + 'img/shrine_temple/test/unknown'
x_test, true_labels = load_random_imgs(test_data_dir, seed=1)

# Apply the same transform as the training generator: preprocess_input first,
# then scale by 1/255. A copy is passed because preprocess_input may modify
# float arrays in place, and x_test is still needed untouched for display.
x_test_preproc = preprocess_input(x_test.copy())
x_test_preproc = x_test_preproc / 255

probs = model.predict(x_test_preproc)
print(probs)

show_test_samples(
    x_test,
    probs,
    img_itr_train.class_indices,
    true_labels,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment