Created August 21, 2015 13:13
Save albanD/954021a4be9e1ccab753 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'nn'
require 'stn'
------
-- Prepare your localization network.
-- Replace 'your_locnet.t7' with the path to your own serialized network.
-- It receives the same input as the spatial transformer (presumably BDHW,
-- like the standard nn modules -- confirm for your model). Its output is fed
-- to nn.AffineTransformMatrixGenerator, so it must produce the parameters
-- that module expects for the chosen use_* restrictions.
local localization_network = torch.load('your_locnet.t7')
------
-- Prepare both branches of the spatial transformer (ST).
--
-- This snippet expects the following settings to be defined before it runs.
-- They used to be read as implicit globals; they are captured as locals here
-- so the dependency is explicit.
--
-- use_rot / use_sca / use_tra: booleans selecting which transformations
--   (rotation, scale, translation) nn.AffineTransformMatrixGenerator allows.
--   Leaving all three nil keeps stn's own default (presumably an
--   unrestricted affine transform -- confirm against the stn sources).
local use_rot, use_sca, use_tra = use_rot, use_sca, use_tra
-- input_size: passed as both the height and the width of the sampling grid,
--   so the ST output is input_size x input_size. Fail early with a clear
--   message rather than inside the stn constructors.
local input_size = assert(input_size,
  'input_size must be set (height/width of the square input images) before building the spatial transformer')

local ct = nn.ConcatTable()
-- Branch 1 does not modify the data; it only permutes the layout from BDHW
-- to BHWD (swap dims 3,4 -> B,D,W,H; then swap 2,4 -> B,H,W,D), which is the
-- layout the BHWD sampler works on.
local branch1 = nn.Transpose({3,4},{2,4})
-- Branch 2 predicts the transformation parameters and turns them into a
-- sampling grid.
local branch2 = nn.Sequential()
branch2:add(localization_network)
-- The use_* booleans restrict which transformations can be produced.
branch2:add(nn.AffineTransformMatrixGenerator(use_rot, use_sca, use_tra))
branch2:add(nn.AffineGridGeneratorBHWD(input_size, input_size))
ct:add(branch1)
ct:add(branch2)
------
-- Wrap the ST in a single module. The ConcatTable produces the table
-- {BHWD input, sampling grid}, which the bilinear sampler consumes.
-- st_module is an nn.Sequential, so it can be added to a larger model
-- with :add(st_module).
local st_module = nn.Sequential()
st_module:add(ct)
st_module:add(nn.BilinearSamplerBHWD())
-- Go back to the BDHW layout (used by the standard torch nn modules):
-- swap dims 2,4 -> B,D,W,H; then swap 3,4 -> B,D,H,W.
st_module:add(nn.Transpose({2,4},{3,4}))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.