
@CasiaFan
Last active June 14, 2022 08:15
post-training quantization tensorflow model to float16
import tensorflow as tf
from tensorflow.core.framework import types_pb2, graph_pb2, attr_value_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from google.protobuf import text_format
import numpy as np
# object detection api input and output nodes
input_name = "image_tensor"
output_names = ["detection_boxes", "detection_classes", "detection_scores", "num_detections"]
# Const should be float32 in object detection api during nms (see here: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v4.html)
keep_fp32_node_name = ["Postprocessor/BatchMultiClassNonMaxSuppression/MultiClassNonMaxSuppression/non_max_suppression/iou_threshold",
"Postprocessor/BatchMultiClassNonMaxSuppression/MultiClassNonMaxSuppression/non_max_suppression/score_threshold"]
def load_graph(model_path):
    # load a frozen graph (.pb) or a text graph def into a new tf.Session
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        if model_path.endswith("pb"):
            with open(model_path, "rb") as f:
                graph_def.ParseFromString(f.read())
        else:
            with open(model_path, "r") as pf:
                text_format.Parse(pf.read(), graph_def)
        tf.import_graph_def(graph_def, name="")
        sess = tf.Session(graph=graph)
        return sess
def rewrite_batch_norm_node_v2(node, graph_def, target_type='fp16'):
    """
    Rewrite FusedBatchNorm as FusedBatchNormV2, because reserve_space_1 and reserve_space_2 in FusedBatchNorm
    require float32 for gradient calculation (see: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/fused-batch-norm)
    """
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    new_node = graph_def.node.add()
    new_node.op = "FusedBatchNormV2"
    new_node.name = node.name
    new_node.input.extend(node.input)
    # keep the U attr (dtype of mean/variance and reserve spaces) in float32
    new_node.attr["U"].CopyFrom(attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
    for attr in list(node.attr.keys()):
        if attr == "T":
            # cast the data dtype of the op to the target type
            node.attr[attr].type = dtype
        new_node.attr[attr].CopyFrom(node.attr[attr])
    print("rewrite fused_batch_norm done!")
def convert_graph_to_fp16(model_path, save_path, name, as_text=False, target_type='fp16', input_name=None, output_names=None):
    if target_type == 'fp16':
        dtype = types_pb2.DT_HALF
    elif target_type == 'fp64':
        dtype = types_pb2.DT_DOUBLE
    else:
        dtype = types_pb2.DT_FLOAT
    source_sess = load_graph(model_path)
    source_graph_def = source_sess.graph.as_graph_def()
    target_graph_def = graph_pb2.GraphDef()
    target_graph_def.versions.CopyFrom(source_graph_def.versions)
    for node in source_graph_def.node:
        # fused batch norm node
        if node.op == "FusedBatchNorm":
            rewrite_batch_norm_node_v2(node, target_graph_def, target_type=target_type)
            continue
        # replicate node
        new_node = target_graph_def.node.add()
        new_node.op = node.op
        new_node.name = node.name
        new_node.input.extend(node.input)
        attrs = list(node.attr.keys())
        # keep batch norm params node
        if ("BatchNorm" in node.name) or ('batch_normalization' in node.name):
            for attr in attrs:
                new_node.attr[attr].CopyFrom(node.attr[attr])
            continue
        # replace dtype in node attr with target dtype
        for attr in attrs:
            # keep special node in fp32
            if node.name in keep_fp32_node_name:
                new_node.attr[attr].CopyFrom(node.attr[attr])
                continue
            if node.attr[attr].type == types_pb2.DT_FLOAT:
                # modify the dtype on the source node so the CopyFrom at the end of
                # this loop carries the new type over to new_node
                node.attr[attr].type = dtype
            if attr == "value":
                tensor = node.attr[attr].tensor
                if tensor.dtype == types_pb2.DT_FLOAT:
                    # if float_val exists
                    if tensor.float_val:
                        float_val = tf.make_ndarray(node.attr[attr].tensor)
                        new_node.attr[attr].tensor.CopyFrom(tf.make_tensor_proto(float_val, dtype=dtype))
                        continue
                    # if tensor content exists
                    if tensor.tensor_content:
                        tensor_shape = [x.size for x in tensor.tensor_shape.dim]
                        tensor_weights = tf.make_ndarray(tensor)
                        # reshape tensor
                        tensor_weights = np.reshape(tensor_weights, tensor_shape)
                        tensor_proto = tf.make_tensor_proto(tensor_weights, dtype=dtype)
                        new_node.attr[attr].tensor.CopyFrom(tensor_proto)
                        continue
            new_node.attr[attr].CopyFrom(node.attr[attr])
    # transform graph
    if output_names:
        if not input_name:
            input_name = []
        transforms = ["strip_unused_nodes"]
        target_graph_def = TransformGraph(target_graph_def, input_name, output_names, transforms)
    # write graph_def to model
    tf.io.write_graph(target_graph_def, logdir=save_path, name=name, as_text=as_text)
    print("Converting done ...")
save_path = "test"
name = "test.pb"
as_text = False
target_type = 'fp16'
convert_graph_to_fp16(model_path, save_path, name, as_text=as_text, target_type=target_type, input_name=input_name, output_names=output_names)
# test loading
# ISSUE: loading detection model is extremely slow while loading classification model is normal
sess = load_graph(save_path+"/"+name)
@CasiaFan (Author) commented Mar 5, 2020

My fault. Thanks for fixing this issue! @sayradley

@longside

Hi Fan, do you have any idea about adding a new node into a pb?

@CasiaFan (Author)

@longside I think you may need to modify the network structure file. This repo may help you.

@longside commented Jun 1, 2020

Thanks, Fan. Actually, I wasn't expecting to be able to add a node to an already-frozen graph.

@longside commented Jun 1, 2020

I found that your method does work for adding a new node to a pb, LOL.
Like this:

def add_cast_node(node, src_graph, graph_def, target_type='fp16', node_name='import/cast'):
    new_node = graph_def.node.add()
    new_node.op = "Cast"
    new_node.name = node_name
    new_node.input.append(node.input[0])
    new_node.attr["SrcT"].type = types_pb2.DT_INT32
    new_node.attr["DstT"].type = types_pb2.DT_INT64
    new_node.attr["Truncate"].b = False

    return new_node
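
A follow-up sketch on the snippet above (mine, not from the thread): after add_cast_node returns, the node that should consume the cast output still has to be rewired to it, otherwise the new Cast node is dangling. Here node, source_graph, target_graph_def and consumer_node are placeholders for whatever source node, graphs and downstream node are being edited:

cast_node = add_cast_node(node, source_graph, target_graph_def, node_name='import/cast')
# repoint the first input of the downstream node at the freshly added Cast node
consumer_node.input[0] = cast_node.name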

@glennford49 commented Jun 7, 2020

ValueError: Input 0 of node InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/FusedBatchNorm was passed float from InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/FusedBatchNorm/Switch:1 incompatible with expected half.

I'm trying to convert a pretrained FaceNet model with input_names='input' and output_names='embeddings', using TensorFlow 1.14.0.

@CasiaFan (Author) commented Jun 9, 2020

@glennford49 As the error says, try to keep the attributes of the Switch input node of the FusedBatchNorm in float32, rather than converting them to fp16.
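
For anyone debugging this, a small sketch (not from the original gist) to list which nodes feed each FusedBatchNorm, so the inputs that must stay float32 (e.g. the cond/Switch ones) can be identified; source_graph_def is the GraphDef already loaded inside convert_graph_to_fp16:

for n in source_graph_def.node:
    if n.op == "FusedBatchNorm":
        print(n.name, "<-", list(n.input))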

@glennford49

I have copied your code entirely but it still gives me the same error; I will try downgrading TF to 1.13.1.

@CasiaFan (Author)

@glennford49 I'm not sure about this error. Could you provide your model?

@glennford49

I just downloaded the model 20180402-114759.pb trained on VGGFace2 from David Sandberg's GitHub. I have used this model without converting to fp16 and it works fine, but when I try to convert it using your code, replacing the input name with input and the output name with embeddings, I get errors.

@xiexie123

ValueError: Input 0 of node InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/FusedBatchNorm was passed float from InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/FusedBatchNorm/Switch:1 incompatible with expected half.

I'm trying to convert a pretrained FaceNet model with input_names='input' and output_names='embeddings', using TensorFlow 1.14.0.

Same issue here. Any idea how to solve it?

@154912369

Some bugs:
values larger than tf.float16.max are not handled, so after conversion some of them become inf.
for the attrs "SrcT", "T", "Tparams", and "DstT", this code doesn't make any change.
