Skip to content

Instantly share code, notes, and snippets.

@ariG23498
Last active July 10, 2022 15:53
Show Gist options
  • Save ariG23498/3777f8d9be25de8ae782256f5aacb2c5 to your computer and use it in GitHub Desktop.
Save ariG23498/3777f8d9be25de8ae782256f5aacb2c5 to your computer and use it in GitHub Desktop.
TensorFlow equivalent of F.interpolate
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ariG23498/3777f8d9be25de8ae782256f5aacb2c5/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"# Import the necessary packages\n",
"from typing import Optional, Union, List\n",
"import tensorflow as tf\n",
"# from tensorflow.python.ops import gen_image_ops\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn"
],
"metadata": {
"id": "ty9809HQNcf2"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Build the numpy logits\n",
"SEED = 0  # fixed seed so every Restart & Run All compares the same random maps\n",
"\n",
"height = 64\n",
"feat_height = height // 2\n",
"\n",
"width = 64\n",
"feat_width = width // 2\n",
"\n",
"# Flattened attention maps of shape [batch_size, groups, feat_height*feat_width]\n",
"rng = np.random.default_rng(SEED)\n",
"attentions = rng.random((32, 4, feat_height * feat_width))"
],
"metadata": {
"id": "qCjQIf2qN5dE"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# PyTorch"
],
"metadata": {
"id": "GV0jlJKsN1GG"
}
},
{
"cell_type": "code",
"source": [
"def resize_attention_map_pt(attentions, height, width, align_corners=False):\n",
"    \"\"\"\n",
"    Reshape flattened attention maps into 2D feature maps and bilinearly resize them.\n",
"\n",
"    Args:\n",
"        attentions (`torch.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]\n",
"        height (`int`): height of the output attention map\n",
"        width (`int`): width of the output attention map\n",
"        align_corners (`bool`, *optional*): the `align_corners` argument for `nn.functional.interpolate`.\n",
"\n",
"    Returns:\n",
"        `torch.Tensor`: resized attention map of shape [batch_size, groups, height, width]\n",
"    \"\"\"\n",
"    batch_size = attentions.shape[0]\n",
"    groups = attentions.shape[1]  # number of group tokens\n",
"    num_features = attentions.shape[2]  # feat_height * feat_width\n",
"\n",
"    # Recover the 2D feature-map size: the shorter output side is divided by the\n",
"    # (square-rooted) output/feature area ratio, the other side follows by division.\n",
"    scale = (height * width // num_features) ** 0.5\n",
"    if height > width:\n",
"        feat_width = int(np.round(width / scale))\n",
"        feat_height = num_features // feat_width\n",
"    else:\n",
"        feat_height = int(np.round(height / scale))\n",
"        feat_width = num_features // feat_height\n",
"\n",
"    # [batch_size, groups, feat_height*feat_width] -> [batch_size, groups, feat_height, feat_width]\n",
"    attentions = attentions.reshape(batch_size, groups, feat_height, feat_width)\n",
"    # [batch_size, groups, feat_height, feat_width] -> [batch_size, groups, height, width]\n",
"    attentions = nn.functional.interpolate(\n",
"        attentions, size=(height, width), mode=\"bilinear\", align_corners=align_corners\n",
"    )\n",
"    return attentions"
],
"metadata": {
"id": "BIRj2Gu2N0-i"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Wrap the numpy logits in a torch tensor\n",
"attentions_pt = torch.from_numpy(attentions)\n",
"# Sanity check: the torch tensor must hold the same values as the numpy logits\n",
"np.testing.assert_allclose(\n",
"    actual=attentions, desired=attentions_pt, rtol=1e-4, atol=1e-4\n",
")"
],
"metadata": {
"id": "VfQVDhtpN_b0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Run the PyTorch reference once per `align_corners` setting\n",
"attentions_pt_false, attentions_pt_true = (\n",
"    resize_attention_map_pt(attentions_pt, height=height, width=width, align_corners=flag)\n",
"    for flag in (False, True)\n",
")"
],
"metadata": {
"id": "w44wMjLoOhkA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# TensorFlow"
],
"metadata": {
"id": "JQ8LuvO-Nze6"
}
},
{
"cell_type": "code",
"source": [
"def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:\n",
"    \"\"\"\n",
"    Return the shape of `tensor`, preferring static dimensions over dynamic ones.\n",
"\n",
"    Args:\n",
"        tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.\n",
"\n",
"    Returns:\n",
"        `List[int]`: The shape of the tensor as a list. Dimensions that are not\n",
"        statically known are taken from `tf.shape(tensor)`; when the rank itself\n",
"        is unknown, the `tf.shape(tensor)` result is returned as-is.\n",
"    \"\"\"\n",
"    # numpy arrays always carry a fully known shape\n",
"    if isinstance(tensor, np.ndarray):\n",
"        return list(tensor.shape)\n",
"\n",
"    runtime_shape = tf.shape(tensor)\n",
"\n",
"    # Unknown rank: there is no static information to merge in\n",
"    if tensor.shape == tf.TensorShape(None):\n",
"        return runtime_shape\n",
"\n",
"    dims = []\n",
"    for axis, static_dim in enumerate(tensor.shape.as_list()):\n",
"        # Fall back to the runtime value only where the static dimension is missing\n",
"        dims.append(runtime_shape[axis] if static_dim is None else static_dim)\n",
"    return dims"
],
"metadata": {
"id": "1MO_21gLNycA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lIYdn1woOS1n"
},
"outputs": [],
"source": [
"def resize_attention_map_tf(attentions: tf.Tensor, height: int, width: int, align_corners: Optional[bool] = False) -> tf.Tensor:\n",
"    \"\"\"\n",
"    Reshape flattened attention maps into 2D feature maps and bilinearly resize them.\n",
"\n",
"    TensorFlow counterpart of `resize_attention_map_pt`.\n",
"\n",
"    Args:\n",
"        attentions (`tf.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]\n",
"        height (`int`): height of the output attention map\n",
"        width (`int`): width of the output attention map\n",
"        align_corners (`bool`, *optional*): mirrors the `align_corners` argument of `nn.functional.interpolate`.\n",
"\n",
"    Returns:\n",
"        `tf.Tensor`: resized attention map of shape [batch_size, groups, height, width]\n",
"    \"\"\"\n",
"    # Query the shape once instead of re-running `shape_list` for every dimension.\n",
"    attn_shape = shape_list(attentions)\n",
"    batch_size = attn_shape[0]\n",
"    groups = attn_shape[1]  # number of group tokens\n",
"    num_features = attn_shape[2]  # feat_height * feat_width\n",
"\n",
"    # `scale` feeds `int(np.round(...))`, so it is computed from the static shape.\n",
"    scale = (height * width // attentions.shape[2]) ** 0.5\n",
"    if height > width:\n",
"        feat_width = int(np.round(width / scale))\n",
"        feat_height = num_features // feat_width\n",
"    else:\n",
"        feat_height = int(np.round(height / scale))\n",
"        feat_width = num_features // feat_height\n",
"\n",
"    # [batch_size, groups, feat_height*feat_width] -> [batch_size, groups, feat_height, feat_width]\n",
"    attentions = tf.reshape(attentions, (batch_size, groups, feat_height, feat_width))\n",
"    # The resize ops work on channels-last input: move `groups` to the last axis.\n",
"    attentions = tf.transpose(attentions, perm=(0, 2, 3, 1))\n",
"    # NOTE(review): bilinear `tf.image.resize` is documented to return float32, so\n",
"    # float64 inputs are presumably downcast, unlike the PyTorch path - confirm if\n",
"    # dtype parity matters.\n",
"    if align_corners:\n",
"        # Only the v1 resize op accepts `align_corners`.\n",
"        attentions = tf.compat.v1.image.resize(\n",
"            attentions, size=(height, width), method=\"bilinear\", align_corners=align_corners,\n",
"        )\n",
"    else:\n",
"        attentions = tf.image.resize(\n",
"            attentions, size=(height, width), method=\"bilinear\"\n",
"        )\n",
"    # Back to channels-first: [batch_size, height, width, groups] -> [batch_size, groups, height, width]\n",
"    attentions = tf.transpose(attentions, perm=(0, 3, 1, 2))\n",
"    return attentions"
]
},
{
"cell_type": "code",
"source": [
"# Wrap the numpy logits in a TensorFlow tensor\n",
"attentions_tf = tf.convert_to_tensor(attentions)\n",
"# Sanity check: the TensorFlow tensor must hold the same values as the numpy logits\n",
"np.testing.assert_allclose(\n",
"    actual=attentions, desired=attentions_tf, rtol=1e-4, atol=1e-4\n",
")"
],
"metadata": {
"id": "bUkm8YpRPGpK"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Run the TensorFlow port once per `align_corners` setting\n",
"attentions_tf_false, attentions_tf_true = (\n",
"    resize_attention_map_tf(attentions_tf, height=height, width=width, align_corners=flag)\n",
"    for flag in (False, True)\n",
")"
],
"metadata": {
"id": "JTM2ebs3PEtq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# With `align_corners=False`, the PyTorch and TensorFlow outputs must agree within tolerance\n",
"np.testing.assert_allclose(\n",
"    actual=attentions_pt_false, desired=attentions_tf_false, rtol=1e-4, atol=1e-4\n",
")"
],
"metadata": {
"id": "ZIlpDuxdP5_K"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# With `align_corners=True`, the PyTorch and TensorFlow outputs must agree within tolerance\n",
"np.testing.assert_allclose(\n",
"    actual=attentions_pt_true, desired=attentions_tf_true, rtol=1e-4, atol=1e-4\n",
")"
],
"metadata": {
"id": "y97PgK9kiLFY"
},
"execution_count": null,
"outputs": []
}
],
"metadata": {
"colab": {
"name": "scratchpad",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment