Last active
February 6, 2022 20:54
-
-
Save bmabir17/754a6e0450ec4fd5e25e462af949cde6 to your computer and use it in GitHub Desktop.
Converts the Mask R-CNN Keras model (https://github.com/matterport/Mask_RCNN/releases/tag/v2.0) to TFLite.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import tensorflow as tf
import numpy as np
import mrcnn.model as modellib  # https://github.com/matterport/Mask_RCNN/
from mrcnn.config import Config
import keras.backend as keras
# Directory the frozen graph (.pb) is written to and later read back from.
PATH_TO_SAVE_FROZEN_PB = "./"
# File name of the frozen graph inside PATH_TO_SAVE_FROZEN_PB.
FROZEN_NAME = "saved_model.pb"
def load_model(Weights):
    """Build an inference-mode Mask R-CNN and load weights into it.

    The model and the current default TensorFlow graph are published through
    the module-level globals ``model`` and ``graph``; the model is also
    returned for callers that prefer not to rely on globals.

    Args:
        Weights: Path of the weights file handed to ``MaskRCNN.load_weights``
            (loaded with ``by_name=True``).

    Returns:
        The ``modellib.MaskRCNN`` instance that was created.
    """
    global model, graph

    class InferenceConfig(Config):
        NAME = "coco"
        # presumably 1 background class + 80 object classes -- TODO confirm
        NUM_CLASSES = 1 + 80
        IMAGE_META_SIZE = 1 + 3 + 3 + 4 + 1 + NUM_CLASSES
        DETECTION_MAX_INSTANCES = 100
        DETECTION_MIN_CONFIDENCE = 0.7
        DETECTION_NMS_THRESHOLD = 0.3
        GPU_COUNT = 1
        IMAGES_PER_GPU = 1

    config = InferenceConfig()
    logs_dir = "./logs"
    model = modellib.MaskRCNN(mode="inference", config=config,
                              model_dir=logs_dir)
    model.load_weights(Weights, by_name=True)
    # Captured after the model is built so the global refers to the graph
    # that was the default at that point.
    graph = tf.get_default_graph()
    return model
# Reference: https://github.com/bendangnuksung/mrcnn_serving_ready/blob/master/main.py
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Return the session's graph with its variables converted to constants.

    Args:
        session: TensorFlow session whose ``graph`` is frozen.
        keep_var_names: Optional iterable of variable op names to leave out
            of the list passed to the conversion call.
        output_names: Optional list of output node names forwarded to the
            conversion call; treated as empty when omitted.
        clear_devices: When true, the ``device`` field of every node in the
            graph definition is reset to an empty string before conversion.

    Returns:
        The result of ``tf.graph_util.convert_variables_to_constants``.
    """
    tf_graph = session.graph
    with tf_graph.as_default():
        all_var_names = {var.op.name for var in tf.global_variables()}
        excluded = keep_var_names or []
        freeze_var_names = list(all_var_names.difference(excluded))

        if not output_names:
            output_names = []

        graph_def = tf_graph.as_graph_def()
        if clear_devices:
            for graph_node in graph_def.node:
                graph_node.device = ""

        return tf.graph_util.convert_variables_to_constants(
            session, graph_def, output_names, freeze_var_names)
def freeze_model(model, name):
    """Freeze a Keras model's graph and write it as a binary GraphDef.

    Args:
        model: Keras model; the op names of its first four ``outputs`` are
            used as the output nodes of the frozen graph.
        name: File name for the frozen graph, written inside
            ``PATH_TO_SAVE_FROZEN_PB``.
    """
    # The session used to be read from an undefined name ``sess`` (it only
    # exists as a local in keras_to_tflite), which raised NameError. Ask the
    # Keras backend for its current session instead.
    sess = keras.get_session()
    output_names = [out.op.name for out in model.outputs][:4]
    frozen_graph = freeze_session(sess, output_names=output_names)
    tf.train.write_graph(frozen_graph, PATH_TO_SAVE_FROZEN_PB, name,
                         as_text=False)
def keras_to_tflite(in_weight_file, out_weight_file):
    """Convert Mask R-CNN Keras weights to a TFLite flatbuffer file.

    Steps: create a session and register it with the Keras backend, load the
    model, freeze it to ``PATH_TO_SAVE_FROZEN_PB``/``FROZEN_NAME``, convert
    the frozen graph with the TFLite converter, and write the result.

    Args:
        in_weight_file: Path of the Keras weights file to load.
        out_weight_file: Path the converted TFLite model is written to.
    """
    global model

    sess = tf.Session()
    keras.set_session(sess)
    load_model(in_weight_file)
    freeze_model(model.keras_model, FROZEN_NAME)

    # https://github.com/matterport/Mask_RCNN/issues/2020#issuecomment-596449757
    input_arrays = ["input_image"]
    output_arrays = ["mrcnn_class/Softmax", "mrcnn_bbox/Reshape"]
    frozen_graph_path = os.path.join(PATH_TO_SAVE_FROZEN_PB, FROZEN_NAME)
    # NOTE(review): this mixes tf.contrib.lite (converter) with tf.lite
    # (OpsSet); confirm both are available in the installed TF 1.x version.
    converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
        frozen_graph_path,
        input_arrays, output_arrays,
        input_shapes={"input_image": [1, 256, 256, 3]}
    )
    converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                            tf.lite.OpsSet.SELECT_TF_OPS]
    converter.post_training_quantize = True
    tflite_model = converter.convert()

    # Context manager guarantees the output file is flushed and closed.
    with open(out_weight_file, "wb") as tflite_file:
        tflite_file.write(tflite_model)

    print("*"*80)
    print("Finished converting keras model to Frozen tflite")
    print('PATH: ', out_weight_file)
    print("*" * 80)
# Script entry: convert the COCO weights in the working directory to TFLite.
keras_to_tflite(
    in_weight_file="./mask_rcnn_coco.h5",
    out_weight_file="./mask_rcnn_coco.tflite",
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@kaamlaS I got past this; check out matterport/Mask_RCNN#2020.
Let me know if you get it working.
Also, I can't get the output names to work, so I left them out, but that is stopping the conversion. How did you pass the output names to the parameter?