# GroupViT -- hard-softmax implementation

A scratchpad that ports the hard (straight-through) softmax used in GroupViT from PyTorch to TensorFlow and checks that the two implementations agree on the same logits.
[Open In Colab](https://colab.research.google.com/gist/ariG23498/08cdae21637b8b61bdd6d21d11719fb3/scratchpad.ipynb)
```python
# Import the necessary packages
from typing import Optional, Union, List

import tensorflow as tf
import numpy as np
import torch
```
```python
# Build the numpy logits shared by both frameworks
logits = np.random.rand(32, 128, 128)
dim = 2
```
## PyTorch
```python
# PYTORCH
def hard_softmax_pt(logits: torch.Tensor, dim: int) -> torch.Tensor:
    y_soft = logits.softmax(dim)
    # Straight through: take the one-hot argmax on the forward pass.
    index = y_soft.max(dim, keepdim=True)[1]
    y_hard = torch.zeros_like(
        logits, memory_format=torch.legacy_contiguous_format
    ).scatter_(dim, index, 1.0)
    print(y_hard.dtype)
    print(y_soft.dtype)
    # Forward value is y_hard; the gradient is that of y_soft.
    ret = y_hard - y_soft.detach() + y_soft
    return ret
```
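The last line is the straight-through trick: `y_hard - y_soft.detach()` contributes nothing to the backward pass, so the output equals the one-hot `y_hard` in the forward direction while gradients follow `y_soft`. A minimal sanity check of both properties, not part of the original gist (tensor names are illustrative):

```python
x = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)
out = hard_softmax_pt(x, dim=1)

# Forward pass: each row is (numerically) one-hot.
np.testing.assert_allclose(out.detach().sum(dim=1).numpy(), np.ones(2))

# Backward pass: gradients flow through the soft path, so a generic
# weighted sum of the output has a non-zero gradient w.r.t. x.
(out * torch.randn(2, 3, dtype=torch.float64)).sum().backward()
assert x.grad is not None and x.grad.abs().sum() > 0
```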
```python
logits_pt = torch.from_numpy(logits)
# Check whether the numpy logits and the pytorch logits are the same
np.testing.assert_allclose(logits, logits_pt, rtol=1e-4, atol=1e-4)
```
```python
# Execute the PyTorch function
ret_pt = hard_softmax_pt(logits_pt, dim)
```

Output:

```
torch.float64
torch.float64
```
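Both tensors come out as `torch.float64` because `np.random.rand` returns float64 arrays, `torch.from_numpy` keeps that dtype, and `torch.zeros_like(logits)` inherits it, so the hard and soft tensors already match without a cast.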
## TensorFlow
```python
def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor:
    """
    Stable wrapper that returns the same output as `tf.nn.softmax` but works reliably with XLA on CPU. It is
    meant as a workaround for the [following issue](https://github.com/tensorflow/tensorflow/issues/55682) and
    will be removed after it gets fixed. The arguments and outputs are the same as `tf.nn.softmax`; the wrapper
    relies on the fact that `softmax(x) = softmax(x + c)` (see
    https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html).

    Args:
        logits (`tf.Tensor`):
            Must be one of the following types: half, float32, float64.
        axis (`int`, *optional*):
            The dimension softmax would be performed on. The default is -1, which indicates the last dimension.
        name (`str`, *optional*):
            A name for the operation.

    Returns:
        `tf.Tensor`:
            A Tensor. Has the same type and shape as logits.
    """
    # TODO: When the issue linked above gets sorted, add a check on the TF version here and use the original
    # function if it has the fix. After we drop support for unfixed versions, remove this function.
    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
```
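The `softmax(x) = softmax(x + c)` identity that the workaround leans on is easy to confirm numerically; a quick sketch:

```python
x = tf.constant([[1.0, 2.0, 3.0]])
# Shifting every logit by the same constant leaves the softmax unchanged.
np.testing.assert_allclose(
    tf.nn.softmax(x).numpy(),
    tf.nn.softmax(x + 100.0).numpy(),
    rtol=1e-6, atol=1e-6,
)
```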
```python
def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
    """
    Deal with dynamic shapes in TensorFlow cleanly.

    Args:
        tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.

    Returns:
        `List[int]`: The shape of the tensor as a list.
    """
    if isinstance(tensor, np.ndarray):
        return list(tensor.shape)

    dynamic = tf.shape(tensor)

    # Rank unknown at trace time: fall back to the fully dynamic shape.
    if tensor.shape == tf.TensorShape(None):
        return dynamic

    static = tensor.shape.as_list()

    # Use the static dimension where known, the dynamic one otherwise.
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
```
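In eager mode every dimension is known, so the helper returns plain Python ints; only inside a `tf.function` traced with an unknown dimension does an entry come back as a scalar tensor instead. A small illustration of the static paths:

```python
print(shape_list(tf.zeros((4, 5))))   # [4, 5] -- all dimensions static
print(shape_list(np.zeros((2, 3))))   # [2, 3] -- NumPy branch
```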
```python
def hard_softmax_tf(logits: tf.Tensor, dim: int) -> tf.Tensor:
    """
    Reference: https://gist.github.com/ariG23498/b9eca9a73fc9d93884fb2f59c4a303fb
    """
    y_soft = stable_softmax(logits, dim)
    # Straight through: tf.one_hot replaces the zeros_like + scatter_
    # combination used on the PyTorch side.
    index = tf.argmax(y_soft, dim)
    y_hard = tf.one_hot(
        index,
        depth=shape_list(logits)[dim],
        axis=dim,
        dtype=y_soft.dtype,
    )
    print(y_hard.dtype)
    print(y_soft.dtype)
    # Forward value is y_hard; the gradient is that of y_soft.
    ret = y_hard - tf.stop_gradient(y_soft) + y_soft
    return ret
```
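Mirroring the PyTorch check, `tf.stop_gradient` plays the role of `.detach()`: the forward output is hard while gradients follow `y_soft`. A minimal sketch, again with illustrative names, not part of the original gist:

```python
x = tf.Variable(tf.random.normal((2, 3), dtype=tf.float64))
with tf.GradientTape() as tape:
    out = hard_softmax_tf(x, dim=1)
    # Generic weighted sum so the soft-path gradient is non-zero.
    loss = tf.reduce_sum(out * tf.random.normal((2, 3), dtype=tf.float64))
grad = tape.gradient(loss, x)
assert grad is not None and float(tf.reduce_sum(tf.abs(grad))) > 0
```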
```python
logits_tf = tf.convert_to_tensor(logits)
# Check whether the TensorFlow and the Numpy logits are the same
np.testing.assert_allclose(logits, logits_tf, rtol=1e-4, atol=1e-4)
```
```python
# Execute the TensorFlow function
ret_tf = hard_softmax_tf(logits_tf, dim)

# Check whether the TensorFlow and PyTorch outputs are the same
np.testing.assert_allclose(ret_pt, ret_tf, rtol=1e-4, atol=1e-4)
```

Output:

```
<dtype: 'float64'>
<dtype: 'float64'>
```
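The assertion above only compares forward outputs. Since the point of the straight-through trick is the backward pass, a gradient parity check is a natural follow-up; a sketch, not part of the original gist, reusing the shared `logits` with a common set of random weights:

```python
weights = np.random.rand(*logits.shape)

# PyTorch: gradient of a weighted sum of the hard-softmax output
logits_pt_grad = torch.from_numpy(logits).requires_grad_(True)
(hard_softmax_pt(logits_pt_grad, dim) * torch.from_numpy(weights)).sum().backward()

# TensorFlow: gradient of the same scalar
logits_tf_grad = tf.Variable(logits_tf)
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(hard_softmax_tf(logits_tf_grad, dim) * weights)
grad_tf = tape.gradient(loss, logits_tf_grad)

# Both should equal the gradient of the *soft* softmax path.
np.testing.assert_allclose(logits_pt_grad.grad.numpy(), grad_tf.numpy(), rtol=1e-4, atol=1e-4)
```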