Last active
May 17, 2021 09:29
-
-
Save berak/c5d58315a332ba2bf3246b3ad2686c4c to your computer and use it in GitHub Desktop.
weighted box fusion
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
template<class _Tp> | |
struct WeightedBoxesFusion { | |
typedef Rect_<_Tp> RECT; | |
struct Box { | |
int label; | |
float x1,y1,x2,y2, score, weight; | |
Box() {} | |
Box(int label, float x1,float y1,float x2,float y2, float score, float weight) | |
: label(label), x1(x1), y1(y1), x2(x2), y2(y2), score(score), weight(weight) | |
{ | |
} | |
double distance(const Box &b) const { | |
RECT r1(x1,y1,x2-x1,y2-y1); | |
RECT r2(b.x1,b.y1,b.x2-b.x1,b.y2-b.y1); | |
return jaccardDistance(r1,r2); | |
} | |
}; | |
WeightedBoxesFusion(double iou_thresh=0.45) : THR_IOU(iou_thresh) {} | |
vector<Box> B; | |
vector<pair<vector<Box>,Box>> FL; // combined F and L lists | |
int num_models = 0; | |
double weights = 0; | |
double THR_IOU = (1.0 - 0.55); // jaccardDistance returns (1-iou) | |
bool addModel(const vector<RECT> &rects, const vector<float> &scores, const vector<int> &labels, float weight, float conf_thresh) { | |
CV_Assert(rects.size()==scores.size()); | |
CV_Assert(scores.size()==labels.size()); | |
for (size_t i=0; i<rects.size(); i++) { | |
if (scores[i] < conf_thresh) | |
continue; | |
B.push_back(Box(labels[i], | |
rects[i].x, rects[i].y, rects[i].x+rects[i].width, rects[i].y+rects[i].height, | |
scores[i] * weight, weight)); | |
} | |
num_models ++; | |
weights += weight; | |
return true; | |
} | |
bool fuse(vector<RECT> &rects, vector<float> &scores, vector<int> &labels) { | |
sort(B.begin(), B.end(), [](const Box &a, const Box &b) { | |
return a.score > b.score; | |
}); | |
// 3. | |
for (const auto b : B) { | |
double min_d = THR_IOU; | |
int best = -1; | |
for (int i=0; i<FL.size(); i++) { | |
auto &f = FL[i]; | |
if (b.label != f.second.label) | |
continue; | |
double d = b.distance(f.second); | |
if (d < min_d) { | |
best = i; | |
min_d = d; | |
} | |
} | |
if (best != -1) { | |
FL[best].first.push_back(b); | |
} else { | |
FL.push_back(make_pair(vector<Box>{b},b)); | |
} | |
} | |
// 6. | |
for (auto f : FL) { | |
float x1=0, x2=0, y1=0, y2=0, sum_score=0; | |
for (auto q : f.first) { | |
x1 += q.score * q.x1; y1 += q.score * q.y1; | |
x2 += q.score * q.x2; y2 += q.score * q.y2; | |
sum_score += q.score; | |
} | |
x1 /= sum_score; | |
y1 /= sum_score; | |
x2 /= sum_score; | |
y2 /= sum_score; | |
rects.push_back(RECT(x1,y1,x2-x1,y2-y1)); | |
scores.push_back(sum_score / weights); | |
labels.push_back(f.second.label); | |
} | |
return true; | |
} | |
}; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment