Skip to content

Instantly share code, notes, and snippets.

@rohit-gupta
Created September 13, 2017 18:15
Show Gist options
  • Save rohit-gupta/7668b79389e29598ace41813fc1a50d2 to your computer and use it in GitHub Desktop.
Save rohit-gupta/7668b79389e29598ace41813fc1a50d2 to your computer and use it in GitHub Desktop.
Fitting a generator to video data to train an action classification model
# Create generator
# Training and validation data generators. Both are built with the same
# video_folder, fps_dict, tag_dict and batch_size; only the caption split
# passed as the first argument differs (train first, then validation).
train_generator, validation_generator = (
    MSRVTTSequence(
        caption_split,
        video_folder=videos_folder,
        fps_dict=video_fps,
        tag_dict=tags,
        batch_size=16,
    )
    for caption_split in (train_captions, validation_captions)
)
from keras.applications.resnet50 import ResNet50
from keras.layers import TimeDistributed, Bidirectional
from keras.layers import Input, LSTM, Dense
from keras.models import Model
from keras.callbacks import CSVLogger, ModelCheckpoint, ReduceLROnPlateau
from keras import backend as K
# Force Keras into the training learning phase before any layers are built.
# NOTE(review): a fixed learning phase of 1 means train-mode behaviour (the
# LSTM dropout below, and ResNet50's BatchNormalization using batch
# statistics) also applies during validation and prediction. This is
# sometimes used as a workaround for BatchNormalization inside
# TimeDistributed in older Keras versions -- confirm it is intended here.
K.set_learning_phase(1)

# Define Model
# Input: a clip of NUM_FRAMES RGB frames at 224x224, ResNet50's default
# input size. NUM_FRAMES is defined elsewhere in this file.
video_input = Input(shape=(NUM_FRAMES, 224, 224, 3))

# Per-frame feature extractor: ImageNet-pretrained ResNet50 without its
# classifier head; pooling='avg' global-average-pools the last feature map
# into one vector per frame.
convnet_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')

# Freeze every ResNet50 layer so only the LSTM and Dense head are trained.
# The trainable flags must be set before compile() to take effect.
for layer in convnet_model.layers:
    layer.trainable = False

# Apply the same (shared-weight) convnet to every frame of the clip.
encoded_frame_sequence = TimeDistributed(convnet_model)(video_input)

# Temporal encoder: collapse the per-frame feature sequence into one vector
# (LSTM's return_sequences defaults to False, so only the final output is
# kept). results.lstm_size presumably comes from parsed CLI arguments
# defined elsewhere in this file.
# Alternative (disabled): bidirectional encoder.
#encoded_video = Bidirectional(LSTM(results.lstm_size,implementation=1,dropout=0.5))(encoded_frame_sequence)
encoded_video = LSTM(results.lstm_size, implementation=2, dropout=0.2)(encoded_frame_sequence)

# Multi-label head: one independent sigmoid per tag, paired with
# binary_crossentropy. With this loss Keras resolves 'accuracy' to per-tag
# binary accuracy.
output = Dense(NUM_TAGS, activation='sigmoid')(encoded_video)

tag_model = Model(inputs=video_input, outputs=output)
tag_model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
tag_model.summary()
# NOTE(review): this weights path differs from the checkpoint path used for
# training in this script -- verify it before re-enabling.
# tag_model.load_weights('../models/lstm_vgg_'+results.tag_type+'_tag_model_augmented.h5')
# Train Model
import os

# Output locations, one pair per tag type. results presumably holds the
# parsed command-line arguments defined elsewhere in this file.
log_path = 'logs/tag_model_{}_tag_model.log'.format(results.tag_type)
checkpoint_path = 'models/tag_model_{}_tag_model.h5'.format(results.tag_type)

# CSVLogger and ModelCheckpoint open these files directly and do not create
# missing parent directories, so training would crash on a fresh checkout.
# (isdir + makedirs instead of exist_ok keeps this Python 2 compatible.)
for output_path in (log_path, checkpoint_path):
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.isdir(output_dir):
        os.makedirs(output_dir)

# Write one row of metrics per epoch to log_path. The file is overwritten
# when training starts, because CSVLogger's append flag defaults to False.
csv_logger = CSVLogger(log_path)
# Save the model only when the monitored quantity (val_loss by default)
# improves.
checkpointer = ModelCheckpoint(filepath=checkpoint_path, verbose=1, save_best_only=True)
#reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=0.0001)
#tag_model.fit(augmented_train_frames, augmented_train_tags, epochs=10, batch_size=16, validation_split=0.2, callbacks=[csv_logger, checkpointer, reduce_lr])

# NOTE(review): 407 and 31 are hard-coded step counts. With batch_size=16
# they presumably equal (train split size // 16) and (validation split
# size // 16) -- confirm, or derive them from the generators (e.g.
# len(train_generator) if MSRVTTSequence implements __len__).
tag_model.fit_generator(
    train_generator,
    steps_per_epoch=407,
    epochs=10,
    verbose=1,
    callbacks=[csv_logger, checkpointer],
    validation_data=validation_generator,
    validation_steps=31,
    max_queue_size=5,
    workers=1,
    use_multiprocessing=True,
    shuffle=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment