Last active
August 29, 2021 12:37
-
-
Save tempdeltavalue/1a47b5c32fc94e16ef73b2eb6822b457 to your computer and use it in GitHub Desktop.
MN vs MN+LSTM (inference timing: MobileNetV2 vs MobileNetV2 + LSTM)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import keras | |
import time | |
from keras.layers import TimeDistributed, GlobalAveragePooling2D, Dense, LSTM, Input | |
from keras.models import Model | |
def inference_model():
    """Build the per-frame baseline: MobileNetV2 followed by a small dense head.

    The backbone is created with ``include_top=True`` and ``weights=None``.
    The head therefore sits on the backbone's classifier output, and every
    weight is randomly initialised.

    NOTE(review): this backbone differs from the one in ``mobilenetv2lstm``:
    no ``alpha`` argument (library default) and ``include_top=True``, versus
    ``alpha=0.5``, ImageNet weights, ``include_top=False`` and global average
    pooling there. The two models therefore differ in more than the LSTM.
    Confirm that this is intended for the "MN vs MN+LSTM" timing comparison.

    Returns:
        A ``keras.models.Model`` mapping the backbone's image input to a
        single sigmoid unit.
    """
    backbone = keras.applications.MobileNetV2(include_top=True, weights=None)
    hidden = Dense(130, activation='relu')(backbone.output)
    score = Dense(1, activation='sigmoid')(hidden)
    return Model(backbone.input, score)
def mobilenetv2lstm():
    """Build a MobileNetV2 + LSTM model that emits a score for every frame.

    The model takes input of shape ``(batch, time, 224, 224, 3)``, with the
    time axis left unspecified (``None``), and applies these stages:

    1. A MobileNetV2 backbone (``alpha=0.5``, ImageNet weights, no top),
       shared across frames via ``TimeDistributed``.
    2. Per-frame global average pooling, then a 130-unit ELU dense layer.
    3. A 16-unit ELU LSTM that returns its full output sequence.
    4. A per-step 1-unit sigmoid dense layer.

    Returns:
        A ``keras.models.Model`` producing one sigmoid score per time step.
    """
    backbone = keras.applications.MobileNetV2(include_top=False,
                                              weights='imagenet',
                                              alpha=0.5)
    clip = Input(shape=[None, 224, 224, 3])

    # Stages are listed (and therefore instantiated) in the order they are applied.
    stages = [
        TimeDistributed(backbone),
        TimeDistributed(GlobalAveragePooling2D()),
        TimeDistributed(Dense(130, activation='elu')),
        # NOTE(review): per the Keras LSTM docs, the fast cuDNN kernel is only
        # used with activation='tanh'. With 'elu' the generic implementation
        # runs instead, which matters when timing on GPU. Confirm intended.
        LSTM(16, activation='elu', return_sequences=True),
        # NOTE(review): 'predictions' names the inner Dense, not the
        # TimeDistributed wrapper held by the outer model. Confirm nothing
        # looks this layer up by name on the returned model.
        TimeDistributed(Dense(1, activation='sigmoid', name='predictions')),
    ]

    features = clip
    for stage in stages:
        features = stage(features)
    return Model(clip, features)
def benchmark_predict(predict_fn, inputs, n_runs=20, warmup=1):
    """Time repeated calls of ``predict_fn(inputs)``.

    Args:
        predict_fn: Callable invoked as ``predict_fn(inputs)``, for example
            a Keras ``model.predict``.
        inputs: Argument passed unchanged to every call.
        n_runs: Number of timed calls.
        warmup: Number of untimed calls made first. This keeps one-off
            start-up cost (presumably model building/tracing on the first
            Keras call) out of the reported timings.

    Returns:
        A list of ``n_runs`` elapsed durations in seconds, measured with the
        monotonic, high-resolution ``time.perf_counter`` clock.
    """
    for _ in range(warmup):
        predict_fn(inputs)

    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        predict_fn(inputs)
        timings.append(time.perf_counter() - start)
    return timings


if __name__ == '__main__':
    batch_size = 1
    seq_len = 8
    # True times the MobileNetV2+LSTM model on (batch, seq, H, W, C) clips.
    # False times the per-frame baseline on the same number of frames,
    # flattened into the batch axis.
    use_lstm = True

    if use_lstm:
        model = mobilenetv2lstm()
        input_shape = (batch_size, seq_len, 224, 224, 3)
    else:
        model = inference_model()
        input_shape = (batch_size * seq_len, 224, 224, 3)

    # np.random.rand yields float64. Casting once here (assuming the model's
    # default float32 input dtype) keeps the conversion out of every timed call.
    inputs = np.random.rand(*input_shape).astype(np.float32)

    print('\n input shape', inputs.shape)
    for elapsed in benchmark_predict(model.predict, inputs, n_runs=20):
        print(elapsed)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment