Last active
December 1, 2023 17:28
-
-
Save ariG23498/777ea321f4c294842f3f7de45dde8258 to your computer and use it in GitHub Desktop.
anchor-generator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/ariG23498/777ea321f4c294842f3f7de45dde8258/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "!pip install --upgrade -qq tensorflow\n", | |
| "!pip install --upgrade -qq keras" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "kXSYq9j_Oq76", | |
| "outputId": "af0ee3cc-7dca-487a-848f-868479121d89" | |
| }, | |
| "execution_count": 1, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m475.2/475.2 MB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m58.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m442.0/442.0 kB\u001b[0m \u001b[31m31.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m58.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m997.1/997.1 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
| "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", | |
| "tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.0 which is incompatible.\u001b[0m\u001b[31m\n", | |
| "\u001b[0m" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import os\n", | |
| "os.environ[\"KERAS_BACKEND\"] = \"numpy\"\n", | |
| "from keras import ops\n", | |
| "import math" | |
| ], | |
| "metadata": { | |
| "id": "VJJSDbHOWvFM" | |
| }, | |
| "execution_count": 2, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "id": "lIYdn1woOS1n" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
# Per-level anchor parameters for three levels.
# Base anchor side lengths: 32, 64, 128.
sizes = [2**exponent for exponent in (5, 6, 7)]
# Scale multipliers 2**0, 2**(1/3), 2**(2/3); the first entry is the int 1.
scales = [2**exponent for exponent in (0, 1/3, 2/3)]
# Aspect ratios 1/2, 1/1 and 2/1, written out as floats.
aspect_ratios = [0.5, 1.0, 2.0]
# Grid strides: 8, 16, 32.
strides = [2**exponent for exponent in (3, 4, 5)]
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
def param_to_level_dict(param):
    """Key the entries of ``param`` by their position (the level index).

    ``param`` is read as ``param[i]`` for every ``i`` in ``range(len(param))``,
    so ``[a, b, c]`` becomes ``{0: a, 1: b, 2: c}``. A dict whose keys are
    already ``0 .. len - 1`` maps to an equal new dict, so re-running the cell
    that calls this on its own output is harmless.
    """
    return {level: param[level] for level in range(len(param))}
| ], | |
| "metadata": { | |
| "id": "P11eSroq_CMz" | |
| }, | |
| "execution_count": 4, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
# Re-key the size and stride lists by level index (0, 1, 2, ...).
sizes, strides = map(param_to_level_dict, (sizes, strides))
| ], | |
| "metadata": { | |
| "id": "HOsU26mdBHN9" | |
| }, | |
| "execution_count": 5, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
# Display the level -> base anchor size mapping.
sizes
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "oWM0jrSKBPMJ", | |
| "outputId": "cc106551-9802-47c3-f76e-f77c0029bf0e" | |
| }, | |
| "execution_count": 6, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{0: 32, 1: 64, 2: 128}" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 6 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
# Display the level -> stride mapping.
strides
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "7IRvRRRWBPvq", | |
| "outputId": "6b162d57-c109-48dd-eae6-5bf3f69d1938" | |
| }, | |
| "execution_count": 7, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{0: 8, 1: 16, 2: 32}" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 7 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
def match_param_structure_to_sizes(params, sizes):
    """Give every level in ``sizes`` its own copy of a shared parameter.

    Args:
        params: Value shared by all levels (e.g. the aspect-ratio or scale
            list). If it is already a dict keyed by level, the entry for each
            level in ``sizes`` is used instead. This way re-running the
            calling cell on its own output does not produce nested dicts.
        sizes: Dict keyed by level index; only its keys are used.

    Returns:
        Dict mapping each key of ``sizes`` to a shallow copy of the parameter.
        Because each level gets a distinct object, mutating one level's list
        does not change the other levels.

    Raises:
        KeyError: if ``params`` is a dict missing one of the keys of ``sizes``.
    """
    import copy

    if isinstance(params, dict):
        return {key: copy.copy(params[key]) for key in sizes.keys()}
    # A fresh copy per level prevents all levels aliasing one list object.
    return {key: copy.copy(params) for key in sizes.keys()}
| ], | |
| "metadata": { | |
| "id": "5-LdF9wcBQ1i" | |
| }, | |
| "execution_count": 8, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
# Share the same aspect ratios and scales across every level in `sizes`.
aspect_ratios, scales = (
    match_param_structure_to_sizes(level_param, sizes)
    for level_param in (aspect_ratios, scales)
)
| ], | |
| "metadata": { | |
| "id": "pgY6sI1SBZ_t" | |
| }, | |
| "execution_count": 9, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
# Display the level -> aspect-ratio list mapping.
aspect_ratios
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "QUZ1rKnZBhbE", | |
| "outputId": "03bda8e6-783e-4c82-d373-4b455988922f" | |
| }, | |
| "execution_count": 10, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{0: [0.5, 1.0, 2.0], 1: [0.5, 1.0, 2.0], 2: [0.5, 1.0, 2.0]}" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 10 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
# Display the level -> scale list mapping.
scales
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "hJfIfUcdBii5", | |
| "outputId": "2af1338c-865a-4c3a-a0e2-4474c4458a5c" | |
| }, | |
| "execution_count": 11, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{0: [1, 1.2599210498948732, 1.5874010519681994],\n", | |
| " 1: [1, 1.2599210498948732, 1.5874010519681994],\n", | |
| " 2: [1, 1.2599210498948732, 1.5874010519681994]}" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 11 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
# One config dict per level, collecting that level's entry from each
# per-level mapping. The key names are the ones single_anchor_gen reads.
anchor_generators = {
    level: {
        "sizes": sizes[level],
        "scales": scales[level],
        "aspect_ratios": aspect_ratios[level],
        "strides": strides[level],
    }
    for level in sizes
}
| ], | |
| "metadata": { | |
| "id": "KK48MXtCBz0z" | |
| }, | |
| "execution_count": 12, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
# Display the per-level anchor configs.
anchor_generators
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "yHVQt-98CMZ5", | |
| "outputId": "103353cf-9429-46f6-9f67-62885be09688" | |
| }, | |
| "execution_count": 13, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{0: {'sizes': 32,\n", | |
| " 'scales': [1, 1.2599210498948732, 1.5874010519681994],\n", | |
| " 'aspect_ratios': [0.5, 1.0, 2.0],\n", | |
| " 'strides': 8},\n", | |
| " 1: {'sizes': 64,\n", | |
| " 'scales': [1, 1.2599210498948732, 1.5874010519681994],\n", | |
| " 'aspect_ratios': [0.5, 1.0, 2.0],\n", | |
| " 'strides': 16},\n", | |
| " 2: {'sizes': 128,\n", | |
| " 'scales': [1, 1.2599210498948732, 1.5874010519681994],\n", | |
| " 'aspect_ratios': [0.5, 1.0, 2.0],\n", | |
| " 'strides': 32}}" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 13 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "```\n", | |
| "SingleAnchorGenerator(\n", | |
| " sizes=32,\n", | |
| "    scales=[1, 1.26, 1.59],  # 2**0, 2**(1/3), 2**(2/3)\n", | |
| " aspect_ratios=[0.5, 1.0, 2.0],\n", | |
| " stride=8,\n", | |
| ")\n", | |
| "```" | |
| ], | |
| "metadata": { | |
| "id": "8t9o93jMNpq7" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
def _anchor_centers(extent, stride):
    """Return float32 anchor-centre coordinates along one image axis.

    Centres are ``0.5 * stride, 1.5 * stride, ...``. Using
    ``ceil(extent / stride) * stride`` as the exclusive stop yields exactly
    ``ceil(extent / stride)`` centres. A trailing partial cell is therefore
    still covered when ``extent % stride != 0``.
    """
    return ops.cast(
        ops.arange(0.5 * stride, math.ceil(extent / stride) * stride, stride),
        "float32",
    )


def single_anchor_gen(config, image_height=512, image_width=512):
    """Generate the anchor boxes of one level on a regular centre grid.

    Args:
        config: Dict with keys ``"sizes"`` (scalar base anchor side),
            ``"scales"`` (iterable of size multipliers), ``"aspect_ratios"``
            (iterable of width / height ratios) and ``"strides"`` (grid step).
        image_height: Height spanned by the centre grid. Defaults to 512.
        image_width: Width spanned by the centre grid. Defaults to 512.

    Returns:
        float32 tensor of shape ``[H * W * K, 4]`` with rows
        ``(y_min, x_min, y_max, x_max)``, where
        ``H = ceil(image_height / stride)``,
        ``W = ceil(image_width / stride)`` and
        ``K = len(scales) * len(aspect_ratios)``. Rows are ordered by grid
        row, then grid column, then anchor (scale-major, aspect ratio minor).
    """
    sizes = config["sizes"]  # scalar base anchor side, e.g. 32
    scales = config["scales"]  # e.g. [1, 1.26, 1.59]
    aspect_ratios = config["aspect_ratios"]  # e.g. [0.5, 1.0, 2.0]
    stride = config["strides"]  # e.g. 8

    aspect_ratios = ops.cast(aspect_ratios, "float32")
    aspect_ratios_sqrt = ops.cast(ops.sqrt(aspect_ratios), dtype="float32")
    anchor_size = ops.cast(sizes, "float32")

    # One (height, width) pair per (scale, aspect ratio), scale-major: [K]
    anchor_heights = []
    anchor_widths = []
    for scale in scales:
        scaled_size = anchor_size * scale
        # Dividing / multiplying by sqrt(ratio) keeps the area at
        # scaled_size**2 while making width / height equal to the ratio.
        anchor_heights.append(scaled_size / aspect_ratios_sqrt)
        anchor_widths.append(scaled_size * aspect_ratios_sqrt)
    # [1, 1, K] so they broadcast against the [H, W, 1] centre grids below.
    half_anchor_heights = ops.reshape(
        0.5 * ops.concatenate(anchor_heights, axis=0), [1, 1, -1]
    )
    half_anchor_widths = ops.reshape(
        0.5 * ops.concatenate(anchor_widths, axis=0), [1, 1, -1]
    )

    cx = _anchor_centers(image_width, stride)  # [W]
    cy = _anchor_centers(image_height, stride)  # [H]
    # Default "xy" indexing gives grids of shape [H, W]; add a trailing axis
    # to get [H, W, 1].
    cx_grid, cy_grid = ops.meshgrid(cx, cy)
    cx_grid = ops.expand_dims(cx_grid, axis=-1)
    cy_grid = ops.expand_dims(cy_grid, axis=-1)

    # Broadcast to [H, W, K], then flatten to [H * W * K].
    y_min = ops.reshape(cy_grid - half_anchor_heights, (-1,))
    y_max = ops.reshape(cy_grid + half_anchor_heights, (-1,))
    x_min = ops.reshape(cx_grid - half_anchor_widths, (-1,))
    x_max = ops.reshape(cx_grid + half_anchor_widths, (-1,))

    # [H * W * K, 4] in (y_min, x_min, y_max, x_max) order.
    return ops.cast(ops.stack([y_min, x_min, y_max, x_max], axis=-1), "float32")
| ], | |
| "metadata": { | |
| "id": "MvGGb4jQCN-3" | |
| }, | |
| "execution_count": 23, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
level0_config = anchor_generators[0]
level0_anchors = single_anchor_gen(level0_config)

# Sanity check: one row per (grid cell, scale, aspect ratio). By default
# single_anchor_gen lays a ceil(512 / stride) x ceil(512 / stride) centre grid.
grid_side = math.ceil(512 / level0_config["strides"])
anchors_per_cell = len(level0_config["scales"]) * len(level0_config["aspect_ratios"])
assert level0_anchors.shape == (grid_side * grid_side * anchors_per_cell, 4)

level0_anchors
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "4FYWI2dnWdma", | |
| "outputId": "089c223f-8b91-409d-a47a-8e779bac48a3" | |
| }, | |
| "execution_count": 24, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "array([[-18.627417 , -7.3137083, 26.627417 , 15.313708 ],\n", | |
| " [-12. , -12. , 20. , 20. ],\n", | |
| " [ -7.3137083, -18.627417 , 15.313708 , 26.627417 ],\n", | |
| " ...,\n", | |
| " [472.0812 , 490.04062 , 543.91876 , 525.9594 ],\n", | |
| " [482.6016 , 482.6016 , 533.39844 , 533.39844 ],\n", | |
| " [490.04062 , 472.0812 , 525.9594 , 543.91876 ]],\n", | |
| " dtype=float32)" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 24 | |
| } | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "name": "scratchpad", | |
| "provenance": [], | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "name": "python3" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment