depthai + flask
from os import path

import cv2
import depthai as dai
import numpy as np

import mediapipe_utils as mpu
from FPS import FPS, now


# def to_planar(arr: np.ndarray, shape: tuple) -> list:
def to_planar(arr: np.ndarray, shape: tuple) -> np.ndarray:
    resized = cv2.resize(arr, shape)
    return resized.transpose(2, 0, 1)
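
# Illustrative example: depthai NNData expects planar CHW data, so a square BGR
# frame becomes a (3, h, w) array, e.g.:
#   frame = np.zeros((720, 720, 3), dtype=np.uint8)   # any HWC image
#   planar = to_planar(frame, (128, 128))
#   planar.shape  # -> (3, 128, 128)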

class HandTracker:
    def __init__(self,
                 on_halt=None,
                 on_step=None,
                 input_file=None,
                 pd_path="models/palm_detection.blob",
                 pd_score_thresh=0.5,
                 pd_nms_thresh=0.3,
                 use_lm=True,
                 lm_path="models/hand_landmark.blob",
                 lm_score_threshold=0.5,
                 use_gesture=False):
        here = path.dirname(__file__)
        pd_path = path.abspath(path.join(here, pd_path))
        lm_path = path.abspath(path.join(here, lm_path))

        self.camera = input_file is None
        self.pd_path = pd_path
        self.pd_score_thresh = pd_score_thresh
        self.pd_nms_thresh = pd_nms_thresh
        self.use_lm = use_lm
        self.lm_path = lm_path
        self.lm_score_threshold = lm_score_threshold
        self.use_gesture = use_gesture

        # Callbacks
        self.on_halt = on_halt
        self.on_step = on_step
        self.dataset = {"left": None, "right": None}

        # Create SSD anchors
        # https://github.com/google/mediapipe/blob/master/mediapipe/modules/palm_detection/palm_detection_cpu.pbtxt
        anchor_options = mpu.SSDAnchorOptions(num_layers=4,
                                              min_scale=0.1484375,
                                              max_scale=0.75,
                                              input_size_height=128,
                                              input_size_width=128,
                                              anchor_offset_x=0.5,
                                              anchor_offset_y=0.5,
                                              strides=[8, 16, 16, 16],
                                              aspect_ratios=[1.0],
                                              reduce_boxes_in_lowest_layer=False,
                                              interpolated_scale_aspect_ratio=1.0,
                                              fixed_anchor_size=True)
        self.anchors = mpu.generate_anchors(anchor_options)
        self.nb_anchors = self.anchors.shape[0]
        print(f"{self.nb_anchors} anchors have been created")

        # Rendering flags
        if self.use_lm:
            self.show_pd_box = False
            self.show_pd_kps = False
            self.show_rot_rect = False
            self.show_handedness = False
            self.show_landmarks = True
            self.show_scores = False
            self.show_gesture = self.use_gesture
        else:
            self.show_pd_box = True
            self.show_pd_kps = False
            self.show_rot_rect = False
            self.show_scores = False

    def create_pipeline(self):
        print("Creating pipeline...")
        # Start defining a pipeline
        pipeline = dai.Pipeline()
        pipeline.setOpenVINOVersion(version=dai.OpenVINO.Version.VERSION_2021_2)
        self.pd_input_length = 128

        # Depth
        # ==============================================================================================================
        # Closer-in minimum depth, disparity range is doubled (from 95 to 190):
        extended_disparity = False
        # Better accuracy for longer distance, fractional disparity 32-levels:
        subpixel = False
        # Better handling for occlusions:
        lr_check = False

        # Define a source - two mono (grayscale) cameras
        left = pipeline.createMonoCamera()
        left.setResolution(dai.MonoCameraProperties.SensorResolution.THE_400_P)
        left.setBoardSocket(dai.CameraBoardSocket.LEFT)
        right = pipeline.createMonoCamera()
        right.setResolution(dai.MonoCameraProperties.SensorResolution.THE_400_P)
        right.setBoardSocket(dai.CameraBoardSocket.RIGHT)

        # Create a node that will produce the depth map (using disparity output as it's easier to visualize depth this way)
        depth = pipeline.createStereoDepth()
        depth.setConfidenceThreshold(200)
        # Options: MEDIAN_OFF, KERNEL_3x3, KERNEL_5x5, KERNEL_7x7 (default)
        median = dai.StereoDepthProperties.MedianFilter.KERNEL_7x7  # For depth filtering
        depth.setMedianFilter(median)
        depth.setLeftRightCheck(lr_check)

        # Normal disparity values range from 0..95, will be used for normalization
        max_disparity = 95
        if extended_disparity: max_disparity *= 2  # Double the range
        depth.setExtendedDisparity(extended_disparity)
        if subpixel: max_disparity *= 32  # 5 fractional bits, x32
        depth.setSubpixel(subpixel)

        # When we get disparity to the host, we will multiply all values with the multiplier
        # for better visualization
        self.disparity_multiplier = 255 / max_disparity
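        # e.g. with the default 0..95 range: 255 / 95 is about 2.68, so raw
        # disparities are stretched to roughly 0..255 and can be shown as an
        # 8-bit image.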
        left.out.link(depth.left)
        right.out.link(depth.right)

        # Create output
        xout = pipeline.createXLinkOut()
        xout.setStreamName("disparity")
        depth.disparity.link(xout.input)
        # Depth END
        # ==============================================================================================================

        # ColorCamera
        print("Creating Color Camera...")
        cam = pipeline.createColorCamera()
        cam.setPreviewSize(self.pd_input_length, self.pd_input_length)
        cam.setResolution(dai.ColorCameraProperties.SensorResolution.THE_1080_P)
        # Crop video to square shape (palm detection takes square image as input)
        self.video_size = min(cam.getVideoSize())
        cam.setVideoSize(self.video_size, self.video_size)
        cam.setFps(30)
        cam.setInterleaved(False)
        cam.setBoardSocket(dai.CameraBoardSocket.RGB)
        cam_out = pipeline.createXLinkOut()
        cam_out.setStreamName("cam_out")
        # Link video output to host for higher resolution
        cam.video.link(cam_out.input)

        # Define palm detection model
        print("Creating Palm Detection Neural Network...")
        pd_nn = pipeline.createNeuralNetwork()
        pd_nn.setBlobPath(self.pd_path)
        # Increase threads for detection
        pd_nn.setNumInferenceThreads(2)
        # Specify that network takes latest arriving frame in non-blocking manner
        # Palm detection input
        if self.camera:
            pd_nn.input.setQueueSize(1)
            pd_nn.input.setBlocking(False)
            cam.preview.link(pd_nn.input)
        else:
            pd_in = pipeline.createXLinkIn()
            pd_in.setStreamName("pd_in")
            pd_in.out.link(pd_nn.input)
        # Palm detection output
        pd_out = pipeline.createXLinkOut()
        pd_out.setStreamName("pd_out")
        pd_nn.out.link(pd_out.input)

        # Define hand landmark model
        if self.use_lm:
            print("Creating Hand Landmark Neural Network...")
            lm_nn = pipeline.createNeuralNetwork()
            lm_nn.setBlobPath(self.lm_path)
            lm_nn.setNumInferenceThreads(2)
            # Hand landmark input
            self.lm_input_length = 224
            lm_in = pipeline.createXLinkIn()
            lm_in.setStreamName("lm_in")
            lm_in.out.link(lm_nn.input)
            # Hand landmark output
            lm_out = pipeline.createXLinkOut()
            lm_out.setStreamName("lm_out")
            lm_nn.out.link(lm_out.input)

        print("Pipeline created.")
        return pipeline

    def pd_postprocess(self, inference):
        scores = np.array(inference.getLayerFp16("classificators"), dtype=np.float16)  # 896
        bboxes = np.array(inference.getLayerFp16("regressors"), dtype=np.float16).reshape((self.nb_anchors, 18))  # 896x18
        # Decode bboxes
        self.regions = mpu.decode_bboxes(self.pd_score_thresh, scores, bboxes, self.anchors)
        # Non maximum suppression
        self.regions = mpu.non_max_suppression(self.regions, self.pd_nms_thresh)
        if self.use_lm:
            mpu.detections_to_rect(self.regions)
            mpu.rect_transformation(self.regions, self.video_size, self.video_size)

    def pd_render(self, frame):
        for r in self.regions:
            if self.show_pd_box:
                box = (np.array(r.pd_box) * self.video_size).astype(int)
                cv2.rectangle(frame, (box[0], box[1]), (box[0] + box[2], box[1] + box[3]), (0, 255, 0), 2)
            if self.show_pd_kps:
                for i, kp in enumerate(r.pd_kps):
                    x = int(kp[0] * self.video_size)
                    y = int(kp[1] * self.video_size)
                    cv2.circle(frame, (x, y), 6, (0, 0, 255), -1)
                    cv2.putText(frame, str(i), (x, y + 12), cv2.FONT_HERSHEY_PLAIN, 1.5, (0, 255, 0), 2)
            if self.show_scores:
                cv2.putText(frame, f"Palm score: {r.pd_score:.2f}",
                            (int(r.pd_box[0] * self.video_size + 10), int((r.pd_box[1] + r.pd_box[3]) * self.video_size + 60)),
                            cv2.FONT_HERSHEY_PLAIN, 2, (255, 255, 0), 2)

    def lm_postprocess(self, region, inference, dataset):
        region.lm_score = inference.getLayerFp16("Identity_1")[0]
        region.handedness = inference.getLayerFp16("Identity_2")[0]
        lm_raw = np.array(inference.getLayerFp16("Identity_dense/BiasAdd/Add"))
        lm = []
        for i in range(int(len(lm_raw) / 3)):
            # x,y,z -> x/w,y/h,z/w (here h=w)
            lm.append(lm_raw[3 * i:3 * (i + 1)] / self.lm_input_length)
        region.landmarks = lm

        # convert landmarks into xyz
        # print(region.landmarks[4])
        src = np.array([(0, 0), (1, 0), (1, 1)], dtype=np.float32)
        dst = np.array([(x, y) for x, y in region.rect_points[1:]], dtype=np.float32)  # region.rect_points[0] is left bottom point !
        mat = cv2.getAffineTransform(src, dst)
        raw_xyz = np.expand_dims(np.array([(l[0], l[1], l[2]) for l in region.landmarks]), axis=0)
        raw_xy = raw_xyz[:, :, 0:2]
        lm_xy = np.squeeze(cv2.transform(raw_xy, mat))
        lm_xyz = np.hstack((lm_xy, raw_xyz[:, :, 2].reshape(lm_xy.shape[0], 1)))
        lm_xyz = lm_xyz * np.array([1 / self.video_size, 1 / self.video_size, 1])

        data = {
            "points": lm_xyz.reshape(lm_xyz.shape[0] * lm_xyz.shape[1]).tolist(),
            "score": region.lm_score,
        }
        if region.handedness > 0.85:
            if dataset['right'] is None or region.lm_score > dataset['right']['score']:
                dataset['right'] = data
        elif region.handedness < 0.25:
            if dataset['left'] is None or region.lm_score > dataset['left']['score']:
                dataset['left'] = data
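
        # For reference, the shape of a dataset entry built above (and ultimately
        # what on_step receives); "depth" is added later in run() when a disparity
        # frame is available:
        #   dataset["right"] = {
        #       "points": [x0, y0, z0, ..., x20, y20, z20],  # 21 landmarks, x/y normalized to the square video frame
        #       "score": <landmark confidence>,
        #       # "depth": <mean disparity over the 21 landmarks>, set in run()
        #   }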
    def lm_render(self, frame, region):
        if region.lm_score > self.lm_score_threshold:
            # self.show_landmarks:
            src = np.array([(0, 0), (1, 0), (1, 1)], dtype=np.float32)
            dst = np.array([(x, y) for x, y in region.rect_points[1:]], dtype=np.float32)  # region.rect_points[0] is left bottom point !
            mat = cv2.getAffineTransform(src, dst)
            lm_xy = np.expand_dims(np.array([(l[0], l[1]) for l in region.landmarks]), axis=0)
            lm_xy = np.squeeze(cv2.transform(lm_xy, mat)).astype(int)
            list_connections = [
                [0, 1, 2, 3, 4],
                [0, 5, 6, 7, 8],
                [0, 9, 10, 11, 12],
                [0, 13, 14, 15, 16],
                [0, 17, 18, 19, 20]]
            lines = [np.array([lm_xy[point] for point in line]) for line in list_connections]
            cv2.polylines(frame, lines, False, (255, 0, 0), 2, cv2.LINE_AA)

    def run(self):
        device = dai.Device(self.create_pipeline())
        # see: https://docs.luxonis.com/projects/api/en/latest/samples/09_mono_mobilenet/
        # you can manip the frames like this...

        # Define data queues
        q_video = device.getOutputQueue(name="cam_out", maxSize=4, blocking=False)
        q_pd_out = device.getOutputQueue(name="pd_out", maxSize=1, blocking=False)
        if self.use_lm:
            q_lm_out = device.getOutputQueue(name="lm_out", maxSize=2, blocking=False)
            q_lm_in = device.getInputQueue(name="lm_in")
        depth_out = device.getOutputQueue(name="disparity", maxSize=4, blocking=False)

        self.fps = FPS(mean_nb_frames=20)
        seq_num = 0
        nb_pd_inferences = 0
        nb_lm_inferences = 0
        glob_pd_rtrip_time = 0
        glob_lm_rtrip_time = 0

        while True:
            self.fps.update()
            in_video = q_video.get()
            video_frame = in_video.getCvFrame()

            has_depth = False
            inDepth = depth_out.tryGet()  # non-blocking: returns None if no new disparity frame has arrived
            if inDepth is not None:
                has_depth = True
            if has_depth:
                depthFrame = inDepth.getFrame()
                depthFrame = (depthFrame * self.disparity_multiplier).astype(np.uint8)

            annotated_frame = video_frame.copy()

            # Get palm detection
            inference = q_pd_out.get()
            if not self.camera: glob_pd_rtrip_time += now() - pd_rtrip_time  # only meaningful for file input, which this run loop does not feed
            self.pd_postprocess(inference)
            self.pd_render(annotated_frame)
            nb_pd_inferences += 1

            # Hand landmarks
            if self.use_lm:
                for i, r in enumerate(self.regions):
                    img_hand = mpu.warp_rect_img(r.rect_points, video_frame, self.lm_input_length, self.lm_input_length)
                    nn_data = dai.NNData()
                    nn_data.setLayer("input_1", to_planar(img_hand, (self.lm_input_length, self.lm_input_length)))
                    q_lm_in.send(nn_data)
                    if i == 0: lm_rtrip_time = now()  # We measure only for the first region
                # Retrieve hand landmarks
                self.dataset = {"left": None, "right": None}
                for i, r in enumerate(self.regions):
                    inference = q_lm_out.get()
                    if i == 0: glob_lm_rtrip_time += now() - lm_rtrip_time
                    self.lm_postprocess(r, inference, self.dataset)
                    self.lm_render(annotated_frame, r)
                    nb_lm_inferences += 1

            self.fps.display(annotated_frame, orig=(50, 50), color=(240, 180, 100))

            # Make depth frame the right size
            if has_depth:
                dh_offset = 40
                dw_offset = 40
                dh = depthFrame.shape[0]
                dw = depthFrame.shape[1]
                effective_dw = int(dh - dh_offset)
                effective_x = int(dw / 4 - dw_offset)
                depthFrame = depthFrame[dh_offset:dh_offset + effective_dw, effective_x:effective_x + effective_dw]
                depthFrame = cv2.resize(depthFrame, (dh, dh))

                # Sample depth ("points" is packed as x,y,z per landmark, hence the stride of 3;
                # numpy indexes the depth frame as [row, col] = [y, x])
                if self.dataset['left'] is not None:
                    self.dataset['left']['depth'] = 0
                    for i in range(21):
                        x = int(self.dataset['left']['points'][i * 3 + 0] * dh)
                        y = int(self.dataset['left']['points'][i * 3 + 1] * dh)
                        x = 0 if x < 0 else (dh - 1 if x > dh - 1 else x)
                        y = 0 if y < 0 else (dh - 1 if y > dh - 1 else y)
                        self.dataset['left']['depth'] += float(depthFrame[y, x])
                    self.dataset['left']['depth'] /= 21.0
                if self.dataset['right'] is not None:
                    self.dataset['right']['depth'] = 0
                    for i in range(21):
                        x = int(self.dataset['right']['points'][i * 3 + 0] * dh)
                        y = int(self.dataset['right']['points'][i * 3 + 1] * dh)
                        x = 0 if x < 0 else (dh - 1 if x > dh - 1 else x)
                        y = 0 if y < 0 else (dh - 1 if y > dh - 1 else y)
                        self.dataset['right']['depth'] += float(depthFrame[y, x])
                    self.dataset['right']['depth'] /= 21.0

                # Update listeners
                if self.on_step is not None:
                    self.on_step(self.dataset)

                depthFrameRgb = cv2.applyColorMap(depthFrame, cv2.COLORMAP_JET)
                annotated_frame = cv2.resize(annotated_frame, (dh, dh))

            # frame is ready to be shown
            if has_depth:
                cv2.imshow("depthFrameRgb", depthFrameRgb)
            cv2.imshow("video", annotated_frame)

            key = cv2.waitKey(1)
            if key == ord('q') or key == 27:
                if self.on_halt is not None:
                    self.on_halt()
                break
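A minimal way to exercise the tracker without the Flask layer (a sketch; it assumes an OAK device is attached and the palm/landmark .blob files sit in models/ next to this file):

# Standalone smoke test (sketch): print each published hand payload.
from HandTracker import HandTracker  # adjust to your layout, e.g. hand_tracking.HandTracker

def print_step(dataset):
    for side in ("left", "right"):
        if dataset[side] is not None:
            print(side, "score:", dataset[side]["score"], "depth:", dataset[side].get("depth"))

tracker = HandTracker(on_step=print_step)
tracker.run()  # press 'q' or Esc in the video window to stop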
Second file in the gist: the Flask + flask-sockets server that streams the HandTracker output over a websocket.
import json
from threading import Thread

from flask import Flask
from flask_sockets import Sockets
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler

from hand_tracking.HandTracker import HandTracker

app = Flask(__name__)
sockets = Sockets(app)


class ApplicationState:
    connected = []
    server = None

    @classmethod
    def register_server(cls, server):
        cls.server = server

    @classmethod
    def halt_server(cls):
        print("Halting server")
        if cls.server is not None:
            cls.server.close()

    @classmethod
    def publish(cls, data):
        for target in cls.connected:
            target.send(data)

    @classmethod
    def remove(cls, ws):
        print("Client disconnected")
        cls.connected.remove(ws)

    @classmethod
    def add(cls, ws):
        print("Client connected")
        cls.connected.append(ws)

    @classmethod
    def shutdown(cls):
        for ws in cls.connected:
            ws.close()
        cls.connected = []


@sockets.route('/data')
def echo_socket(ws):
    ApplicationState.add(ws)
    while not ws.closed:
        ws.receive()
    ApplicationState.remove(ws)


@app.route('/')
def hello():
    return 'Streaming data...'


def on_halt():
    ApplicationState.shutdown()


def on_step(data):
    if data['left'] is None and data['right'] is None:
        return
    data = json.dumps(data)
    ApplicationState.publish(data)


def service_thread():
    port = 5005
    print(f"Serving /data on 0.0.0.0:{port}")
    server = pywsgi.WSGIServer(('0.0.0.0', port), app, handler_class=WebSocketHandler)
    ApplicationState.register_server(server)
    server.serve_forever()


def hand_tracking_service():
    thread = Thread(target=service_thread)
    thread.start()
    ht = HandTracker(use_gesture=True, on_halt=on_halt, on_step=on_step)
    ht.run()
    ApplicationState.halt_server()
    thread.join()


if __name__ == "__main__":
    hand_tracking_service()
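For reference, a minimal consumer sketch for the /data stream (assumes the third-party websocket-client package; the payload is the JSON produced by on_step above):

# ws_client.py (hypothetical name) - print depth per hand from ws://localhost:5005/data
import json
import websocket  # pip install websocket-client

ws = websocket.create_connection("ws://localhost:5005/data")
try:
    while True:
        data = json.loads(ws.recv())
        for side in ("left", "right"):
            if data.get(side):
                print(side, "depth:", data[side].get("depth"))
finally:
    ws.close()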