@bmabir17
Last active February 6, 2022 20:54
Converts the Mask R-CNN Keras model (https://github.com/matterport/Mask_RCNN/releases/tag/v2.0) to TFLite.
import os

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.contrib.lite)
import numpy as np
import mrcnn.model as modellib  # https://github.com/matterport/Mask_RCNN/
from mrcnn.config import Config
import keras.backend as keras

PATH_TO_SAVE_FROZEN_PB = "./"
FROZEN_NAME = "saved_model.pb"


def load_model(weights):
    """Build the Mask R-CNN inference model and load the COCO weights."""
    global model, graph

    class InferenceConfig(Config):
        NAME = "coco"
        NUM_CLASSES = 1 + 80  # background + 80 COCO classes
        IMAGE_META_SIZE = 1 + 3 + 3 + 4 + 1 + NUM_CLASSES
        DETECTION_MAX_INSTANCES = 100
        DETECTION_MIN_CONFIDENCE = 0.7
        DETECTION_NMS_THRESHOLD = 0.3
        GPU_COUNT = 1
        IMAGES_PER_GPU = 1

    config = InferenceConfig()
    logs = "./logs"
    model = modellib.MaskRCNN(mode="inference", config=config, model_dir=logs)
    model.load_weights(weights, by_name=True)
    graph = tf.get_default_graph()


# Reference: https://github.com/bendangnuksung/mrcnn_serving_ready/blob/master/main.py
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Replace the variables in the session's graph with constants (a "frozen" GraphDef)."""
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            # Drop device placements so the frozen graph can be loaded on any device.
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph


def freeze_model(model, name):
    """Freeze the Keras model's first four outputs and write the GraphDef to disk."""
    # `sess` is local to keras_to_tflite(), so use the session registered with Keras instead.
    frozen_graph = freeze_session(
        keras.get_session(),
        output_names=[out.op.name for out in model.outputs][:4])
    tf.train.write_graph(frozen_graph, PATH_TO_SAVE_FROZEN_PB, name, as_text=False)


def keras_to_tflite(in_weight_file, out_weight_file):
    global model
    sess = tf.Session()
    keras.set_session(sess)
    load_model(in_weight_file)
    freeze_model(model.keras_model, FROZEN_NAME)

    # https://github.com/matterport/Mask_RCNN/issues/2020#issuecomment-596449757
    input_arrays = ["input_image"]
    output_arrays = ["mrcnn_class/Softmax", "mrcnn_bbox/Reshape"]
    converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
        os.path.join(PATH_TO_SAVE_FROZEN_PB, FROZEN_NAME),
        input_arrays, output_arrays,
        input_shapes={"input_image": [1, 256, 256, 3]}
    )
    # Let ops without a TFLite builtin kernel fall back to TensorFlow ops (Select TF ops).
    converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    # Quantize the model's weights to shrink the .tflite file.
    converter.post_training_quantize = True
    tflite_model = converter.convert()
    with open(out_weight_file, "wb") as f:
        f.write(tflite_model)
    print("*" * 80)
    print("Finished converting keras model to Frozen tflite")
    print("PATH: ", out_weight_file)
    print("*" * 80)


keras_to_tflite("./mask_rcnn_coco.h5", "./mask_rcnn_coco.tflite")
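The `converter.post_training_quantize = True` line asks the TF 1.x converter to quantize the model's weights. As a rough illustration of what 8-bit weight quantization means, here is a NumPy sketch of a generic min-max (affine) scheme. It is not necessarily the exact scheme TFLite uses, and `quantize_uint8` / `dequantize` are hypothetical helpers.

```python
import numpy as np


def quantize_uint8(weights):
    """Hypothetical helper: affine (min-max) 8-bit quantization of a float array."""
    # Widen the range to include 0.0 so zero is exactly representable.
    w_min = min(float(weights.min()), 0.0)
    w_max = max(float(weights.max()), 0.0)
    scale = (w_max - w_min) / 255.0 if w_max > w_min else 1.0
    zero_point = int(round(-w_min / scale))
    q = np.clip(np.round(weights / scale) + zero_point, 0, 255).astype(np.uint8)
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Map the 8-bit codes back to approximate float values."""
    return (q.astype(np.float64) - zero_point) * scale


rng = np.random.default_rng(0)
w = rng.standard_normal(1000).astype(np.float32)  # stand-in for a weight tensor
q, scale, zp = quantize_uint8(w)
err = np.abs(w - dequantize(q, scale, zp)).max()
print(q.nbytes, w.nbytes)  # 1000 4000: each weight goes from 4 bytes to 1
print(err <= scale / 2 + 1e-5)  # True: rounding error is at most half a step
```

The trade-off is a file roughly a quarter of the float32 size in exchange for a per-weight error of up to half a quantization step.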
@lbininhbl
Hi @bmabir17
Thanks to your code, I was able to convert the model to TFLite successfully. But when I ran the .tflite model on iOS, I got the error below:
TensorFlow Lite Error: Input tensor 314 lacks data

I looked through the code and noticed this in the keras_to_tflite function:
input_arrays = ["input_image"]
output_arrays = ["mrcnn_class/Softmax", "mrcnn_bbox/Reshape"]

But the real Keras model needs more inputs when its predict function is executed. I don't know if that has anything to do with it. What do you think?
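For reference, the matterport inference model should take two more inputs besides `input_image`: `input_image_meta` and `input_anchors` (names from `mrcnn/model.py`; worth confirming with `model.keras_model.summary()`). The sketch below shows what shapes those would have for a 256x256 input. It assumes the image is already 256x256 with no resizing or padding, and the default `Config` values `BACKBONE_STRIDES = [4, 8, 16, 32, 64]`, `RPN_ANCHOR_RATIOS = [0.5, 1, 2]` and `RPN_ANCHOR_STRIDE = 1`. The meta vector follows the same 1 + 3 + 3 + 4 + 1 + NUM_CLASSES layout as `IMAGE_META_SIZE` in the gist; `build_image_meta` and `count_anchors` are hypothetical helpers, not functions from the repo.

```python
import math

import numpy as np

NUM_CLASSES = 1 + 80  # same value as InferenceConfig in the gist


def build_image_meta(image_id, original_shape, image_shape, window, scale, num_classes):
    """Hypothetical helper: pack image metadata in the 1+3+3+4+1+NUM_CLASSES layout."""
    active_class_ids = [0] * num_classes  # assumption: all zeros at inference time
    return np.array(
        [image_id]
        + list(original_shape)  # 3 values: H, W, C before resizing
        + list(image_shape)     # 3 values: H, W, C fed to the network
        + list(window)          # 4 values: y1, x1, y2, x2 of the real image area
        + [scale]               # 1 value: resize factor
        + active_class_ids,     # NUM_CLASSES values
        dtype=np.float32,
    )


def count_anchors(image_hw, backbone_strides=(4, 8, 16, 32, 64),
                  num_ratios=3, anchor_stride=1):
    """Hypothetical helper: number of anchors over all feature-pyramid levels."""
    total = 0
    for stride in backbone_strides:
        fh = math.ceil(image_hw[0] / stride)  # feature map height at this level
        fw = math.ceil(image_hw[1] / stride)  # feature map width at this level
        # One anchor per aspect ratio at every sampled feature-map cell.
        total += math.ceil(fh / anchor_stride) * math.ceil(fw / anchor_stride) * num_ratios
    return total


meta = build_image_meta(0, (256, 256, 3), (256, 256, 3), (0, 0, 256, 256), 1.0, NUM_CLASSES)
print(meta.shape)                 # (93,)  -> input_image_meta would be [1, 93]
print(count_anchors((256, 256)))  # 16368  -> input_anchors would be [1, 16368, 4]
```

If these assumptions hold, a converted graph that is only fed `input_image` leaves the other two inputs empty, which may be what the "lacks data" error refers to.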

@kaamlaS
kaamlaS commented Jan 27, 2022

Hi @bmabir17,
I can't seem to get past this point:
ValueError: Invalid tensors 'input_image' were found.
I checked the model using model.keras_model.summary() and it shows that the first layer is named input_image. What should I do?
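One way to chase an "Invalid tensors" error is to list the node names actually present in the frozen graph (for a `GraphDef`, `[n.name for n in graph_def.node]`) and compare them with the names passed to the converter. Below is a small pure-Python sketch of that comparison; `check_tensor_names` is a hypothetical helper and the node names are made up. The `_1` suffix is just one plausible cause: TensorFlow uniquifies an op name that already exists in the graph, for example if the model is built twice in the same session.

```python
import difflib


def check_tensor_names(requested, graph_node_names, n=3, cutoff=0.6):
    """Return {missing_name: [closest existing node names]} for names not in the graph."""
    available = set(graph_node_names)
    missing = {}
    for name in requested:
        if name not in available:
            missing[name] = difflib.get_close_matches(name, graph_node_names, n=n, cutoff=cutoff)
    return missing


# Hypothetical node names, as they might look if the model was built twice in one session.
nodes = ["input_image_1", "input_image_meta_1", "input_anchors_1", "mrcnn_class/Softmax"]
print(check_tensor_names(["input_image", "mrcnn_class/Softmax"], nodes))
```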

@Tubhalooter
@kaamlaS I got past this; check out matterport/Mask_RCNN#2020.

Let me know if you get it working.
Also, I can't get the output names to work, so I left them out, but that is what's stopping the conversion. How did you pass the output names to the parameter?
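On the output-name question: the gist passes `output_arrays` as plain op names, the same form `freeze_model` gets from `out.op.name`, i.e. a tensor name such as `mrcnn_class/Softmax:0` without its `:0` output index. A small sketch of deriving that list from tensor-name strings and keeping the first four, mirroring the `[:4]` slice in `freeze_model`; the helper and the example tensor names are hypothetical. Any name passed to the converter presumably also has to survive freezing, so it should sit on a path to one of those four frozen outputs.

```python
def to_op_names(tensor_names, keep=4):
    """Hypothetical helper: strip the ':<output index>' suffix and keep the first `keep`."""
    return [name.split(":")[0] for name in tensor_names][:keep]


# Hypothetical output tensor names, in the "op_name:index" form TF 1.x uses.
outputs = ["detections:0", "mrcnn_class/Softmax:0", "mrcnn_bbox/Reshape:0",
           "masks:0", "extra_output:0"]
print(to_op_names(outputs))
# ['detections', 'mrcnn_class/Softmax', 'mrcnn_bbox/Reshape', 'masks']
```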
