import torch, math
import torch.nn as nn
import torch.nn.functional as F


class RetinaNet(nn.Module):
    num_anchors = 9

    def __init__(self, num_classes):
        super(RetinaNet, self).__init__()
        self.fpn = FPN50()
        self.num_classes = num_classes
        self.loc_head = self._make_head(self.num_anchors*4)
        self.cls_head = self._make_head(self.num_anchors*self.num_classes)

    def forward(self, x):
        loc_preds = []
        cls_preds = []
        fms = self.fpn(x)
        for fm in fms:
            loc_pred = self.loc_head(fm)
            cls_pred = self.cls_head(fm)
            loc_pred = loc_pred.permute(0,2,3,1).reshape(x.size(0),-1,4)                 # [N,9*4,H,W] -> [N,H,W,9*4] -> [N,H*W*9,4]
            cls_pred = cls_pred.permute(0,2,3,1).reshape(x.size(0),-1,self.num_classes)  # [N,9*NC,H,W] -> [N,H,W,9*NC] -> [N,H*W*9,NC]
            loc_preds.append(loc_pred)
            cls_preds.append(cls_pred)
        return torch.cat(loc_preds, 1), torch.cat(cls_preds, 1)

    def _make_head(self, out_planes):
        layers = []
        for _ in range(4):
            layers.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(True))
        layers.append(nn.Conv2d(256, out_planes, kernel_size=3, stride=1, padding=1))
        return nn.Sequential(*layers)
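
# A minimal, hedged sketch (not part of the original gist): run the detector on a dummy
# batch to check the concatenated prediction shapes. The helper name
# `_demo_retinanet_shapes` and the 640x640 input size are illustrative assumptions.
def _demo_retinanet_shapes(num_classes=20):
    net = RetinaNet(num_classes=num_classes)
    x = torch.randn(2, 3, 640, 640)              # dummy batch of two RGB images
    loc_preds, cls_preds = net(x)
    # With a 640x640 input, P3..P7 have strides 8..128, so the anchor count is
    # 9 * (80*80 + 40*40 + 20*20 + 10*10 + 5*5) = 76725.
    print(loc_preds.shape)    # torch.Size([2, 76725, 4])
    print(cls_preds.shape)    # torch.Size([2, 76725, num_classes])
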
def FPN50():
    return FPN(Bottleneck, [3,4,6,3])


class FPN(nn.Module):
    def __init__(self, block, num_blocks):
        super(FPN, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Bottom-up layers
        self.layer1 = self._make_layer(block,  64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.conv6 = nn.Conv2d(2048, 256, kernel_size=3, stride=2, padding=1)
        self.conv7 = nn.Conv2d( 256, 256, kernel_size=3, stride=2, padding=1)

        # Top-down layers
        self.toplayer = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)

        # Lateral layers
        self.latlayer1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d( 512, 256, kernel_size=1, stride=1, padding=0)

        # Smooth layers
        self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def _upsample_add(self, x, y):
        '''Upsample and add two feature maps.

        Args:
          x: (tensor) top feature map to be upsampled.
          y: (tensor) lateral feature map.

        Returns:
          (tensor) added feature map.

        Note: in PyTorch, when the input size is odd, the feature map upsampled with
        `F.interpolate(..., scale_factor=2, mode='nearest')` may not match the lateral
        feature map size, e.g.
          original input size: [N,_,15,15] ->
          conv2d feature map size: [N,_,8,8] ->
          upsampled feature map size: [N,_,16,16]
        So we use bilinear upsampling with an explicit output size instead.
        '''
        _,_,H,W = y.size()
        # F.upsample is deprecated; F.interpolate takes the same arguments here.
        return F.interpolate(x, size=(H,W), mode='bilinear', align_corners=False) + y

    def forward(self, x):
        # Bottom-up
        c1 = F.relu(self.bn1(self.conv1(x)))
        c1 = F.max_pool2d(c1, kernel_size=3, stride=2, padding=1)
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        p6 = self.conv6(c5)
        p7 = self.conv7(F.relu(p6))
        # Top-down
        p5 = self.toplayer(c5)
        p4 = self._upsample_add(p5, self.latlayer1(c4))
        p4 = self.smooth1(p4)
        p3 = self._upsample_add(p4, self.latlayer2(c3))
        p3 = self.smooth2(p3)
        return p3, p4, p5, p6, p7
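
# A hedged sketch (not in the original gist): check that FPN50 yields a five-level
# pyramid with 256 channels at strides 8, 16, 32, 64 and 128. The helper name
# `_demo_fpn_shapes` is an illustrative assumption.
def _demo_fpn_shapes():
    fpn = FPN50()
    x = torch.randn(1, 3, 640, 640)
    for p in fpn(x):
        print(p.shape)
    # Expected: [1,256,80,80], [1,256,40,40], [1,256,20,20], [1,256,10,10], [1,256,5,5]
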
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.downsample = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.downsample(x)
        out = F.relu(out)
        return out


class RetinaBoxCoder:
    '''Encode object boxes and labels.'''

    def __init__(self):
        self.anchor_areas = (32*32., 64*64., 128*128., 256*256., 512*512.)  # p3 -> p7
        self.aspect_ratios = (1/2., 1/1., 2/1.)
        self.scale_ratios = (1., pow(2,1/3.), pow(2,2/3.))
        self.anchor_boxes = self._get_anchor_boxes(input_size=torch.tensor([640.,640.]))

    def _get_anchor_wh(self):
        '''Compute anchor width and height for each feature map.

        Returns:
          anchor_wh: (tensor) anchor wh, sized [#fm, #anchors_per_cell, 2].
        '''
        anchor_wh = []
        for s in self.anchor_areas:
            for ar in self.aspect_ratios:  # w/h = ar
                h = math.sqrt(s/ar)
                w = ar * h
                for sr in self.scale_ratios:  # scale
                    anchor_h = h*sr
                    anchor_w = w*sr
                    anchor_wh.append([anchor_w, anchor_h])
        num_fms = len(self.anchor_areas)
        return torch.Tensor(anchor_wh).view(num_fms, -1, 2)

    def _get_anchor_boxes(self, input_size):
        '''Compute anchor boxes for each feature map.

        Args:
          input_size: (tensor) model input size of (w,h).

        Returns:
          boxes: (tensor) anchor boxes for all feature maps concatenated along dim 0,
            sized [#anchors,4] in (xmin,ymin,xmax,ymax) order, where #anchors is the
            sum over levels of fm_w * fm_h * #anchors_per_cell.
        '''
        num_fms = len(self.anchor_areas)
        anchor_wh = self._get_anchor_wh()
        fm_sizes = [(input_size/pow(2.,i+3)).ceil() for i in range(num_fms)]  # p3 -> p7 feature map sizes

        boxes = []
        for i in range(num_fms):
            fm_size = fm_sizes[i]
            grid_size = input_size / fm_size
            fm_w, fm_h = int(fm_size[0]), int(fm_size[1])
            xy = meshgrid(fm_w,fm_h) + 0.5  # [fm_h*fm_w, 2]
            xy = (xy*grid_size).view(fm_h,fm_w,1,2).expand(fm_h,fm_w,9,2)
            wh = anchor_wh[i].view(1,1,9,2).expand(fm_h,fm_w,9,2)
            box = torch.cat([xy-wh/2.,xy+wh/2.], 3)  # [x,y,x,y]
            boxes.append(box.view(-1,4))
        return torch.cat(boxes, 0)

    def encode(self, boxes, labels):
        '''Encode target bounding boxes and class labels.

        We obey the Faster RCNN box coder:
          tx = (x - anchor_x) / anchor_w
          ty = (y - anchor_y) / anchor_h
          tw = log(w / anchor_w)
          th = log(h / anchor_h)

        Args:
          boxes: (tensor) bounding boxes of (xmin,ymin,xmax,ymax), sized [#obj, 4].
          labels: (tensor) object class labels, sized [#obj,].

        Returns:
          loc_targets: (tensor) encoded bounding boxes, sized [#anchors,4].
          cls_targets: (tensor) encoded class labels, sized [#anchors,].
        '''
        anchor_boxes = self.anchor_boxes
        ious = box_iou(anchor_boxes, boxes)
        max_ious, max_ids = ious.max(1)
        boxes = boxes[max_ids]

        boxes = change_box_order(boxes, 'xyxy2xywh')
        anchor_boxes = change_box_order(anchor_boxes, 'xyxy2xywh')

        loc_xy = (boxes[:,:2]-anchor_boxes[:,:2]) / anchor_boxes[:,2:]
        loc_wh = torch.log(boxes[:,2:]/anchor_boxes[:,2:])
        loc_targets = torch.cat([loc_xy,loc_wh], 1)
        cls_targets = 1 + labels[max_ids]
        # cls_targets[max_ious<0.5] = 0
        # ignore = (max_ious>0.4) & (max_ious<0.5)  # ignore ious between [0.4,0.5]
        # cls_targets[ignore] = -1                  # mark ignored to -1
        return loc_targets, cls_targets

    def decode(self, loc_preds, cls_preds, input_size):
        '''Decode outputs back to bounding box locations and class labels.

        Args:
          loc_preds: (tensor) predicted locations, sized [#anchors, 4].
          cls_preds: (tensor) predicted class labels, sized [#anchors, #classes].
          input_size: (tuple) model input size of (w,h).

        Returns:
          boxes: (tensor) decoded box locations, sized [#obj,4].
          labels: (tensor) class labels for each box, sized [#obj,].
        '''
        CLS_THRESH = 0.5
        NMS_THRESH = 0.5

        input_size = torch.Tensor(input_size)
        anchor_boxes = self._get_anchor_boxes(input_size)           # (xmin,ymin,xmax,ymax)
        anchor_boxes = change_box_order(anchor_boxes, 'xyxy2xywh')  # decoding below expects (cx,cy,w,h)

        loc_xy = loc_preds[:,:2]
        loc_wh = loc_preds[:,2:]

        xy = loc_xy * anchor_boxes[:,2:] + anchor_boxes[:,:2]
        wh = loc_wh.exp() * anchor_boxes[:,2:]
        boxes = torch.cat([xy-wh/2, xy+wh/2], 1)  # [#anchors,4]

        score, labels = cls_preds.sigmoid().max(1)        # [#anchors,]
        ids = (score > CLS_THRESH).nonzero().squeeze(1)   # [#obj,]; squeeze(1) keeps the dim even for a single hit
        keep = box_nms(boxes[ids], score[ids], threshold=NMS_THRESH)
        return boxes[ids][keep], labels[ids][keep]
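
# A hedged sketch (not in the original gist): encode a single ground-truth box and decode
# it back, to sanity-check that encode() and decode() are inverses when the "predictions"
# equal the targets. The helper name `_demo_box_coder_roundtrip`, the two-class setup and
# the dummy box are illustrative assumptions.
def _demo_box_coder_roundtrip():
    coder = RetinaBoxCoder()
    boxes = torch.tensor([[100., 100., 200., 250.]])   # one box in (xmin,ymin,xmax,ymax)
    labels = torch.tensor([0])                         # its class index
    loc_targets, cls_targets = coder.encode(boxes, labels)
    print(loc_targets.shape, cls_targets.shape)        # [#anchors,4], [#anchors,]
    # Feed the targets back through decode with near-certain scores for class 0;
    # the recovered box should match the input up to floating-point error.
    num_classes = 2
    cls_preds = torch.full((loc_targets.size(0), num_classes), -10.)
    cls_preds[:, 0] = 10.                              # sigmoid(10) ~ 1 for class 0 everywhere
    dec_boxes, dec_labels = coder.decode(loc_targets, cls_preds, input_size=(640, 640))
    print(dec_boxes[:1], dec_labels[:1])               # ~[[100,100,200,250]], class 0
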

class FocalLoss(nn.Module):
    def __init__(self, num_classes):
        super(FocalLoss, self).__init__()
        self.num_classes = num_classes

    def _focal_loss(self, x, y):
        '''Focal loss.

        This is described in the original paper.
        With BCELoss, the background should not be counted in num_classes.

        Args:
          x: (tensor) predictions, sized [N,D].
          y: (tensor) targets, sized [N,].

        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25
        gamma = 2

        t = one_hot_embedding(y-1, self.num_classes)
        p = x.sigmoid()
        pt = torch.where(t>0, p, 1-p)                        # pt = p if t > 0 else 1-p
        w = (1-pt).pow(gamma)
        w = torch.where(t>0, alpha*w, (1-alpha)*w).detach()
        # `size_average` is deprecated; reduction='sum' is the equivalent of size_average=False.
        loss = F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')
        return loss

    def forward(self, loc_preds, cls_preds, loc_targets, cls_targets):
        '''Compute loss between (loc_preds, loc_targets) and (cls_preds, cls_targets).

        Args:
          loc_preds: (tensor) predicted locations, sized [batch_size, #anchors, 4].
          loc_targets: (tensor) encoded target locations, sized [batch_size, #anchors, 4].
          cls_preds: (tensor) predicted class confidences, sized [batch_size, #anchors, #classes].
          cls_targets: (tensor) encoded target labels, sized [batch_size, #anchors].

        Returns:
          loss: (tensor) loss = SmoothL1Loss(loc_preds, loc_targets) + FocalLoss(cls_preds, cls_targets).
        '''
        batch_size, num_boxes = cls_targets.size()
        pos = cls_targets > 0                                # [N,#anchors]
        num_pos = max(pos.sum().item(), 1)                   # guard against a batch with no positive anchors

        #===============================================================
        # loc_loss = SmoothL1Loss(pos_loc_preds, pos_loc_targets)
        #===============================================================
        mask = pos.unsqueeze(2).expand_as(loc_preds)         # [N,#anchors,4]
        loc_loss = F.smooth_l1_loss(loc_preds[mask], loc_targets[mask], reduction='sum')

        #===============================================================
        # cls_loss = FocalLoss(cls_preds, cls_targets)
        #===============================================================
        pos_neg = cls_targets > -1                           # exclude ignored anchors
        mask = pos_neg.unsqueeze(2).expand_as(cls_preds)
        masked_cls_preds = cls_preds[mask].view(-1,self.num_classes)
        cls_loss = self._focal_loss(masked_cls_preds, cls_targets[pos_neg])

        # print('loc_loss: %.3f | cls_loss: %.3f' % (loc_loss.item()/num_pos, cls_loss.item()/num_pos), end=' | ')
        loss = (loc_loss+cls_loss)/num_pos
        return loss
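
# A hedged sketch (not in the original gist): compare the focal term against plain BCE on
# an easy, well-classified example and a hard, misclassified one, to show how the
# (1-pt)^gamma weight down-weights easy examples. `_demo_focal_loss` is an illustrative name.
def _demo_focal_loss():
    criterion = FocalLoss(num_classes=3)
    logits = torch.tensor([[ 6., -6., -6.],     # easy example: very confident and correct
                           [-2.,  1.,  0.]])    # hard example: wrong and uncertain
    targets = torch.tensor([1, 1])              # both belong to class 0 (encoded as 1 + label)
    fl = criterion._focal_loss(logits, targets)
    bce = F.binary_cross_entropy_with_logits(
        logits, one_hot_embedding(targets - 1, 3), reduction='sum')
    print(float(fl), float(bce))   # the focal term shrinks the easy row's contribution far more than the hard row's
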

def change_box_order(boxes, order):
    '''Change box order between (xmin,ymin,xmax,ymax) and (xcenter,ycenter,width,height).

    Args:
      boxes: (tensor) bounding boxes, sized [N,4].
      order: (str) either 'xyxy2xywh' or 'xywh2xyxy'.

    Returns:
      (tensor) converted bounding boxes, sized [N,4].
    '''
    assert order in ['xyxy2xywh','xywh2xyxy']
    a = boxes[:,:2]
    b = boxes[:,2:]
    if order == 'xyxy2xywh':
        return torch.cat([(a+b)/2,b-a], 1)
    return torch.cat([a-b/2,a+b/2], 1)
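
# A hedged sketch (not in the original gist): convert a box to center form and back, to
# make the two orders concrete. `_demo_change_box_order` is an illustrative name.
def _demo_change_box_order():
    xyxy = torch.tensor([[10., 20., 50., 80.]])
    xywh = change_box_order(xyxy, 'xyxy2xywh')
    print(xywh)                                   # [[30., 50., 40., 60.]] as (cx,cy,w,h)
    print(change_box_order(xywh, 'xywh2xyxy'))    # back to [[10., 20., 50., 80.]]
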

def box_clamp(boxes, xmin, ymin, xmax, ymax):
    '''Clamp boxes.

    Args:
      boxes: (tensor) bounding boxes of (xmin,ymin,xmax,ymax), sized [N,4].
      xmin: (number) min value of x.
      ymin: (number) min value of y.
      xmax: (number) max value of x.
      ymax: (number) max value of y.

    Returns:
      (tensor) clamped boxes.
    '''
    boxes[:,0].clamp_(min=xmin, max=xmax)
    boxes[:,1].clamp_(min=ymin, max=ymax)
    boxes[:,2].clamp_(min=xmin, max=xmax)
    boxes[:,3].clamp_(min=ymin, max=ymax)
    return boxes


def box_select(boxes, xmin, ymin, xmax, ymax):
    '''Select boxes in range (xmin,ymin,xmax,ymax).

    Args:
      boxes: (tensor) bounding boxes of (xmin,ymin,xmax,ymax), sized [N,4].
      xmin: (number) min value of x.
      ymin: (number) min value of y.
      xmax: (number) max value of x.
      ymax: (number) max value of y.

    Returns:
      (tensor) selected boxes, sized [M,4].
      (tensor) selected mask, sized [N,].
    '''
    mask = (boxes[:,0]>=xmin) & (boxes[:,1]>=ymin) \
         & (boxes[:,2]<=xmax) & (boxes[:,3]<=ymax)
    boxes = boxes[mask,:]
    return boxes, mask

def box_iou(box1, box2):
    '''Compute the intersection over union of two sets of boxes.

    The box order must be (xmin, ymin, xmax, ymax).

    Args:
      box1: (tensor) bounding boxes, sized [N,4].
      box2: (tensor) bounding boxes, sized [M,4].

    Return:
      (tensor) iou, sized [N,M].

    Reference:
      https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
    '''
    N = box1.size(0)
    M = box2.size(0)

    lt = torch.max(box1[:,None,:2], box2[:,:2])  # [N,M,2]
    rb = torch.min(box1[:,None,2:], box2[:,2:])  # [N,M,2]

    wh = (rb-lt).clamp(min=0)       # [N,M,2]
    inter = wh[:,:,0] * wh[:,:,1]   # [N,M]

    area1 = (box1[:,2]-box1[:,0]) * (box1[:,3]-box1[:,1])  # [N,]
    area2 = (box2[:,2]-box2[:,0]) * (box2[:,3]-box2[:,1])  # [M,]
    iou = inter / (area1[:,None] + area2 - inter)
    return iou
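
# A hedged sketch (not in the original gist): a hand-checkable IoU. The two boxes below
# overlap in a 5x5 region and each has area 100, so IoU = 25 / (100 + 100 - 25) = 1/7.
# `_demo_box_iou` is an illustrative name.
def _demo_box_iou():
    a = torch.tensor([[0., 0., 10., 10.]])
    b = torch.tensor([[5., 5., 15., 15.]])
    print(box_iou(a, b))   # tensor([[0.1429]])
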

def box_nms(bboxes, scores, threshold=0.5):
    '''Non maximum suppression.

    Args:
      bboxes: (tensor) bounding boxes, sized [N,4].
      scores: (tensor) confidence scores, sized [N,].
      threshold: (float) overlap threshold.

    Returns:
      keep: (tensor) selected indices.

    Reference:
      https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py
    '''
    x1 = bboxes[:,0]
    y1 = bboxes[:,1]
    x2 = bboxes[:,2]
    y2 = bboxes[:,3]

    areas = (x2-x1) * (y2-y1)
    _, order = scores.sort(0, descending=True)

    keep = []
    while order.numel() > 0:
        i = order[0].item()
        keep.append(i)

        if order.numel() == 1:
            break

        xx1 = x1[order[1:]].clamp(min=x1[i].item())
        yy1 = y1[order[1:]].clamp(min=y1[i].item())
        xx2 = x2[order[1:]].clamp(max=x2[i].item())
        yy2 = y2[order[1:]].clamp(max=y2[i].item())

        w = (xx2-xx1).clamp(min=0)
        h = (yy2-yy1).clamp(min=0)
        inter = w * h

        overlap = inter / (areas[i] + areas[order[1:]] - inter)
        ids = (overlap<=threshold).nonzero().squeeze(1)  # squeeze(1) keeps the dim even when only one box survives
        if ids.numel() == 0:
            break
        order = order[ids+1]
    return torch.tensor(keep, dtype=torch.long)
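
# A hedged sketch (not in the original gist): three boxes where the two highest-scoring
# ones overlap heavily, so NMS keeps the best of that pair plus the distant third box.
# `_demo_box_nms` is an illustrative name.
def _demo_box_nms():
    boxes = torch.tensor([[0., 0., 10., 10.],
                          [1., 1., 11., 11.],      # IoU with the first box ~ 0.68 -> suppressed
                          [50., 50., 60., 60.]])   # far away -> kept
    scores = torch.tensor([0.9, 0.8, 0.7])
    print(box_nms(boxes, scores, threshold=0.5))   # tensor([0, 2])
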

def meshgrid(x, y, row_major=True):
    '''Return meshgrid in range x & y.

    Args:
      x: (int) first dim range.
      y: (int) second dim range.
      row_major: (bool) row major or column major.

    Returns:
      (tensor) meshgrid, sized [x*y,2]

    Example:
    >> meshgrid(3,2)
    0  0
    1  0
    2  0
    0  1
    1  1
    2  1
    [torch.FloatTensor of size 6x2]

    >> meshgrid(3,2,row_major=False)
    0  0
    0  1
    0  2
    1  0
    1  1
    1  2
    [torch.FloatTensor of size 6x2]
    '''
    # Use float explicitly: torch.arange with int arguments returns a LongTensor in
    # current PyTorch, while the docstring (and the anchor-grid arithmetic) expects floats.
    a = torch.arange(0, x, dtype=torch.float)
    b = torch.arange(0, y, dtype=torch.float)
    xx = a.repeat(y).view(-1,1)
    yy = b.view(-1,1).repeat(1,x).view(-1,1)
    return torch.cat([xx,yy],1) if row_major else torch.cat([yy,xx],1)

def one_hot_embedding(labels, num_classes):
    '''Embedding labels to one-hot.

    Args:
      labels: (LongTensor) class labels, sized [N,].
      num_classes: (int) number of classes.

    Returns:
      (tensor) encoded labels, sized [N,#classes].
    '''
    y = torch.eye(num_classes, device=labels.device)  # [D,D]
    return y[labels]                                  # [N,D]
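
# A hedged end-to-end sketch (not in the original gist): wire the pieces together for one
# training step on random data. The 640x640 input, 20 classes, and the single dummy
# target box are illustrative assumptions.
if __name__ == '__main__':
    num_classes = 20
    net = RetinaNet(num_classes)
    coder = RetinaBoxCoder()
    criterion = FocalLoss(num_classes)

    images = torch.randn(1, 3, 640, 640)
    gt_boxes = torch.tensor([[120., 80., 360., 300.]])   # one dummy ground-truth box
    gt_labels = torch.tensor([4])                        # its class index

    loc_targets, cls_targets = coder.encode(gt_boxes, gt_labels)
    loc_preds, cls_preds = net(images)

    loss = criterion(loc_preds, cls_preds,
                     loc_targets.unsqueeze(0), cls_targets.unsqueeze(0))
    loss.backward()
    print('loss:', loss.item())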