Last active
April 20, 2020 18:57
-
-
Save jens25/31e2faf9d3c16ee17f879872d28a1124 to your computer and use it in GitHub Desktop.
efficientdet_camera_inference
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import cv2 | |
import time | |
import requests | |
import argparse | |
import tarfile | |
import numpy as np | |
import tensorflow.compat.v1 as tf | |
sys.path.append('.') | |
import inference | |
# min_score_thresh = 0.2 | |
# max_boxes_to_draw = 100 | |
# line_thickness = 4 | |
def maybe_download(model_name):
    """Download and unpack the checkpoint archive for *model_name* if needed.

    Nothing is done when a file or directory named *model_name* already
    exists in the current working directory. Otherwise
    ``<model_name>.tar.gz`` is fetched from the public EfficientDet
    checkpoint bucket, saved next to the script and extracted in place.

    Args:
        model_name: Checkpoint name, e.g. ``'efficientdet-d0'``.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (for example an unknown model name).
        tarfile.TarError: If the downloaded file is not a valid archive.
    """
    if os.path.exists(model_name):
        return

    url = "https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco/{}.tar.gz".format(model_name)
    archive_path = "{}.tar.gz".format(model_name)

    # Stream to disk so large checkpoints are never held fully in memory.
    with requests.get(url, allow_redirects=True, stream=True) as response:
        # Fail here rather than saving an HTML error page as a ".tar.gz".
        response.raise_for_status()
        with open(archive_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)

    # A distinct local name avoids shadowing the module-level TensorFlow
    # alias `tf`; the context manager closes the archive handle.
    # NOTE(review): extractall() trusts member paths inside the archive;
    # acceptable only because the source is the fixed bucket URL above.
    with tarfile.open(archive_path) as archive:
        archive.extractall()
    print('Use model {}'.format(model_name))
def inference(model_name, camera_index=1):
    """Run live detection on camera frames and display the annotated result.

    Frames are read from an OpenCV capture device, passed to
    ``ServingDriver.serve_images`` and drawn with ``ServingDriver.visualize``.
    The measured frame rate is printed for every frame. The loop ends when
    'q' is pressed in the preview window or the camera stops delivering
    frames.

    Args:
        model_name: Checkpoint name, e.g. ``'efficientdet-d0'``. The
            checkpoint is looked up in ``<cwd>/<model_name>``.
        camera_index: Index handed to ``cv2.VideoCapture``. Defaults to 1,
            the value that was previously hard-coded.

    Raises:
        RuntimeError: If no initial frame can be read from the camera.
    """
    # This function's name rebinds the module-level `import inference`, so
    # `inference.ServingDriver` would resolve to this function and fail.
    # Re-import the module under a distinct local alias instead.
    import inference as inference_lib

    ckpt_path = os.path.join(os.getcwd(), model_name)
    cap = cv2.VideoCapture(camera_index)
    try:
        # Grab one frame up front to derive the network input size.
        ret, frame = cap.read()
        if not ret or frame is None:
            raise RuntimeError(
                "Could not read a frame from camera {}".format(camera_index))

        tf.reset_default_graph()
        # Round the largest frame dimension down to a multiple of 128.
        image_size = max(frame.shape) // 128 * 128
        driver = inference_lib.ServingDriver(
            model_name, ckpt_path, image_size=image_size)

        while True:
            start = time.time()
            ret, frame = cap.read()
            if not ret or frame is None:
                # Camera disconnected or stream ended.
                break
            # The former np.rollaxis(frame, 0, 1) call left the axis order
            # unchanged (no-op), so the frame is passed on as read.
            pred = driver.serve_images([frame])
            elapsed = time.time() - start
            if elapsed > 0:  # guard against a zero-length interval
                print("Fps: %f" % (1.0 / elapsed))
            frame = driver.visualize(frame, pred[0])
            cv2.imshow("Image", frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always free the device and close the preview window.
        cap.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Script entry point: pick the EfficientDet variant, make sure its
    # checkpoint is available locally, then start the camera loop.
    parser = argparse.ArgumentParser(description='Camera inference for efficient det networks')
    # Default to d0 so that omitting the flag no longer produces the
    # invalid model name 'efficientdet-dNone'.
    parser.add_argument('--model_type', help='Model to run: 0-6', type=int, default=0)
    args = parser.parse_args()
    model_name = 'efficientdet-d{}'.format(args.model_type)
    maybe_download(model_name)
    inference(model_name)
Hi @jens25, does your code use the GPU?
Hey @juanmanuelrq,
Unfortunately I don't have a GPU to test it, but it should use the GPU.
On my Laptop (Intel i5) with the efficientdet-d0 Network I get around 2.7 fps.
Hi @jens25, thank you for your answer,
2.7 fps with what size of image?
best regards
Sorry, I forgot: 640x640.
If I set the image size to None in line 41, I get 3-4 fps with the d0 network.
Thanks @jens25
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, thank you for sharing your code, what FPS do you get with this code?
I got 0.8 FPS
best regards