Created
July 24, 2020 08:00
-
-
Save akash-ch2812/eaf58537ffc46075e59370327ee3b0cb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import keras | |
def create_model(max_caption_length, vocab_length, feature_vector_size=18432):
    """Build an image-captioning model: an image-feature branch merged with a
    caption (text) branch, predicting the next word in the sequence.

    Args:
        max_caption_length: maximum caption length in tokens; also sizes the
            hidden layers (each uses max_caption_length * 4 units).
        vocab_length: vocabulary size; sizes the embedding input and the
            final softmax output layer.
        feature_vector_size: length of the flattened image-feature vector fed
            to the first input (defaults to 18432, the original hard-coded
            value, so existing callers are unaffected).

    Returns:
        An uncompiled keras.models.Model with two inputs
        ([image_features, caption_tokens]) and one softmax output over the
        vocabulary.
    """
    # --- sub network for the image-feature part ---
    # NOTE: shape must be a tuple; the original `shape=(18432)` was just an
    # int (no trailing comma) and raises a TypeError inside keras.Input.
    input_layer1 = keras.Input(shape=(feature_vector_size,))
    feature1 = keras.layers.Dropout(0.2)(input_layer1)
    feature2 = keras.layers.Dense(max_caption_length * 4, activation='relu')(feature1)
    feature3 = keras.layers.Dense(max_caption_length * 4, activation='relu')(feature2)
    feature4 = keras.layers.Dense(max_caption_length * 4, activation='relu')(feature3)
    feature5 = keras.layers.Dense(max_caption_length * 4, activation='relu')(feature4)

    # --- sub network for the text-generation part ---
    input_layer2 = keras.Input(shape=(max_caption_length,))
    cap_layer1 = keras.layers.Embedding(vocab_length, 300,
                                        input_length=max_caption_length)(input_layer2)
    cap_layer2 = keras.layers.Dropout(0.2)(cap_layer1)
    cap_layer3 = keras.layers.LSTM(max_caption_length * 4, activation='relu',
                                   return_sequences=True)(cap_layer2)
    cap_layer4 = keras.layers.LSTM(max_caption_length * 4, activation='relu',
                                   return_sequences=True)(cap_layer3)
    cap_layer5 = keras.layers.LSTM(max_caption_length * 4, activation='relu',
                                   return_sequences=True)(cap_layer4)
    # last LSTM drops return_sequences to collapse the sequence to one vector
    cap_layer6 = keras.layers.LSTM(max_caption_length * 4, activation='relu')(cap_layer5)

    # --- merge the two sub networks ---
    # `keras.layers.merge.add` is a removed/private module path in modern
    # Keras; the public functional helper is `keras.layers.add`.
    decoder1 = keras.layers.add([feature5, cap_layer6])
    decoder2 = keras.layers.Dense(256, activation='relu')(decoder1)
    decoder3 = keras.layers.Dense(256, activation='relu')(decoder2)

    # output is a probability distribution over the next word in the sequence
    output_layer = keras.layers.Dense(vocab_length, activation='softmax')(decoder3)

    model = keras.models.Model(inputs=[input_layer1, input_layer2],
                               outputs=output_layer)
    model.summary()
    return model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment