Skip to content

Instantly share code, notes, and snippets.

@albanD
Last active September 30, 2015 14:46
Show Gist options
  • Save albanD/8568df9ed3981418a133 to your computer and use it in GitHub Desktop.
Save albanD/8568df9ed3981418a133 to your computer and use it in GitHub Desktop.
require 'nn'
require 'stn' -- github.com/qassemoquab/stnbhwd by Maxime Oquab
-- Builds a spatial transformer module on top of stnbhwd and returns it.
--
-- Configuration is taken from globals of the same name, so a caller can set
-- them before loading this file. Replace them with literal values when
-- adapting the snippet:
--   use_rot, use_sca, use_tra : forwarded unchanged to
--       nn.AffineTransformMatrixGenerator (presumably they select which
--       transform components the localization network predicts; the
--       locnet's output size must match that choice -- confirm in stnbhwd).
--   input_size : passed as both height and width to
--       nn.AffineGridGeneratorBHWD, so the sampling grid is square.
local use_rot, use_sca, use_tra = use_rot, use_sca, use_tra
local input_size = input_size
-- Fail early with a clear message rather than deep inside the grid generator.
assert(type(input_size) == 'number' and input_size > 0,
  'input_size must be set to the (square) size of the sampling grid')

-- Placeholder path: point this at your trained localization network.
local localization_network = torch.load('your_locnet.t7')

-- Both branches receive the same input.
local ct = nn.ConcatTable()

-- Branch 1: reorder dims (1,2,3,4) -> (1,3,4,2), i.e. B x D x H x W ->
-- B x H x W x D if the input is in the usual Torch layout (assumed -- confirm),
-- which is the layout the BHWD sampler consumes.
local branch1 = nn.Transpose({3,4},{2,4})

-- Branch 2: predict transform parameters, turn them into a matrix, then
-- into an input_size x input_size sampling grid.
local branch2 = nn.Sequential()
branch2:add(localization_network)
branch2:add(nn.AffineTransformMatrixGenerator(use_rot, use_sca, use_tra))
branch2:add(nn.AffineGridGeneratorBHWD(input_size, input_size))

ct:add(branch1)
ct:add(branch2)

-- Sample the permuted image with the grid, then undo branch 1's permutation:
-- (1,2,3,4) -> (1,4,2,3).
local st_module = nn.Sequential()
st_module:add(ct)
st_module:add(nn.BilinearSamplerBHWD())
st_module:add(nn.Transpose({2,4},{3,4}))
return st_module
require 'nn'
-- Minimal training example: one forward/backward pass and an SGD parameter
-- update on a random 3x48x48 input.

-- Create a network.
-- Spatial sizes: 3x48x48 -> conv 5x5, stride 1, pad 2 -> 32x48x48
--                        -> 2x2 max-pool, stride 2  -> 32x24x24
local network = nn.Sequential()
network:add(nn.SpatialConvolution(3, 32, 5, 5, 1, 1, 2, 2))
network:add(nn.ReLU())
network:add(nn.SpatialMaxPooling(2, 2, 2, 2))
-- nn.Linear accepts only vector or matrix input, so flatten the 32x24x24
-- feature maps first. setNumInputDims(3) lets the same View also handle a
-- leading batch dimension.
network:add(nn.View(32*24*24):setNumInputDims(3))
network:add(nn.Linear(32*24*24, 10))

-- Create a criterion.
local criterion = nn.CrossEntropyCriterion()

-- Create a dummy sample and its target class index.
local input = torch.rand(3, 48, 48)
local target = 1

-- Learn: a single SGD step with learning rate 0.01.
network:zeroGradParameters()
local output = network:forward(input)
-- Named 'loss' rather than 'error' so Lua's builtin error() is not shadowed.
local loss = criterion:forward(output, target)
local grad_output = criterion:backward(output, target)
network:backward(input, grad_output)
network:updateParameters(0.01)
43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment