import tensorflow as tf
from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format
import numpy as np

# Object Detection API input and output nodes
input_name = "image_tensor"
output_names = ["detection_boxes", "detection_classes", "detection_scores", "num_detections"]

# Const nodes consumed during NMS must stay float32 in the Object Detection API
# (see https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v4.html)
keep_fp32_node_name = ["Postprocessor/BatchMultiClassNonMaxSuppression/MultiClassNonMaxSuppression/non_max_suppression/iou_threshold",
                       "Postprocessor/BatchMultiClassNonMaxSuppression/MultiClassNonMaxSuppression/non_max_suppression/score_threshold"]

def load_graph(model_path):
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        if model_path.endswith("pb"):
            # binary frozen graph
            with open(model_path, "rb") as f:
                graph_def.ParseFromString(f.read())
        else:
            # text-format graph (pbtxt)
            with open(model_path, "r") as pf:
                text_format.Parse(pf.read(), graph_def)
        tf.import_graph_def(graph_def, name="")
        sess = tf.Session(graph=graph)
        return sess

def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'):
    """
    Rewrite FusedBatchNorm as FusedBatchNormV2, since reserve_space_1 and
    reserve_space_2 in FusedBatchNorm require float32 for gradient calculation
    (see https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/fused-batch-norm).
    """
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    new_node = graph_def.node.add()
    new_node.op = "FusedBatchNormV2"
    new_node.name = node.name
    new_node.input.extend(node.input)
    new_node.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
    for attr in list(node.attr.keys()):
        if attr == "T":
            node.attr[attr].type = dtype
        new_node.attr[attr].CopyFrom(node.attr[attr])
    print("rewrite fused_batch_norm done!")

def convert_graph_to_fp16(model_path, save_path, name, as_text=False, target_type='fp16', input_name=None, output_names=None):
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    source_sess = load_graph(model_path)
    source_graph_def = source_sess.graph.as_graph_def()
    target_graph_def = graph_pb2.GraphDef()
    target_graph_def.versions.CopyFrom(source_graph_def.versions)
    for node in source_graph_def.node:
        # fused batch norm node
        if node.op == "FusedBatchNorm":
            rewrite_batch_norm_node_v2(node, target_graph_def, target_type=target_type)
            continue
        # replicate node
        new_node = target_graph_def.node.add()
        new_node.op = node.op
        new_node.name = node.name
        new_node.input.extend(node.input)
        attrs = list(node.attr.keys())
        # keep batch norm params node
        if ("BatchNorm" in node.name) or ('batch_normalization' in node.name):
            for attr in attrs:
                new_node.attr[attr].CopyFrom(node.attr[attr])
            continue
        # replace dtype in node attr with target dtype
        for attr in attrs:
            # keep special node in fp32
            if node.name in keep_fp32_node_name:
                new_node.attr[attr].CopyFrom(node.attr[attr])
                continue
            if node.attr[attr].type == types_pb2.DT_FLOAT:
                # modify node dtype
                new_node.attr[attr].type = dtype
            if attr == "value":
                tensor = node.attr[attr].tensor
                if tensor.dtype == types_pb2.DT_FLOAT:
                    # if float_val exists
                    if tensor.float_val:
                        float_val = tf.make_ndarray(node.attr[attr].tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(float_val, dtype=dtype))
                        continue
                    # if tensor content exists
                    if tensor.tensor_content:
                        tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                        tensor_weights = tf.make_ndarray(tensor)
                        # reshape tensor
                        tensor_weights = np.reshape(tensor_weights, tensor_shape)
                        tensor_proto = tf.make_tensor_proto(tensor_weights, dtype=dtype)
                        new_node.attr[attr].tensor.CopyFrom(tensor_proto)
                        continue
            new_node.attr[attr].CopyFrom(node.attr[attr])
    # transform graph
    if output_names:
        if not input_name:
            input_name = []
        transforms = ["strip_unused_nodes"]
        target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
    # write graph_def to model
    tf.io.write_graph(target_graph_def, logdir=save_path, name=name, as_text=as_text)
    print("Converting done ...")

model_path = "frozen_inference_graph.pb"  # path to your source frozen graph (placeholder; set to your own model)
save_path = "test"
name = "test.pb"
as_text = False
target_type = 'fp16'
convert_graph_to_fp16(model_path, save_path, name, as_text=as_text, target_type=target_type, input_name=input_name, output_names=output_names)
# test loading
# ISSUE: loading the detection model is extremely slow, while loading a classification model is normal
sess = load_graph(save_path + "/" + name)
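To sanity-check the converted graph beyond just loading it, here is a minimal inference sketch; the dummy 640x640 uint8 input is an assumption for illustration, not part of the gist:

# Sketch: run the converted fp16 detection graph on a dummy image.
image = np.random.randint(0, 255, size=(1, 640, 640, 3), dtype=np.uint8)
input_tensor = sess.graph.get_tensor_by_name(input_name + ":0")
output_tensors = [sess.graph.get_tensor_by_name(n + ":0") for n in output_names]
boxes, classes, scores, num = sess.run(output_tensors, feed_dict={input_tensor: image})
print("num_detections:", num)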
@CasiaFan I am using tensorflow version 1.13
@CasiaFan Hi, I met the same issue:
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/importer.py", line 430, in import_graph_def
raise ValueError(str(e))
ValueError: Input 0 of node Preprocessor/map/while/Enter_2 was passed float from Preprocessor/map/TensorArray_2:1 incompatible with expected half.
Have you fixed it?
In line 85, change
node.attr[attr].type = dtype
to:
new_node.attr[attr].type = dtype
My fault. Thanks for fixing this issue! @sayradley
Hi fan, do you have any idea about adding a new node into a pb?
Thanks, fan. Actually, I didn't plan to add a node to the frozen graph, but I found that your method works for adding a new node to the pb, LOL.
Like this:
def add_cast_node(node, src_graph, graph_def, target_type='fp16', node_name='import/cast'):
    new_node = graph_def.node.add()
    new_node.op = "Cast"
    new_node.name = node_name
    new_node.input.append(node.input[0])
    new_node.attr["SrcT"].type = types_pb2.DT_INT32
    new_node.attr["DstT"].type = types_pb2.DT_INT64
    new_node.attr["Truncate"].b = False
    return new_node
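For reference, a minimal sketch of wiring such a cast in front of an existing consumer node; the node name "consumer" below is hypothetical, not from the gist:

# Sketch: insert the Cast between a consumer node and its first input,
# then repoint the consumer at the cast's output. Names are illustrative.
for node in target_graph_def.node:
    if node.name == "consumer":
        cast = add_cast_node(node, None, target_graph_def, node_name=node.name + "/cast")
        node.input[0] = cast.name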
ValueError: Input 0 of node InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/FusedBatchNorm was passed float from InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/FusedBatchNorm/Switch:1 incompatible with expected half.
I'm trying to convert a pretrained FaceNet model with input_names='input' and output_names='embeddings', using TensorFlow 1.14.0.
@glennford49 As the error says, try to keep the attributes of the Switch input node to the FusedBatchNorm in fp32, rather than converting them to fp16.
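A minimal sketch of one way to do that, assuming you collect the Switch producers of FusedBatchNorm nodes before the conversion loop and then treat them like keep_fp32_node_name (this helper logic is illustrative, not from the gist):

# Sketch: find Switch nodes that feed FusedBatchNorm so their attrs can be
# kept in fp32 during conversion. Illustrative only.
name_to_node = {n.name: n for n in source_graph_def.node}
switch_fp32_names = set()
for node in source_graph_def.node:
    if node.op == "FusedBatchNorm":
        for inp in node.input:
            producer = name_to_node.get(inp.split(":")[0])
            if producer is not None and producer.op == "Switch":
                switch_fp32_names.add(producer.name)
# Then, inside the conversion loop, copy the attrs of these nodes unchanged,
# the same way keep_fp32_node_name is handled.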
I have copied your code entirely but it still gives me the same error; I will try downgrading TF to 1.13.1.
@glennford49 I'm not sure about this error. Could you provide your model?
I just downloaded the model 20180402-114759.pb trained on VGGFace2 from David Sandberg's GitHub. I have used this model without converting to fp16 and it works fine. When I try to convert it using your code, replacing the input name with input and the output name with embeddings, I get this error:
ValueError: Input 0 of node InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/FusedBatchNorm was passed float from InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/FusedBatchNorm/Switch:1 incompatible with expected half.
Same issue here. Any idea how to solve it?
Some bugs (see the sketch below):
1. Values beyond tf.float16.max are not handled, so some values become inf after the cast.
2. For the attrs "SrcT", "T", "Tparams", "DstT", this code doesn't actually make any change.
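For what it's worth, a hedged sketch of possible fixes inside convert_graph_to_fp16; this is my reading of the two bugs, not a confirmed fix from the author:

# Bug 1 sketch: clip weights into the representable fp16 range before
# casting, so out-of-range float32 values don't turn into inf:
fp16_max = np.finfo(np.float16).max
tensor_weights = np.clip(tensor_weights, -fp16_max, fp16_max)
tensor_proto = tf.make_tensor_proto(tensor_weights, dtype=dtype)

# Bug 2 sketch: the unconditional CopyFrom at the end of the attr loop
# overwrites the dtype change for type attrs like "T", "SrcT", "DstT",
# "Tparams". Skipping the rest of the iteration preserves the new type:
if node.attr[attr].type == types_pb2.DT_FLOAT:
    new_node.attr[attr].type = dtype
    continue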
@durgabhavaniv It seems loading the frozen model fails right at the beginning. It may be due to an incompatibility between the TF versions used for training and for loading. So check your TF version, and could you provide your model?