import torch
import torch.nn as nn


class HEADS(nn.Module):
    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
        super(HEADS, self).__init__()
        self.nc = nc  # number of classes
        self.nl = len(anchors)  # number of detection layers (scales)
        self.naxs = len(anchors[0])  # number of anchors per scale
        self.stride = [8, 16, 32]  # downsampling factor of each detection head

        # anchors are divided by the stride of their head
        # (anchors_for_head_1 / 8, anchors_for_head_2 / 16, anchors_for_head_3 / 32)
        anchors_ = torch.tensor(anchors).float().view(self.nl, -1, 2) / torch.tensor(self.stride).repeat(6, 1).T.reshape(3, 3, 2)
        self.register_buffer('anchors', anchors_)

        # one 1x1 conv per scale: it maps the incoming feature-map channels to
        # naxs * (5 + nc) output channels (x, y, w, h, objectness, class scores)
        self.out_convs = nn.ModuleList()
        for in_channels in ch:
            self.out_convs += [
                nn.Conv2d(in_channels=in_channels, out_channels=(5 + self.nc) * self.naxs, kernel_size=1)
            ]

    def forward(self, x):
        for i in range(self.nl):
            x[i] = self.out_convs[i](x[i])
            bs, _, grid_y, grid_x = x[i].shape
            # reshape to (batch, anchors, grid_y, grid_x, 5 + nc)
            x[i] = x[i].view(bs, self.naxs, (5 + self.nc), grid_y, grid_x).permute(0, 1, 3, 4, 2).contiguous()

        return x
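
A minimal usage sketch, not part of the original gist: the anchor values, channel widths, and feature-map sizes below are illustrative assumptions chosen to match the 3-scale, 3-anchors-per-scale layout the head expects.

if __name__ == "__main__":
    anchors = [
        [(10, 13), (16, 30), (33, 23)],       # anchors for the stride-8 head
        [(30, 61), (62, 45), (59, 119)],      # anchors for the stride-16 head
        [(116, 90), (156, 198), (373, 326)],  # anchors for the stride-32 head
    ]
    ch = (128, 256, 512)  # assumed channel widths of the three incoming feature maps

    heads = HEADS(nc=80, anchors=anchors, ch=ch)

    feats = [
        torch.rand(1, 128, 80, 80),  # stride 8
        torch.rand(1, 256, 40, 40),  # stride 16
        torch.rand(1, 512, 20, 20),  # stride 32
    ]
    out = heads(feats)
    for o in out:
        print(o.shape)  # (1, 3, grid_y, grid_x, 85) for nc=80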