Created October 19, 2021 05:55.
Save tiandiao123/844ce92d56ef1127b23205ea82fba1ff to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### Demo: how to convert a TF2 (Keras) model into a TVM Relay module
import tensorflow as tf | |
from tensorflow.python.tools import saved_model_utils | |
from tensorflow import keras | |
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 | |
from tvm import relay | |
from tvm.relay.frontend.tensorflow2 import from_tensorflow | |
# Supported dtype names; each maps from its "tf.<name>" label to the plain
# dtype string TVM expects (e.g. "tf.float32" -> "float32").
_TF_DTYPE_NAMES = (
    "float32",
    "float16",
    "float64",
    "int32",
    "int16",
    "int8",
    "uint8",
    "int64",
)
dtype_dict = {"tf." + name: name for name in _TF_DTYPE_NAMES}
# Location of the saved Keras model to convert.
model_path = "/data00/cuiqing.li/models/debug_model/1"

### convert to frozen pb
# Register the Keras backend under both aliases it is commonly referenced by
# inside saved model configs ("K" and "backend").
custom_objects = {}
for alias in ("K", "backend"):
    custom_objects.setdefault(alias, tf.keras.backend)

# compile=False: we only need the inference graph, not the training setup.
new_model = tf.keras.models.load_model(
    model_path, custom_objects=custom_objects, compile=False
)
print(new_model.summary())
# Show the shape of every model input (avoid shadowing the builtin `input`).
print("print input info: ")
for model_input in new_model.inputs:
    print(model_input.shape)
# Wrap the model in a tf.function so it can be traced into a concrete function.
full_model = tf.function(lambda x: new_model(x))

# Log each symbolic input, then build a TensorSpec per input for tracing.
for tensor in new_model.inputs:
    print(tensor)
input_info = [tf.TensorSpec(t.shape, t.dtype) for t in new_model.inputs]

# Trace with the collected specs and fold all variables into constants,
# yielding a frozen function whose graph can be exported as a GraphDef.
full_model = full_model.get_concrete_function(input_info)
frozen_func = convert_variables_to_constants_v2(full_model)
# Serialize the frozen graph and collect name/shape/dtype metadata for every
# graph input; TVM needs fully static shapes, so dynamic (None) dimensions
# are replaced with a concrete batch size.
graph_def = frozen_func.graph.as_graph_def()

input_names = []
input_shapes = []
input_types = []
batch_size = 1  # value substituted for any unknown (None) dimension
for tensor in frozen_func.inputs:  # renamed from `input` — don't shadow the builtin
    input_names.append(tensor.name)
    input_types.append(tensor.dtype)
    # `is None` (identity), not `== None`, per PEP 8.
    input_shapes.append([batch_size if dim is None else dim for dim in tensor.shape])
# Normalize each TF dtype object to the plain dtype string TVM expects.
# NOTE(review): str() of a tf.DType renders like "<dtype: 'float32'>", not
# "tf.float32", so the str()-keyed lookup likely never matched and every dtype
# silently fell back to "float32". Keep the original lookup for compatibility,
# but also try the dtype's `.name` against the same "tf.<name>" keys.
for i, tf_dtype in enumerate(input_types):
    if str(tf_dtype) in dtype_dict:
        input_types[i] = dtype_dict[str(tf_dtype)]
    else:
        name = getattr(tf_dtype, "name", "")  # e.g. "float32" — TODO confirm on tf.DType
        input_types[i] = dtype_dict.get("tf." + name, "float32")
# Echo the collected input metadata, then pair names with shapes for TVM.
print("input names: ")
print(input_names)
print("input_shapes: ")
print(input_shapes)
print("input_types")
print(input_types)

tvm_shape_dict = dict(zip(input_names, input_shapes))
print(tvm_shape_dict)
print("converting to relay graph ... ")
# Import the frozen TF2 GraphDef into TVM; `shape` pins every graph input to
# the static shapes collected above. Yields the Relay IRModule plus the
# extracted weight params.
mod, params = from_tensorflow(graph_def, shape=tvm_shape_dict)
# Print the Relay IR of the module's entry function as a sanity check.
print(mod['main'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.