TensorFlow equivalent of F.interpolate
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/ariG23498/3777f8d9be25de8ae782256f5aacb2c5/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Import the necessary packages\n",
        "from typing import Optional, Union, List\n",
        "import tensorflow as tf\n",
        "# from tensorflow.python.ops import gen_image_ops\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn"
      ],
      "metadata": {
        "id": "ty9809HQNcf2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Build random NumPy attention maps (the \"logits\" compared below)\n",
        "height = 64\n",
        "feat_height = int(0.5 * height)\n",
        "\n",
        "width = 64\n",
        "feat_width = int(0.5 * width)\n",
        "attentions = np.random.rand(32, 4, feat_height * feat_width)"
      ],
      "metadata": {
        "id": "qCjQIf2qN5dE"
      },
      "execution_count": null,
      "outputs": []
    },
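    {
      "cell_type": "markdown",
      "source": [
        "A quick sanity check on the synthetic data: with `height = width = 64` and a 0.5 downscale, the feature grid is 32×32, so the last axis holds `32 * 32 = 1024` values per group token (illustrative check, not needed for the port itself):"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Batch of 32, 4 group tokens, flattened 32x32 feature grid\n",
        "assert attentions.shape == (32, 4, feat_height * feat_width)\n",
        "assert feat_height * feat_width == 1024"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },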
    {
      "cell_type": "markdown",
      "source": [
        "# PyTorch"
      ],
      "metadata": {
        "id": "GV0jlJKsN1GG"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def resize_attention_map_pt(attentions, height, width, align_corners=False):\n",
        "    \"\"\"\n",
        "    Args:\n",
        "        attentions (`torch.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]\n",
        "        height (`int`): height of the output attention map\n",
        "        width (`int`): width of the output attention map\n",
        "        align_corners (`bool`, *optional*): the `align_corners` argument of `nn.functional.interpolate`.\n",
        "\n",
        "    Returns:\n",
        "        `torch.Tensor`: resized attention map of shape [batch_size, groups, height, width]\n",
        "    \"\"\"\n",
        "\n",
        "    # Infer the feature-map size from the flattened axis, assuming a uniform downscale\n",
        "    scale = (height * width // attentions.shape[2]) ** 0.5\n",
        "    if height > width:\n",
        "        feat_width = int(np.round(width / scale))\n",
        "        feat_height = attentions.shape[2] // feat_width\n",
        "    else:\n",
        "        feat_height = int(np.round(height / scale))\n",
        "        feat_width = attentions.shape[2] // feat_height\n",
        "\n",
        "    batch_size = attentions.shape[0]\n",
        "    groups = attentions.shape[1]  # number of group tokens\n",
        "    # [batch_size, groups, feat_height*feat_width] -> [batch_size, groups, feat_height, feat_width]\n",
        "    attentions = attentions.reshape(batch_size, groups, feat_height, feat_width)\n",
        "    attentions = nn.functional.interpolate(\n",
        "        attentions, size=(height, width), mode=\"bilinear\", align_corners=align_corners\n",
        "    )\n",
        "    return attentions"
      ],
      "metadata": {
        "id": "BIRj2Gu2N0-i"
      },
      "execution_count": null,
      "outputs": []
    },
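    {
      "cell_type": "markdown",
      "source": [
        "The `scale` arithmetic assumes the flattened axis came from a uniformly downscaled version of the target map. Worked through with the numbers above: `scale = ((64 * 64) // 1024) ** 0.5 = 2`, and since `height == width` the `else` branch gives `feat_height = round(64 / 2) = 32` and `feat_width = 1024 // 32 = 32`, recovering the original 32×32 grid."
      ],
      "metadata": {}
    },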
    {
      "cell_type": "code",
      "source": [
        "attentions_pt = torch.from_numpy(attentions)\n",
        "# Check whether the NumPy logits and the PyTorch logits are the same\n",
        "np.testing.assert_allclose(attentions, attentions_pt, rtol=1e-4, atol=1e-4)"
      ],
      "metadata": {
        "id": "VfQVDhtpN_b0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Execute the PyTorch function\n",
        "attentions_pt_false = resize_attention_map_pt(attentions_pt, height=height, width=width, align_corners=False)\n",
        "attentions_pt_true = resize_attention_map_pt(attentions_pt, height=height, width=width, align_corners=True)"
      ],
      "metadata": {
        "id": "w44wMjLoOhkA"
      },
      "execution_count": null,
      "outputs": []
    },
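    {
      "cell_type": "markdown",
      "source": [
        "Both calls should upsample the flattened 32×32 grid to the full 64×64 output; a quick shape check (illustrative):"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Expected: one 64x64 map per group token\n",
        "assert attentions_pt_false.shape == (32, 4, height, width)\n",
        "assert attentions_pt_true.shape == (32, 4, height, width)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },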
    {
      "cell_type": "markdown",
      "source": [
        "# TensorFlow"
      ],
      "metadata": {
        "id": "JQ8LuvO-Nze6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:\n",
        "    \"\"\"\n",
        "    Deal with dynamic shapes in TensorFlow cleanly.\n",
        "\n",
        "    Args:\n",
        "        tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.\n",
        "\n",
        "    Returns:\n",
        "        `List[int]`: The shape of the tensor as a list.\n",
        "    \"\"\"\n",
        "    if isinstance(tensor, np.ndarray):\n",
        "        return list(tensor.shape)\n",
        "\n",
        "    dynamic = tf.shape(tensor)\n",
        "\n",
        "    if tensor.shape == tf.TensorShape(None):\n",
        "        return dynamic\n",
        "\n",
        "    static = tensor.shape.as_list()\n",
        "\n",
        "    return [dynamic[i] if s is None else s for i, s in enumerate(static)]"
      ],
      "metadata": {
        "id": "1MO_21gLNycA"
      },
      "execution_count": null,
      "outputs": []
    },
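    {
      "cell_type": "markdown",
      "source": [
        "Why `shape_list` rather than plain `tensor.shape`: inside a `tf.function`, some dimensions can be unknown at trace time, and `shape_list` returns static Python ints where available and dynamic scalar tensors otherwise. A minimal illustration (the `traced` function is just for demonstration):"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Static dims come back as Python ints\n",
        "x = tf.zeros((2, 3))\n",
        "print(shape_list(x))  # [2, 3]\n",
        "\n",
        "@tf.function(input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)])\n",
        "def traced(t):\n",
        "    # The batch dim is unknown at trace time; the second dim is the static int 3\n",
        "    tf.print(shape_list(t)[1])\n",
        "    return t\n",
        "\n",
        "_ = traced(tf.zeros((5, 3)))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },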
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lIYdn1woOS1n"
      },
      "outputs": [],
      "source": [
        "def resize_attention_map_tf(attentions: tf.Tensor, height: int, width: int, align_corners: Optional[bool] = False) -> tf.Tensor:\n",
        "    \"\"\"\n",
        "    Args:\n",
        "        attentions (`tf.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]\n",
        "        height (`int`): height of the output attention map\n",
        "        width (`int`): width of the output attention map\n",
        "        align_corners (`bool`, *optional*): mimics the `align_corners` argument of `nn.functional.interpolate`.\n",
        "\n",
        "    Returns:\n",
        "        `tf.Tensor`: resized attention map of shape [batch_size, groups, height, width]\n",
        "    \"\"\"\n",
        "\n",
        "    # Infer the feature-map size from the flattened axis, assuming a uniform downscale\n",
        "    scale = (height * width // attentions.shape[2]) ** 0.5\n",
        "    if height > width:\n",
        "        feat_width = int(np.round(width / scale))\n",
        "        feat_height = shape_list(attentions)[2] // feat_width\n",
        "    else:\n",
        "        feat_height = int(np.round(height / scale))\n",
        "        feat_width = shape_list(attentions)[2] // feat_height\n",
        "\n",
        "    batch_size = shape_list(attentions)[0]\n",
        "    groups = shape_list(attentions)[1]  # number of group tokens\n",
        "    # [batch_size, groups, feat_height*feat_width] -> [batch_size, groups, feat_height, feat_width]\n",
        "    attentions = tf.reshape(attentions, (batch_size, groups, feat_height, feat_width))\n",
        "    # tf.image.resize expects channels-last, so go NCHW -> NHWC and back\n",
        "    attentions = tf.transpose(attentions, perm=(0, 2, 3, 1))\n",
        "    if align_corners:\n",
        "        # TF2's tf.image.resize has no align_corners flag, so use the v1 op here\n",
        "        attentions = tf.compat.v1.image.resize(\n",
        "            attentions, size=(height, width), method=\"bilinear\", align_corners=align_corners,\n",
        "        )\n",
        "    else:\n",
        "        attentions = tf.image.resize(\n",
        "            attentions, size=(height, width), method=\"bilinear\"\n",
        "        )\n",
        "    attentions = tf.transpose(attentions, perm=(0, 3, 1, 2))\n",
        "    return attentions"
      ]
    },
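    {
      "cell_type": "markdown",
      "source": [
        "Before the full comparison, a tiny self-contained check of both code paths. This is a sketch on a made-up 2×2 map (`toy` is illustrative and not used elsewhere), upsampled to 4×4 with each `align_corners` setting:"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Illustrative toy check: a flattened 2x2 map, upsampled to 4x4 in both frameworks\n",
        "toy = np.arange(4, dtype=np.float64).reshape(1, 1, 4)\n",
        "for ac in (False, True):\n",
        "    out_pt = resize_attention_map_pt(torch.from_numpy(toy), height=4, width=4, align_corners=ac)\n",
        "    out_tf = resize_attention_map_tf(tf.convert_to_tensor(toy), height=4, width=4, align_corners=ac)\n",
        "    np.testing.assert_allclose(out_pt.numpy(), out_tf.numpy(), rtol=1e-4, atol=1e-4)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },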
    {
      "cell_type": "code",
      "source": [
        "attentions_tf = tf.convert_to_tensor(attentions)\n",
        "# Check whether the TensorFlow and the NumPy logits are the same\n",
        "np.testing.assert_allclose(attentions, attentions_tf, rtol=1e-4, atol=1e-4)"
      ],
      "metadata": {
        "id": "bUkm8YpRPGpK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Execute the TensorFlow function\n",
        "attentions_tf_false = resize_attention_map_tf(attentions_tf, height=height, width=width, align_corners=False)\n",
        "attentions_tf_true = resize_attention_map_tf(attentions_tf, height=height, width=width, align_corners=True)"
      ],
      "metadata": {
        "id": "JTM2ebs3PEtq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# PyTorch vs TensorFlow outputs, align_corners=False\n",
        "np.testing.assert_allclose(attentions_pt_false, attentions_tf_false, rtol=1e-4, atol=1e-4)"
      ],
      "metadata": {
        "id": "ZIlpDuxdP5_K"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# PyTorch vs TensorFlow outputs, align_corners=True\n",
        "np.testing.assert_allclose(attentions_pt_true, attentions_tf_true, rtol=1e-4, atol=1e-4)"
      ],
      "metadata": {
        "id": "y97PgK9kiLFY"
      },
      "execution_count": null,
      "outputs": []
    },
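    {
      "cell_type": "markdown",
      "source": [
        "If both assertions pass silently, `resize_attention_map_tf` reproduces the `F.interpolate`-based PyTorch resize for this input, for both `align_corners` settings, within the chosen tolerances."
      ],
      "metadata": {}
    }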
  ],
  "metadata": {
    "colab": {
      "name": "scratchpad",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}