anchor-generator
@ariG23498, last active December 1, 2023 17:28
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ariG23498/777ea321f4c294842f3f7de45dde8258/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"!pip install --upgrade -qq tensorflow\n",
"!pip install --upgrade -qq keras"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kXSYq9j_Oq76",
"outputId": "af0ee3cc-7dca-487a-848f-868479121d89"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m475.2/475.2 MB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m58.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m442.0/442.0 kB\u001b[0m \u001b[31m31.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m58.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m997.1/997.1 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.0 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0m"
]
}
]
},
{
"cell_type": "code",
"source": [
"import os\n",
"os.environ[\"KERAS_BACKEND\"] = \"numpy\"\n",
"from keras import ops\n",
"import math"
],
"metadata": {
"id": "VJJSDbHOWvFM"
},
"execution_count": 2,
"outputs": []
},
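{
"cell_type": "markdown",
"source": [
"In Keras 3 the backend is selected with the `KERAS_BACKEND` environment variable, and it has to be set before `keras` is imported. With the `\"numpy\"` backend, `keras.ops` calls run eagerly on NumPy arrays, which keeps this scratchpad framework-agnostic. A minimal sanity check (a sketch, assuming the cell above has already run):\n",
"\n",
"```\n",
"import keras\n",
"import numpy as np\n",
"\n",
"print(keras.backend.backend())  # \"numpy\"\n",
"print(np.asarray(ops.sqrt([1.0, 4.0])))  # [1. 2.]\n",
"```"
],
"metadata": {}
},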
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "lIYdn1woOS1n"
},
"outputs": [],
"source": [
"sizes = [2**x for x in [5, 6, 7]]\n",
"scales = [2**x for x in [0, 1/3, 2/3]]\n",
"aspect_ratios = [1/2, 1/1, 2/1]\n",
"strides = [2**x for x in [3, 4, 5]]"
]
},
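{
"cell_type": "markdown",
"source": [
"These lists follow the usual RetinaNet-style anchor configuration: three pyramid levels with base sizes 32 / 64 / 128 and strides 8 / 16 / 32, three scales spanning one octave (2^0, 2^(1/3), 2^(2/3)), and three aspect ratios (1:2, 1:1, 2:1). Every spatial location therefore gets `len(scales) * len(aspect_ratios) = 9` anchors. A quick check against the lists above (a sketch, not an extra dependency):\n",
"\n",
"```\n",
"print(sizes)    # [32, 64, 128]\n",
"print(strides)  # [8, 16, 32]\n",
"print(scales)   # [1, 1.2599..., 1.5874...]\n",
"print(len(scales) * len(aspect_ratios))  # 9 anchors per location\n",
"```"
],
"metadata": {}
},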
{
"cell_type": "code",
"source": [
"def param_to_level_dict(param):\n",
" result = dict()\n",
" for i in range(len(param)):\n",
" result[i] = param[i]\n",
" return result"
],
"metadata": {
"id": "P11eSroq_CMz"
},
"execution_count": 4,
"outputs": []
},
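{
"cell_type": "markdown",
"source": [
"`param_to_level_dict` only maps list position to feature-pyramid level index; an equivalent one-liner (a sketch, shown here for clarity rather than used below) would be:\n",
"\n",
"```\n",
"def param_to_level_dict(param):\n",
"    return dict(enumerate(param))\n",
"\n",
"param_to_level_dict([32, 64, 128])  # {0: 32, 1: 64, 2: 128}\n",
"```"
],
"metadata": {}
},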
{
"cell_type": "code",
"source": [
"sizes = param_to_level_dict(sizes)\n",
"strides = param_to_level_dict(strides)"
],
"metadata": {
"id": "HOsU26mdBHN9"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"sizes"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oWM0jrSKBPMJ",
"outputId": "cc106551-9802-47c3-f76e-f77c0029bf0e"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{0: 32, 1: 64, 2: 128}"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"source": [
"strides"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7IRvRRRWBPvq",
"outputId": "6b162d57-c109-48dd-eae6-5bf3f69d1938"
},
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{0: 8, 1: 16, 2: 32}"
]
},
"metadata": {},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"source": [
"def match_param_structure_to_sizes(params, sizes):\n",
" return {key: params for key in sizes.keys()}"
],
"metadata": {
"id": "5-LdF9wcBQ1i"
},
"execution_count": 8,
"outputs": []
},
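{
"cell_type": "markdown",
"source": [
"Note that the dict comprehension reuses the same list object for every level, so mutating one level's entry would show up at all levels; that is fine here because the lists are only read. A small illustration (a sketch):\n",
"\n",
"```\n",
"shared = match_param_structure_to_sizes([0.5, 1.0, 2.0], {0: 32, 1: 64, 2: 128})\n",
"print(shared[0] is shared[1])  # True: every level points at the same list\n",
"```"
],
"metadata": {}
},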
{
"cell_type": "code",
"source": [
"aspect_ratios = match_param_structure_to_sizes(\n",
" aspect_ratios, sizes\n",
")\n",
"scales = match_param_structure_to_sizes(\n",
" scales, sizes\n",
")"
],
"metadata": {
"id": "pgY6sI1SBZ_t"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"source": [
"aspect_ratios"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QUZ1rKnZBhbE",
"outputId": "03bda8e6-783e-4c82-d373-4b455988922f"
},
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{0: [0.5, 1.0, 2.0], 1: [0.5, 1.0, 2.0], 2: [0.5, 1.0, 2.0]}"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"source": [
"scales"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hJfIfUcdBii5",
"outputId": "2af1338c-865a-4c3a-a0e2-4474c4458a5c"
},
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{0: [1, 1.2599210498948732, 1.5874010519681994],\n",
" 1: [1, 1.2599210498948732, 1.5874010519681994],\n",
" 2: [1, 1.2599210498948732, 1.5874010519681994]}"
]
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"source": [
"anchor_generators = {}\n",
"for k in sizes.keys():\n",
" anchor_generators[k] = {\n",
" \"sizes\": sizes[k],\n",
" \"scales\": scales[k],\n",
" \"aspect_ratios\": aspect_ratios[k],\n",
" \"strides\": strides[k],\n",
" }"
],
"metadata": {
"id": "KK48MXtCBz0z"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"anchor_generators"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yHVQt-98CMZ5",
"outputId": "103353cf-9429-46f6-9f67-62885be09688"
},
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{0: {'sizes': 32,\n",
" 'scales': [1, 1.2599210498948732, 1.5874010519681994],\n",
" 'aspect_ratios': [0.5, 1.0, 2.0],\n",
" 'strides': 8},\n",
" 1: {'sizes': 64,\n",
" 'scales': [1, 1.2599210498948732, 1.5874010519681994],\n",
" 'aspect_ratios': [0.5, 1.0, 2.0],\n",
" 'strides': 16},\n",
" 2: {'sizes': 128,\n",
" 'scales': [1, 1.2599210498948732, 1.5874010519681994],\n",
" 'aspect_ratios': [0.5, 1.0, 2.0],\n",
" 'strides': 32}}"
]
},
"metadata": {},
"execution_count": 13
}
]
},
{
"cell_type": "markdown",
"source": [
"```\n",
"SingleAnchorGenerator(\n",
" sizes=32,\n",
" scales=[1, 1.25, 1.58],\n",
" aspect_ratios=[0.5, 1.0, 2.0],\n",
" stride=8,\n",
")\n",
"```"
],
"metadata": {
"id": "8t9o93jMNpq7"
}
},
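{
"cell_type": "markdown",
"source": [
"For each `(scale, aspect_ratio)` pair the generator below derives the anchor height and width so that the box keeps the area `(size * scale)**2` while its width/height ratio equals the aspect ratio. A worked check of those two relations (a sketch of the formulas used in the implementation):\n",
"\n",
"```\n",
"import math\n",
"\n",
"size, scale, aspect_ratio = 32, 1.0, 0.5\n",
"h = size * scale / math.sqrt(aspect_ratio)  # ~45.25\n",
"w = size * scale * math.sqrt(aspect_ratio)  # ~22.63\n",
"assert abs(w * h - (size * scale) ** 2) < 1e-6  # area is preserved\n",
"assert abs(w / h - aspect_ratio) < 1e-9         # ratio is respected\n",
"```"
],
"metadata": {}
},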
{
"cell_type": "code",
"source": [
"def single_anchor_gen(config):\n",
" sizes = config['sizes'] # 32\n",
" scales = config['scales'] # [1, 1.25, 1.33]\n",
" aspect_ratios = config['aspect_ratios'] # [0.5, 1.0, 2.0]\n",
" stride = config['strides'] # 8\n",
"\n",
" image_height = 512\n",
" image_width = 512\n",
"\n",
" aspect_ratios = ops.cast(aspect_ratios, \"float32\")\n",
" aspect_ratios_sqrt = ops.cast(ops.sqrt(aspect_ratios), dtype=\"float32\")\n",
" anchor_size = ops.cast(sizes, \"float32\")\n",
"\n",
" # [K]\n",
" anchor_heights = []\n",
" anchor_widths = []\n",
" for scale in scales:\n",
" anchor_size_t = anchor_size * scale\n",
" anchor_height = anchor_size_t / aspect_ratios_sqrt\n",
" anchor_width = anchor_size_t * aspect_ratios_sqrt\n",
" anchor_heights.append(anchor_height)\n",
" anchor_widths.append(anchor_width)\n",
" anchor_heights = ops.concatenate(anchor_heights, axis=0)\n",
" anchor_widths = ops.concatenate(anchor_widths, axis=0)\n",
" half_anchor_heights = ops.reshape(0.5 * anchor_heights, [1, 1, -1])\n",
" half_anchor_widths = ops.reshape(0.5 * anchor_widths, [1, 1, -1])\n",
"\n",
" stride = stride\n",
" # make sure range of `cx` is within limit of `image_width` with\n",
" # `stride`, also for sizes where `image_width % stride != 0`.\n",
" # [W]\n",
" cx = ops.cast(\n",
" ops.arange(\n",
" 0.5 * stride, math.ceil(image_width / stride) * stride, stride\n",
" ),\n",
" \"float32\",\n",
" )\n",
" # make sure range of `cy` is within limit of `image_height` with\n",
" # `stride`, also for sizes where `image_height % stride != 0`.\n",
" # [H]\n",
" cy = ops.cast(\n",
" ops.arange(\n",
" 0.5 * stride, math.ceil(image_height / stride) * stride, stride\n",
" ),\n",
" \"float32\",\n",
" )\n",
" # [H, W]\n",
" cx_grid, cy_grid = ops.meshgrid(cx, cy)\n",
" # [H, W, 1]\n",
" cx_grid = ops.expand_dims(cx_grid, axis=-1)\n",
" cy_grid = ops.expand_dims(cy_grid, axis=-1)\n",
"\n",
" y_min = ops.reshape(cy_grid - half_anchor_heights, (-1,))\n",
" y_max = ops.reshape(cy_grid + half_anchor_heights, (-1,))\n",
" x_min = ops.reshape(cx_grid - half_anchor_widths, (-1,))\n",
" x_max = ops.reshape(cx_grid + half_anchor_widths, (-1,))\n",
"\n",
" # [H * W * K, 1]\n",
" y_min = ops.expand_dims(y_min, axis=-1)\n",
" y_max = ops.expand_dims(y_max, axis=-1)\n",
" x_min = ops.expand_dims(x_min, axis=-1)\n",
" x_max = ops.expand_dims(x_max, axis=-1)\n",
"\n",
" # [H * W * K, 4]\n",
" return ops.cast(\n",
" ops.concatenate([y_min, x_min, y_max, x_max], axis=-1), \"float32\"\n",
" )"
],
"metadata": {
"id": "MvGGb4jQCN-3"
},
"execution_count": 23,
"outputs": []
},
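{
"cell_type": "markdown",
"source": [
"With the hard-coded 512 x 512 image and stride 8, the grid has `ceil(512 / 8) = 64` centers per side, and each center carries `3 * 3 = 9` anchors, so level 0 should produce `64 * 64 * 9 = 36864` boxes of 4 coordinates. A quick shape check (a sketch, assuming the cells above have run):\n",
"\n",
"```\n",
"boxes = single_anchor_gen(anchor_generators[0])\n",
"print(ops.shape(boxes))  # (36864, 4)\n",
"```"
],
"metadata": {}
},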
{
"cell_type": "code",
"source": [
"single_anchor_gen(anchor_generators[0])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4FYWI2dnWdma",
"outputId": "089c223f-8b91-409d-a47a-8e779bac48a3"
},
"execution_count": 24,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[-18.627417 , -7.3137083, 26.627417 , 15.313708 ],\n",
" [-12. , -12. , 20. , 20. ],\n",
" [ -7.3137083, -18.627417 , 15.313708 , 26.627417 ],\n",
" ...,\n",
" [472.0812 , 490.04062 , 543.91876 , 525.9594 ],\n",
" [482.6016 , 482.6016 , 533.39844 , 533.39844 ],\n",
" [490.04062 , 472.0812 , 525.9594 , 543.91876 ]],\n",
" dtype=float32)"
]
},
"metadata": {},
"execution_count": 24
}
]
}
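,
{
"cell_type": "markdown",
"source": [
"The same function can be run for every pyramid level; with the 512 x 512 image baked into `single_anchor_gen` this gives `36864 + 9216 + 2304 = 48384` anchors in total. A concatenation sketch using the pieces defined above:\n",
"\n",
"```\n",
"all_anchors = ops.concatenate(\n",
"    [single_anchor_gen(cfg) for cfg in anchor_generators.values()], axis=0\n",
")\n",
"print(ops.shape(all_anchors))  # (48384, 4)\n",
"```"
],
"metadata": {}
}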
],
"metadata": {
"colab": {
"name": "scratchpad",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}