# RetinaNet (FPN-50 backbone, anchor box coder, focal loss) in PyTorch.
# Gist by @sizhky, created June 1, 2020.
import torch, math
import torch.nn as nn
import torch.nn.functional as F
class RetinaNet(nn.Module):
    num_anchors = 9  # 3 aspect ratios x 3 scales per feature-map cell

    def __init__(self, num_classes):
        super(RetinaNet, self).__init__()
        self.fpn = FPN50()
        self.num_classes = num_classes
        self.loc_head = self._make_head(self.num_anchors*4)
        self.cls_head = self._make_head(self.num_anchors*self.num_classes)

    def forward(self, x):
        loc_preds = []
        cls_preds = []
        fms = self.fpn(x)
        for fm in fms:
            loc_pred = self.loc_head(fm)
            cls_pred = self.cls_head(fm)
            loc_pred = loc_pred.permute(0,2,3,1).reshape(x.size(0),-1,4)                 # [N,9*4,H,W] -> [N,H,W,9*4] -> [N,H*W*9,4]
            cls_pred = cls_pred.permute(0,2,3,1).reshape(x.size(0),-1,self.num_classes)  # [N,9*NC,H,W] -> [N,H,W,9*NC] -> [N,H*W*9,NC]
            loc_preds.append(loc_pred)
            cls_preds.append(cls_pred)
        return torch.cat(loc_preds, 1), torch.cat(cls_preds, 1)

    def _make_head(self, out_planes):
        layers = []
        for _ in range(4):
            layers.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(True))
        layers.append(nn.Conv2d(256, out_planes, kernel_size=3, stride=1, padding=1))
        return nn.Sequential(*layers)
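# Rough shape sketch (illustration only, not part of the original gist): with a
# hypothetical 640x640 input, the FPN emits p3..p7 at strides 8..128, so the
# concatenated predictions cover 9*(80^2 + 40^2 + 20^2 + 10^2 + 5^2) = 76725 anchors:
#   >>> model = RetinaNet(num_classes=20)
#   >>> loc, cls = model(torch.randn(1, 3, 640, 640))
#   >>> loc.shape, cls.shape
#   (torch.Size([1, 76725, 4]), torch.Size([1, 76725, 20]))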
def FPN50():
    return FPN(Bottleneck, [3,4,6,3])
class FPN(nn.Module):
    def __init__(self, block, num_blocks):
        super(FPN, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Bottom-up layers
        self.layer1 = self._make_layer(block,  64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.conv6 = nn.Conv2d(2048, 256, kernel_size=3, stride=2, padding=1)
        self.conv7 = nn.Conv2d( 256, 256, kernel_size=3, stride=2, padding=1)

        # Top-down layers
        self.toplayer = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)

        # Lateral layers
        self.latlayer1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d( 512, 256, kernel_size=1, stride=1, padding=0)

        # Smooth layers
        self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def _upsample_add(self, x, y):
        '''Upsample and add two feature maps.

        Args:
          x: (tensor) top feature map to be upsampled.
          y: (tensor) lateral feature map.

        Returns:
          (tensor) added feature map.

        Note: in PyTorch, when the input size is odd, the feature map upsampled
        with `scale_factor=2, mode='nearest'` may not match the size of the
        lateral feature map, e.g.
          original input size: [N,_,15,15] ->
          conv2d feature map size: [N,_,8,8] ->
          upsampled feature map size: [N,_,16,16]
        So we use bilinear upsampling to the exact target size instead.
        '''
        _,_,H,W = y.size()
        return F.interpolate(x, size=(H,W), mode='bilinear', align_corners=False) + y

    def forward(self, x):
        # Bottom-up
        c1 = F.relu(self.bn1(self.conv1(x)))
        c1 = F.max_pool2d(c1, kernel_size=3, stride=2, padding=1)
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        p6 = self.conv6(c5)
        p7 = self.conv7(F.relu(p6))
        # Top-down
        p5 = self.toplayer(c5)
        p4 = self._upsample_add(p5, self.latlayer1(c4))
        p4 = self.smooth1(p4)
        p3 = self._upsample_add(p4, self.latlayer2(c3))
        p3 = self.smooth2(p3)
        return p3, p4, p5, p6, p7
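# FPN50 above wires this FPN with ResNet-50's Bottleneck layout ([3,4,6,3]),
# producing pyramid levels p3..p7 at strides 8, 16, 32, 64 and 128, each with
# 256 channels. p6 comes from a strided 3x3 conv on c5 and p7 from a further
# ReLU + strided 3x3 conv, as in the RetinaNet paper.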
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.downsample = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.downsample(x)
        out = F.relu(out)
        return out
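# Quick check (illustrative, hypothetical values): a Bottleneck(64, 64) block maps
# 64 input channels to expansion*64 = 256 output channels at the same spatial size:
#   >>> blk = Bottleneck(64, 64, stride=1)
#   >>> blk(torch.randn(1, 64, 56, 56)).shape
#   torch.Size([1, 256, 56, 56])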
'''Encode object boxes and labels.'''
class RetinaBoxCoder:
    def __init__(self):
        self.anchor_areas = (32*32., 64*64., 128*128., 256*256., 512*512.)  # p3 -> p7
        self.aspect_ratios = (1/2., 1/1., 2/1.)
        self.scale_ratios = (1., pow(2,1/3.), pow(2,2/3.))
        self.anchor_boxes = self._get_anchor_boxes(input_size=torch.tensor([640.,640.]))

    def _get_anchor_wh(self):
        '''Compute anchor width and height for each feature map.

        Returns:
          anchor_wh: (tensor) anchor wh, sized [#fm, #anchors_per_cell, 2].
        '''
        anchor_wh = []
        for s in self.anchor_areas:
            for ar in self.aspect_ratios:  # w/h = ar
                h = math.sqrt(s/ar)
                w = ar * h
                for sr in self.scale_ratios:  # scale
                    anchor_h = h*sr
                    anchor_w = w*sr
                    anchor_wh.append([anchor_w, anchor_h])
        num_fms = len(self.anchor_areas)
        return torch.Tensor(anchor_wh).view(num_fms, -1, 2)

    def _get_anchor_boxes(self, input_size):
        '''Compute anchor boxes for each feature map.

        Args:
          input_size: (tensor) model input size of (w,h).

        Returns:
          boxes: (tensor) anchor boxes for all feature maps, stacked into one
            [#anchors,4] tensor in (xmin,ymin,xmax,ymax) order, where
            #anchors = sum over feature maps of fm_w * fm_h * #anchors_per_cell.
        '''
        num_fms = len(self.anchor_areas)
        anchor_wh = self._get_anchor_wh()
        fm_sizes = [(input_size/pow(2.,i+3)).ceil() for i in range(num_fms)]  # p3 -> p7 feature map sizes

        boxes = []
        for i in range(num_fms):
            fm_size = fm_sizes[i]
            grid_size = input_size / fm_size
            fm_w, fm_h = int(fm_size[0]), int(fm_size[1])
            xy = meshgrid(fm_w,fm_h) + 0.5  # [fm_h*fm_w, 2], cell centers
            xy = (xy*grid_size).view(fm_h,fm_w,1,2).expand(fm_h,fm_w,9,2)
            wh = anchor_wh[i].view(1,1,9,2).expand(fm_h,fm_w,9,2)
            box = torch.cat([xy-wh/2.,xy+wh/2.], 3)  # [x,y,x,y]
            boxes.append(box.view(-1,4))
        return torch.cat(boxes, 0)
    def encode(self, boxes, labels):
        '''Encode target bounding boxes and class labels.

        We obey the Faster RCNN box coder:
          tx = (x - anchor_x) / anchor_w
          ty = (y - anchor_y) / anchor_h
          tw = log(w / anchor_w)
          th = log(h / anchor_h)

        Args:
          boxes: (tensor) bounding boxes of (xmin,ymin,xmax,ymax), sized [#obj, 4].
          labels: (tensor) object class labels, sized [#obj,].

        Returns:
          loc_targets: (tensor) encoded bounding boxes, sized [#anchors,4].
          cls_targets: (tensor) encoded class labels, sized [#anchors,].
        '''
        anchor_boxes = self.anchor_boxes
        ious = box_iou(anchor_boxes, boxes)
        max_ious, max_ids = ious.max(1)
        boxes = boxes[max_ids]

        boxes = change_box_order(boxes, 'xyxy2xywh')
        anchor_boxes = change_box_order(anchor_boxes, 'xyxy2xywh')

        loc_xy = (boxes[:,:2]-anchor_boxes[:,:2]) / anchor_boxes[:,2:]
        loc_wh = torch.log(boxes[:,2:]/anchor_boxes[:,2:])
        loc_targets = torch.cat([loc_xy,loc_wh], 1)
        cls_targets = 1 + labels[max_ids]
        # cls_targets[max_ious<0.5] = 0
        # ignore = (max_ious>0.4) & (max_ious<0.5)  # ignore ious between [0.4,0.5]
        # cls_targets[ignore] = -1  # mark ignored to -1
        return loc_targets, cls_targets
    def decode(self, loc_preds, cls_preds, input_size):
        '''Decode outputs back to bounding box locations and class labels.

        Args:
          loc_preds: (tensor) predicted locations, sized [#anchors, 4].
          cls_preds: (tensor) predicted class labels, sized [#anchors, #classes].
          input_size: (tuple) model input size of (w,h).

        Returns:
          boxes: (tensor) decoded box locations, sized [#obj,4].
          labels: (tensor) class labels for each box, sized [#obj,].
        '''
        CLS_THRESH = 0.5
        NMS_THRESH = 0.5

        input_size = torch.Tensor(input_size)
        anchor_boxes = self._get_anchor_boxes(input_size)           # (xmin,ymin,xmax,ymax)
        anchor_boxes = change_box_order(anchor_boxes, 'xyxy2xywh')  # -> (cx,cy,w,h), as assumed below

        loc_xy = loc_preds[:,:2]
        loc_wh = loc_preds[:,2:]

        xy = loc_xy * anchor_boxes[:,2:] + anchor_boxes[:,:2]
        wh = loc_wh.exp() * anchor_boxes[:,2:]
        boxes = torch.cat([xy-wh/2, xy+wh/2], 1)  # [#anchors,4]

        score, labels = cls_preds.sigmoid().max(1)       # [#anchors,]
        ids = (score > CLS_THRESH).nonzero().squeeze(1)  # [#obj,]
        keep = box_nms(boxes[ids], score[ids], threshold=NMS_THRESH)
        return boxes[ids][keep], labels[ids][keep]
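# Illustrative usage (hypothetical values, not part of the original gist):
# encoding one ground-truth box against the default 640x640 anchor grid yields
# one regression target and one class target per anchor:
#   >>> coder = RetinaBoxCoder()
#   >>> loc_t, cls_t = coder.encode(torch.tensor([[100., 120., 300., 400.]]),
#   ...                             torch.tensor([0]))
#   >>> loc_t.shape, cls_t.shape
#   (torch.Size([76725, 4]), torch.Size([76725]))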
class FocalLoss(nn.Module):
    def __init__(self, num_classes):
        super(FocalLoss, self).__init__()
        self.num_classes = num_classes

    def _focal_loss(self, x, y):
        '''Focal loss, as described in the original paper.

        With BCELoss, the background should not be counted in num_classes.

        Args:
          x: (tensor) predictions, sized [N,D].
          y: (tensor) targets, sized [N,].

        Return:
          (tensor) focal loss.
        '''
        alpha = 0.25
        gamma = 2

        t = one_hot_embedding(y-1, self.num_classes)
        p = x.sigmoid()
        pt = torch.where(t>0, p, 1-p)  # pt = p if t > 0 else 1-p
        w = (1-pt).pow(gamma)
        w = torch.where(t>0, alpha*w, (1-alpha)*w).detach()
        loss = F.binary_cross_entropy_with_logits(x, t, w, reduction='sum')
        return loss

    def forward(self, loc_preds, cls_preds, loc_targets, cls_targets):
        '''Compute loss between (loc_preds, loc_targets) and (cls_preds, cls_targets).

        Args:
          loc_preds: (tensor) predicted locations, sized [batch_size, #anchors, 4].
          loc_targets: (tensor) encoded target locations, sized [batch_size, #anchors, 4].
          cls_preds: (tensor) predicted class confidences, sized [batch_size, #anchors, #classes].
          cls_targets: (tensor) encoded target labels, sized [batch_size, #anchors].

        Returns:
          (tensor) loss = SmoothL1Loss(loc_preds, loc_targets) + FocalLoss(cls_preds, cls_targets),
          normalized by the number of positive anchors.
        '''
        batch_size, num_boxes = cls_targets.size()
        pos = cls_targets > 0  # [N,#anchors]
        num_pos = pos.sum().item()

        #===============================================================
        # loc_loss = SmoothL1Loss(pos_loc_preds, pos_loc_targets)
        #===============================================================
        mask = pos.unsqueeze(2).expand_as(loc_preds)  # [N,#anchors,4]
        loc_loss = F.smooth_l1_loss(loc_preds[mask], loc_targets[mask], reduction='sum')

        #===============================================================
        # cls_loss = FocalLoss(cls_preds, cls_targets)
        #===============================================================
        pos_neg = cls_targets > -1  # exclude ignored anchors
        mask = pos_neg.unsqueeze(2).expand_as(cls_preds)
        masked_cls_preds = cls_preds[mask].view(-1,self.num_classes)
        cls_loss = self._focal_loss(masked_cls_preds, cls_targets[pos_neg])

        # print('loc_loss: %.3f | cls_loss: %.3f' % (loc_loss.item()/num_pos, cls_loss.item()/num_pos), end=' | ')
        loss = (loc_loss+cls_loss)/num_pos
        return loss
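# For reference, the per-anchor, per-class term implemented above is the
# alpha-balanced focal loss from the RetinaNet paper:
#   FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t),  with alpha = 0.25, gamma = 2,
# applied as binary cross-entropy on the sigmoid outputs, summed over anchors
# and classes and normalized by the number of positive anchors in forward().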
def change_box_order(boxes, order):
    '''Change box order between (xmin,ymin,xmax,ymax) and (xcenter,ycenter,width,height).

    Args:
      boxes: (tensor) bounding boxes, sized [N,4].
      order: (str) either 'xyxy2xywh' or 'xywh2xyxy'.

    Returns:
      (tensor) converted bounding boxes, sized [N,4].
    '''
    assert order in ['xyxy2xywh','xywh2xyxy']
    a = boxes[:,:2]
    b = boxes[:,2:]
    if order == 'xyxy2xywh':
        return torch.cat([(a+b)/2,b-a], 1)
    return torch.cat([a-b/2,a+b/2], 1)
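# Example (illustrative): a box (0, 0, 10, 20) in xyxy order becomes
# center (5, 10) with width 10 and height 20 in xywh order:
#   >>> change_box_order(torch.tensor([[0., 0., 10., 20.]]), 'xyxy2xywh')
#   tensor([[ 5., 10., 10., 20.]])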
def box_clamp(boxes, xmin, ymin, xmax, ymax):
    '''Clamp boxes.

    Args:
      boxes: (tensor) bounding boxes of (xmin,ymin,xmax,ymax), sized [N,4].
      xmin: (number) min value of x.
      ymin: (number) min value of y.
      xmax: (number) max value of x.
      ymax: (number) max value of y.

    Returns:
      (tensor) clamped boxes.
    '''
    boxes[:,0].clamp_(min=xmin, max=xmax)
    boxes[:,1].clamp_(min=ymin, max=ymax)
    boxes[:,2].clamp_(min=xmin, max=xmax)
    boxes[:,3].clamp_(min=ymin, max=ymax)
    return boxes

def box_select(boxes, xmin, ymin, xmax, ymax):
    '''Select boxes in range (xmin,ymin,xmax,ymax).

    Args:
      boxes: (tensor) bounding boxes of (xmin,ymin,xmax,ymax), sized [N,4].
      xmin: (number) min value of x.
      ymin: (number) min value of y.
      xmax: (number) max value of x.
      ymax: (number) max value of y.

    Returns:
      (tensor) selected boxes, sized [M,4].
      (tensor) selected mask, sized [N,].
    '''
    mask = (boxes[:,0]>=xmin) & (boxes[:,1]>=ymin) \
         & (boxes[:,2]<=xmax) & (boxes[:,3]<=ymax)
    boxes = boxes[mask,:]
    return boxes, mask
def box_iou(box1, box2):
    '''Compute the intersection over union of two sets of boxes.

    The box order must be (xmin, ymin, xmax, ymax).

    Args:
      box1: (tensor) bounding boxes, sized [N,4].
      box2: (tensor) bounding boxes, sized [M,4].

    Return:
      (tensor) iou, sized [N,M].

    Reference:
      https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
    '''
    N = box1.size(0)
    M = box2.size(0)

    lt = torch.max(box1[:,None,:2], box2[:,:2])  # [N,M,2]
    rb = torch.min(box1[:,None,2:], box2[:,2:])  # [N,M,2]

    wh = (rb-lt).clamp(min=0)      # [N,M,2]
    inter = wh[:,:,0] * wh[:,:,1]  # [N,M]

    area1 = (box1[:,2]-box1[:,0]) * (box1[:,3]-box1[:,1])  # [N,]
    area2 = (box2[:,2]-box2[:,0]) * (box2[:,3]-box2[:,1])  # [M,]
    iou = inter / (area1[:,None] + area2 - inter)
    return iou
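# Worked example (illustrative): two 10x10 boxes shifted by 5 px overlap on a
# 5x10 region, so IoU = 50 / (100 + 100 - 50) = 1/3:
#   >>> box_iou(torch.tensor([[0., 0., 10., 10.]]),
#   ...         torch.tensor([[5., 0., 15., 10.]]))
#   tensor([[0.3333]])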
def box_nms(bboxes, scores, threshold=0.5):
    '''Non maximum suppression.

    Args:
      bboxes: (tensor) bounding boxes, sized [N,4].
      scores: (tensor) confidence scores, sized [N,].
      threshold: (float) overlap threshold.

    Returns:
      keep: (tensor) selected indices.

    Reference:
      https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py
    '''
    x1 = bboxes[:,0]
    y1 = bboxes[:,1]
    x2 = bboxes[:,2]
    y2 = bboxes[:,3]

    areas = (x2-x1) * (y2-y1)
    _, order = scores.sort(0, descending=True)

    keep = []
    while order.numel() > 0:
        i = order[0].item()
        keep.append(i)

        if order.numel() == 1:
            break

        # Intersection of the top-scoring box with the remaining boxes.
        xx1 = x1[order[1:]].clamp(min=x1[i].item())
        yy1 = y1[order[1:]].clamp(min=y1[i].item())
        xx2 = x2[order[1:]].clamp(max=x2[i].item())
        yy2 = y2[order[1:]].clamp(max=y2[i].item())

        w = (xx2-xx1).clamp(min=0)
        h = (yy2-yy1).clamp(min=0)
        inter = w * h

        overlap = inter / (areas[i] + areas[order[1:]] - inter)
        ids = (overlap<=threshold).nonzero().squeeze(1)  # keep as a 1-D index tensor
        if ids.numel() == 0:
            break
        order = order[ids+1]
    return torch.tensor(keep, dtype=torch.long)
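# Example (illustrative, hypothetical values): with two heavily overlapping boxes,
# NMS keeps only the higher-scoring one:
#   >>> box_nms(torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.]]),
#   ...         torch.tensor([0.9, 0.8]))
#   tensor([0])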
def meshgrid(x, y, row_major=True):
    '''Return meshgrid in range x & y.

    Args:
      x: (int) first dim range.
      y: (int) second dim range.
      row_major: (bool) row-major or column-major.

    Returns:
      (tensor) meshgrid, sized [x*y,2]

    Example:
    >>> meshgrid(3,2)
    0  0
    1  0
    2  0
    0  1
    1  1
    2  1
    [torch.FloatTensor of size 6x2]

    >>> meshgrid(3,2,row_major=False)
    0  0
    0  1
    0  2
    1  0
    1  1
    1  2
    [torch.FloatTensor of size 6x2]
    '''
    a = torch.arange(0, x, dtype=torch.float)
    b = torch.arange(0, y, dtype=torch.float)
    xx = a.repeat(y).view(-1,1)
    yy = b.view(-1,1).repeat(1,x).view(-1,1)
    return torch.cat([xx,yy],1) if row_major else torch.cat([yy,xx],1)
def one_hot_embedding(labels, num_classes):
    '''Embedding labels to one-hot.

    Args:
      labels: (LongTensor) class labels, sized [N,].
      num_classes: (int) number of classes.

    Returns:
      (tensor) encoded labels, sized [N,#classes].
    '''
    y = torch.eye(num_classes, device=labels.device)  # [D,D]
    return y[labels]  # [N,D]
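
if __name__ == '__main__':
    # Minimal end-to-end sanity check (added for illustration, not part of the
    # original gist). It assumes a 640x640 input so the model's anchor count
    # matches the RetinaBoxCoder default, 2 classes, and one dummy box.
    model = RetinaNet(num_classes=2)
    coder = RetinaBoxCoder()
    criterion = FocalLoss(num_classes=2)

    gt_boxes = torch.tensor([[100., 120., 300., 400.]])  # one dummy ground-truth box (xyxy)
    gt_labels = torch.tensor([0])
    loc_targets, cls_targets = coder.encode(gt_boxes, gt_labels)

    with torch.no_grad():
        loc_preds, cls_preds = model(torch.randn(1, 3, 640, 640))  # [1,76725,4], [1,76725,2]
        loss = criterion(loc_preds, cls_preds,
                         loc_targets.unsqueeze(0), cls_targets.unsqueeze(0))
    print('sanity-check loss:', loss.item())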