Last active
December 1, 2023 17:28
-
-
Save ariG23498/777ea321f4c294842f3f7de45dde8258 to your computer and use it in GitHub Desktop.
anchor-generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ariG23498/777ea321f4c294842f3f7de45dde8258/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install --upgrade -qq tensorflow\n", | |
"!pip install --upgrade -qq keras" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "kXSYq9j_Oq76", | |
"outputId": "af0ee3cc-7dca-487a-848f-868479121d89" | |
}, | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m475.2/475.2 MB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m58.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m442.0/442.0 kB\u001b[0m \u001b[31m31.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m58.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m997.1/997.1 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", | |
"tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.0 which is incompatible.\u001b[0m\u001b[31m\n", | |
"\u001b[0m" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import os\n", | |
"os.environ[\"KERAS_BACKEND\"] = \"numpy\"\n", | |
"from keras import ops\n", | |
"import math" | |
], | |
"metadata": { | |
"id": "VJJSDbHOWvFM" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"id": "lIYdn1woOS1n" | |
}, | |
"outputs": [], | |
"source": [ | |
# Base anchor size in pixels for each pyramid level: 32, 64, 128.
sizes = [2 ** exponent for exponent in (5, 6, 7)]
# Octave scales 2**0, 2**(1/3), 2**(2/3), applied to every base size.
scales = [2 ** exponent for exponent in (0, 1 / 3, 2 / 3)]
# Anchor width / height ratios (tall, square, wide).
aspect_ratios = [0.5, 1.0, 2.0]
# Pixel spacing between neighbouring anchor centres for each level: 8, 16, 32.
strides = [2 ** exponent for exponent in (3, 4, 5)]
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
def param_to_level_dict(param):
    """Key a per-level sequence by its position: ``[a, b, c] -> {0: a, 1: b, 2: c}``.

    Args:
        param: any object supporting ``len()`` and integer indexing
            (e.g. a list of per-level sizes or strides).

    Returns:
        dict mapping each index ``0 .. len(param) - 1`` to ``param[index]``.
    """
    return {level: param[level] for level in range(len(param))}
], | |
"metadata": { | |
"id": "P11eSroq_CMz" | |
}, | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Re-key the per-level size and stride lists by level index (0, 1, 2, ...).
sizes, strides = param_to_level_dict(sizes), param_to_level_dict(strides)
], | |
"metadata": { | |
"id": "HOsU26mdBHN9" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Per-level base anchor sizes after re-keying: {level_index: size}.
sizes
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "oWM0jrSKBPMJ", | |
"outputId": "cc106551-9802-47c3-f76e-f77c0029bf0e" | |
}, | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{0: 32, 1: 64, 2: 128}" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 6 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Per-level strides after re-keying: {level_index: stride}.
strides
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "7IRvRRRWBPvq", | |
"outputId": "6b162d57-c109-48dd-eae6-5bf3f69d1938" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{0: 8, 1: 16, 2: 32}" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 7 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
def match_param_structure_to_sizes(params, sizes):
    """Broadcast one shared parameter to every pyramid level present in `sizes`.

    Args:
        params: value used for all levels (e.g. the list of aspect ratios or
            scales).
        sizes: mapping whose keys are the level indices.

    Returns:
        dict mapping each key of `sizes` to its own shallow copy of `params`.
    """
    import copy

    # Give each level its own copy: handing every level the *same* list object
    # means mutating one level's entry would silently change all levels.
    return {level: copy.copy(params) for level in sizes.keys()}
], | |
"metadata": { | |
"id": "5-LdF9wcBQ1i" | |
}, | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Broadcast the shared aspect ratios and scales to every pyramid level in `sizes`.
aspect_ratios, scales = (
    match_param_structure_to_sizes(aspect_ratios, sizes),
    match_param_structure_to_sizes(scales, sizes),
)
], | |
"metadata": { | |
"id": "pgY6sI1SBZ_t" | |
}, | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# The same aspect-ratio list, now keyed by level index.
aspect_ratios
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "QUZ1rKnZBhbE", | |
"outputId": "03bda8e6-783e-4c82-d373-4b455988922f" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{0: [0.5, 1.0, 2.0], 1: [0.5, 1.0, 2.0], 2: [0.5, 1.0, 2.0]}" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 10 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# The same octave-scale list, now keyed by level index.
scales
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "hJfIfUcdBii5", | |
"outputId": "2af1338c-865a-4c3a-a0e2-4474c4458a5c" | |
}, | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{0: [1, 1.2599210498948732, 1.5874010519681994],\n", | |
" 1: [1, 1.2599210498948732, 1.5874010519681994],\n", | |
" 2: [1, 1.2599210498948732, 1.5874010519681994]}" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 11 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Gather each level's anchor configuration into one dict per level; these are
# the configs passed to single_anchor_gen.
anchor_generators = {
    level: {
        "sizes": sizes[level],
        "scales": scales[level],
        "aspect_ratios": aspect_ratios[level],
        "strides": strides[level],
    }
    for level in sizes.keys()
}
], | |
"metadata": { | |
"id": "KK48MXtCBz0z" | |
}, | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Per-level config dicts with keys "sizes", "scales", "aspect_ratios", "strides".
anchor_generators
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "yHVQt-98CMZ5", | |
"outputId": "103353cf-9429-46f6-9f67-62885be09688" | |
}, | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{0: {'sizes': 32,\n", | |
" 'scales': [1, 1.2599210498948732, 1.5874010519681994],\n", | |
" 'aspect_ratios': [0.5, 1.0, 2.0],\n", | |
" 'strides': 8},\n", | |
" 1: {'sizes': 64,\n", | |
" 'scales': [1, 1.2599210498948732, 1.5874010519681994],\n", | |
" 'aspect_ratios': [0.5, 1.0, 2.0],\n", | |
" 'strides': 16},\n", | |
" 2: {'sizes': 128,\n", | |
" 'scales': [1, 1.2599210498948732, 1.5874010519681994],\n", | |
" 'aspect_ratios': [0.5, 1.0, 2.0],\n", | |
" 'strides': 32}}" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 13 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"```\n", | |
"SingleAnchorGenerator(\n", | |
" sizes=32,\n", | |
" scales=[1, 1.26, 1.59],\n", | |
" aspect_ratios=[0.5, 1.0, 2.0],\n", | |
" stride=8,\n", | |
")\n", | |
"```" | |
], | |
"metadata": { | |
"id": "8t9o93jMNpq7" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
def _anchor_centers(image_extent, stride):
    """Return float32 anchor-centre coordinates along one image axis.

    Centres sit at 0.5 * stride, 1.5 * stride, ... The stop value is rounded up
    to a multiple of `stride`, so the axis yields ceil(image_extent / stride)
    centres even when `image_extent % stride != 0`.
    """
    return ops.cast(
        ops.arange(0.5 * stride, math.ceil(image_extent / stride) * stride, stride),
        "float32",
    )


def _anchor_half_extents(size, scales, aspect_ratios):
    """Return (half_heights, half_widths) of the K anchor shapes, each shaped [1, 1, K].

    K = len(scales) * len(aspect_ratios), ordered scale-major: every aspect ratio
    for scales[0], then every aspect ratio for scales[1], and so on. Each shape
    has width / height == aspect ratio and area (size * scale) ** 2.

    Raises:
        ValueError: if `scales` is empty.
    """
    scales = list(scales)
    if not scales:
        raise ValueError("config['scales'] must contain at least one scale")

    aspect_ratios_sqrt = ops.cast(
        ops.sqrt(ops.cast(aspect_ratios, "float32")), dtype="float32"
    )
    base_size = ops.cast(size, "float32")

    heights, widths = [], []
    for scale in scales:
        scaled_size = base_size * scale
        # Dividing/multiplying by sqrt(ratio) keeps the area at scaled_size**2.
        heights.append(scaled_size / aspect_ratios_sqrt)
        widths.append(scaled_size * aspect_ratios_sqrt)

    half_heights = ops.reshape(0.5 * ops.concatenate(heights, axis=0), [1, 1, -1])
    half_widths = ops.reshape(0.5 * ops.concatenate(widths, axis=0), [1, 1, -1])
    return half_heights, half_widths


def single_anchor_gen(config, image_height=512, image_width=512):
    """Generate every anchor box for one feature-pyramid level.

    Args:
        config: dict with keys
            'sizes'         -- base anchor size in pixels (a scalar, e.g. 32),
            'scales'        -- iterable of size multipliers
                               (e.g. [1, 1.26, 1.59] = 2**(0, 1/3, 2/3)),
            'aspect_ratios' -- iterable of width / height ratios
                               (e.g. [0.5, 1.0, 2.0]),
            'strides'       -- pixel spacing between anchor centres (e.g. 8).
        image_height: input image height in pixels (default 512, the value
            previously hard-coded).
        image_width: input image width in pixels (default 512).

    Returns:
        float32 tensor of shape [H * W * K, 4] whose rows are
        (y_min, x_min, y_max, x_max), with H = ceil(image_height / stride),
        W = ceil(image_width / stride) and K = len(scales) * len(aspect_ratios).
        Rows run over the H x W centre grid row-major, then over the K shapes.

    Raises:
        ValueError: if config['scales'] is empty.
    """
    stride = config["strides"]
    half_heights, half_widths = _anchor_half_extents(
        config["sizes"], config["scales"], config["aspect_ratios"]
    )

    # [W] x-centres and [H] y-centres -> two [H, W] grids.
    cx_grid, cy_grid = ops.meshgrid(
        _anchor_centers(image_width, stride), _anchor_centers(image_height, stride)
    )
    # [H, W, 1], so each grid broadcasts against the [1, 1, K] half extents.
    cx_grid = ops.expand_dims(cx_grid, axis=-1)
    cy_grid = ops.expand_dims(cy_grid, axis=-1)

    # Each corner map is [H, W, K]; flatten it to an [H * W * K, 1] column.
    corners = [
        cy_grid - half_heights,  # y_min
        cx_grid - half_widths,  # x_min
        cy_grid + half_heights,  # y_max
        cx_grid + half_widths,  # x_max
    ]
    columns = [ops.expand_dims(ops.reshape(c, (-1,)), axis=-1) for c in corners]

    # [H * W * K, 4]
    return ops.cast(ops.concatenate(columns, axis=-1), "float32")
], | |
"metadata": { | |
"id": "MvGGb4jQCN-3" | |
}, | |
"execution_count": 23, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
# Generate the anchors for pyramid level 0 (base size 32, stride 8): a 64 x 64
# grid of centres over the 512 x 512 image size used by single_anchor_gen,
# with 3 scales x 3 aspect ratios = 9 boxes per centre.
single_anchor_gen(anchor_generators[0])
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "4FYWI2dnWdma", | |
"outputId": "089c223f-8b91-409d-a47a-8e779bac48a3" | |
}, | |
"execution_count": 24, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[-18.627417 , -7.3137083, 26.627417 , 15.313708 ],\n", | |
" [-12. , -12. , 20. , 20. ],\n", | |
" [ -7.3137083, -18.627417 , 15.313708 , 26.627417 ],\n", | |
" ...,\n", | |
" [472.0812 , 490.04062 , 543.91876 , 525.9594 ],\n", | |
" [482.6016 , 482.6016 , 533.39844 , 533.39844 ],\n", | |
" [490.04062 , 472.0812 , 525.9594 , 543.91876 ]],\n", | |
" dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 24 | |
} | |
] | |
} | |
], | |
"metadata": { | |
"colab": { | |
"name": "scratchpad", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment