-- nn.SpatialAffine: per-channel scale-and-shift layer for Torch7 nn.
-- Author: @szagoruyko (GitHub gist 315a94ed7622f2fe7fe5), created February 28, 2016 22:45.
local SpatialAffine, parent = torch.class('nn.SpatialAffine', 'nn.Module')

--- Per-channel affine layer for 4D (batch x channel x height x width) input:
-- output[b][c][h][w] = weight[c] * input[b][c][h][w] + bias[c].
-- @param nOutput number of channels (must equal dim 2 of the input).
function SpatialAffine:__init(nOutput)
   parent.__init(self)
   self.weight = torch.Tensor(nOutput)
   self.bias = torch.Tensor(nOutput)
   -- Gradient buffers: without them nn.Module:parameters() returns an empty
   -- gradient list, which breaks getParameters() and zeroGradParameters().
   self.gradWeight = torch.Tensor(nOutput):zero()
   self.gradBias = torch.Tensor(nOutput):zero()
   -- torch.Tensor(n) does not initialize its storage; give the parameters
   -- well-defined starting values.
   self:reset()
end

--- Initialize the parameters to the identity transform (weight = 1, bias = 0).
-- @return self
function SpatialAffine:reset()
   self.weight:fill(1)
   self.bias:zero()
   return self
end

--- Accumulate parameter gradients (nn.Module backward hook):
--   gradWeight[c] += scale * sum_{b,h,w} input[b][c][h][w] * gradOutput[b][c][h][w]
--   gradBias[c]   += scale * sum_{b,h,w} gradOutput[b][c][h][w]
-- @param input the 4D tensor passed to updateOutput.
-- @param gradOutput gradient w.r.t. the output, same size as input.
-- @param scale optional multiplier for the accumulated gradients (default 1).
function SpatialAffine:accGradParameters(input, gradOutput, scale)
   scale = scale or 1
   local nFeature = self.weight:numel()
   assert(input:nDimension() == 4, 'SpatialAffine: expected 4D input (batch x channel x height x width)')
   assert(input:size(2) == nFeature, 'SpatialAffine: input channel count does not match weight size')

   -- Elementwise input * gradOutput in a reusable scratch tensor.
   self._gradProd = self._gradProd or input.new()
   self._gradProd:resizeAs(input):copy(input):cmul(gradOutput)

   -- Reduce over batch, height and width. sum(dim) keeps the reduced
   -- dimension with size 1, so each result is 1 x nFeature x 1 x 1.
   local dWeight = self._gradProd:sum(1):sum(3):sum(4)
   local dBias = gradOutput:sum(1):sum(3):sum(4)

   self.gradWeight:add(scale, dWeight:view(nFeature))
   self.gradBias:add(scale, dBias:view(nFeature))
end
--- Forward pass: scale every channel c of the input by weight[c], then
-- shift it by bias[c]. Both parameter vectors are tiled to the full input
-- size with repeatTensor and applied elementwise.
-- @param input 4D tensor whose dim 2 has weight:numel() entries.
-- @return self.output, resized to match input.
function SpatialAffine:updateOutput(input)
   local channels = self.weight:numel()
   assert(input:nDimension() == 4)
   assert(input:size(2) == channels)
   local batchSize, height, width = input:size(1), input:size(3), input:size(4)

   -- Scratch tensor stored on the module and reused across calls.
   if not self.buffer then
      self.buffer = input.new()
   end
   local tiled = self.buffer
   tiled:resizeAs(input)

   local scaleMap = self.weight:view(1, channels, 1, 1)
   local shiftMap = self.bias:view(1, self.bias:numel(), 1, 1)

   -- Multiplicative part: output = input .* tiled(weight).
   tiled:repeatTensor(scaleMap, batchSize, 1, height, width)
   local out = self.output
   out:resizeAs(input)
   out:copy(input)
   out:cmul(tiled)

   -- Additive part: output = output + tiled(bias).
   tiled:repeatTensor(shiftMap, batchSize, 1, height, width)
   out:add(tiled)

   return out
end
--- Gradient w.r.t. the input: gradOutput scaled per channel by weight[c].
-- The bias is an additive term, so it does not contribute here.
-- @param input the 4D tensor that was passed to updateOutput.
-- @param gradOutput gradient w.r.t. the output; multiplied elementwise
--   (cmul) with the tiled weight.
-- @return self.gradInput, same size as input.
function SpatialAffine:updateGradInput(input, gradOutput)
   local channels = self.weight:numel()
   assert(input:nDimension() == 4)
   assert(input:size(2) == channels)
   local batchSize = input:size(1)
   local height, width = input:size(3), input:size(4)

   local gradIn = self.gradInput
   gradIn:resizeAs(input)
   -- Tile the weight over the batch and spatial dims, then apply it to gradOutput.
   gradIn:repeatTensor(self.weight:view(1, channels, 1, 1), batchSize, 1, height, width)
   gradIn:cmul(gradOutput)
   return gradIn
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment