Last active
December 16, 2018 15:27
-
-
Save jinyu121/911ea807de9775ebf877b580430dd981 to your computer and use it in GitHub Desktop.
Faster R-CNN 结果可视化 (Faster R-CNN result visualization)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import pandas as pd | |
import numpy as np | |
import os | |
from sklearn.externals import joblib | |
from skimage import io, draw, img_as_float | |
# --- Inputs: paths inside a py-faster-rcnn checkout (VOC2007 test split) ---
pkl_pred = "py-faster-rcnn/output/faster_rcnn_alt_opt/voc_2007_test/VGG16_faster_rcnn_final/detections.pkl"
pkl_anno = "py-faster-rcnn/data/VOCdevkit2007/annotations_cache/annots.pkl"
txt_test = "py-faster-rcnn/data/VOCdevkit2007/VOC2007/ImageSets/Main/test.txt"
img_dir = "py-faster-rcnn/data/VOCdevkit2007/VOC2007/JPEGImages"

# A detection is counted as correct when its overlap (IoU) with some
# ground-truth box is strictly greater than this threshold.
ovthresh = 0.3

# Running totals: detections seen / ground-truth boxes seen.
total_predicted = total_groundtruth = 0
# Running totals: correct detections / wrong detections.
counter_yes = counter_no = 0

# NOTE(review): [1] keeps a single top-level entry of the detections pickle,
# presumably one object class of py-faster-rcnn's per-class box lists -- confirm.
predictions = joblib.load(pkl_pred)[1]
annotations = joblib.load(pkl_anno)
def fnms(txt):
    """Lazily yield ``(line_number, name)`` pairs from the list file ``txt``.

    Line numbers start at 0 and each name has surrounding whitespace
    stripped; the file is opened in text mode and closed when the
    generator finishes.
    """
    with open(txt, 'r') as listing:
        for pair in enumerate(line.strip() for line in listing):
            yield pair
def vis(img, bbgt, bb, yes, no):
    """Display image ``img`` with ground-truth and predicted boxes overlaid.

    Colours: blue = ground truth, green = prediction whose index is in
    ``yes``, red = every other prediction.

    :param img: path of the image file to read.
    :param bbgt: ground-truth boxes, each ``[x1, y1, x2, y2, ...]``.
    :param bb: predicted boxes in the same layout; columns past the fourth
        (e.g. a detection score) are ignored.
    :param yes: indices into ``bb`` of correct predictions.
    :param no: indices of wrong predictions. Not read: any index not in
        ``yes`` is drawn red. Kept for interface compatibility.
    """
    image_data = img_as_float(io.imread(img))
    # draw.set_color requires the colour length to equal the last image axis,
    # so grayscale (2-D) and RGBA images are converted to 3-channel RGB first;
    # otherwise drawing an RGB colour on them raises ValueError.
    if image_data.ndim == 2:
        image_data = np.dstack([image_data] * 3)
    elif image_data.shape[-1] == 4:
        image_data = image_data[..., :3]
    correct = set(yes)  # O(1) membership per predicted box

    def _outline(box):
        # polygon_perimeter takes (rows, cols): rows are y (box[1], box[3]),
        # cols are x (box[0], box[2]).
        return draw.polygon_perimeter(
            [int(box[1]), int(box[1]), int(box[3]), int(box[3])],
            [int(box[0]), int(box[2]), int(box[2]), int(box[0])])

    # Ground truth in blue.
    for b in bbgt:
        draw.set_color(image_data, _outline(b), (0, 0, 1))
    # Correct predictions in green, wrong ones in red.
    for ith, b in enumerate(bb):
        colour = (0, 1, 0) if ith in correct else (1, 0, 0)
        draw.set_color(image_data, _outline(b), colour)
    io.imshow(image_data)
    io.show()
def _iou(bbgt, bb):
    """Return the IoU of box ``bb`` against every row of ``bbgt``.

    Boxes are ``[x1, y1, x2, y2, ...]``; extra columns are ignored. Widths
    and heights use the inclusive-pixel ``+ 1`` convention.
    """
    ixmin = np.maximum(bbgt[:, 0], bb[0])
    iymin = np.maximum(bbgt[:, 1], bb[1])
    ixmax = np.minimum(bbgt[:, 2], bb[2])
    iymax = np.minimum(bbgt[:, 3], bb[3])
    iw = np.maximum(ixmax - ixmin + 1., 0.)
    ih = np.maximum(iymax - iymin + 1., 0.)
    inters = iw * ih
    # union = area(bb) + area(gt) - intersection
    uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
           (bbgt[:, 2] - bbgt[:, 0] + 1.) *
           (bbgt[:, 3] - bbgt[:, 1] + 1.) - inters)
    return inters / uni


# Ground-truth boxes (over all images) not matched by any correct detection.
counter_miss = 0

# For every image in the test list.
for img_idx, fnm in fnms(txt_test):
    # NOTE(review): VOC2007's JPEGImages folder normally holds .jpg files;
    # '.png' presumably matches a custom dataset stored there -- confirm.
    img = os.path.join(img_dir, fnm + '.png')
    # Ground truth and predictions for this image.
    gt = annotations[fnm]
    pr = predictions[img_idx]
    total_predicted += len(pr)
    total_groundtruth += len(gt)
    yes = list()       # indices into pr of correct detections
    yes_pair = list()  # GT index matched by each correct detection
    no = list()        # indices into pr of wrong detections
    BBGT = np.array([x['bbox'] for x in gt])
    # For every detected box, find its best-overlapping GT box.
    for det_idx, bbpr in enumerate(pr):
        bb = np.array(bbpr)
        ovmax = -np.inf
        jmax = -1
        if len(gt) > 0:
            overlaps = _iou(BBGT, bb)
            jmax = np.argmax(overlaps)
            ovmax = overlaps[jmax]
        if ovmax > ovthresh:
            counter_yes += 1
            yes.append(det_idx)
            # Only record a GT index when one was actually matched; this
            # avoids an undefined/stale jmax on images without GT boxes.
            yes_pair.append(jmax)
        else:
            counter_no += 1
            no.append(det_idx)
    # Several detections may match the same GT box, so a GT box counts as
    # missed only if no correct detection matched it (never negative).
    missed = len(gt) - len(set(yes_pair))
    counter_miss += missed
    print("Image:", fnm)
    print("Truth:", len(gt))
    print("Yes:", len(yes))
    print("No:", len(no))
    print("Miss", missed)
    print()
    vis(img, BBGT, pr, yes, no)
print()
print("total_detected", total_predicted)
print("total_groundtruth", total_groundtruth)
print("Counter_yes", counter_yes)
print("Counter_no", counter_no)
print("Counter_miss", counter_miss)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import re | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def str2num(str_list, output_type):
    """Return a new list with ``output_type`` applied to each item of ``str_list``."""
    return list(map(output_type, str_list))
# Capture group for one number as it appears in the log, including
# scientific notation such as "8.1e-05" and non-finite "nan"/"-nan"/"inf",
# which a plain [\d\.]+ would truncate or skip.
_NUM = r"(-?(?:nan|inf)|[\d\.]+(?:[eE][-+]?\d+)?)"


def _parse_log(text, pattern_itr, pattern_loss):
    """Extract paired iteration numbers and loss values from a training log.

    :param text: full log contents.
    :param pattern_itr: compiled regex whose group 1 is an iteration number.
    :param pattern_loss: compiled regex whose group 1 is a loss value.
    :returns: ``(itrs, loss)`` as numpy arrays of equal length.

    Both match lists are cut to their common length, so a log that is still
    being written (iteration logged, its loss not yet) can still be plotted
    instead of failing on mismatched x/y lengths.
    """
    itrs = pattern_itr.findall(text)
    loss = pattern_loss.findall(text)
    n = min(len(itrs), len(loss))
    return (np.array([int(x) for x in itrs[:n]]),
            np.array([float(x) for x in loss[:n]]))


if "__main__" == __name__:
    import sys
    # An optional first command-line argument overrides the default log path.
    log_file = (sys.argv[1] if len(sys.argv) > 1
                else "experiments/logs/faster_rcnn_end2end_VGG16_.txt.2017-04-11_10-32-10")
    # "219]" presumably anchors on the source-line tag of the per-iteration
    # summary message in the log; NOTE(review): verify it for your build.
    pattern_itr = re.compile(r"219\]\s+Iteration\s+([\d]+)")
    pattern_loss = re.compile(r", loss[\s=]{1,3}" + _NUM)
    # Alternative loss components; swap one in to plot it instead.
    # pattern_loss = re.compile(r" loss_bbox[\s=]{1,3}" + _NUM)
    # pattern_loss = re.compile(r" loss_cls[\s=]{1,3}" + _NUM)
    # pattern_loss = re.compile(r" rpn_cls_loss[\s=]{1,3}" + _NUM)
    # pattern_loss = re.compile(r" rpn_loss_bbox[\s=]{1,3}" + _NUM)
    with open(log_file, 'r') as f:
        lines = f.read()
    itrs, loss = _parse_log(lines, pattern_itr, pattern_loss)
    plt.figure()
    plt.plot(itrs, loss)
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment