Skip to content

Instantly share code, notes, and snippets.

@KostaMalsev
Last active January 11, 2021 13:43
Show Gist options
  • Save KostaMalsev/90fd36dede052f128c8046ce7f83b864 to your computer and use it in GitHub Desktop.
Save KostaMalsev/90fd36dede052f128c8046ce7f83b864 to your computer and use it in GitHub Desktop.
# Rebuild the detection model from the pipeline config and load the weights
# saved in the latest training checkpoint.
pipeline_config = pipeline_file
# Put the last ckpt from training in here, don't use long pathnames.
# NOTE: despite its name, this is a checkpoint *prefix* (e.g. ".../ckpt-2"),
# not a directory; it is passed unchanged to Checkpoint.restore() below.
model_dir = '/content/training/ckpt-2'
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
# The model is only used for detection from here on, hence is_training=False.
detection_model = model_builder.build(
    model_config=model_config, is_training=False)

# Restore last checkpoint. Only the model is attached to this Checkpoint
# object, so anything else stored in the checkpoint file (presumably
# optimizer/training state -- confirm against the training job) is left
# unmatched; expect_partial() silences the warnings about those values.
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(model_dir).expect_partial()
# Build the function that performs object detection on an image in tensor format:
def get_model_detection_function(model):
    """Get a tf.function for detection.

    Args:
        model: Detection model object providing ``preprocess``, ``predict``
            and ``postprocess`` methods.

    Returns:
        A ``tf.function``-decorated callable ``detect_fn(image)`` bound to
        ``model``.
    """

    @tf.function
    def detect_fn(image):
        """Detect objects in image.

        Args:
            image: Input handed directly to ``model.preprocess``
                (presumably a batched image tensor -- confirm with caller).

        Returns:
            A 3-tuple ``(detections, prediction_dict, shapes)``:
            the output of ``model.postprocess``, the raw output of
            ``model.predict``, and the shapes returned by
            ``model.preprocess`` flattened to one dimension.
        """
        image, shapes = model.preprocess(image)
        prediction_dict = model.predict(image, shapes)
        detections = model.postprocess(prediction_dict, shapes)
        # Flatten the shapes tensor to 1-D before returning it.
        return detections, prediction_dict, tf.reshape(shapes, [-1])

    return detect_fn
# Create the detection callable for the restored model.
detect_fn = get_model_detection_function(detection_model)

# Load the label map named in the eval input config; it is used to decode
# the detections produced at inference time.
label_map_path = configs['eval_input_config'].label_map_path
label_map = label_map_util.load_labelmap(label_map_path)

# Label-map dictionary built with display names.
label_map_dict = label_map_util.get_label_map_dict(
    label_map,
    use_display_name=True,
)

# Category list, sized by the largest index in the label map, and the
# category index built from that list.
categories = label_map_util.convert_label_map_to_categories(
    label_map,
    max_num_classes=label_map_util.get_max_label_map_index(label_map),
    use_display_name=True,
)
category_index = label_map_util.create_category_index(categories)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment