Created
November 10, 2020 14:34
-
-
Save tonyreina/50c5e57053612142395d64ef948b31fa to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
import tensorflow as tf | |
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph | |
from tensorflow.lite.python.util import run_graph_optimizations, get_grappler_config | |
from pathlib import Path | |
import argparse | |
def frozen_keras_graph(model):
    """Freeze a Keras model into a constant-folded, optimized ``GraphDef``.

    The model is traced with ``tf.function`` against ``TensorSpec``s built from
    ``model.inputs``. Its variables are then converted to constants, and
    Grappler's ``constfold`` and ``function`` passes are run over the result.

    Args:
        model: A ``tf.keras`` model whose ``model.inputs`` is populated
            (one ``TensorSpec`` is created per entry).

    Returns:
        The optimized ``GraphDef`` produced by ``run_graph_optimizations``.
    """
    input_specs = [tf.TensorSpec(t.shape, t.dtype) for t in model.inputs]
    # Single-input models are traced with a bare spec, exactly as before.
    # Multi-input Keras models take their inputs as one list argument, so
    # pass the full list of specs instead of silently dropping all but the
    # first input.
    trace_args = input_specs[0] if len(input_specs) == 1 else input_specs
    tf_model = tf.function(model).get_concrete_function(trace_args)
    frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(tf_model)
    # Skip resource-typed placeholders (presumably captured variable handles),
    # which are not data inputs of the model.
    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != tf.resource
    ]
    output_tensors = frozen_func.outputs
    graph_def = run_graph_optimizations(
        graph_def,
        input_tensors,
        output_tensors,
        config=get_grappler_config(["constfold", "function"]),
        graph=frozen_func.graph,
    )
    return graph_def
def get_args():
    """Build (but do not run) the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: a parser with one required option,
        ``--input_model`` / ``-m``, holding the path to the Keras model.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--input_model',
        '-m',
        type=str,
        required=True,
        help='Path to Keras model.',
    )
    return cli
def export_keras_to_tf(input_model, output_model, output_dir='.'):
    """Load a saved Keras model and write it out as a frozen binary graph.

    Args:
        input_model: Path to the saved Keras model, passed straight to
            ``tf.keras.models.load_model``.
        output_model: File name for the frozen graph (written in binary,
            ``as_text=False``).
        output_dir: Directory to write ``output_model`` into. Defaults to the
            current working directory, which was the previously hard-coded
            location.
    """
    print('Loading Keras model: ', input_model)
    # The optimizer/loss/metric configuration plays no part in the frozen
    # inference graph. Skipping compilation avoids load failures for models
    # saved with custom losses or metrics that are not registered here.
    model = tf.keras.models.load_model(input_model, compile=False)
    model.summary()
    graph_def = frozen_keras_graph(model)
    tf.io.write_graph(graph_def, output_dir, output_model, as_text=False)
def main():
    """Command-line entry point: freeze the model given by ``--input_model``.

    The output file name is the base name of the input path with ``.pb``
    appended. For example, ``models/net.h5`` becomes ``net.h5.pb``.
    """
    args = get_args().parse_args()
    src = args.input_model
    dst = Path(src).name + '.pb'
    export_keras_to_tf(src, dst)
    print('Saved as TF frozen model to: ', dst)
if __name__ == "__main__":
    # Only run the CLI when executed as a script, not when imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment