Created
July 20, 2018 11:07
-
-
Save n5ken/7f5e0c9646c471dfacdfcfa95f0323a5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!python3
#pylint: disable=R, W0401, W0614, W0703
from ctypes import *
import math
import random
import os
import configparser
class BOX(Structure):
    """ctypes structure holding a bounding box as four C floats.

    The field order and types define the in-memory layout shared with the
    native library, so they must not be reordered.
    NOTE(review): presumably mirrors darknet's ``box`` struct -- confirm
    against the header of the libdarknet build in use.
    """

    _fields_ = [
        ("x", c_float),
        ("y", c_float),
        ("w", c_float),
        ("h", c_float),
    ]
class DETECTION(Structure):
    """ctypes structure describing one detection returned by the library.

    Holds a ``BOX`` plus an int ``classes``, two float pointers (``prob`` and
    ``mask``), a float ``objectness`` and an int ``sort_class``.  The field
    order and types define the C layout and must not be reordered.
    NOTE(review): the layout has to match the ``detection`` struct of the
    loaded libdarknet build -- verify against its header.
    """

    _fields_ = [
        ("bbox", BOX),
        ("classes", c_int),
        ("prob", POINTER(c_float)),
        ("mask", POINTER(c_float)),
        ("objectness", c_float),
        ("sort_class", c_int),
    ]
class IMAGE(Structure):
    """ctypes structure for an image: three ints and a float data pointer.

    ``w``, ``h`` and ``c`` are C ints and ``data`` is a ``float *``.  The
    field order and types define the C layout and must not be reordered.
    NOTE(review): presumably width, height and channel count as in darknet's
    ``image`` struct -- confirm against the library header.
    """

    _fields_ = [
        ("w", c_int),
        ("h", c_int),
        ("c", c_int),
        ("data", POINTER(c_float)),
    ]
class METADATA(Structure):
    """ctypes structure pairing a class count with an array of C strings.

    ``classes`` is a C int and ``names`` is a ``char **``.  The field order
    and types define the C layout and must not be reordered.
    """

    _fields_ = [
        ("classes", c_int),
        ("names", POINTER(c_char_p)),
    ]
class Darknet:
    """Thin ctypes wrapper around ``./lib/darknet/libdarknet.so``.

    Construction loads the shared library, declares the C prototypes that are
    used, loads the network and metadata, and reads the label names used to
    tag results.  ``classify`` and ``detect`` are the public entry points.
    """

    def __init__(self, metaPath, configPath, weightPath, hasGPU=True):
        """Load the library, declare prototypes and load the network.

        Args:
            metaPath: Path of the meta file.  It is passed to the library's
                ``get_metadata`` and is also parsed with ``configparser``; the
                ``names`` key of its ``[name]`` section must point at a text
                file with one label per line.
            configPath: Network configuration path passed to
                ``load_network_custom``.
            weightPath: Weights path passed to ``load_network_custom``.
            hasGPU: When true, the ``cuda_set_device`` prototype is declared
                as well (the function itself is not called here).

        Raises:
            OSError: If the shared library or the names file cannot be opened.
            KeyError: If the meta file has no ``[name]`` section / ``names`` key.
            UnicodeEncodeError: If a path contains non-ASCII characters.
        """
        lib = CDLL("./lib/darknet/libdarknet.so", RTLD_GLOBAL)

        # Prototypes are declared directly on ``lib``; CDLL caches each
        # function object, so the declarations stick without local aliases.
        lib.network_width.argtypes = [c_void_p]
        lib.network_width.restype = c_int
        lib.network_height.argtypes = [c_void_p]
        lib.network_height.restype = c_int
        lib.network_predict.argtypes = [c_void_p, POINTER(c_float)]
        lib.network_predict.restype = POINTER(c_float)
        if hasGPU:
            lib.cuda_set_device.argtypes = [c_int]
        lib.make_image.argtypes = [c_int, c_int, c_int]
        lib.make_image.restype = IMAGE
        lib.make_network_boxes.argtypes = [c_void_p]
        lib.make_network_boxes.restype = POINTER(DETECTION)
        lib.free_ptrs.argtypes = [POINTER(c_void_p), c_int]
        lib.reset_rnn.argtypes = [c_void_p]
        lib.load_network.argtypes = [c_char_p, c_char_p, c_int]
        lib.load_network.restype = c_void_p
        lib.load_network_custom.argtypes = [c_char_p, c_char_p, c_int, c_int]
        lib.load_network_custom.restype = c_void_p
        lib.do_nms_obj.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]
        lib.letterbox_image.argtypes = [IMAGE, c_int, c_int]
        lib.letterbox_image.restype = IMAGE
        lib.get_metadata.argtypes = [c_char_p]
        lib.get_metadata.restype = METADATA
        lib.rgbgr_image.argtypes = [IMAGE]

        # Functions used by the methods below are kept on the instance.
        self.get_network_boxes = lib.get_network_boxes
        self.get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int), c_int]
        self.get_network_boxes.restype = POINTER(DETECTION)
        self.free_detections = lib.free_detections
        self.free_detections.argtypes = [POINTER(DETECTION), c_int]
        self.do_nms_sort = lib.do_nms_sort
        self.do_nms_sort.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]
        self.free_image = lib.free_image
        self.free_image.argtypes = [IMAGE]
        self.load_image = lib.load_image_color
        self.load_image.argtypes = [c_char_p, c_int, c_int]
        self.load_image.restype = IMAGE
        self.predict_image = lib.network_predict_image
        self.predict_image.argtypes = [c_void_p, IMAGE]
        self.predict_image.restype = POINTER(c_float)

        self.meta = lib.get_metadata(metaPath.encode("ascii"))
        self.net = lib.load_network_custom(configPath.encode("ascii"), weightPath.encode("ascii"), 0, 1)  # batch size = 1

        # Labels come from the file named in the meta file, one per line.
        config = configparser.RawConfigParser()
        config.read(metaPath)
        with open(config['name']['names'], 'r') as names_file:
            self.altNames = names_file.read().splitlines()

    def __sample(self, probs):
        """Draw an index at random, weighted by ``probs``.

        The weights are normalised by their sum.  Returns ``len(probs) - 1``
        if rounding leaves the draw unresolved (``-1`` for empty input).
        """
        total = sum(probs)
        r = random.uniform(0, 1)
        for i, p in enumerate(probs):
            r = r - p / total
            if r <= 0:
                return i
        return len(probs) - 1

    def __c_array(self, ctype, values):
        """Return a ctypes array of ``ctype`` filled with ``values``."""
        arr = (ctype * len(values))()
        arr[:] = values
        return arr

    def array_to_image(self, arr):
        """Wrap a 3-D numpy array in an ``IMAGE`` without copying into C memory.

        ``arr`` is indexed as (height, width, channel); it is transposed to
        channel-first order, flattened to float32 and divided by 255.

        Returns:
            ``(im, data)`` where ``im.data`` points into ``data``.  The caller
            must keep ``data`` alive for as long as ``im`` is in use, otherwise
            Python frees the buffer behind the pointer.
        """
        import numpy as np
        arr = arr.transpose(2, 0, 1)
        c, h, w = arr.shape[0], arr.shape[1], arr.shape[2]
        arr = np.ascontiguousarray(arr.flat, dtype=np.float32) / 255.0
        data = arr.ctypes.data_as(POINTER(c_float))
        im = IMAGE(w, h, c, data)
        return im, arr

    def classify(self, net, meta, im):
        """Run the network on ``im`` and return labelled scores, best first.

        Args:
            net: Network handle passed to ``predict_image``.
            meta: Object whose ``classes`` attribute gives the number of scores.
            im: ``IMAGE`` to classify.

        Returns:
            List of ``(label, score)`` tuples sorted by descending score, with
            labels taken from ``self.altNames``.
        """
        out = self.predict_image(net, im)
        # Labels live on the instance; there is no module-level ``altNames``.
        res = [(self.altNames[i], out[i]) for i in range(meta.classes)]
        res.sort(key=lambda item: -item[1])
        return res

    def detect(self, image, thresh=.5, hier_thresh=.5, nms=.45, debug=False):
        """Detect objects in ``image`` and return them, best score first.

        Args:
            image: Either a file path (``str`` or ``bytes``) or an ``IMAGE``.
                A path is loaded here and the loaded image is freed before
                returning; an ``IMAGE`` passed in stays owned by the caller
                and is never freed here.
            thresh: Threshold passed to ``get_network_boxes``.
            hier_thresh: Hierarchy threshold passed to ``get_network_boxes``.
            nms: Threshold passed to ``do_nms_sort``; a falsy value skips it.
            debug: When true, progress messages are printed.

        Returns:
            List of ``(label, prob, (x, y, w, h))`` tuples for every class with
            a probability above zero, sorted by descending probability.
        """
        owns_image = isinstance(image, (str, bytes))
        if owns_image:
            # load_image is declared with c_char_p, which only accepts bytes.
            path = image.encode("utf-8") if isinstance(image, str) else image
            im = self.load_image(path, 0, 0)
        else:
            im = image
        if debug: print("Loaded image")
        try:
            num = c_int(0)
            if debug: print("Assigned num")
            pnum = pointer(num)
            if debug: print("Assigned pnum")
            self.predict_image(self.net, im)
            if debug: print("did prediction")
            dets = self.get_network_boxes(self.net, im.w, im.h, thresh, hier_thresh, None, 0, pnum, 0)
            if debug: print("Got dets")
            num = pnum[0]
            if debug: print("got zeroth index of pnum")
            try:
                if nms:
                    self.do_nms_sort(dets, num, self.meta.classes, nms)
                if debug: print("did sort")
                res = []
                if debug: print("about to range")
                for j in range(num):
                    if debug: print("Ranging on "+str(j)+" of "+str(num))
                    if debug: print("Classes: "+str(self.meta), self.meta.classes, self.meta.names)
                    det = dets[j]
                    for i in range(self.meta.classes):
                        prob = det.prob[i]
                        if debug: print("Class-ranging on "+str(i)+" of "+str(self.meta.classes)+"= "+str(prob))
                        if prob > 0:
                            b = det.bbox
                            nameTag = self.altNames[i]
                            if debug:
                                print("Got bbox", b)
                                print(nameTag)
                                print(prob)
                                print((b.x, b.y, b.w, b.h))
                            # Copy values out as Python floats before the
                            # native detections are released.
                            res.append((nameTag, prob, (b.x, b.y, b.w, b.h)))
                if debug: print("did range")
                res = sorted(res, key=lambda item: -item[1])
                if debug: print("did sort")
            finally:
                # Always release the native detections, even on error.
                self.free_detections(dets, num)
                if debug: print("freed detections")
        finally:
            # Only free what this call loaded; caller-supplied images may be
            # backed by memory the library does not own.
            if owns_image:
                self.free_image(im)
                if debug: print("freed image")
        return res
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment