Created
October 5, 2021 02:01
-
-
Save PWhiddy/689155a8fb292d62ddc66e8cf53bcf56 to your computer and use it in GitHub Desktop.
Improved implementation of bbox_to_mask
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "49dcc530", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import random" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
@torch.jit.ignore
def validate_bbox(boxes: torch.Tensor) -> bool:
    """Check that a batch of 2D boxes is well formed.

    Corners are expected in clockwise order -- top-left, top-right,
    bottom-right, bottom-left -- each given as ``(x, y)``.

    Two inclusive edge-length equalities are checked (within ``torch.allclose``
    tolerance): the top width (TR.x - TL.x + 1) must match the bottom width
    (BR.x - BL.x + 1), and the vertical extent TL->BR (BR.y - TL.y + 1) must
    match TR->BL (BL.y - TR.y + 1). Nothing else is verified (e.g. that TL and
    TR share the same y), so this is a necessary, not sufficient, condition for
    an axis-aligned rectangle.

    Args:
        boxes: box corners with shape ``(B, 4, 2)``.

    Returns:
        ``True`` when every check passes.

    Raises:
        AssertionError: if ``boxes`` does not have shape ``(B, 4, 2)``.
        ValueError: if the widths or the heights of the two edges disagree.
    """
    if boxes.dim() != 3 or boxes.shape[1:] != torch.Size([4, 2]):
        raise AssertionError(f"Box shape must be (B, 4, 2). Got {boxes.shape}.")

    tl, tr, br, bl = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]

    # Inclusive pixel counts, hence the +1.
    width_top = tr[:, 0] - tl[:, 0] + 1
    width_bottom = br[:, 0] - bl[:, 0] + 1
    if not torch.allclose(width_top, width_bottom):
        raise ValueError(
            "Boxes must have be rectangular, while get widths %s and %s" % (str(width_top), str(width_bottom))
        )

    height_main = br[:, 1] - tl[:, 1] + 1
    height_anti = bl[:, 1] - tr[:, 1] + 1
    if not torch.allclose(height_main, height_anti):
        raise ValueError(
            "Boxes must have be rectangular, while get heights %s and %s" % (str(height_main), str(height_anti))
        )

    return True
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "69c587df", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def bbox_to_mask_old(boxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Rasterise 2D boxes into binary masks, one box at a time (reference version).

    Inside each box the mask is 1., everywhere else 0.

    Args:
        boxes: box corners with shape ``(B, 4, 2)``, ordered clockwise as
            top-left, top-right, bottom-right, bottom-left, each as ``(x, y)``.
        width: width of the produced mask.
        height: height of the produced mask.

    Returns:
        A float tensor of shape ``(B, height, width)``.

    Note:
        Not differentiable. The working canvas is allocated without a ``device``
        argument, so the result lives on PyTorch's default device (normally CPU)
        regardless of ``boxes.device``.

    Examples:
        >>> boxes = torch.tensor([[
        ...     [1., 1.],
        ...     [3., 1.],
        ...     [3., 2.],
        ...     [1., 2.],
        ... ]])  # 1x4x2
        >>> bbox_to_mask_old(boxes, 5, 5)
        tensor([[[0., 0., 0., 0., 0.],
                 [0., 1., 1., 1., 0.],
                 [0., 1., 1., 1., 0.],
                 [0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0.]]])
    """
    validate_bbox(boxes)
    # One zero canvas per box with a one-pixel border on every side; the border
    # guarantees each row/column contains at least one unfilled pixel.
    canvases = torch.zeros((len(boxes), height + 2, width + 2))
    # Move the corners into the bordered frame.
    shifted = boxes + 1
    fill_value = torch.tensor(1)

    per_box = []
    # TODO: Looking for a vectorized way
    for canvas, corners in zip(canvases, shifted):
        col_idx = torch.arange(corners[0, 0].item(), corners[1, 0].item() + 1, dtype=torch.long)
        row_idx = torch.arange(corners[1, 1].item(), corners[2, 1].item() + 1, dtype=torch.long)
        # Paint whole columns, then whole rows, with 1.
        painted = canvas.index_fill(1, col_idx, fill_value).index_fill(0, row_idx, fill_value)
        is_one = painted.unsqueeze(dim=0) == 1
        # A column is entirely 1 only if it was painted as a column (and likewise
        # for rows), so the outer product of the two flags is the box itself.
        full_cols = is_one.all(dim=1)    # (1, W + 2)
        full_rows = is_one.all(dim=2).T  # (H + 2, 1)
        per_box.append((full_cols * full_rows)[1:-1, 1:-1])

    return torch.stack(per_box, dim=0).float()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def bbox_to_mask_new(boxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Convert 2D bounding boxes to masks. Covered area is 1. and the remaining is 0.

    Every box in the batch is rasterised independently by comparing a row/column
    index grid against its corner coordinates, so any batch size ``B`` (including
    0) works without a Python-level loop.

    Args:
        boxes: a tensor containing the coordinates of the bounding boxes to be extracted. The tensor must have the shape
            of Bx4x2, where each box is defined in the following ``clockwise`` order: top-left, top-right, bottom-right
            and bottom-left. The coordinates must be in the x, y order.
        width: width of the masked image.
        height: height of the masked image.

    Returns:
        the output mask tensor of shape ``(B, height, width)`` on ``boxes.device`` with PyTorch's default floating
        dtype. Pixel ``(row, col)`` of box ``b`` is 1. when ``col`` lies between the top-left and top-right x and
        ``row`` lies between the top-left and bottom-right y (both inclusive). Parts of a box outside the canvas are
        clipped.

    Note:
        It is currently non-differentiable.

    Examples:
        >>> boxes = torch.tensor([[
        ...     [1., 1.],
        ...     [3., 1.],
        ...     [3., 2.],
        ...     [1., 2.],
        ... ]])  # 1x4x2
        >>> bbox_to_mask_new(boxes, 5, 5)
        tensor([[[0., 0., 0., 0., 0.],
                 [0., 1., 1., 1., 0.],
                 [0., 1., 1., 1., 0.],
                 [0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0.]]])
    """
    validate_bbox(boxes)
    # Integer corners. The +1 / .long() / -1 round trip keeps the previous
    # truncation rule ((boxes + 1).long()) so float coordinates map to the same
    # pixels as before, just expressed in unpadded output coordinates.
    box_i = (boxes + 1).long() - 1

    # Per-box bounds shaped (B, 1, 1) so they broadcast against the index grids.
    # Slicing with these (B,) tensors directly (the old approach) only works for
    # B == 1 and applies one slice to every box; broadcasting handles any B.
    x_min = box_i[:, 0, 0][:, None, None]
    x_max = box_i[:, 1, 0][:, None, None]
    y_min = box_i[:, 0, 1][:, None, None]
    y_max = box_i[:, 2, 1][:, None, None]

    cols = torch.arange(width, device=boxes.device)           # (W,)
    rows = torch.arange(height, device=boxes.device)[:, None]  # (H, 1)

    # Plain comparisons clip out-of-canvas boxes instead of letting negative
    # coordinates wrap around like Python negative slice indices would.
    inside_x = (cols >= x_min) & (cols <= x_max)  # (B, 1, W)
    inside_y = (rows >= y_min) & (rows <= y_max)  # (B, H, 1)
    return (inside_x & inside_y).to(torch.get_default_dtype())
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "6b955970", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
test_dev = "cpu"  # device for the example tensors, the random-box helper and the benchmarks below
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "42f7a15d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[0., 0., 0., 0., 0.],\n", | |
" [0., 1., 1., 1., 0.],\n", | |
" [0., 1., 1., 1., 0.],\n", | |
" [0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0.]]])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Docstring example: one box covering x = 1..3 and y = 1..2 on a 5x5 canvas.
boxes = torch.tensor(
    [[[1., 1.], [3., 1.], [3., 2.], [1., 2.]]],  # 1x4x2: TL, TR, BR, BL as (x, y)
    device=test_dev,
)
bbox_to_mask_old(boxes, 5, 5)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "30f58ded", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor([[[0., 0., 0., 0., 0.],\n", | |
" [0., 1., 1., 1., 0.],\n", | |
" [0., 1., 1., 1., 0.],\n", | |
" [0., 0., 0., 0., 0.],\n", | |
" [0., 0., 0., 0., 0.]]])" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Same docstring example, run through the vectorised implementation.
boxes = torch.tensor(
    [[[1., 1.], [3., 1.], [3., 2.], [1., 2.]]],  # 1x4x2: TL, TR, BR, BL as (x, y)
    device=test_dev,
)
bbox_to_mask_new(boxes, 5, 5)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "30fccef8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Both implementations must agree on the example box.
torch.allclose(
    bbox_to_mask_old(boxes, 5, 5),
    bbox_to_mask_new(boxes, 5, 5),
)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def test_random_boxes(func):
    """Run ``func`` on one reproducible pseudo-random box and return its result.

    Draws, in order, the box's right edge ``w`` in [130, 180], bottom edge ``h``
    in [70, 120], then ``width`` and ``height`` in [200, 300]. It builds a single
    1x4x2 box with top-left corner (10, 10) on ``test_dev`` and returns
    ``func(bbox, width, height)``.

    The draws come from a private ``random.Random(0)``, which yields the same
    sequence as ``random.seed(0)`` followed by module-level calls. Unlike that
    approach, it does not reset the notebook-wide global ``random`` state, which
    other cells also draw from.
    """
    rng = random.Random(0)
    w = rng.randint(130, 180)
    h = rng.randint(70, 120)
    bbox = torch.tensor([[
        [10., 10.],
        [w, 10.],
        [w, h],
        [10., h],
    ]], device=test_dev)  # 1x4x2: TL, TR, BR, BL as (x, y)
    return func(bbox, rng.randint(200, 300), rng.randint(200, 300))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# Compare both implementations on the same seeded random box and canvas size.
torch.allclose(
    test_random_boxes(bbox_to_mask_old),
    test_random_boxes(bbox_to_mask_new),
)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "bc60c111", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"624 µs ± 12.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
%%timeit -n 1000
# Benchmark bbox_to_mask_old on a single box. Each iteration draws a fresh box
# size and canvas size from the global `random` module (not re-seeded here), so
# the timing also includes building the input tensor.
w = random.randint(130, 180)
h = random.randint(70, 120)
bbox = torch.tensor([[
    [10., 10.],
    [w, 10.],
    [w, h],
    [10., h],
    ]], device=test_dev)
bbox_to_mask_old(bbox, random.randint(200, 300), random.randint(200, 300))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "c7f704f0", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"257 µs ± 3.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
%%timeit -n 1000
# Benchmark bbox_to_mask_new with exactly the same per-iteration workload as the
# bbox_to_mask_old cell (random box/canvas sizes, tensor construction included),
# so the two timings are directly comparable.
w = random.randint(130, 180)
h = random.randint(70, 120)
bbox = torch.tensor([[
    [10., 10.],
    [w, 10.],
    [w, h],
    [10., h],
    ]], device=test_dev)
bbox_to_mask_new(bbox, random.randint(200, 300), random.randint(200, 300))
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "548a4217", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment