Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save tarun-ssharma/8cd4c3cd57e53a3e3be7a52fa3e762fc to your computer and use it in GitHub Desktop.
Save tarun-ssharma/8cd4c3cd57e53a3e3be7a52fa3e762fc to your computer and use it in GitHub Desktop.
import tensorflow as tf
from object_detection.builders import model_builder
from object_detection.utils import config_util

# Directory the exported SavedModel is written to.
output_directory = './try2_trained_model'

# Pipeline config file describing the model to rebuild.
pipeline_config = 'pipeline.config'

# Checkpoint to restore. Despite the name, this value is handed straight to
# `Checkpoint.restore`, so it is a checkpoint path prefix rather than a
# directory. Use the last checkpoint from training and keep the path short.
model_dir = './trained_checkpoints/ckpt-31'

# Rebuild the detection model in inference mode from the pipeline config.
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
detection_model = model_builder.build(
    model_config=model_config, is_training=False)

# Restore the trained weights into the freshly built model.
# Only the `model` branch is attached here; a training checkpoint may hold
# additional objects (presumably optimizer state / step counter - confirm),
# so `expect_partial()` suppresses warnings about those unused values.
ckpt = tf.train.Checkpoint(model=detection_model)
ckpt.restore(model_dir).expect_partial()

# Next: define the function that performs inference, i.e. runs object
# detection on an image supplied as a tensor.
def get_model_detection_function(model):
    """Build a ``tf.function`` that runs the full detection pipeline.

    Args:
        model: Object exposing ``preprocess``, ``predict`` and
            ``postprocess`` methods (presumably a model built by the
            Object Detection API's ``model_builder`` - confirm at call site).

    Returns:
        A ``tf.function`` that takes an image tensor and returns whatever
        ``model.postprocess`` produces for it.
    """

    @tf.function
    def detect_fn(image):
        """Preprocess ``image``, run the model, and postprocess the output."""
        preprocessed, true_shapes = model.preprocess(image)
        raw_predictions = model.predict(preprocessed, true_shapes)
        return model.postprocess(raw_predictions, true_shapes)

    return detect_fn
# Wrap the model's preprocess/predict/postprocess chain in a tf.function.
detect_fn = get_model_detection_function(detection_model)

# Trace the function once for a fixed float32 input of shape [1, 300, 300, 3]
# (presumably a single 300x300 three-channel image - confirm against the
# pipeline config) so it can be exported as a named serving signature.
image_spec = tf.TensorSpec([1, 300, 300, 3], dtype=tf.float32, name='detect')
export_signatures = {'detect': detect_fn.get_concrete_function(image_spec)}

# Write the model together with its 'detect' signature as a SavedModel.
tf.saved_model.save(
    detection_model, output_directory, signatures=export_signatures)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment