Skip to content

Instantly share code, notes, and snippets.

@rubenhorn
Created April 12, 2020 22:10
Show Gist options
  • Save rubenhorn/9856f64c9353334782969de94937d963 to your computer and use it in GitHub Desktop.
Fast (30 fps on CPU) object localization using a pretrained model from TensorFlow Hub
#!/usr/bin/env python3
import cv2
import tensorflow as tf
import tensorflow_hub as hub
import time
module_handle = 'https://tfhub.dev/google/object_detection/mobile_object_localizer_v1/1'

# Webcam object-localization demo: each frame is resized, run through the
# TF Hub localizer, and shown with a box around every confident detection.
# Press 'q' in the window to quit.

# Resolution frames are resized to before inference (value kept from the
# original script; presumably the localizer's native input size -- TODO confirm).
INPUT_SIZE = (192, 192)
# Detections scoring below this are not drawn (a score equal to it is drawn).
MIN_SCORE = 0.5
# Milliseconds cv2.waitKey blocks per frame; this also paces the display loop.
KEY_DELAY_MS = 25


def preprocess(frame):
    """Convert a BGR webcam frame into a batched float32 RGB tensor for the model.

    Args:
        frame: image array as returned by ``cv2.VideoCapture.read``.

    Returns:
        A tensor of shape ``(1, INPUT_SIZE[1], INPUT_SIZE[0], 3)``.
    """
    img = cv2.resize(frame, INPUT_SIZE)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # NOTE(review): pixels are passed unscaled (0-255), as in the original
    # script; confirm this matches the signature's expected input range.
    img_tensor = tf.convert_to_tensor(img, dtype=tf.float32)
    return tf.expand_dims(img_tensor, 0)


def box_to_pixels(box, width, height):
    """Convert one normalized detection box into integer OpenCV corner points.

    ``box`` is read as ``(ymin, xmin, ymax, xmax)`` scaled to [0, 1]. That is
    the order the original script indexed it in, presumably the TF
    object-detection convention.

    Args:
        box: sequence of four numbers (list, numpy row, ...).
        width: frame width in pixels.
        height: frame height in pixels.

    Returns:
        ``((x1, y1), (x2, y2))`` as plain ints. ``cv2.rectangle`` rejects
        float or tensor coordinates.
    """
    ymin, xmin = float(box[0]), float(box[1])
    ymax, xmax = float(box[2]), float(box[3])
    top_left = (int(round(xmin * width)), int(round(ymin * height)))
    bottom_right = (int(round(xmax * width)), int(round(ymax * height)))
    return top_left, bottom_right


def select_boxes(scores, boxes, width, height, min_score=MIN_SCORE):
    """Return pixel rectangles for every detection scoring at least ``min_score``.

    Args:
        scores: per-detection confidence values.
        boxes: per-detection normalized boxes, parallel to ``scores``.
        width: frame width in pixels.
        height: frame height in pixels.
        min_score: inclusive confidence threshold.

    Returns:
        A list of ``((x1, y1), (x2, y2))`` integer point pairs.
    """
    return [box_to_pixels(box, width, height)
            for score, box in zip(scores, boxes)
            if score >= min_score]


def main():
    """Load the model, then localize objects in webcam frames until 'q' is pressed."""
    print('loading object detection model...')
    model = hub.load(module_handle).signatures['default']
    cap = cv2.VideoCapture(0)
    try:
        if not cap.isOpened():
            raise SystemExit('error: could not open webcam (device 0)')
        while True:
            ok, frame = cap.read()
            if not ok:
                # On a failed read, frame is None; stop cleanly instead of
                # crashing inside cv2.resize.
                print('failed to read frame from webcam; exiting')
                break
            img_tensor = preprocess(frame)
            # perf_counter is monotonic and high-resolution, so the elapsed
            # time cannot come out as 0 the way time.time() deltas can.
            start = time.perf_counter()
            output = model(img_tensor)
            inference_time = time.perf_counter() - start
            fps = str(int(1 / inference_time)) if inference_time > 0 else 'inf'
            print('Inference time: {:.4f} ({} fps)'.format(inference_time, fps))
            # Convert each output tensor once instead of indexing tensors per element.
            scores = output['detection_scores'][0].numpy()
            boxes = output['detection_boxes'][0].numpy()
            height, width = frame.shape[:2]
            for top_left, bottom_right in select_boxes(scores, boxes, width, height):
                cv2.rectangle(frame, top_left, bottom_right, (0, 255, 0), 2)
            cv2.imshow('webcam', frame)
            if cv2.waitKey(KEY_DELAY_MS) & 0xFF == ord('q'):
                break
    finally:
        # Release the camera and close windows even if inference or drawing raises.
        cap.release()
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment