Created
February 22, 2016 11:02
-
-
Save ck196/198885b45118e2963cbf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
# NOTE(review): import order is kept as in the original; _init_paths
# presumably prepares sys.path for the project packages below - confirm
# before regrouping.
import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse
# Class labels, indexed by class id. Index 0 is the background entry, which
# demo() skips by iterating over CLASSES[1:].
CLASSES = ('__background__', 'face')
def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw the detections scoring at least `thresh` onto `im` and show it.

    Args:
        im: image array handed to OpenCV; boxes and labels are drawn on it
            in place.
        class_name: label text written next to every drawn box.
        dets: 2-D array with one detection per row. The first four columns
            are used as the two box corners (x1, y1, x2, y2) and the last
            column as the score.
        thresh: minimum score a detection needs in order to be drawn.

    Returns:
        None. When no detection reaches `thresh` the function returns early
        without drawing anything or showing a window.
    """
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    for i in inds:
        # OpenCV drawing calls expect integer pixel coordinates, while the
        # detections built by demo() are float32; convert the corners once
        # and reuse them for both the rectangle and the label.
        x1, y1, x2, y2 = [int(v) for v in dets[i, :4]]
        cv2.rectangle(im, (x1, y1), (x2, y2), (0, 0, 255), 2)
        cv2.putText(im, class_name, (x1, y1 + 15),
                    cv2.FONT_HERSHEY_DUPLEX, 0.6, (255, 255, 0), 1)

    # Show the annotated image once, after every box has been drawn.
    cv2.imshow('image', im)
def demo(net, image_name):
    """Run the detector on one image from data/demo and display the result.

    Args:
        net: network object passed straight through to `im_detect`.
        image_name: file name of the image, resolved under
            `<cfg.ROOT_DIR>/data/demo/`.

    Raises:
        IOError: if OpenCV cannot read the image file.
    """
    # Load the demo image
    im_file = os.path.join(cfg.ROOT_DIR, 'data', 'demo', image_name)
    im = cv2.imread(im_file)
    if im is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here so the error names the offending file instead of
        # surfacing later inside im_detect.
        raise IOError('Could not read image: {}'.format(im_file))

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im)
    timer.toc()
    # Single-argument print(...) emits the same text on Python 2 and 3.
    print('Detection took {:.3f}s for {:d} object proposals'.format(
        timer.total_time, boxes.shape[0]))

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    # Start counting at 1 because the background entry CLASSES[0] is skipped.
    for cls_ind, cls in enumerate(CLASSES[1:], 1):
        # Each class owns four consecutive box columns and one score column.
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        vis_detections(im, cls, dets, thresh=CONF_THRESH)
# Script entry point: load the face-detection network on GPU 0 and run the
# demo over a fixed list of images from data/demo.
if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals

    prototxt = "models/VGG16/face/test.prototxt"
    caffemodel = "/data/kju/facedetect/face.caffemodel"

    caffe.set_mode_gpu()
    caffe.set_device(0)
    cfg.GPU_ID = 0
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)

    # Warmup on a dummy image
    im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
    for _ in range(2):
        im_detect(net, im)

    im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
                '001763.jpg', '004545.jpg', 'test.jpg']
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(net, im_name)
        # NOTE(review): every result is shown in the same 'image' window, so
        # wait for a key press per image; otherwise earlier results would be
        # replaced before they can be looked at. Confirm this matches the
        # intended interaction.
        cv2.waitKey(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment