Last active
July 14, 2021 21:24
-
-
Save surya00060/a0bcf49a53353011c36bf5290fa81355 to your computer and use it in GitHub Desktop.
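A code snippet that implements inconsistent (heterogeneous) operator scheduling: individual Keras layers are pinned to different devices, some on the GPU and one on the CPU. Some of the code was written by Jake Stevens.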
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
def device_mapped_call_factory(layer, mapping, model):
    """Build a replacement ``call`` for *layer* that runs inside a device scope.

    The returned function reads ``layer.mapping`` each time it is invoked
    (it does not use the *mapping* argument). Reassigning ``layer.mapping``
    afterwards therefore changes the device string used by later calls.
    The actual computation is delegated to ``layer.orig_call``, which must
    already be set on the layer before the returned function is called.

    Args:
        layer: The layer whose ``call`` is being wrapped.
        mapping: Device string intended for the layer. Unused here; the live
            ``layer.mapping`` attribute is read instead.
        model: Unused; kept so the signature stays unchanged.

    Returns:
        A function with the same calling convention as ``layer.orig_call``.
    """

    def device_mapped_call(inp, *args, **kwargs):
        device_scope = tf.device(layer.mapping)
        with device_scope:
            return layer.orig_call(inp, *args, **kwargs)

    return device_mapped_call
def device_map_model(model, mappings):
    """Wrap every layer of *model* so its ``call`` runs under a device scope.

    The i-th entry of ``model.layers`` is assigned ``mappings[i]`` (a device
    string such as ``'/GPU:0'``). Its ``call`` is replaced by the function
    built by ``device_mapped_call_factory``, and the unwrapped ``call`` is
    kept on the layer as ``orig_call``.

    Calling this more than once is safe. The unwrapped ``call`` is captured
    only the first time, so a repeat invocation just re-assigns devices.
    Re-capturing would store the wrapper itself as ``orig_call`` and make it
    recurse forever.

    Args:
        model: Object exposing ``layers``; presumably a Keras model.
        mappings: Indexable sequence of device strings, one per entry of
            ``model.layers`` and in the same order.

    Raises:
        IndexError: If *mappings* has fewer entries than ``model.layers``.
    """
    for i, layer in enumerate(model.layers):
        # Only save the pristine call once; on a second pass ``layer.call``
        # is already our wrapper and must not become ``orig_call``.
        if not hasattr(layer, "orig_call"):
            layer.orig_call = layer.call
        mapping = mappings[i]
        layer.mapping = mapping
        layer.call = device_mapped_call_factory(layer, mapping, model)
def device_remap_model(model, mappings):
    """Reassign the device string of every layer of *model* in place.

    Only ``layer.mapping`` is updated; ``layer.call`` is left untouched. The
    update therefore affects only layers whose ``call`` already reads
    ``layer.mapping`` at call time (i.e. layers wrapped by
    ``device_map_model``).

    NOTE(review): a function that was already traced (e.g. a concrete
    function from ``tf.function``) presumably keeps the placement captured
    at trace time; re-trace to pick up the new mapping — confirm.

    Args:
        model: Object exposing ``layers``; presumably a Keras model.
        mappings: Indexable sequence of device strings, one per entry of
            ``model.layers`` and in the same order.

    Raises:
        IndexError: If *mappings* has fewer entries than ``model.layers``.
    """
    for idx, layer in enumerate(model.layers):
        layer.mapping = mappings[idx]
def evaluate(model, x):
    """Call *model* on *x* with ``training=False`` and return the result."""
    outputs = model(x, training=False)
    return outputs
# Change the number of filters and the batch size together so that the CPU
# execution time is significant compared with the GPU execution time.
input_shape = (8, 56, 56, 384)
x = tf.random.normal(input_shape)

conv1 = tf.keras.layers.Conv2D(384, 3, padding='same', use_bias=False)
conv2 = tf.keras.layers.Conv2D(768, 3, padding='same', use_bias=False)
conv3 = tf.keras.layers.Conv2D(384, 3, padding='same', use_bias=False)
conv4 = tf.keras.layers.Conv2D(768, 3, padding='same', use_bias=False)

# Per-example shape only; the leading batch dimension is dropped.
inputs = tf.keras.layers.Input(shape=input_shape[1:])

# Chain the four convolutions into a functional model. ``model`` was used
# below but never defined, so the script previously failed with NameError.
outputs = conv4(conv3(conv2(conv1(inputs))))
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# One device string per entry of ``model.layers``. Presumably the layers are
# [InputLayer, conv1, conv2, conv3, conv4], which puts conv3 on the CPU —
# confirm with ``print(model.layers)``.
maps = ['/GPU:0', '/GPU:0', '/GPU:0', '/CPU:0', '/GPU:0']
device_map_model(model, maps)

modelCall = tf.function(evaluate).get_concrete_function(model, x)

# Warm-up run outside the profiler. Presumably this keeps one-time
# first-execution costs (e.g. allocation/autotuning) out of the trace.
modelCall(x)

with tf.profiler.experimental.Profile("logs"):
    with tf.profiler.experimental.Trace("trace"):
        modelCall(x)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment