Last active
June 7, 2019 17:34
-
-
Save tonyreina/77ab836bf33db815d4d398bb1284a500 to your computer and use it in GitHub Desktop.
Calculate the time it takes TensorFlow Keras to handle dynamic reshaping of model inputs.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
num_channels = 3  # size of the trailing (channel) axis of model inputs/outputs


def create_model():
    """
    Build and compile a small fully-convolutional (FCN) encoder/decoder.

    Only the channel axis of the input is fixed; height and width are left
    as None. That is possible because no layer needs a predefined spatial
    shape (there is no Dense or Flatten), so the same model accepts images
    of different sizes from call to call.

    Layer stack:
        Conv2D(32, 3x3, relu) -> MaxPooling2D(2x2) -> Conv2D(64, 3x3, relu)
        -> Conv2DTranspose(32, 2x2, stride 2) -> Conv2D(num_channels, 3x3, sigmoid)

    NOTE(review): the pooling layer uses Keras' default padding and the
    transpose conv doubles the size back, so an odd input height/width
    presumably yields an output one pixel smaller than the input. Confirm
    this if exact reconstruction shapes matter.

    Returns:
        A compiled tf.keras Model with binary cross-entropy loss, the Adam
        optimizer and an accuracy metric. The model summary is printed as a
        side effect.
    """
    layers = tf.keras.layers

    image = layers.Input([None, None, num_channels], name="myInput")

    # Encoder: one convolution, then halve the spatial resolution.
    features = layers.Conv2D(32, (3, 3), activation="relu", padding="same")(image)
    features = layers.MaxPooling2D((2, 2))(features)
    features = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(features)

    # Decoder: learnable 2x upsampling, then project back to num_channels.
    features = layers.Conv2DTranspose(
        name="transconv1",
        filters=32,
        kernel_size=(2, 2),
        strides=(2, 2),
        padding="same",
    )(features)
    reconstruction = layers.Conv2D(
        num_channels, (3, 3), activation="sigmoid", padding="same"
    )(features)

    model = tf.keras.models.Model(inputs=[image], outputs=[reconstruction])
    model.summary()  # print the architecture to stdout
    model.compile(loss="binary_crossentropy", optimizer="Adam", metrics=["accuracy"])
    return model
def make_random_input(dataset_range, spatial_range, channels, dtype=np.float32):
    """
    Return a random image batch whose batch size and spatial size are drawn at runtime.

    Draw order:
    1. The batch size from ``np.random.randint(*dataset_range)``.
    2. The height from ``np.random.randint(*spatial_range)``.
    3. The width, independently, from ``np.random.randint(*spatial_range)``.

    Upper bounds are exclusive, as with ``np.random.randint``. Pixel values
    come from ``np.random.random`` (uniform on [0, 1)). Rounding to float32
    can map values extremely close to 1 to exactly 1.0.

    The result is cast to ``dtype`` (float32 by default). Keras runs the
    model in float32 anyway, so keeping float64 only doubles the retained
    memory. It would also make every ``predict`` call pay for a dtype
    conversion inside the timed region.

    Args:
        dataset_range: (low, high) bounds for the number of images.
        spatial_range: (low, high) bounds used for both height and width.
        channels: Size of the trailing channel axis.
        dtype: Element type of the returned array.

    Returns:
        numpy array of shape (batch, height, width, channels).
    """
    batch = np.random.randint(*dataset_range)
    height = np.random.randint(*spatial_range)
    width = np.random.randint(*spatial_range)
    return np.random.random([batch, height, width, channels]).astype(dtype, copy=False)


if __name__ == "__main__":
    import datetime
    import time

    # Create the fully convolutional model and train it on [224, 224, 3]
    # images. The model just acts as an autoencoder (inputs == targets).
    model = create_model()
    dataset_size = 1024 * 4
    input_shape = [dataset_size, 224, 224, num_channels]
    random_input = np.random.random(input_shape).astype(np.float32)
    model.fit(random_input, random_input, epochs=2, batch_size=64)
    del random_input  # ~2.5 GB even as float32; not needed for the timing runs

    # Do inference with input sizes and dataset sizes that differ from the
    # training shape and from each other, drawn at runtime, to measure the
    # cost of the model handling a new input shape.
    # Each entry is ((low, high) dataset size, (low, high) height/width).
    # NOTE(review): the last entry can reach ~1023 x 1023 x 1023 x 3 floats
    # (~12.8 GB as float32), so this script needs a lot of RAM.
    size_ranges = [
        ((32, 1024), (32, 1024)),
        ((32, 1024), (128, 1024)),
        ((512, 1024), (32, 1024)),
        ((128, 1024), (640, 1024)),
    ]
    random_inputs = [
        make_random_input(batch_range, hw_range, num_channels)
        for batch_range, hw_range in size_ranges
    ]

    # Looking for the delay between one predict and the next.
    print("\n\nTest time between predictions on dynamic reshape")
    print("Add up the time for each prediction and subtract from total time.")
    print("That's the time to reshape model to new shape.")
    for idx, batch in enumerate(random_inputs, start=1):
        print("Random input size {} = {}".format(idx, batch.shape))
    print("Testing random sizes with back to back model.predict()")

    for run in range(3):
        print("Run #{}".format(run + 1))
        # perf_counter is monotonic and high resolution, unlike the wall
        # clock behind datetime.now(), which can jump during a run.
        start = time.perf_counter()
        for _ in range(3):  # Repeat to give us better estimate of time
            for batch in random_inputs:
                model.predict(batch, verbose=1)
        elapsed = time.perf_counter() - start
        # Format through timedelta to keep the "H:MM:SS.ffffff" output.
        print("Elapsed time = {}".format(datetime.timedelta(seconds=elapsed)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment