Created
May 8, 2019 23:05
-
-
Save zhreshold/5f44dcb1a00f84bc1d981465f1c1f0e2 to your computer and use it in GitHub Desktop.
GluonCV cam demo pose
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import argparse, time, logging, os, math, tqdm, cv2 | |
import numpy as np | |
import mxnet as mx | |
from mxnet import gluon, nd, image | |
from mxnet.gluon.data.vision import transforms | |
import matplotlib.pyplot as plt | |
import matplotlib | |
matplotlib.use("TkAgg") | |
import gluoncv as gcv | |
from gluoncv import data | |
from gluoncv.data import mscoco | |
from gluoncv.model_zoo import get_model | |
from gluoncv.data.transforms.pose import detector_to_simple_pose, heatmap_to_coord | |
from gluoncv.utils.viz import plot_image, plot_keypoints | |
# Command-line options for the webcam pose demo.
parser = argparse.ArgumentParser(
    description='Estimate human poses from a webcam stream with GluonCV')
parser.add_argument('--detector', type=str, default='yolo3_mobilenet1.0_coco',
                    help='name of the detection model to use')
parser.add_argument('--pose-model', type=str, default='simple_pose_resnet50_v1b',
                    help='name of the pose estimation model to use')
parser.add_argument('--num-frames', type=int, default=100,
                    help='Number of frames to capture')
# Parsed at import time; the script entry point reads `opt`.
opt = parser.parse_args()
def cv_plot_image(img, **kwargs):
    """Show an RGB image in the OpenCV window named 'demo'.

    Parameters
    ----------
    img : numpy.ndarray or mxnet.nd.NDArray or None
        Image in RGB channel order; it is converted to BGR for OpenCV.
        ``None`` is ignored.
    **kwargs
        Accepted and ignored, so callers may pass matplotlib-style
        arguments such as ``ax``.

    Returns
    -------
    None
    """
    # `if not img` is ambiguous for multi-element arrays (numpy and MXNet both
    # raise ValueError), so only an explicit None means "nothing to show".
    if img is None:
        return
    if isinstance(img, mx.nd.NDArray):
        # OpenCV works on numpy arrays, not MXNet NDArrays.
        img = img.asnumpy()
    canvas = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imshow('demo', canvas)
    # A short waitKey lets OpenCV's GUI loop process events so the window updates.
    cv2.waitKey(1)
