Created
February 20, 2016 15:57
-
-
Save szagoruyko/b358209ab31c370fe006 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Example of extracting descriptors on CPU\n", | |
"===========" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"nn = require 'nn'" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Define input patches" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
" 8\n", | |
" 1\n", | |
" 64\n", | |
" 64\n", | |
"[torch.LongStorage of size 4]\n", | |
"\n" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"-- Build a batch of random single-channel 64x64 patches to use as demo input.\n", | |
"-- Fix the RNG seed first so the generated patches (and everything computed\n", | |
"-- from them below) are the same on every Restart & Run All.\n", | |
"torch.manualSeed(0)\n", | |
"\n", | |
"N = 8 -- the number of patches to match\n", | |
"patches = torch.randn(N,1,64,64):float() -- N x 1 x 64 x 64 FloatTensor\n", | |
"print(patches:size())" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Load network" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"nn.Sequential {\n", | |
" [input -> (1) -> (2) -> (3) -> (4) -> output]\n", | |
" (1): nn.Parallel {\n", | |
" input\n", | |
" |`-> (1): nn.Sequential {\n", | |
" | [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> output]\n", | |
" | (1): nn.Reshape(1x64x64)\n", | |
" | (2): nn.SpatialConvolution(1 -> 96, 7x7, 3,3)\n", | |
" | (3): nn.ReLU\n", | |
" | (4): nn.SpatialMaxPooling(2,2,2,2)\n", | |
" | (5): nn.SpatialConvolution(96 -> 192, 5x5, 1,1)\n", | |
" | (6): nn.ReLU\n", | |
" | (7): nn.SpatialMaxPooling(2,2,2,2)\n", | |
" | (8): nn.SpatialConvolution(192 -> 256, 3x3, 1,1)\n", | |
" | (9): nn.ReLU\n", | |
" | (10): nn.View(-1)\n", | |
" | }\n", | |
" |`-> (2): nn.Sequential {\n", | |
" [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> output]\n", | |
" (1): nn.Reshape(1x64x64)\n", | |
" (2): nn.SpatialConvolution(1 -> 96, 7x7, 3,3)\n", | |
" (3): nn.ReLU\n", | |
" (4): nn.SpatialMaxPooling(2,2,2,2)\n", | |
" (5): nn.SpatialConvolution(96 -> 192, 5x5, 1,1)\n", | |
" (6): nn.ReLU\n", | |
" (7): nn.SpatialMaxPooling(2,2,2,2)\n", | |
" (8): nn.SpatialConvolution(192 -> 256, 3x3, 1,1)\n", | |
" (9): nn.ReLU\n", | |
" (10): nn.View(-1)\n", | |
" }\n", | |
" ... -> output\n", | |
" }\n", | |
" (2): nn.Linear(512 -> 512)\n", | |
" (3): nn.ReLU\n", | |
" (4): nn.Linear(512 -> 1)\n", | |
"}\t\n" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"-- location of the serialized network (Torch .t7 file)\n", | |
"local model_path = '/tmp/networks/siam/siam_liberty.t7'\n", | |
"\n", | |
"-- deserialize it into the global `net` and show its module tree\n", | |
"net = torch.load(model_path)\n", | |
"print(tostring(net))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Compute similarity" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's compute similarity between the generated patches and a copy of them" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"patch_pairs size:\n", | |
"\t 8\n", | |
" 2\n", | |
" 64\n", | |
" 64\n", | |
"[torch.LongStorage of size 4]\n", | |
"\n", | |
"similarities:\t\n" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
" 19.0403\n", | |
" 21.5694\n", | |
" 23.9795\n", | |
" 21.1930\n", | |
" 20.9802\n", | |
" 23.6913\n", | |
" 20.4643\n", | |
" 17.5520\n", | |
"[torch.FloatTensor of size 8x1]\n", | |
"\n" | |
] | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"-- subtract each patch's own mean, in place: `flat` is a view that shares\n", | |
"-- storage with `patches`, so `patches` itself ends up mean-subtracted\n", | |
"local flat = patches:view(N, 1, 64*64)\n", | |
"local patch_means = flat:mean(3)\n", | |
"flat:add(-1, patch_means:expandAs(flat))\n", | |
"\n", | |
"-- pair every patch with itself along dimension 2 (N x 2 x 64 x 64)\n", | |
"local pair_batch = torch.cat(patches, patches, 2)\n", | |
"print('patch_pairs size:\\n', pair_batch:size())\n", | |
"\n", | |
"-- forward the pairs through the full network and show its output\n", | |
"print('similarities:')\n", | |
"print(net:forward(pair_batch))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Similarities are high, as expected: each patch is paired with an exact copy of itself" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Compute descriptors" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's take the first branch of our network:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"nn.Sequential {\n", | |
" [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> output]\n", | |
" (1): nn.Reshape(1x64x64)\n", | |
" (2): nn.SpatialConvolution(1 -> 96, 7x7, 3,3)\n", | |
" (3): nn.ReLU\n", | |
" (4): nn.SpatialMaxPooling(2,2,2,2)\n", | |
" (5): nn.SpatialConvolution(96 -> 192, 5x5)\n", | |
" (6): nn.ReLU\n", | |
" (7): nn.SpatialMaxPooling(2,2,2,2)\n", | |
" (8): nn.SpatialConvolution(192 -> 256, 3x3)\n", | |
" (9): nn.ReLU\n", | |
" (10): nn.View(-1)\n", | |
"}\t\n" | |
] | |
}, | |
"execution_count": 25, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"-- module 1 of `net` is the nn.Parallel container holding the two branches;\n", | |
"-- keep its first child as the stand-alone descriptor extractor\n", | |
"local branches = net:get(1)\n", | |
"descriptor_net = branches:get(1)\n", | |
"print(tostring(descriptor_net))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
" 124\n", | |
" 1\n", | |
" 64\n", | |
" 64\n", | |
"[torch.LongStorage of size 4]\n", | |
"\n" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
" 124\n", | |
" 256\n", | |
"[torch.LongStorage of size 2]\n", | |
"\n" | |
] | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"-- Compute one descriptor per patch with a single branch of the network.\n", | |
"N = 124 -- the number of patches to match\n", | |
"patches = torch.randn(N,1,64,64):float()\n", | |
"print(patches:size())\n", | |
"\n", | |
"-- in place per-patch mean subtraction: the same preprocessing the similarity\n", | |
"-- example applies before net:forward, so both paths feed the branch\n", | |
"-- identically normalized input\n", | |
"local p = patches:view(N,1,64*64)\n", | |
"p:add(-1, p:mean(3):expandAs(p))\n", | |
"\n", | |
"-- propagate only patches, not pairs\n", | |
"output = descriptor_net:forward(patches)\n", | |
"\n", | |
"-- print the size of the output tensor (one row per patch)\n", | |
"print(#output)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"Console does not support images" | |
] | |
}, | |
"metadata": { | |
"image/png": { | |
"height": 124, | |
"width": 256 | |
} | |
}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"\n" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"-- Visualize our descriptors as an image: one row per patch, one column per descriptor dimension\n", | |
"itorch.image(output)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "iTorch", | |
"language": "lua", | |
"name": "itorch" | |
}, | |
"language_info": { | |
"name": "lua", | |
"version": "5.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment