Skip to content

Instantly share code, notes, and snippets.

@ariG23498
Last active July 8, 2022 14:25
Show Gist options
  • Select an option

  • Save ariG23498/08cdae21637b8b61bdd6d21d11719fb3 to your computer and use it in GitHub Desktop.

Select an option

Save ariG23498/08cdae21637b8b61bdd6d21d11719fb3 to your computer and use it in GitHub Desktop.
GroupViT -- hard-softmax implementation
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ariG23498/08cdae21637b8b61bdd6d21d11719fb3/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lIYdn1woOS1n"
},
"outputs": [],
"source": [
"# Import the necessary packages\n",
"from typing import Optional, Union, List\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"import torch"
]
},
{
"cell_type": "code",
"source": [
"# Build the numpy logits\n",
"logits = np.random.rand(32, 128, 128)\n",
"dim = 2"
],
"metadata": {
"id": "KcLOElvhI1BU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# PyTorch"
],
"metadata": {
"id": "MhmkC-t2Ln-Q"
}
},
{
"cell_type": "code",
"source": [
"# PYTORCH\n",
"def hard_softmax_pt(logits: torch.Tensor, dim: int):\n",
" y_soft = logits.softmax(dim)\n",
" # Straight through.\n",
" index = y_soft.max(dim, keepdim=True)[1]\n",
" y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)\n",
" print(y_hard.dtype)\n",
" print(y_soft.dtype)\n",
" ret = y_hard - y_soft.detach() + y_soft\n",
" return ret"
],
"metadata": {
"id": "xvE-iqSJI2lq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"logits_pt = torch.from_numpy(logits)\n",
"# Check whether the numpy logits and the pytorch logits are the same\n",
"np.testing.assert_allclose(logits, logits_pt, rtol=1e-4, atol=1e-4)"
],
"metadata": {
"id": "84_O_K9JL0Bp"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Execute the PyTorch function\n",
"ret_pt = hard_softmax_pt(logits_pt, dim)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JeTBHbySL2Dq",
"outputId": "ee5bead9-fc45-475c-eabe-ddeae96e3db4"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.float64\n",
"torch.float64\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# TensorFlow"
],
"metadata": {
"id": "cC57idGULqqL"
}
},
{
"cell_type": "code",
"source": [
"def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor:\n",
" \"\"\"\n",
" Stable wrapper that returns the same output as `tf.nn.softmax`, but that works reliably with XLA on CPU. It is\n",
" meant as a workaround for the [following issue](https://github.com/tensorflow/tensorflow/issues/55682), and will be\n",
" removed after it gets fixed. The arguments and outputs are the same as `tf.nn.softmax`, and it relies on the fact that\n",
" `softmax(x) = softmax(x + c)` (see https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html).\n",
"\n",
" Args:\n",
" logits (`tf.Tensor`):\n",
" Must be one of the following types: half, float32, float64.\n",
" axis (`int`, *optional*):\n",
" The dimension softmax would be performed on. The default is `None`, which `tf.nn.softmax` treats as -1 (the last dimension).\n",
" name (`str`, *optional*):\n",
" A name for the operation.\n",
"\n",
" Returns:\n",
" `tf.Tensor`:\n",
" A Tensor. Has the same type and shape as logits.\n",
" \"\"\"\n",
" # TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if\n",
" # it has the fix. After we drop the support for unfixed versions, remove this function.\n",
" return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)"
],
"metadata": {
"id": "kADPSXqkJhkF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:\n",
" \"\"\"\n",
" Deal with dynamic shape in tensorflow cleanly.\n",
"\n",
" Args:\n",
" tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.\n",
"\n",
" Returns:\n",
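" `List[int]`: The shape of the tensor as a list, with Python ints for static dimensions and scalar tensors for dynamic ones.\n",
" If the rank of `tensor` is unknown, the `tf.shape(tensor)` tensor itself is returned instead of a list.\n",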
" \"\"\"\n",
" if isinstance(tensor, np.ndarray):\n",
" return list(tensor.shape)\n",
"\n",
" dynamic = tf.shape(tensor)\n",
"\n",
" if tensor.shape == tf.TensorShape(None):\n",
" return dynamic\n",
"\n",
" static = tensor.shape.as_list()\n",
"\n",
" return [dynamic[i] if s is None else s for i, s in enumerate(static)]"
],
"metadata": {
"id": "BDNuA611JrKS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def hard_softmax_tf(logits: tf.Tensor, dim: int):\n",
" \"\"\"\n",
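" Straight-through hard softmax: the forward value is the one-hot argmax of the softmax over `dim`, while\n",
" gradients flow through the soft probabilities (the soft term is subtracted under `tf.stop_gradient`).\n",
"\n",
" Reference: https://gist.github.com/ariG23498/b9eca9a73fc9d93884fb2f59c4a303fb\n",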
" \"\"\"\n",
" y_soft = stable_softmax(logits, dim)\n",
" # Straight through.\n",
" index = tf.argmax(y_soft, dim)\n",
" y_hard = tf.one_hot(\n",
" index,\n",
" depth=shape_list(logits)[dim],\n",
" axis=dim,\n",
" dtype=y_soft.dtype\n",
" )\n",
" print(y_hard.dtype)\n",
" print(y_soft.dtype)\n",
" ret = y_hard - tf.stop_gradient(y_soft) + y_soft\n",
" return ret"
],
"metadata": {
"id": "sZCv5xQIJd-3"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"logits_tf = tf.convert_to_tensor(logits)\n",
"# Check whether the TensorFlow and the Numpy logits are the same\n",
"np.testing.assert_allclose(logits, logits_tf, rtol=1e-4, atol=1e-4)"
],
"metadata": {
"id": "7Q2iMD9hL-j-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Execute the TensorFlow function\n",
"ret_tf = hard_softmax_tf(logits_tf, dim)\n",
"\n",
"# Check whether the TensorFlow and PyTorch outputs are the same\n",
"np.testing.assert_allclose(ret_pt, ret_tf, rtol=1e-4, atol=1e-4)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_V3dPu-xMFxO",
"outputId": "1107c2c8-3149-42ff-8bfe-eddaf0925475"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<dtype: 'float64'>\n",
"<dtype: 'float64'>\n"
]
}
]
}
],
"metadata": {
"colab": {
"name": "scratchpad",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment