# Source: GitHub Gist 42f1eafb1944e6a9c1b532bd8547dcb5 by @ksindi,
# created April 2, 2017 15:39.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Train model using transfer learning on InceptionV3."""
from keras.applications.inception_v3 import InceptionV3, preprocess_input
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator
# Flower class names. Each one is also the name of an image sub-directory under
# ./flower_photos/train/ and ./flower_photos/test/. Because the list is passed
# as `classes=` to flow_from_directory, its order fixes the label indices.
CLASSES = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
# Build the network from two parts:
#  - the ImageNet-pretrained InceptionV3 convolutional base, loaded with
#    include_top=False so it has no ImageNet classifier;
#  - a new classification head for the flower classes.
base_model = InceptionV3(weights='imagenet', include_top=False)
# Global spatial average pooling collapses the base's feature maps into one
# feature vector per image.
x = base_model.output
x = GlobalAveragePooling2D()(x)
# Fully-connected hidden layer of the new head.
x = Dense(1024, activation='relu')(x)
# Softmax output layer with one unit per entry in CLASSES (5 flower classes).
predictions = Dense(len(CLASSES), activation='softmax')(x)
# The model we will train. Keras 2 names these arguments `inputs`/`outputs`;
# `input`/`output` are the legacy Keras 1 spellings.
model = Model(inputs=base_model.input, outputs=predictions)
# Transfer learning, phase 1: train only the newly added head.
# Freeze every layer of the pretrained InceptionV3 base. Do NOT freeze
# `model.layers[:-1]`: that would also freeze the new Dense(1024) layer and
# leave it stuck at its random initial weights.
for layer in base_model.layers:
    layer.trainable = False
# Compile only *after* setting `trainable`; the flags take effect when the
# model is compiled.
model.compile(optimizer='rmsprop', loss='categorical_crossentropy',
              metrics=['accuracy'])
# Input pipeline.
# InceptionV3's ImageNet weights expect pixels scaled to [-1, 1] by
# `inception_v3.preprocess_input`. A plain `rescale=1. / 255` gives [0, 1],
# which does not match what the pretrained features were trained on.
IMAGE_SIZE = (299, 299)  # InceptionV3's default input resolution
BATCH_SIZE = 32

# Training images get light random augmentation; validation images are only
# preprocessed.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    shear_range=0.2,
    horizontal_flip=True,
    rotation_range=10.,
    width_shift_range=0.2,
    height_shift_range=0.2)
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Both iterators read the class sub-directories named in CLASSES and yield
# batches of images with one-hot ('categorical') labels.
train_generator = train_datagen.flow_from_directory(
    './flower_photos/train/',
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    classes=CLASSES,
    class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
    './flower_photos/test/',
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    classes=CLASSES,
    class_mode='categorical')
# Train the head.
# Keras 2's fit_generator counts *batches*, not samples:
#  - `steps_per_epoch` replaces Keras 1's `samples_per_epoch`;
#  - `validation_steps` replaces `nb_val_samples`.
# Passing the old sample counts through unchanged would mean 625 training
# batches per epoch instead of 625 images. The intended sample counts are
# therefore converted to whole batches, rounding up. Integer ceil division
# behaves the same on Python 2 and 3.
TRAIN_SAMPLES_PER_EPOCH = 625
VALIDATION_SAMPLES = 100
EPOCHS = 10
train_batch = train_generator.batch_size
val_batch = validation_generator.batch_size
model.fit_generator(
    train_generator,
    steps_per_epoch=(TRAIN_SAMPLES_PER_EPOCH + train_batch - 1) // train_batch,
    validation_data=validation_generator,
    validation_steps=(VALIDATION_SAMPLES + val_batch - 1) // val_batch,
    epochs=EPOCHS)