def cv_plot_bbox(img, bboxes, scores=None, labels=None, thresh=0.5,
                 class_names=None, colors=None, ax=None,
                 reverse_rgb=False, absolute_coordinates=True):
    """Draw bounding boxes and their label/score captions onto an image with OpenCV.

    Parameters
    ----------
    img : numpy.ndarray
        Image to draw on; it is modified in place by ``cv2.rectangle`` and
        ``cv2.putText``.
    bboxes : numpy.ndarray or mxnet.nd.NDArray
        One ``(xmin, ymin, xmax, ymax)`` row per box.
    scores : numpy.ndarray or mxnet.nd.NDArray, optional
        Per-box scores; boxes scoring below ``thresh`` are skipped.
    labels : numpy.ndarray or mxnet.nd.NDArray, optional
        Per-box class ids; boxes with a negative id are skipped.
    thresh : float
        Minimum score for a box to be drawn.
    class_names : list of str, optional
        Maps class ids to caption names; also used to spread box colours
        over matplotlib's 'hsv' colormap.
    colors : dict, optional
        ``{class_id: color}`` with components in [0, 1]. Missing ids get a
        colour assigned, and this dict is updated with it.
    ax, reverse_rgb
        Unused; accepted only for signature compatibility.
    absolute_coordinates : bool
        If False, box coordinates are fractions of the image width/height.
        They are scaled on a copy, so the caller's array is left untouched.

    Returns
    -------
    numpy.ndarray
        ``img``, with the boxes drawn.

    Raises
    ------
    ValueError
        If ``labels`` or ``scores`` has a different length than ``bboxes``.
    """
    import random

    if labels is not None and not len(bboxes) == len(labels):
        raise ValueError('The length of labels and bboxes mismatch, {} vs {}'
                         .format(len(labels), len(bboxes)))
    if scores is not None and not len(bboxes) == len(scores):
        raise ValueError('The length of scores and bboxes mismatch, {} vs {}'
                         .format(len(scores), len(bboxes)))
    if len(bboxes) < 1:
        return img

    if isinstance(bboxes, mx.nd.NDArray):
        bboxes = bboxes.asnumpy()
    if isinstance(labels, mx.nd.NDArray):
        labels = labels.asnumpy()
    if isinstance(scores, mx.nd.NDArray):
        scores = scores.asnumpy()

    if not absolute_coordinates:
        # Convert to absolute pixel coordinates on a float copy, so a numpy
        # array passed by the caller is not rescaled in place.
        bboxes = np.array(bboxes, dtype=np.float64)
        height, width = img.shape[0], img.shape[1]
        bboxes[:, (0, 2)] *= width
        bboxes[:, (1, 3)] *= height

    # Random colours are used when no colour table is provided.
    if colors is None:
        colors = dict()

    for i, bbox in enumerate(bboxes):
        if scores is not None and scores.flat[i] < thresh:
            continue
        if labels is not None and labels.flat[i] < 0:
            continue
        cls_id = int(labels.flat[i]) if labels is not None else -1
        if cls_id not in colors:
            if class_names is not None:
                colors[cls_id] = plt.get_cmap('hsv')(cls_id / len(class_names))
            else:
                colors[cls_id] = (random.random(), random.random(), random.random())

        xmin, ymin, xmax, ymax = [int(x) for x in bbox]
        # Colour components are in [0, 1]; OpenCV expects 0-255.
        bcolor = [x * 255 for x in colors[cls_id]]
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), bcolor, 3)

        if class_names is not None and cls_id < len(class_names):
            class_name = class_names[cls_id]
        else:
            class_name = str(cls_id) if cls_id >= 0 else ''
        score = '{:.3f}'.format(scores.flat[i]) if scores is not None else ''
        if class_name or score:
            # Put the caption just above the box, or just inside it when the
            # box touches the top edge of the image.
            text_y = ymin - 5 if ymin > 15 else ymin + 15
            cv2.putText(img, '{:s} {:s}'.format(class_name, score).strip(),
                        (xmin, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                        bcolor, 1, cv2.LINE_AA)
    return img
def cv_plot_keypoints(img, coords, confidence, class_ids, bboxes, scores,
                      box_thresh=0.5, keypoint_thresh=0.2, **kwargs):
    """Draw detected boxes and pose skeletons onto an RGB image and show it.

    Parameters
    ----------
    img : numpy.ndarray
        RGB image the detections refer to; it is drawn on in place.
    coords : numpy.ndarray or mxnet.nd.NDArray
        Keypoint positions, indexed as ``coords[person, joint, (x, y)]``.
    confidence : numpy.ndarray or mxnet.nd.NDArray
        Keypoint confidences, indexed as ``confidence[person, joint, 0]``.
    class_ids, bboxes, scores : numpy.ndarray or mxnet.nd.NDArray
        Detector outputs. Only the first batch item (``[0]``) is used, and
        only boxes whose class id is 0 are drawn.
    box_thresh : float
        Minimum detector score for a box to be drawn.
    keypoint_thresh : float
        Minimum confidence for a joint to count as visible. A limb is drawn
        only when both of its joints are visible.
    **kwargs
        Forwarded to ``cv_plot_bbox``.

    Returns
    -------
    numpy.ndarray
        ``img`` with boxes and skeletons drawn. The image is also shown in
        the OpenCV window named 'demo'.
    """
    def _to_numpy(arr):
        # Drawing and boolean indexing below operate on numpy arrays.
        return arr.asnumpy() if isinstance(arr, mx.nd.NDArray) else arr

    coords = _to_numpy(coords)
    confidence = _to_numpy(confidence)
    class_ids = _to_numpy(class_ids)
    bboxes = _to_numpy(bboxes)
    scores = _to_numpy(scores)

    joint_visible = confidence[:, :, 0] > keypoint_thresh
    # Skeleton edges as (joint, joint) index pairs over joints 0-16; presumably
    # the COCO 17-keypoint order -- TODO confirm against the pose model.
    joint_pairs = [[0, 1], [1, 3], [0, 2], [2, 4],
                   [5, 6], [5, 7], [7, 9], [6, 8], [8, 10],
                   [5, 11], [6, 12], [11, 12],
                   [11, 13], [12, 14], [13, 15], [14, 16]]

    # Class id 0 is the person class for the demo's person-only detector.
    person_ind = class_ids[0] == 0
    img = cv_plot_bbox(img, bboxes[0][person_ind[:, 0]],
                       scores[0][person_ind[:, 0]], thresh=box_thresh, **kwargs)

    # One colour per limb, spread evenly over matplotlib's 'cool' colormap.
    colormap_index = np.linspace(0, 1, len(joint_pairs))
    for i in range(coords.shape[0]):
        pts = coords[i]
        for cm_ind, (j1, j2) in zip(colormap_index, joint_pairs):
            if not (joint_visible[i, j1] and joint_visible[i, j2]):
                continue
            # The colormap returns RGBA in [0, 1]. Keep R, G, B (the image is
            # still RGB here; it is converted to BGR just before display).
            cm_color = tuple(int(c * 255) for c in plt.cm.cool(cm_ind)[:3])
            pt1 = (int(pts[j1, 0]), int(pts[j1, 1]))
            pt2 = (int(pts[j2, 0]), int(pts[j2, 1]))
            cv2.line(img, pt1, pt2, cm_color, 3)

    canvas = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imshow('demo', canvas)
    cv2.waitKey(1)
    return img
def keypoint_detection(img, detector, pose_net, ctx=mx.cpu(), axes=None):
    """Detect people in one frame, estimate their poses and display the result.

    Parameters
    ----------
    img : mxnet.nd.NDArray
        RGB uint8 frame, as built by the script's capture loop.
    detector : gluon Block
        Object detector; called on the transformed frame and expected to
        return ``(class_ids, scores, bounding_boxes)``.
    pose_net : gluon Block
        Pose network; maps the cropped person inputs to heatmaps.
    ctx : mxnet.Context
        Device the detector input and pose crops are placed on.
    axes : object, optional
        Passed through to the plotting helpers as ``ax``.

    Returns
    -------
    object
        Whatever the plotting helper returns; the caller passes it back in
        as ``axes`` on the next frame.
    """
    # `img` is rebound to the resized display image returned by the transform
    # (presumably a numpy array per the GluonCV docs). The detections refer
    # to this image, so it is also the one drawn and shown below.
    x, img = gcv.data.transforms.presets.yolo.transform_test(img, short=512, max_size=350)
    x = x.as_in_context(ctx)
    class_IDs, scores, bounding_boxs = detector(x)

    pose_input, upscale_bbox = detector_to_simple_pose(img, class_IDs, scores, bounding_boxs,
                                                       output_shape=(256, 192), ctx=ctx)
    if len(upscale_bbox) > 0:
        predicted_heatmap = pose_net(pose_input)
        pred_coords, confidence = heatmap_to_coord(predicted_heatmap, upscale_bbox)
        axes = cv_plot_keypoints(img, pred_coords, confidence, class_IDs, bounding_boxs, scores,
                                 box_thresh=0.5, keypoint_thresh=0.2, ax=axes)
    else:
        # No person crops: just show this function's own (resized) image
        # rather than relying on a module-level `frame` global.
        axes = cv_plot_image(img, ax=axes)
    return axes
if __name__ == '__main__':
    ctx = mx.cpu()
    # Model names come from the command line (--detector / --pose-model)
    # instead of being hard-coded, so the parsed options take effect.
    detector = get_model(opt.detector, pretrained=True, ctx=ctx)
    # Restrict the detector to the 'person' class, reusing its pretrained
    # 'person' weights.
    detector.reset_class(classes=['person'], reuse_weights={'person': 'person'})
    net = get_model(opt.pose_model, pretrained=True, ctx=ctx)

    cap = cv2.VideoCapture(0)
    time.sleep(1)  # letting the camera autofocus
    axes = None
    try:
        for _ in range(opt.num_frames):
            ret, frame = cap.read()
            if not ret:
                # No frame (camera missing or stream ended); frame is None and
                # cvtColor would raise.
                logging.warning('Failed to read a frame from the camera; stopping.')
                break
            frame = mx.nd.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).astype('uint8')
            axes = keypoint_detection(frame, detector, net, ctx, axes=axes)
    finally:
        # Always free the camera and close the preview window, even on error.
        cap.release()
        cv2.destroyAllWindows()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment