Skip to content

Instantly share code, notes, and snippets.

@AlessandroMondin
Last active December 19, 2022 10:22
Show Gist options
  • Save AlessandroMondin/46cc8a920807895411ba7cbf6cf7b6c6 to your computer and use it in GitHub Desktop.
Save AlessandroMondin/46cc8a920807895411ba7cbf6cf7b6c6 to your computer and use it in GitHub Desktop.
def cells_to_bboxes(predictions, anchors, strides):
    """Decode raw per-layer head outputs into one flat tensor of box predictions.

    Args:
        predictions: Sequence of raw (pre-sigmoid) outputs, one per output
            layer. Each is unpacked as (bs, naxs, ny, nx, C). Channels 0:2 and
            2:4 are decoded as box xy/wh, channel 4 as objectness, and
            channels 5: as class scores.
        anchors: Per-layer anchors, forwarded unchanged to ``make_grids``.
        strides: Per-layer strides; ``strides[i]`` scales layer ``i``.

    Returns:
        Tensor of shape (bs, sum_i naxs_i * ny_i * nx_i, 6). The columns are
        (best_class, objectness, x, y, w, h). ``best_class`` is the argmax
        class index, stored in the predictions' floating dtype.
    """
    all_bboxes = []
    for i, layer_raw in enumerate(predictions):
        bs, naxs, ny, nx, _ = layer_raw.shape
        stride = strides[i]
        grid, anchor_grid = make_grids(anchors, naxs, ny=ny, nx=nx, stride=stride, i=i)
        # The grids may be built on a different device than the predictions
        # (e.g. CPU grid vs. CUDA outputs), which would make the arithmetic
        # below fail, so align them with the predictions first.
        grid = grid.to(layer_raw.device)
        anchor_grid = anchor_grid.to(layer_raw.device)

        layer_prediction = layer_raw.sigmoid()
        obj = layer_prediction[..., 4:5]
        xy = (2 * layer_prediction[..., 0:2] + grid - 0.5) * stride
        wh = (2 * layer_prediction[..., 2:4]) ** 2 * anchor_grid
        # argmax returns int64; cast it so torch.cat concatenates tensors
        # of a floating dtype only.
        best_class = torch.argmax(layer_prediction[..., 5:], dim=-1, keepdim=True).to(obj.dtype)

        scale_bboxes = torch.cat((best_class, obj, xy, wh), dim=-1).reshape(bs, -1, 6)
        all_bboxes.append(scale_bboxes)
    return torch.cat(all_bboxes, dim=1)
def make_grids(anchors, naxs, stride, nx=20, ny=20, i=0, device=None):
    """Build the cell-offset grid and the scaled anchor grid for one layer.

    Args:
        anchors: Indexable per-layer anchors. ``anchors[i]`` must contain
            exactly ``naxs`` (w, h) pairs. It may be a tensor, or anything
            ``torch.as_tensor`` accepts (nested lists, numpy arrays). The
            anchors are multiplied by ``stride`` here, so they are presumably
            given in grid-cell units — TODO confirm against the anchor config.
        naxs: Number of anchors for layer ``i``.
        stride: Scale factor applied to the anchors of layer ``i``.
        nx: Grid width (number of columns).
        ny: Grid height (number of rows).
        i: Index of the layer whose anchors are used.
        device: Device for both returned tensors. Defaults to the device of
            ``anchors[i]`` (CPU for non-tensor anchors).

    Returns:
        ``(xy_grid, anchor_grid)``, both expanded (not copied) to shape
        (1, naxs, ny, nx, 2), as follows:
        - ``xy_grid[0, a, y, x]`` is the int64 pair ``(x, y)``.
        - ``anchor_grid[0, a, y, x]`` is ``anchors[i][a] * stride``.
    """
    # as_tensor returns a tensor input unchanged when device is None.
    anchor = torch.as_tensor(anchors[i], device=device)
    device = anchor.device
    # Create the index grid on the same device as the anchors so the two
    # returned tensors are always usable together.
    ys = torch.arange(ny, device=device).view(ny, 1).expand(ny, nx)
    xs = torch.arange(nx, device=device).view(1, nx).expand(ny, nx)
    xy_grid = torch.stack((xs, ys), dim=-1).expand(1, naxs, ny, nx, 2)
    anchor_grid = (anchor * stride).reshape(1, naxs, 1, 1, 2).expand(1, naxs, ny, nx, 2)
    return xy_grid, anchor_grid
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment