An example that runs the Edge TPU face detection model
import argparse
import cv2
import numpy
from tflite_runtime.interpreter import Interpreter
from tflite_runtime.interpreter import load_delegate
'''
Requirements:
1) Install the tflite_runtime package from here:
https://www.tensorflow.org/lite/guide/python
2) Camera to take inputs
3) Install libedgetpu
https://github.com/google-coral/edgetpu/tree/master/libedgetpu/direct
Download the model:
$ wget https://github.com/google-coral/edgetpu/raw/master/test_data/mobilenet_ssd_v2_face_quant_postprocess_edgetpu.tflite
Run:
$ python3 edgetpu_face_detector.py --model mobilenet_ssd_v2_face_quant_postprocess_edgetpu.tflite --edgetpu
'''


def get_cmd():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help='Path to tflite model.', required=True)
    parser.add_argument(
        '--threshold', help='Minimum confidence threshold.',
        type=float, default=0.5)
    parser.add_argument('--source', help='Video source.', type=int, default=0)
    parser.add_argument('--edgetpu', help='Run with the EdgeTPU delegate.',
                        action='store_true')
    return parser.parse_args()


def main():
    args = get_cmd()
    if args.edgetpu:
        interpreter = Interpreter(args.model, experimental_delegates=[
            load_delegate('libedgetpu.so.1.0')])
    else:
        interpreter = Interpreter(args.model)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # The model expects inputs of shape [1, height, width, 3].
    width = input_details[0]['shape'][2]
    height = input_details[0]['shape'][1]

    cap = cv2.VideoCapture(args.source)
    image_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    image_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV captures BGR frames; the model expects RGB.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_resized = cv2.resize(frame_rgb, (width, height))
        input_data = numpy.expand_dims(frame_resized, axis=0)

        interpreter.set_tensor(input_details[0]['index'], input_data)
        interpreter.invoke()
        # Postprocessed SSD outputs: boxes, classes, scores.
        boxes = interpreter.get_tensor(output_details[0]['index'])[0]
        classes = interpreter.get_tensor(output_details[1]['index'])[0]
        scores = interpreter.get_tensor(output_details[2]['index'])[0]
        for i in range(len(scores)):
            if args.threshold < scores[i] <= 1.0:
                # Boxes are normalized [ymin, xmin, ymax, xmax]; scale
                # them to the capture resolution and clamp to the frame.
                ymin = int(max(1, boxes[i][0] * image_height))
                xmin = int(max(1, boxes[i][1] * image_width))
                ymax = int(min(image_height, boxes[i][2] * image_height))
                xmax = int(min(image_width, boxes[i][3] * image_width))
                cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), (10, 255, 0), 4)

                object_name = 'face'
                label = '%s: %d%%' % (object_name, int(scores[i] * 100))
                labelSize, baseLine = cv2.getTextSize(
                    label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)
                # Keep the label on screen when the box is near the top edge.
                label_ymin = max(ymin, labelSize[1] + 10)
                cv2.putText(frame, label, (xmin, label_ymin - 7),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)

        cv2.imshow('Object detector', frame)
        if cv2.waitKey(1) == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
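For a quick sanity check without a camera, the same interpreter pipeline can be run once on a still image. This is a minimal sketch, not part of the gist: it assumes the model has been downloaded as above, and the image path face.jpg is a hypothetical placeholder for any local test photo.

import cv2
import numpy
from tflite_runtime.interpreter import Interpreter, load_delegate

# Load the model with the EdgeTPU delegate, same as in the script above.
interpreter = Interpreter(
    'mobilenet_ssd_v2_face_quant_postprocess_edgetpu.tflite',
    experimental_delegates=[load_delegate('libedgetpu.so.1.0')])
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
height = int(input_details[0]['shape'][1])
width = int(input_details[0]['shape'][2])

# 'face.jpg' is a placeholder; substitute any local test image.
image = cv2.imread('face.jpg')
rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
resized = cv2.resize(rgb, (width, height))
interpreter.set_tensor(input_details[0]['index'],
                       numpy.expand_dims(resized, axis=0))
interpreter.invoke()

# Index 2 of the postprocessed SSD outputs holds the confidence scores.
scores = interpreter.get_tensor(
    interpreter.get_output_details()[2]['index'])[0]
print('Highest face score: %.2f' % scores.max())

If this prints a score near 1.0 for a photo with a clearly visible face, the delegate and model are set up correctly.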