Example code for post-training quantization with the TensorFlow from_frozen_graph API (deprecated in TensorFlow 2.0).
# More info on Post Training Quantization here:
# https://www.tensorflow.org/lite/performance/post_training_quantization
# The from_frozen_graph API is not in TF 2.0 but can still be used via tf.compat.v1.lite; more on this API:
# https://www.tensorflow.org/api_docs/python/tf/compat/v1/lite/TFLiteConverter#from_frozen_graph
# This is an example of converting a frozen graph model to a fully quantized tflite model.
# The model used here is http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192.tgz
# Note that with post training quantization, it is not guaranteed that the model will be fully quantized.
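# To fetch that model (illustrative commands; the tarball should contain
# mobilenet_v1_0.25_192_frozen.pb, which is the frozen graph to pass to this script):
#   curl -O http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192.tgz
#   tar xzf mobilenet_v1_0.25_192.tgz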
import sys, os, glob
import tensorflow as tf
import pathlib
import numpy as np
if len(sys.argv) != 2:
    print('Usage: ' + sys.argv[0] + ' <frozen_graph_file>')
    sys.exit(1)
tf.compat.v1.enable_eager_execution()
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
# The representative dataset drives calibration; its shape must match the
# model's input (1, 192, 192, 3 for mobilenet_v1_0.25_192).
def fake_representative_data_gen():
    for _ in range(100):
        fake_image = np.random.random((1, 192, 192, 3)).astype(np.float32)
        yield [fake_image]
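# Random data only exercises the graph; quantization ranges are more meaningful
# when calibrated on real inputs. A minimal sketch of such a generator, assuming
# a hypothetical local directory 'calibration_images/' of JPEGs:
def real_representative_data_gen():
    for path in glob.glob('calibration_images/*.jpg')[:100]:
        img = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
        img = tf.image.resize(img, (192, 192))
        # Scaling to [0, 1] is illustrative; match the preprocessing the model was trained with.
        img = tf.cast(img, tf.float32)[tf.newaxis, ...] / 255.0
        yield [img.numpy()]
# To use it instead of the fake generator:
#   converter.representative_dataset = real_representative_data_gen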
frozen_graph = sys.argv[1]
input_arrays = ['input']
output_arrays = ['MobilenetV1/Predictions/Reshape_1']
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(frozen_graph, input_arrays, output_arrays)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = fake_representative_data_gen
# Restricting supported ops to TFLITE_BUILTINS_INT8 makes conversion fail if any op cannot be quantized.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# uint8 input/output tensors avoid float conversion at the model's interface.
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()
quant_dir = pathlib.Path(os.getcwd(), 'output')
quant_dir.mkdir(exist_ok=True, parents=True)
tflite_model_file = quant_dir / 'mobilenet_v1_0.25_192_quant.tflite'
tflite_model_file.write_bytes(tflite_model)
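# Quick sanity check: load the quantized model and confirm the interface is uint8.
# (A minimal sketch; it only inspects tensor details and does not run inference.)
interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
interpreter.allocate_tensors()
print('input dtype:', interpreter.get_input_details()[0]['dtype'])
print('output dtype:', interpreter.get_output_details()[0]['dtype'])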