Skip to content

Instantly share code, notes, and snippets.

@melgor
Last active April 6, 2017 12:06
Show Gist options
  • Save melgor/23679d140cde6fc372bb6ee0ad45df5b to your computer and use it in GitHub Desktop.
Save melgor/23679d140cde6fc372bb6ee0ad45df5b to your computer and use it in GitHub Desktop.
Strange behavior of STN
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import sys,os\n",
"import torch\n",
"import torch.optim as optim\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"import torch.nn as nn\n",
"from torch.autograd import Variable\n",
"from PIL import Image\n",
"from matplotlib import mlab\n",
"import matplotlib.pyplot as plt\n",
"sys.path.append(\"stn.pytorch/script/\")\n",
"from modules.stn import STN\n",
"from modules.gridgen import CylinderGridGen, AffineGridGen, AffineGridGenV2, DenseAffineGridGen"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"''' Same implementation like in https://github.com/qassemoquab/stnbhwd/blob/master/demo/spatial_transformer.lua\n",
" \n",
" get NCHW format\n",
" |\n",
" |\n",
" Sampler\n",
" | \\\n",
" | \\ \n",
" | \\\n",
" | /Network with NCHW format. Output Grid \n",
" to NHWC/ \n",
" ^ / \n",
" |/\n",
" input NCHW\n",
" \n",
" \n",
"Issue: Output image where first channel is used in every other channel\n",
"'''\n",
"\n",
"class ConvSpatialTransformer(nn.Module):\n",
" def __init__(self, height, width):\n",
" super(ConvSpatialTransformer, self).__init__()\n",
" self.s = STN()\n",
" self.height = height\n",
" self.width = width\n",
" self.conv = nn.Sequential( \n",
" torch.nn.MaxPool2d(2,2),\n",
" torch.nn.Conv2d(3, 20, 5, stride=2, padding=2),\n",
" torch.nn.ReLU(),\n",
" torch.nn.MaxPool2d(2,2),\n",
" torch.nn.Conv2d(20, 20, 5, stride=2, padding=2),\n",
" torch.nn.ReLU(),\n",
" )\n",
" self.grid_param = nn.Linear(80, 6)\n",
" self.g = AffineGridGenV2(height, width)\n",
" \n",
" self.grid_param.weight.data.zero_()\n",
" self.grid_param.bias.data.zero_()\n",
" self.grid_param.bias.data[0] = 1\n",
" self.grid_param.bias.data[4] = 1\n",
" \n",
" def forward(self, input):\n",
" conv = self.conv(input)\n",
" x = conv.view(conv.size(0),-1)\n",
" x = self.grid_param(x)\n",
" grid_param = x.view(conv.size(0),2,3)\n",
" out = self.g(grid_param )\n",
" input_NHWC = input.permute(0,2,3,1)\n",
" out2 = self.s(input_NHWC, out) \n",
" out2_NCHW = out2.permute(0,3,1,2)\n",
" return out2_NCHW, out\n",
" \n",
"\n",
"c = ConvSpatialTransformer(28, 28)\n",
"# NCHW, In red channel will be a line\n",
"input_lua = torch.Tensor(1,3,28,28)\n",
"input_lua[:] = 0.0\n",
"input_lua[0,0,5:13] = 1.0\n",
"input_lua[0,1,15:23] = 1.0\n",
"input_lua[0,2,23:28] = 1.0\n",
"\n",
"res_lua, grid_lua = c(Variable(input_lua))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fda880db5d0>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADeFJREFUeJzt3U2sHWd9x/Hv33GMX2OMwbZih1BEhbpBFpWzSRcXIUHU\nTSIWadRNwqJi0RSWSdk43bUsImXDosEgF2EhiAQJldqYKrUgCxoL6iYBx0ZqYwjgS0T9kpv63f8u\nZq7v3JNz7hnf8zb28/1Io5nz3Jkzj4/8O88zb+eJzERSWdbMugKSps/gSwUy+FKBDL5UIIMvFcjg\nSwUaKfgR8UBEvBERJyPiiXFVStJkxWqv40fEGuAk8Gngt8BR4JHMfKNnPW8UkGYkM6Nf+Sgt/n3A\nLzPzVGZeAb4NPDjC+0maklGCvxv4deP1W3WZpI7z5J5UoFGC/xvgw43Xe+oySR03SvCPAh+LiHsj\nYh3wCPDCeKolaZLWrnbDzLwWEY8Dh6m+QA5k5vGx1UzSxKz6cl7rHXg5T5qZSVzOk3SLMvhSgQy+\nVCCDLxXI4EsFMvhSgQy+VCCDLxXI4EsFMvhSgQy+VCCDLxXI4EsFMvhSgQy+VCCDLxXI4EsFMvhS\ngQy+VKBV/9impFkIqvZ60LxpYeC7GHyps6LP8h1Usb2znurlWFy+Y2mT5cNYLmPwpU4KlsLenN8B\nvA9YX02xful1rGdZpA2+dCvpDX1zWksV9I0Qm4BNS3M21S3/cAZf6qze0Dda/NgIbIG4C7irnm8F\n1rV6Z4MvdVIz7GtY3uKvo2rxtwDvh/gAsA1iG7Ch1bsbfKlzBnXz1wBrIequPlvqsG+H+BDEB+vy\n4Qy+1En9Qr+Gpa7+BmCxxf8gxA6IXVTH+sNNJfibN2+exm6k20Bv0BenxbItVRc/7oLYWk1shXh/\nNbGUtYX/HbyXqQT/7rvvnsZupNtAb/B7b9DZXHfrt9eh31Kd1Y/1EOtoRvqkwZduFf1O6DWDv7Hu\n2n+AqptfB58NdfCXLuedfH3wXgy+1Cn9ju2brzfW3fptVMG/qy5b/57gr2Sk4EfEm8A54DpwJTPv\n67eewZfa6nftvjHFepZdt48tVDfurKe6zNcu0qO2+NeBucw8s9JKBl9qqxn03tdRt+qbITZTncHf\n3OjqL96/P9yowe/3SNB7GHyprUH36NfLsY7qJp0N9SW9DfVdfNNt8RP4YURcA/4xM5/tt9Lu3btH\n3I1Umn7hh6V79dfVN/I0l9ex7Om8FYwa/Psz83cR8SGqL4Djmfly70qHDh26sbxv3z727ds34m6l\n21m/x3Ebj+Xe6NI3priTo6/8mKNHf9xuD5k5nqpG7Afeycyne8rzxIkTY9mHVI5+4Ydlz+PH2qVl\n6uVYavE//sdBZjY3vmHVLX5EbATWZOZCRGwCPgP8Xb91161r98SQpKbezC6eUrtj+RS9JwSHG6Wr\nvxP4XkRk/T7fyszD/Va88852Zxol9Qtvoyx6g9+4w6997lcf/Mz8H2Bvm3Vt8aWbNSjFjYd1ovde\n/vamcueewZdG0Xu83/sQz+I1/ul09Vsz+NK4rHRnX3tTCb7H+NLNGhbk5jH9zYUephT8tWv9vQ+p\nSxxJRyqQwZcKZPClAhl8qUAGXyqQwZcKZPClAhl8qUAGXyqQwZcKZPClAhl8qUAGXyqQwZcKZPCl\nAhl8qUAGXyqQwZcKZPClAhl8qUAGXyqQwZcKZPClAhl8qUAGXyqQwZcKZPClAg0NfkQciIj5iHi1\nUbYtIg5HxImIeDEitk62mpLGqU2L/w3gsz1lTwL/lpkfB14C/nbcFZM0OUODn5kvA2d6ih8EDtbL\nB4GHxlwvSRO02mP8HZk5D5CZp4Ed46uSpEkb18D1udIfn3rqqRvLc3NzzM3NjWm3khYdOXKEI0eO\ntFo3MlfMbLVSxL3ADzLzE/Xr48BcZs5HxC7g3zPzTwZsm232IWm8IoLMjH5/a9vVj3pa9ALwWL38\nKPD8qmsnaeqGtvgRcQiYA7YD88B+4PvAd4F7gFPAw5l5dsD2tvjSDKzU4rfq6o+4c4MvzcA4uvqS\nbiMGXyqQwZcKNK7r+JLGYsj5sMx6nUFTOwZf6rTeMF8HrjWmq0vLea3P+v0ZfKmTss9yUoX8ytKU\nV4DLjbJrrd7d4EudlX3m16la+UuQl6o5FxvLV1u9s8GXOqe3tW8ev9ctfl4CLlRT/l9jbvClW9ig\nk3fXqbr0l6nC/i6w0JhfafXuBl/qtN7gL57QuwR5AXgX8h3gfD2/1OpdDb7USc3ufU+Ln1dYOrZ/\nF3gH8lw1cbHVu08l+BcuXJjGbqTbxHWWd+2by4vH8u9CLlStfJ6navHP0qngnznT+8tdkgYbFvyz\nwLlqnufq4C/ODb50ixoW/HPA2Ub3vg59nqdTx/gGX2qr31n85usLSy39YvgNvnSraxP8Zkt/jur4\nvoNn9Q2+1Fab4J9f3tLb4ku3uiHBz4tULXyjpV/22uBLt6C2wT9Pdf3+PMu/AC632stUgn/2bN/f\n4ZT0HkOet8+LVIFfYOlW3cV79S9R3co73FSCf+7cuWnsRroNDLpjb3G6vBT2vFi95irk9cZ2wxl8\nqXN6n8hrvr5CFfoLVCfyLkPWP8ZxE79mbfClzukNfLPsCtXdeZe63+KfP39+GruRbgP9fnyjMc+r\nVGFvTLn4yzsdC74tvnQz+v3s1uJy/Tx+XqV6PPdKPb9+U3sw+FJn9bbgzWfyry+f38QPbYLBl24x\nK534ax/8qYydN9EdSBrIsfMk3WDwpQIZfKlAQ4MfEQciYj4iXm2U7Y+ItyLiZ/X0wGSrKWmc2rT4\n3wA+26f86cz8ZD3965jrJWmChgY/M18G+j1X2/dsoaTuG+UY//GIOBYRX4uIrWOrkaSJW23wvwp8\nNDP3AqeBp8dXJUmTtqrgZ+bbuXTnz7PAvvFVSdKktQ1+0Dimj4hdjb99Dnh9nJWSNFlD79WPiEPA\nHLA9In4F7Ac+FRF7qZ4QeBP4wgTrKGnMvFdfuo15r76kGwy+VCCDLxXI4EsFMvhSgQy+VCCDLxXI\n4EsFMvhSgQy+VCCDLxXI4EsFMvhSgQy+VCCDLxXI4EsFMvhSgQy+VCCDLxXI4EsFMvhSgQy+VCCD\nLxXI4EsFMvhSgQy+VCCDLxXI4EsFMvhSgQy+VCCDLxVoaPAjYk9EvBQRP4+I1yLii3X5tog4HBEn\nIuLFiNg6+epKGofIzJVXiNgF7MrMYxGxGfgp8CDweeAPmfmViHgC2JaZT/bZfuUdSJqYzIx+5UNb\n/Mw8nZnH6uUF4Diwhyr8B+vVDgIPjaeqkibtpo7xI+IjwF7gJ8DOzJyH6ssB2DHuykmajNbBr7v5\nzwFfqlv+3i68XXrpFtEq+BGxlir038zM5+vi+YjYWf99F/D7yVRR0ri1bfG/DvwiM59plL0APFYv\nPwo837uRpG5qc1b/fuBHwGtU3fkEvgy8AnwHuAc4BTycmWf7bO8hgDQjg87qDw3+qAy+NDurvpwn\n6fZj8KUCGXypQAZfKpDBlwpk8KUCGXypQAZfKpDBlwpk8KUCGXypQAZfKpDBlwpk8KUCGXypQAZf\nKpDBlwpk8KUCGXypQAZfKpDBlwpk8KUCGXypQAZfKpDBlwpk8KUCGXypQAZfKpDBlwpk8KUCDQ1+\nROyJiJci4ucR8VpE/E1dvj8i3oqIn9XTA5OvrqRxiMyVh6+PiF3Arsw8FhGbgZ8CDwJ/AbyTmU8P\n2X7lHUiamMyMfuVrW2x4GjhdLy9ExHFgd/3nvm8qqdtu6hg/Ij4C7AX+oy56PCKORcTXImLrmOsm\naUJaB7/u5j8HfCkzF4CvAh/NzL1UPYIVu/ySumPoMT5ARKwF/hn4l8x8ps/f7wV+kJmf6PM3j/Gl\nGRl0jN+2xf868Itm6OuTfos+B7y++upJmqY2Z/XvB34EvAZkPX0Z+Euq4/3rwJvAFzJzvs/2tvjS\njAxq8Vt19Udh8KXZGbWrL+k2YvClAhl8qUAGXyqQwZcKZPClAhl8qUAGXyqQwZcKZPClAhl8qUAG\nXyqQwZcKZPClAhl8qUAGXyqQwZcKNPFf4JHUPbb4UoEMvlSgqQU/Ih6IiDci4mREPDGt/bYVEW9G\nxH9FxH9GxCsdqM+BiJiPiFcbZdsi4nBEnIiIF2c5etGA+nVmINU+g71+sS7vxGc468Fop3KMHxFr\ngJPAp4HfAkeBRzLzjYnvvKWI+G/gTzPzzKzrAhARfwYsAP+0OFBJRPwD8IfM/Er95bktM5/sUP32\n02Ig1WlYYbDXz9OBz3DUwWhHNa0W/z7gl5l5KjOvAN+m+kd2SdChQ5/MfBno/RJ6EDhYLx8EHppq\npRoG1A86MpBqZp7OzGP18gJwHNhDRz7DAfWb2mC00/qPvhv4deP1Wyz9I7sigR9GxNGI+KtZV2aA\nHYuDltSjGO+YcX366dxAqo3BXn8C7OzaZziLwWg708J1wP2Z+Ungz4G/rruyXde1a7GdG0i1z2Cv\nvZ/ZTD/DWQ1GO63g/wb4cOP1nrqsMzLzd/X8beB7VIcnXTMfETvhxjHi72dcn2Uy8+1cOmn0LLBv\nlvWpB3t9DvhmZj5fF3fmM+xXv2l9htMK/lHgYxFxb0SsAx4BXpjSvoeKiI31Ny8RsQn4DN0YBDRY\nfrz3AvBYvfwo8HzvBlO2rH4dHEj1PYO90q3PcGaD0U7tzr36ssQzVF82BzLz76ey4xYi4o+oWvkE\n1gLfmnX9IuIQMAdsB+aB/cD3ge8C9wCngIcz82yH6vcpWgykOqX6DRrs9RXgO8z4Mxx1MNqR9+8t\nu1J5PLknFcjgSwUy+FKBDL5UIIMvFcjgSwUy+FKBDL5UoP8HtQPR1vA76fwAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fda89724f90>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACwRJREFUeJzt3U/IHPUdx/HPR4IHFSWVJg8YzdMiFC8SWvSSHh4RrJRC\nxIMNQlEp4qFWsT1ovURKD62HQCh4qI0SiyIqaGKhNRZdJAdrapsm6hMjSFL/5dFao+Ym5tvDTnTz\nuH/m2ZmdGfN9v2Bxn9k/v18W38/M7O4z44gQgFzOaHsCAJpH+EBChA8kRPhAQoQPJET4QEKVwrd9\nte2Dtg/ZvrOuSQGYLU/7Ob7tMyQdknSlpHcl7ZW0OSIOLrsfXxQAWhIRHra8yhr/cklvRMSRiPhM\n0qOSNlV4PgANqRL+BZLeGvj57WIZgI7jzT0goSrhvyPpooGf1xXLAHRclfD3SrrY9nrbZ0raLGlX\nPdMCMEurpn1gRHxu+1ZJu9X/BbI9IhZrmxmAmZn647zSA/BxHtCaWXycB+BrivCBhAgfSIjwgYQI\nH0iI8IGECB9IiPCBhAgfSIjwgYQIH0iI8IGECB9IiPCBhAgfSIjwgYQIH0iI8IGECB9IiPCBhAgf\nSIjwgYQIH0iI8IGECB9IiPCBhAgfSGjqk2auxLlNDALgFJ+Mua2R8OebGATAKfaPuY3wgdMU4QM4\nRaXwbR+W9LGkE5I+i4jLh91vvsogAGpXdY1/QtJCRHw07k7zFQcBUK+q4VslPhKcrzgIgHo5IqZ/\nsP2mpGOSPpf0h4i4f8h94n/Tzw/AlL4hKSI87Laqa/yNEfGe7W9Ketb2YkTsWX6nbQPXF4oLgHr1\niksZldb4pzyRvUXSpxGxddnymkYAsBLW6DX+1F/ZtX2W7XOK62dLukrSK9M+H4DmVNnUXyvpSdtR\nPM/DEbG7nmkBmKXaNvVHDsCmPtCKmWzqA/j6InwgIcIHEiJ8ICHCBxIifCAhwgcSInwgIcIHEiJ8\nICHCBxIifCAhwgcSInwgIcIHEiJ8ICHCBxIifCAhwgcSInwgIcIHEiJ8ICHCBxIifCAhwgcSInwg\nIcIHEiJ8ICHCBxIifCAhwgcSmhi+7e22l2zvH1i22vZu26/bfsb2ebOdJoA6lVnjPyjpB8uW3SXp\nbxHxHUnPSfpV3RMDMDsTw4+IPZI+WrZ4k6QdxfUdkq6peV4AZmjaffw1EbEkSRFxVNKa+qYEYNZW\n1fQ8Me7GewauLxQXAPXqFZcyHDG22f6d7PWSno6IS4ufFyUtRMSS7TlJz0fEJSMeW2IEAHWzpIjw\nsNvKbuq7uJy0S9KNxfUbJO2cdnIAmjdxjW/7EfW3zs+XtCRpi6SnJD0u6UJJRyRdFxHHRjyeNT7Q\ngnFr/FKb+pUGJ3ygFXVs6gM4jRA+kBDhAwkRPpAQ4QMJET6QEOEDCRE+kBDhAwkRPpAQ4QMJ1fX3\n+GN93MQgAEprJPzDTQwCoDTCBxIifCAhwgcSInwgIcIHEiJ8IKFGjrk30wEAjMQx9wB8gfCBhAgf\nSIjwgYQIH0iI8IGECB9IqJEv8OjcRkYBMOiT0Tc1E/58I6MAGLR/9E2ED5yuCB/AoInh294u6UeS\nliLi0mLZFkk3S3q/uNvdEfHXkU8yX3meAGpUZo3/oKTfS3po2fKtEbG11CjzK5sUgNmaGH5E7LG9\nfshNQ//qZ6j5FcwIwMxV2ce/1fZPJP1D0i8jYvRRtOcrjAKgdtOGf5+kX0dE2P6NpK2Sfjry3tsG\nri8UFwD16hWXEkodiKPY1H/65Jt7ZW8rbg9xKA6gea5+IA5rYJ/e9tzAbddKemX62QFoWpmP8x5R\nf+P8fNv/kbRF0hW2N0g6of4h9W6Z4RwB1KyZY+6xqQ80r4ZNfQCnEcIHEiJ8ICHCBxIifCAhwgcS\nInwgIcIHEiJ8ICHCBxIifCAhwgcSInwgIcIHEiJ8ICHCBxIifCAhwgcSInwgIcIHEiJ8ICHCBxIi\nfCAhwgcSInwgIcIHEiJ8ICHCBxIifCAhwgcSInwgoYnh215n+znbr9o+YPu2Yvlq27ttv277Gdvn\nzX66AOrgiBh/B3tO0lxE7LN9jqSXJW2SdJOkDyPiXtt3SlodEXcNeXxo/BAAZsFSRHjYTRPX+BFx\nNCL2FdePS1qUtE79+HcUd9sh6Zp6Zgtg1la0j297XtIGSS9KWhsRS1L/l4OkNXVPDsBsrCp7x2Iz\n/wlJt0fEcdvLN+BHb9DfM3B9obgAqFevuJQwcR9fkmyvkvRnSX+JiG3FskVJCxGxVLwP8HxEXDLk\nsezjA22oso9feEDSayejL+ySdGNx/QZJO6eeIIBGlXlXf6OkFyQdUH9zPiTdLeklSY9JulDSEUnX\nRcSxIY9njQ+0Ycwav9SmfqWxCR9oRw2b+gBOI4QPJET4QEKEDyRE+EBChA8kRPhAQoQPJET4QEKE\nDyRE+EBCpf8ev5KPz21kGACDPhl5SzPhH55vZBgAg/aPvIXwgdMW4QMYQPhAQoQPJET4QEKEDyTU\nzDH3OOge0AJzzD0AXyJ8ICHCBxIifCAhwgcSInwgIcIHEiJ8ICHCBxIifCChieHbXmf7Oduv2j5g\n++fF8i2237b9z+Jy9eynC6AOE7+rb3tO0lxE7LN9jqSXJW2S9GNJn0bE1gmP57v6QCtGf1d/4l/n\nRcRRSUeL68dtL0q64ItnBvC1s6J9fNvzkjZI+nux6Fbb+2z/0fZ5Nc8NwIyUDr/YzH9C0u0RcVzS\nfZK+HREb1N8iGLvJD6A7Sh2Iw/Yq9aP/U0TslKSI+GDgLvdLenr0M9wzcH2huACoV6+4TFbqQBy2\nH5L034j4xcCyuWL/X7bvkHRZRFw/5LG8uQe0YvSbe2Xe1d8o6QVJB9QvOCTdLel69ff3T0g6LOmW\niFga8njCB1pRIfzKQxM+0BIOvQVgAOEDCRE+kBDhAwkRPpAQ4QMJET6QEOEDCRE+kBDhAwkRPpAQ\n4QMJtRB+r/khV6TX9gQm6LU9gQl6bU9gjF7bE5ig19hIhP8VvbYnMEGv7QlM0Gt7AmP02p7ABL3G\nRmJTH0iI8IGEGjoQB4A2tHYEHgDdw6Y+kBDhAwk1Fr7tq20ftH3I9p1NjVuW7cO2/237X7Zf6sB8\ntttesr1/YNlq27ttv277mTbPXjRifp05keqQk73eVizvxGvY9sloG9nHt32GpEOSrpT0rqS9kjZH\nxMGZD16S7TclfS8iPmp7LpJk+/uSjkt6KCIuLZb9TtKHEXFv8ctzdUTc1aH5bVGJE6k2YczJXm9S\nB17DqiejraqpNf7lkt6IiCMR8ZmkR9X/R3aJ1aFdn4jYI2n5L6FNknYU13dIuqbRSQ0YMT+pIydS\njYijEbGvuH5c0qKkderIazhifo2djLap/9EvkPTWwM9v68t/ZFeEpGdt77V9c9uTGWHNyZOWFGcx\nWtPyfIbp3IlUB072+qKktV17Dds4GW1n1nAdsDEivivph5J+VmzKdl3XPovt3IlUh5zsdflr1upr\n2NbJaJsK/x1JFw38vK5Y1hkR8V7x3w8kPan+7knXLNleK32xj/h+y/M5RUR8EF++aXS/pMvanM+w\nk72qQ6/hqJPRNvEaNhX+XkkX215v+0xJmyXtamjsiWyfVfzmle2zJV0l6ZV2ZyWpv683uL+3S9KN\nxfUbJO1c/oCGnTK/IqSTrlX7r+EDkl6LiG0Dy7r0Gn5lfk29ho19c6/4WGKb+r9stkfEbxsZuATb\n31J/LR/qnzr84bbnZ/sR9c8nfr6kJUlbJD0l6XFJF0o6Ium6iDjWofldoRInUm1ofqNO9vqSpMfU\n8mtY9WS0lcfnK7tAPry5ByRE+EBChA8kRPhAQoQPJET4QEKEDyRE+EBC/wfQaV85t2eqmQAAAABJ\nRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fda881d4710>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure()\n",
"plt.imshow(res_lua.data.numpy()[0].transpose(1,2,0))\n",
"plt.figure()\n",
"plt.imshow(input_lua.numpy()[0].transpose(1,2,0))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"''' Same implementation like in https://github.com/fxia22/stn.pytorch/blob/master/script/test_conv_stn.ipynb\n",
" \n",
" get NHWC format\n",
" |\n",
" |\n",
" Sampler\n",
" | \\\n",
" | \\ \n",
" | \\\n",
" | /Network with NCHW format. So first input to NCHW. Output grid\n",
" |\n",
" | / \n",
" |/\n",
" input NHWC\n",
" \n",
" \n",
"Issue: Output image where first channel is used in every other channel\n",
"'''\n",
"\n",
"class ConvSpatialTransformer(nn.Module):\n",
" def __init__(self, height, width):\n",
" super(ConvSpatialTransformer, self).__init__()\n",
" self.s = STN()\n",
" self.height = height\n",
" self.width = width\n",
" self.conv = nn.Sequential( \n",
" torch.nn.MaxPool2d(2,2),\n",
" torch.nn.Conv2d(3, 20, 5, stride=2, padding=2),\n",
" torch.nn.ReLU(),\n",
" torch.nn.MaxPool2d(2,2),\n",
" torch.nn.Conv2d(20, 20, 5, stride=2, padding=2),\n",
" torch.nn.ReLU(),\n",
" )\n",
" self.grid_param = nn.Linear(80, 6)\n",
" self.g = AffineGridGenV2(height, width)\n",
" \n",
" self.grid_param.weight.data.zero_()\n",
" self.grid_param.bias.data.zero_()\n",
" self.grid_param.bias.data[0] = 1\n",
" self.grid_param.bias.data[4] = 1\n",
" \n",
" def forward(self, input):\n",
" conv = self.conv(input.permute(0,3,1,2))\n",
" x = conv.view(conv.size(0),-1)\n",
" x = self.grid_param(x)\n",
" grid_param = x.view(conv.size(0),2,3)\n",
" out = self.g(grid_param )\n",
" out2 = self.s(input, out) \n",
" return out2, out\n",
" \n",
"\n",
"c = ConvSpatialTransformer(28, 28)\n",
"# NHWC, In red channel will be a line\n",
"input_py = torch.Tensor(1,28,28,3)\n",
"input_py[:] = 0.0\n",
"input_py[0,5:13,:,0] = 1.0\n",
"input_py[0,15:23,:,1] = 1.0\n",
"input_py[0,23:28,:,2] = 1.0\n",
"\n",
"res_py, grid_py = c(Variable(input_py))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fda782dd990>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAC41JREFUeJzt3U2oXdUZxvHnCZqBSiWISWqisUUQJ5K26CRFrghWSiHW\ngQ1CUSnioFZpO9A6iYUOWgcBKXTQNEosBlFBkxRaY9FLyMCa2kaj5kMoSf3KNYqxXmhrMG8HZ6ee\nXM/Hvvecs/ZO3v8PNu6zP85ad+Nz1lp7n5zliBCAXBY1XQEA5RF8ICGCDyRE8IGECD6QEMEHEhop\n+LZvsL3f9kHb946rUgAmywt9jm97kaSDkq6T9K6k3ZLWRcT+OcfxRQGgIRHhXttHafGvlvRmRByO\niOOSHpe0doT3A1DIKMFfIemtrtdvV9sAtBw394CERgn+O5Iu6Xq9stoGoOVGCf5uSZfZXmV7saR1\nkraNp1oAJumshZ4YEZ/ZvkvSDnU+QDZFxL6x1QzAxCz4cV7tAnicBzRmEo/zAJymCD6QEMEHEiL4\nQEIEH0iI4AMJEXwgIYIPJETwgYQIPpAQwQcSIvhAQgQfSIjgAwkRfCAhgg8kRPCBhAg+kBDBBxIi\n+EBCBB9IiOADCRF8ICGCDyRE8IGECD6QEMEHElrwpJnzcV6JQgCcYnbAviLBv6hEIQBOcXDAPoIP\nnKEIPoBTjBR824ckfSzphKTjEXF1r+MIPtAuo7b4JyRNRcRHgw4i+EC7jBp8q8YjQYIPtMuowQ9J\nz9n+TNJvI2Jjr4NWjFgIgPEaNfhrIuI92xeq8wGwLyJ2zT1oS9f6VdUCYLx2V0sdjoixFGp7vaRP\nImLDnO1xYCwlAJiPyyVFhHvtW3CLb/scSYsiYtb2uZKul/TzXscuXmghACZilK7+MklP247qfR6L\niB29Djx7hEIAjN/Yuvp9C7Dj/YmWAKCXpZpAV38+6OoD7ULwgYSKBJ8xPtAuRYJfpBAAtfELPEBC\nBB9IiOADCRF8ICGCDyRE8IGECD6QEMEHEiL4QEIEH0iI4AMJEXwgIYIPJETwgYQIPpAQwQcSIvhA\nQgQfSIjgAwkRfCAhgg8kRPCBhAg+kBDBBxIi+EBCBB9IiOADCQ0Nvu1Ntmdsv9q1bYntHbYP2H7W\n9vmTrSaAcarT4j8i6Vtztt0n6c8Rcbmk5yX9bNwVAzA5Q4MfEbskfTRn81pJm6v1zZJuHHO9AEzQ\nQsf4SyNiRpIi4oikpeOrEoBJG9fU9TFo5wNd61PVAmC8pqulDkcMzGznIHuVpO0RcWX1ep+kqYiY\nsb1c0gsRcUWfc2uUAGDcLCki3Gtf3a6+q+WkbZJuq9ZvlbR1oZUDUN7QFt/2FnV65xdImpG0XtIz\nkp6UdLGkw5Jujohjfc6nxQcaMKjFr9XVH6lwgg80YhxdfQBnEIIPJETwgYQIPpAQwQcSIvhAQgQf\nSIjgAwkRfCAhgg8kRPCBhMb17/EH+neJQgDUViT4c3+3C0CzCD6QEMEHEiL4QEIEH0iI4AMJEXwg\noSLB7/krnAAaUyT4H5coBEBtBB9IiOADCRUJ/r9KFAKgNlp8ICGCDyRE8IGEigT/v4tLlALgFJ/2\n31Uk+Dq/SCkAuh3tv4vgA2cqgg+g29Dg294k6TuSZiLiymrbekl3SHq/Ouz+iPhT3zf50ugVBTA+\ndVr8RyT9WtKjc7ZviIgNtUqhxQdaZWjwI2KX7VU9drl2KQQfaJVRxvh32f6+pL9K+mlE9H9cT/CB\nVnFEDD+o0+Jv7xrjXyjpg4gI27+Q9OWI+EGfc0PXdG24tFoAjNehajlppxQRPXvmC2rxI6L7QcFG\nSdsHnvDdhZQCYF6WSPpa1+ud/Q+tG3yra0xve3lEHKle3iTptaEVAtAadR7nbZE0JekC2/+UtF7S\ntbZXSzqhTufizoFvQvCBVqk1xh+pADsGdTkATMg1Yx7jzxstPtAqBB9IiOADCZUJ/jlFSkEGo96S\nqv9902YU+vsWjVgMgNMQwQcSIvhAQgQfSIjgAwkRfCAhgg8kVOY5PnCmOEO+R0CLDyRE8IGECD6Q\nEMEHEiL4QEIEH0iI4AMJ8Rwfp5eWPAdfsGHfAyj099HiAwkRfCAhgg8kRPCBhAg+kBDBBxIi+EBC\nPMcH5mPYc/bJTkU5NrT4QEIEH0iI4AMJDQ2+7ZW2n7f9uu29tu+uti+xvcP2AdvP2j5/8tUFMA6O\nGHw3wvZyScsjYo/t8yS9LGmtpNslfRgRD9q+V9KSiLivx/lxutzwAEbWph/jtBQRPd9xaIsfEUci\nYk+1Pitpn6SV6oR/c3XYZkk3jqe2ACZtXmN825dKWi3pRUnLImJG6nw4SFo67soBmIzaz/Grbv5T\nku6JiFnbczs1/Ts5D3StT1ULcCZq8vcCpqulhqFjfEmyfZakP0j6Y0Q8VG3bJ2kqImaq+wAvRMQV\nPc5ljA80YZQxfuVhSW+cDH1lm6TbqvVbJW1dcAUBFFXnrv4aSTsl7VWnOx+S7pf0kqQnJF0s6bCk\nmyPiWI/zafGBJgxo8Wt19Ucqm+ADzRhDVx/AGYTgAwkRfCAhgg8kRPCBhAg+kBDBBxIi+EBCBB9I\niOADCRF8IKEyv6v/GZ8vQHkn+u4pE/xPFxcpBkC3//TdUyj4ZxcpBkC3poN/nBYfaBO6+kBCdPWB\nhGjxgYQY4wMJlQn+MabVA9qkTPDfvahIMQDqIfhAQoWCv6JIMQDqocUHEiL4QEIEH0ioTPDfYYwP\ntEmZufOYPA9ogJk7D8DnCD6Q0NDg215p+3nbr9vea/tH1fb1tt+2/bdquWHy1QUwDkPH+LaXS1oe\nEXtsnyfpZUlrJX1P0icRsWHI+YzxgUb0H+MPvasfEUckHanWZ23vk3TyNn3PNwXQbvMa49u+VNJq\nSX+pNt1le4/t39nmn+ABp4nawa+6+U9JuiciZiX9RtJXI2K1Oj2CgV1+AO1R6ws8ts9SJ/S/j4it\nkhQRR7sO2Shpe/93eKBrfapaAIzXdLUMV+sLPLYflfRBRPyka9vyavwv2z+WdFVE3NLjXG7uAY3o\nf3Ovzl39NZJ2StqrToJD0v2SblFnvH9C0iFJd0bETI/zCT7QiBGCP3LRBB9oCF/ZBdCF4AMJEXwg\nIYIPJETwgYQIPpAQwQcSIvhAQgQfSIjgAwkRfCAhgg8k1EDwp8sXOS/TTVdgiOmmKzDEdNMVGGC6\n6QoMMV2sJIL/BdNNV2CI6aYrMMR00xUYYLrpCgwxXawkuvpAQgQfSKjQD3EAaEJjv8ADoH3o6gMJ\nEXwgoWLBt32D7f22D9q+t1S5ddk+ZPsV23+3/VIL6rPJ9oztV7u2LbG9w/YB2882OXtRn/q1ZiLV\nHpO93l1tb8U1bHoy2iJjfNuLJB2UdJ2kdyXtlrQuIvZPvPCabP9D0jci4qOm6yJJtr8paVbSoxFx\nZbXtV5I+jIgHqw/PJRFxX4vqt141JlItYcBkr7erBddw1MloR1Wqxb9a0psRcTgijkt6XJ0/sk2s\nFg19ImKXpLkfQmslba7WN0u6sWiluvSpn9SSiVQj4khE7KnWZyXtk7RSLbmGfepXbDLaUv+jr5D0\nVtfrt/X5H9kWIek527tt39F0ZfpYenLSkmoWo6UN16eX1k2k2jXZ64uSlrXtGjYxGW1rWrgWWBMR\nX5f0bUk/rLqybde2Z7Gtm0i1x2Svc69Zo9ewqcloSwX/HUmXdL1eWW1rjYh4r/rvUUlPqzM8aZsZ\n28uk/48R32+4PqeIiKPx+U2jjZKuarI+vSZ7VYuuYb/JaEtcw1LB3y3pMturbC+WtE7StkJlD2X7\nnOqTV7bPlXS9pNearZWkzlive7y3TdJt1fqtkrbOPaGwU+pXBemkm9T8NXxY0hsR8VDXtjZdwy/U\nr9Q1LPbNveqxxEPqfNhsiohfFim4BttfUaeVD3WmDn+s6frZ3qLOfOIXSJqRtF7SM5KelHSxpMOS\nbo6IYy2q37WqMZFqofr1m+z1JUlPqOFrOOpktCOXz1d2gXy4uQckRPCBhAg+kBDBBxIi+EBCBB9I\niOADCRF8IKH/AWZmc+xcK4EHAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fda8809b890>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD8CAYAAABXXhlaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACwRJREFUeJzt3U/IHPUdx/HPR4IHFSWVJg8YzdMiFC8SWvSSHh4RrJRC\nxIMNQlEp4qFWsT1ovURKD62HQCh4qI0SiyIqaGKhNRZdJAdrapsm6hMjSFL/5dFao+Ym5tvDTnTz\nuH/m2ZmdGfN9v2Bxn9k/v18W38/M7O4z44gQgFzOaHsCAJpH+EBChA8kRPhAQoQPJET4QEKVwrd9\nte2Dtg/ZvrOuSQGYLU/7Ob7tMyQdknSlpHcl7ZW0OSIOLrsfXxQAWhIRHra8yhr/cklvRMSRiPhM\n0qOSNlV4PgANqRL+BZLeGvj57WIZgI7jzT0goSrhvyPpooGf1xXLAHRclfD3SrrY9nrbZ0raLGlX\nPdMCMEurpn1gRHxu+1ZJu9X/BbI9IhZrmxmAmZn647zSA/BxHtCaWXycB+BrivCBhAgfSIjwgYQI\nH0iI8IGECB9IiPCBhAgfSIjwgYQIH0iI8IGECB9IiPCBhAgfSIjwgYQIH0iI8IGECB9IiPCBhAgf\nSIjwgYQIH0iI8IGECB9IiPCBhAgfSGjqk2auxLlNDALgFJ+Mua2R8OebGATAKfaPuY3wgdMU4QM4\nRaXwbR+W9LGkE5I+i4jLh91vvsogAGpXdY1/QtJCRHw07k7zFQcBUK+q4VslPhKcrzgIgHo5IqZ/\nsP2mpGOSPpf0h4i4f8h94n/Tzw/AlL4hKSI87Laqa/yNEfGe7W9Ketb2YkTsWX6nbQPXF4oLgHr1\niksZldb4pzyRvUXSpxGxddnymkYAsBLW6DX+1F/ZtX2W7XOK62dLukrSK9M+H4DmVNnUXyvpSdtR\nPM/DEbG7nmkBmKXaNvVHDsCmPtCKmWzqA/j6InwgIcIHEiJ8ICHCBxIifCAhwgcSInwgIcIHEiJ8\nICHCBxIifCAhwgcSInwgIcIHEiJ8ICHCBxIifCAhwgcSInwgIcIHEiJ8ICHCBxIifCAhwgcSInwg\nIcIHEiJ8ICHCBxIifCAhwgcSmhi+7e22l2zvH1i22vZu26/bfsb2ebOdJoA6lVnjPyjpB8uW3SXp\nbxHxHUnPSfpV3RMDMDsTw4+IPZI+WrZ4k6QdxfUdkq6peV4AZmjaffw1EbEkSRFxVNKa+qYEYNZW\n1fQ8Me7GewauLxQXAPXqFZcyHDG22f6d7PWSno6IS4ufFyUtRMSS7TlJz0fEJSMeW2IEAHWzpIjw\nsNvKbuq7uJy0S9KNxfUbJO2cdnIAmjdxjW/7EfW3zs+XtCRpi6SnJD0u6UJJRyRdFxHHRjyeNT7Q\ngnFr/FKb+pUGJ3ygFXVs6gM4jRA+kBDhAwkRPpAQ4QMJET6QEOEDCRE+kBDhAwkRPpAQ4QMJ1fX3\n+GN93MQgAEprJPzDTQwCoDTCBxIifCAhwgcSInwgIcIHEiJ8IKFGjrk30wEAjMQx9wB8gfCBhAgf\nSIjwgYQIH0iI8IGECB9IqJEv8OjcRkYBMOiT0Tc1E/58I6MAGLR/9E2ED5yuCB/AoInh294u6UeS\nliLi0mLZFkk3S3q/uNvdEfHXkU8yX3meAGpUZo3/oKTfS3po2fKtEbG11CjzK5sUgNmaGH5E7LG9\nfshNQ//qZ6j5FcwIwMxV2ce/1fZPJP1D0i8jYvRRtOcrjAKgdtOGf5+kX0dE2P6NpK2Sfjry3tsG\nri8UFwD16hWXEkodiKPY1H/65Jt7ZW8rbg9xKA6gea5+IA5rYJ/e9tzAbddKemX62QFoWpmP8x5R\nf+P8fNv/kbRF0hW2N0g6of4h9W6Z4RwB1KyZY+6xqQ80r4ZNfQCnEcIHEiJ8ICHCBxIifCAhwgcS\nInwgIcIHEiJ8ICHCBxIifCAhwgcSInwgIcIHEiJ8ICHCBxIifCAhwgcSInwgIcIHEiJ8ICHCBxIi\nfCAhwgcSInwgIcIHEiJ8ICHCBxIifCAhwgcSInwgoYnh215n+znbr9o+YPu2Yvlq27ttv277Gdvn\nzX66AOrgiBh/B3tO0lxE7LN9jqSXJW2SdJOkDyPiXtt3SlodEXcNeXxo/BAAZsFSRHjYTRPX+BFx\nNCL2FdePS1qUtE79+HcUd9sh6Zp6Zgtg1la0j297XtIGSS9KWhsRS1L/l4OkNXVPDsBsrCp7x2Iz\n/wlJt0fEcdvLN+BHb9DfM3B9obgAqFevuJQwcR9fkmyvkvRnSX+JiG3FskVJCxGxVLwP8HxEXDLk\nsezjA22oso9feEDSayejL+ySdGNx/QZJO6eeIIBGlXlXf6OkFyQdUH9zPiTdLeklSY9JulDSEUnX\nRcSxIY9njQ+0Ycwav9SmfqWxCR9oRw2b+gBOI4QPJET4QEKEDyRE+EBChA8kRPhAQoQPJET4QEKE\nDyRE+EBCpf8ev5KPz21kGACDPhl5SzPhH55vZBgAg/aPvIXwgdMW4QMYQPhAQoQPJET4QEKEDyTU\nzDH3OOge0AJzzD0AXyJ8ICHCBxIifCAhwgcSInwgIcIHEiJ8ICHCBxIifCChieHbXmf7Oduv2j5g\n++fF8i2237b9z+Jy9eynC6AOE7+rb3tO0lxE7LN9jqSXJW2S9GNJn0bE1gmP57v6QCtGf1d/4l/n\nRcRRSUeL68dtL0q64ItnBvC1s6J9fNvzkjZI+nux6Fbb+2z/0fZ5Nc8NwIyUDr/YzH9C0u0RcVzS\nfZK+HREb1N8iGLvJD6A7Sh2Iw/Yq9aP/U0TslKSI+GDgLvdLenr0M9wzcH2huACoV6+4TFbqQBy2\nH5L034j4xcCyuWL/X7bvkHRZRFw/5LG8uQe0YvSbe2Xe1d8o6QVJB9QvOCTdLel69ff3T0g6LOmW\niFga8njCB1pRIfzKQxM+0BIOvQVgAOEDCRE+kBDhAwkRPpAQ4QMJET6QEOEDCRE+kBDhAwkRPpAQ\n4QMJtRB+r/khV6TX9gQm6LU9gQl6bU9gjF7bE5ig19hIhP8VvbYnMEGv7QlM0Gt7AmP02p7ABL3G\nRmJTH0iI8IGEGjoQB4A2tHYEHgDdw6Y+kBDhAwk1Fr7tq20ftH3I9p1NjVuW7cO2/237X7Zf6sB8\ntttesr1/YNlq27ttv277mTbPXjRifp05keqQk73eVizvxGvY9sloG9nHt32GpEOSrpT0rqS9kjZH\nxMGZD16S7TclfS8iPmp7LpJk+/uSjkt6KCIuLZb9TtKHEXFv8ctzdUTc1aH5bVGJE6k2YczJXm9S\nB17DqiejraqpNf7lkt6IiCMR8ZmkR9X/R3aJ1aFdn4jYI2n5L6FNknYU13dIuqbRSQ0YMT+pIydS\njYijEbGvuH5c0qKkderIazhifo2djLap/9EvkPTWwM9v68t/ZFeEpGdt77V9c9uTGWHNyZOWFGcx\nWtPyfIbp3IlUB072+qKktV17Dds4GW1n1nAdsDEivivph5J+VmzKdl3XPovt3IlUh5zsdflr1upr\n2NbJaJsK/x1JFw38vK5Y1hkR8V7x3w8kPan+7knXLNleK32xj/h+y/M5RUR8EF++aXS/pMvanM+w\nk72qQ6/hqJPRNvEaNhX+XkkX215v+0xJmyXtamjsiWyfVfzmle2zJV0l6ZV2ZyWpv683uL+3S9KN\nxfUbJO1c/oCGnTK/IqSTrlX7r+EDkl6LiG0Dy7r0Gn5lfk29ho19c6/4WGKb+r9stkfEbxsZuATb\n31J/LR/qnzr84bbnZ/sR9c8nfr6kJUlbJD0l6XFJF0o6Ium6iDjWofldoRInUm1ofqNO9vqSpMfU\n8mtY9WS0lcfnK7tAPry5ByRE+EBChA8kRPhAQoQPJET4QEKEDyRE+EBC/wfQaV85t2eqmQAAAABJ\nRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fda78316210>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure()\n",
"plt.imshow(res_py.data.numpy()[0])\n",
"plt.figure()\n",
"plt.imshow(input_py.numpy()[0])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Variable containing:\n",
" 0\n",
"[torch.FloatTensor of size 1]\n",
"\n",
"0.0\n",
"Variable containing:\n",
" 115.0000\n",
"[torch.FloatTensor of size 1]\n",
"\n"
]
}
],
"source": [
"print (grid_lua - grid_py).sum()\n",
"print (input_lua - input_py).sum()\n",
"print (res_lua - res_py ).sum()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment