Created June 7, 2020 20:02
-
-
Save robisen1/616a2a1be498741d5d3ca65967de06e7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
#added threading to read | |
import os | |
import colorsys | |
import numpy as np | |
from keras import backend as K | |
from keras.layers import Input | |
import cv2 | |
from imutils.video import FPS | |
from imutils.video import FileVideoStream | |
from yolo4.model import yolo_eval, yolo4_body | |
from yolo4.utils import letterbox_image | |
from PIL import Image, ImageFont, ImageDraw | |
from timeit import default_timer as timer | |
class Yolo4(object):
    """YOLOv4 detector built from a Keras model and a TF1-style session.

    The network is constructed with ``yolo4_body`` for a fixed 416x416x3
    input, weights are loaded from an ``.h5`` file, and ``yolo_eval`` builds
    the graph that produces boxes, scores and class indices.
    """

    def __init__(self, score, iou, anchors_path, classes_path, model_path, gpu_num=1):
        """Store the configuration and immediately build/load the model.

        Args:
            score: score threshold passed to ``yolo_eval``.
            iou: IoU threshold passed to ``yolo_eval`` (was previously ignored).
            anchors_path: file whose first line holds comma-separated anchors.
            classes_path: file with one class name per line.
            model_path: ``.h5`` weights file for ``yolo4_body``.
            gpu_num: when >= 2, the model is wrapped with ``multi_gpu_model``.
        """
        self.score = score
        self.iou = iou
        self.anchors_path = anchors_path
        self.classes_path = classes_path
        self.model_path = model_path
        self.gpu_num = gpu_num
        self.load_yolo()

    def get_class(self):
        """Return the class names listed one per line in ``self.classes_path``.

        Whitespace around each name is stripped and blank lines are skipped,
        so a stray empty line does not add a phantom class (which would change
        ``num_classes`` and make the weights fail to load).
        """
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            return [line.strip() for line in f if line.strip()]

    def get_anchors(self):
        """Return the anchors from the first line of ``self.anchors_path``.

        The line holds comma-separated numbers; consecutive pairs are grouped
        into a float array of shape (num_anchors, 2).
        """
        anchors_path = os.path.expanduser(self.anchors_path)
        with open(anchors_path) as f:
            line = f.readline()
        anchors = [float(x) for x in line.split(',')]
        return np.array(anchors).reshape(-1, 2)

    def load_yolo(self):
        """Build the network, load its weights and create the detection graph.

        Sets ``class_names``, ``anchors``, ``colors``, ``sess``,
        ``yolo4_model``, ``input_image_shape`` and the ``boxes`` / ``scores`` /
        ``classes`` output tensors on the instance.

        Raises:
            AssertionError: if ``model_path`` does not end in ``.h5``.
        """
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        self.class_names = self.get_class()
        self.anchors = self.get_anchors()
        num_anchors = len(self.anchors)
        num_classes = len(self.class_names)

        # One evenly spaced hue per class, converted to 0-255 RGB tuples for drawing.
        hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
        self.colors = [
            tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv))
            for hsv in hsv_tuples
        ]

        self.sess = K.get_session()

        # Construct the network for a fixed 416x416x3 input, then load weights.
        # num_anchors // 3 is passed as the per-output anchor count
        # (presumably three output scales -- confirm against yolo4_body).
        self.yolo4_model = yolo4_body(Input(shape=(416, 416, 3)), num_anchors // 3, num_classes)
        self.yolo4_model.load_weights(model_path)
        print('{} model, anchors, and classes loaded.'.format(model_path))

        if self.gpu_num >= 2:
            # Imported lazily: the name was never imported at module level
            # (NameError before), and only Keras 2.x ships it, so single-GPU
            # use must not depend on it.
            from keras.utils import multi_gpu_model
            self.yolo4_model = multi_gpu_model(self.yolo4_model, gpus=self.gpu_num)

        # Fed at run time with the original image's [height, width]
        # (see detect_image); presumably used by yolo_eval to map boxes back
        # from the letterboxed input to image coordinates.
        self.input_image_shape = K.placeholder(shape=(2, ))
        self.boxes, self.scores, self.classes = yolo_eval(
            self.yolo4_model.output, self.anchors, num_classes,
            self.input_image_shape,
            score_threshold=self.score, iou_threshold=self.iou)

    def close_session(self):
        """Close the Keras/TensorFlow session obtained in ``load_yolo``."""
        self.sess.close()

    @staticmethod
    def _text_size(draw, text, font):
        """Return the (width, height) extent of ``text`` drawn at (0, 0).

        ``ImageDraw.textsize`` was removed in Pillow 10; ``textbbox`` is used
        when available (its right/bottom include the glyph offset, so the
        label background fully covers the text), else ``textsize``.
        """
        if hasattr(draw, 'textbbox'):
            _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
            return right, bottom
        return draw.textsize(text, font)

    def detect_image(self, image, model_image_size=(416, 416)):
        """Run detection on a PIL image and draw labelled boxes onto it.

        Args:
            image: PIL image; boxes and labels are drawn onto it in place.
            model_image_size: size the image is letterboxed to before
                inference; it is reversed before being passed to
                ``letterbox_image`` (presumably (height, width)). The network
                input is built as 416x416, so other values will not match it.

        Returns:
            The same ``image`` object, annotated.
        """
        start = timer()
        boxed_image = letterbox_image(image, tuple(reversed(model_image_size)))
        image_data = np.array(boxed_image, dtype='float32')
        image_data /= 255.
        image_data = np.expand_dims(image_data, 0)  # Add batch dimension.

        out_boxes, out_scores, out_classes = self.sess.run(
            [self.boxes, self.scores, self.classes],
            feed_dict={
                self.yolo4_model.input: image_data,
                self.input_image_shape: [image.size[1], image.size[0]],
                K.learning_phase(): 0,
            })
        print('Found {} boxes for {}'.format(len(out_boxes), 'img'))

        width, height = image.size
        font = ImageFont.truetype(font='font/FiraMono-Medium.otf',
                                  size=int(np.floor(3e-2 * height + 0.5)))
        # At least 1px, otherwise images with width + height < 300 get no outline.
        thickness = max(1, (width + height) // 300)
        draw = ImageDraw.Draw(image)

        for i, c in reversed(list(enumerate(out_classes))):
            predicted_class = self.class_names[c]
            box = out_boxes[i]
            score = out_scores[i]
            label = '{} {:.2f}'.format(predicted_class, score)
            label_size = self._text_size(draw, label, font)

            # Round to pixels and clamp the box to the image bounds.
            top, left, bottom, right = box
            top = max(0, np.floor(top + 0.5).astype('int32'))
            left = max(0, np.floor(left + 0.5).astype('int32'))
            bottom = min(height, np.floor(bottom + 0.5).astype('int32'))
            right = min(width, np.floor(right + 0.5).astype('int32'))
            print(label, (left, top), (right, bottom))

            # Put the label above the box, or just inside it when there is
            # no room above.
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            # Nested 1px outlines give a border `thickness` pixels wide.
            for offset in range(thickness):
                draw.rectangle(
                    [left + offset, top + offset, right - offset, bottom - offset],
                    outline=self.colors[c])
            draw.rectangle(
                [tuple(text_origin), tuple(text_origin + label_size)],
                fill=self.colors[c])
            draw.text(text_origin, label, fill=(0, 0, 0), font=font)
        del draw

        end = timer()
        print(end - start)
        return image
if __name__ == '__main__':
    model_path = 'model_data/yolo4_weight.h5'
    anchors_path = 'model_data/yolo4_anchors.txt'
    classes_path = 'model_data/data_classes.txt'
    score = 0.35
    iou = 0.35
    model_image_size = (416, 416)
    yolo4_model = Yolo4(score, iou, anchors_path, classes_path, model_path)

    vid_in = 'BerghouseLeopardJog.mp4'
    # Decode frames on a background thread (imutils) to increase throughput.
    fvs = FileVideoStream(vid_in).start()
    video_FourCC = cv2.VideoWriter_fourcc(*'DIVX')
    # The writer is opened lazily with the size of the first processed frame.
    # cv2.VideoWriter does not resize: frames whose size differs from the size
    # it was opened with are silently dropped (or can crash OpenCV), so a
    # hard-coded (1280, 720) only worked for 720p input.
    # .avi is used because it is the container OpenCV supports most reliably.
    out = None
    cv2.namedWindow("Frame", cv2.WINDOW_NORMAL)
    fps = FPS().start()
    try:
        while True:
            frame = fvs.read()
            # FileVideoStream yields None when no frame could be read
            # (end of the video or a bad frame).
            if frame is None:
                break
            # OpenCV decodes frames as BGR, while PIL and the model expect RGB.
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            image = yolo4_model.detect_image(image, model_image_size)
            # Back to BGR for OpenCV. cvtColor also returns a fresh writable
            # array, which cv2.putText requires (np.asarray on a PIL image
            # can be read-only).
            result = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
            if out is None:
                frame_height, frame_width = result.shape[:2]
                out = cv2.VideoWriter('processed_vid.avi', video_FourCC, 24,
                                      (frame_width, frame_height))
            # The on-screen display is only for monitoring; removing the
            # putText/imshow calls gains a few FPS.
            cv2.putText(result, text="Add Label", org=(3, 15),
                        fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.50,
                        color=(255, 0, 0), thickness=2)
            cv2.imshow("Frame", result)
            out.write(result)
            # Press Q to stop!
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            fps.update()
        fps.stop()
        print("[INFO] elapsed time: {:.2f}".format(fps.elapsed()))
        print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))
    finally:
        # Release the window, reader thread, writer and TF session even if
        # detection raises part-way through the video.
        cv2.destroyAllWindows()
        fvs.stop()
        if out is not None:
            out.release()
        yolo4_model.close_session()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.