Created March 27, 2017 00:56
-
-
Save alexminnaar/a286912efa417dfd0f25bd33992c3d6b to your computer and use it in GitHub Desktop.
Joint image/text classifier in Keras
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from keras import applications
from keras.layers import Dropout
from keras.layers import Dense, GlobalAveragePooling2D, merge, Input
# `merge` is the Keras 1 merge API; `concatenate` is its Keras 2 equivalent.
from keras.layers import concatenate
from keras.models import Model
# Hyperparameters.
max_words = 10000  # width of the text input vector (see text_inputs below)
epochs = 50
batch_size = 32

# Training data placeholders -- fill these in before running.
# NOTE(review): X_train_image is fed straight into InceptionV3; presumably it
# should already be passed through applications.inception_v3.preprocess_input
# -- confirm.
X_train_image = ...  # images training input
# NOTE(review): presumably a (num_samples, max_words) array (e.g. bag-of-words);
# it must match text_inputs' shape=(max_words,) below -- confirm.
X_train_text = ...  # text training input
y_train = ...  # training output: integer class ids or one-hot rows
y_train = np.asarray(y_train)

# Derive the class count and a matching loss from the label encoding:
# max(label) + 1 is only a valid class count for integer class ids, while
# categorical_crossentropy only accepts one-hot targets.
if y_train.ndim > 1 and y_train.shape[-1] > 1:
    # One-hot rows: one column per class.
    num_classes = y_train.shape[-1]
    loss = 'categorical_crossentropy'
else:
    # Integer class ids, shape (N,) or (N, 1).
    num_classes = int(np.max(y_train)) + 1
    loss = 'sparse_categorical_crossentropy'

# Text input branch - just a simple MLP.
text_inputs = Input(shape=(max_words,))
branch_1 = Dense(512, activation='relu')(text_inputs)

# Image input branch - a pre-trained Inception module followed by an added
# fully connected layer.
base_model = applications.InceptionV3(weights='imagenet', include_top=False)

# Freeze Inception's weights - we don't want to train these. This is done
# before the compile() call below.
for layer in base_model.layers:
    layer.trainable = False

# Add a fully connected layer after Inception - we do want to train these.
branch_2 = base_model.output
branch_2 = GlobalAveragePooling2D()(branch_2)
branch_2 = Dense(1024, activation='relu')(branch_2)

# Merge the text and image branches and add another fully connected layer.
# concatenate() replaces the Keras 1 merge(..., mode='concat') call; the rest
# of this script (Model(inputs=..., outputs=...), fit(epochs=...)) is Keras 2.
joint = concatenate([branch_1, branch_2])
joint = Dense(512, activation='relu')(joint)
joint = Dropout(0.5)(joint)
# Single-label multi-class output: softmax makes the class scores a probability
# distribution, which is what the categorical cross-entropy losses expect.
predictions = Dense(num_classes, activation='softmax')(joint)

# Input order must match the order of the arrays passed to fit() below.
full_model = Model(inputs=[base_model.input, text_inputs], outputs=[predictions])
full_model.compile(loss=loss,
                   optimizer='rmsprop',
                   metrics=['accuracy'])

# NOTE(review): per the Keras fit() docs, validation_split takes the last 20%
# of the samples before shuffling -- confirm the data is not ordered by class.
history = full_model.fit([X_train_image, X_train_text], y_train,
                         epochs=epochs, batch_size=batch_size,
                         verbose=1, validation_split=0.2, shuffle=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.