Last active
June 14, 2020 22:00
-
-
Save ivanpanshin/e014c47297cb4378bfe78332510929cf to your computer and use it in GitHub Desktop.
Keypoint detection with Caffe2 (example)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from caffe2.proto import caffe2_pb2 | |
from caffe2.python import core, workspace | |
from detectron2.export.caffe2_inference import ProtobufDetectionModel | |
from detectron2.export.api import Caffe2Model | |
import numpy as np | |
import os | |
import torch | |
import cv2 | |
import time | |
import shutil | |
import math | |
# Reaching this line means every import above succeeded.
print("Required modules imported.", end="\n")
def _load_netdef(path):
    """Parse a serialized Caffe2 ``NetDef`` protobuf from the file at *path*."""
    net = caffe2_pb2.NetDef()
    with open(path, "rb") as f:
        net.ParseFromString(f.read())
    return net


def create_caffe2_model(model_dir="right_anchors_caffe2_model"):
    """Load an exported Caffe2 model and wrap it in a ``ProtobufDetectionModel``.

    Args:
        model_dir: Directory containing ``model.pb`` (the predict net) and
            ``model_init.pb`` (the init net). Defaults to the directory the
            script originally hard-coded.

    Returns:
        The ``ProtobufDetectionModel`` built from the two nets.

    Raises:
        OSError (e.g. FileNotFoundError): if either protobuf file cannot be opened.
    """
    predict_net = _load_netdef(os.path.join(model_dir, "model.pb"))
    init_net = _load_netdef(os.path.join(model_dir, "model_init.pb"))
    print('Trying to create model')
    return ProtobufDetectionModel(predict_net=predict_net, init_net=init_net)
def distance(first_point, second_point):
    """Return the Euclidean distance between two 2-D points.

    Each point is any 2-element ``(x, y)`` sequence (list, tuple, array).
    """
    x1, y1 = first_point
    x2, y2 = second_point
    # math.hypot avoids the overflow/underflow of squaring the deltas by hand.
    return math.hypot(x2 - x1, y2 - y1)
def filter_potential_landmarks(potential_landmarks):
    """Reduce candidate landmark points to at most one pair.

    Candidates are read as consecutive pairs of ``(x, y)`` points. This is
    presumably two points per detected instance, as collected by
    ``detect_keypoints``; verify this for other callers.

    * 0-2 points: the input list itself is returned unchanged.
    * More than 4 points: ``[]`` is returned (input considered unusable).
    * 3 points: the incomplete second pair is ignored and the first pair is
      returned, rather than raising ``IndexError``.
    * 4 points: each pair is treated as opposite corners of an axis-aligned
      box.
      - If the two boxes overlap with strictly positive area, the first pair
        is returned.
      - Otherwise the closer of the two "cross" pairings is returned:
        (second point of pair 1, first point of pair 2) versus
        (second point of pair 2, first point of pair 1).
        Ties go to the latter.

    Returns:
        A list of ``[x, y]`` lists (the input object itself in the 0-2 case).
    """
    amount_of_points = len(potential_landmarks)
    if amount_of_points <= 2:
        return potential_landmarks
    if amount_of_points > 4:
        return []

    f_x1, f_y1 = potential_landmarks[0]
    f_x2, f_y2 = potential_landmarks[1]
    first_pair = [[f_x1, f_y1], [f_x2, f_y2]]
    if amount_of_points == 3:
        # Only one complete pair is available.
        return first_pair

    s_x1, s_y1 = potential_landmarks[2]
    s_x2, s_y2 = potential_landmarks[3]

    # Extent of the boxes' overlap along each axis; negative when disjoint.
    dx = min(max(f_x1, f_x2), max(s_x1, s_x2)) - max(min(f_x1, f_x2), min(s_x1, s_x2))
    dy = min(max(f_y1, f_y2), max(s_y1, s_y2)) - max(min(f_y1, f_y2), min(s_y1, s_y2))
    # Strictly positive overlap area; boxes that merely touch count as disjoint.
    if dx > 0 and dy > 0:
        return first_pair

    dist_1 = distance([f_x2, f_y2], [s_x1, s_y1])
    dist_2 = distance([s_x2, s_y2], [f_x1, f_y1])
    if dist_1 < dist_2:
        return [[f_x2, f_y2], [s_x1, s_y1]]
    return [[s_x2, s_y2], [f_x1, f_y1]]
def show_landmarks(image_path, landmarks):
    """Draw *landmarks* on the image at *image_path* and display it.

    Each landmark is an ``(x, y)`` pair in original-image pixel coordinates.
    It is drawn as a filled dot of radius 1 in colour ``(0, 0, 255)``, which
    is red in OpenCV's BGR channel order. The call blocks until a key is
    pressed, then closes the window.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read *image_path*.
    """
    original_img = cv2.imread(image_path)
    if original_img is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    for x, y in landmarks:
        cv2.circle(original_img, (int(x), int(y)), 1, (0, 0, 255), -1)
    cv2.imshow('keypoints', original_img)
    cv2.waitKey(0)
    cv2.destroyWindow('keypoints')
def detect_keypoints(image_path, model, show_result=False, size_of_image=256,
                     max_instances=2, keypoint_indices=(0, 3)):
    """Run *model* on an image and collect candidate landmark points.

    The image is resized to a ``size_of_image`` x ``size_of_image`` square.
    It is then reordered from HWC to CHW, converted with ``torch.Tensor`` and
    passed to *model* as ``[{'image': tensor}]``. For each of the first
    ``max_instances`` predicted instances, the keypoints whose index is in
    ``keypoint_indices`` are scaled back to original-image pixel coordinates
    and collected in prediction order.

    Args:
        image_path: Path readable by ``cv2.imread``.
        model: Callable whose result's first element has an ``'instances'``
            entry exposing ``pred_classes``, ``pred_boxes`` and
            ``pred_keypoints``.
        show_result: If true, draw the predicted boxes and the keypoints of
            the considered instances on the original image. The image is then
            shown until a key is pressed.
        size_of_image: Side length of the square model input.
        max_instances: Number of instances (in prediction order) to read
            keypoints from. Defaults to the previously hard-coded 2.
        keypoint_indices: Keypoint indices to collect per instance. Defaults
            to the previously hard-coded ``(0, 3)``.

    Returns:
        A list of ``[x, y]`` points in original-image coordinates; empty when
        the model predicts no instances.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read *image_path*.
    """
    original_img = cv2.imread(image_path)
    if original_img is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    real_height, real_width = original_img.shape[:2]
    # Factors mapping model-input (resized) coordinates back onto the original.
    scale_x = real_width / size_of_image
    scale_y = real_height / size_of_image

    resized = cv2.resize(original_img, (size_of_image, size_of_image))
    # HWC -> CHW; identical to swapaxes(1, 2).swapaxes(0, 1).
    chw = resized.transpose(2, 0, 1)
    instances = model([{'image': torch.Tensor(chw)}])[0]['instances']

    potential_landmarks = []
    if len(instances.pred_classes) > 0:
        if show_result:
            for box in instances.pred_boxes:
                x1, y1, x2, y2 = box.numpy()
                cv2.rectangle(original_img,
                              (int(x1 * scale_x), int(y1 * scale_y)),
                              (int(x2 * scale_x), int(y2 * scale_y)),
                              (0, 0, 255), 1)
        for index_of_patch, patch_of_keypoints in enumerate(instances.pred_keypoints):
            if index_of_patch >= max_instances:
                break
            for index_of_keypoint, keypoint in enumerate(patch_of_keypoints):
                keypoint = keypoint.numpy()
                x, y = keypoint[0] * scale_x, keypoint[1] * scale_y
                if index_of_keypoint in keypoint_indices:
                    potential_landmarks.append([x, y])
                # NOTE(review): the scraped source lost its indentation. This
                # draws every keypoint of the considered instances, not only
                # the selected ones; confirm against the original intent.
                if show_result:
                    cv2.circle(original_img, (int(x), int(y)), 1, (0, 0, 255), -1)

    if show_result:
        cv2.imshow('keypoints', original_img)
        cv2.waitKey(0)
        cv2.destroyWindow('keypoints')
    return potential_landmarks
if __name__ == "__main__":
    # Demo: load the exported model and detect candidate points on a sample
    # frame (no intermediate debug window). Then keep a single landmark pair
    # and display it.
    model = create_caffe2_model()
    image_path = 'frame.jpg'
    candidates = detect_keypoints(image_path, model, show_result=False)
    show_landmarks(image_path, filter_potential_landmarks(candidates))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment