@szagoruyko
Forked from fmassa/convertLinear2Conv1x1.lua
Created September 28, 2015 18:16
Simple example of how to convert a Linear layer into an equivalent convolution whose kernel spans the whole input, so that it produces a 1x1 output
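The conversion relies on the identity below. This is a sketch with 1-based indices as in Lua; the symbols ($A$ for the Linear weight, $\iota$ for the flattening index, $K$ for the kernel) are introduced here for illustration. It assumes the input $x$ is $C \times k_H \times k_W$ and is flattened in (plane, row, column) order, as `input:view(C*H*W)` does for a contiguous Torch tensor. That is the layout the plain `copy()` relies on.

$$
y_o = b_o + \sum_{c=1}^{C}\sum_{h=1}^{k_H}\sum_{w=1}^{k_W} A_{o,\iota(c,h,w)}\, x_{c,h,w},
\qquad \iota(c,h,w) = (c-1)\,k_H k_W + (h-1)\,k_W + w .
$$

A stride-1, unpadded convolution with kernel $K_{o,c,h,w} = A_{o,\iota(c,h,w)}$ applied to an input of size $C \times H_{in} \times W_{in}$ computes

$$
z_{o,i,j} = b_o + \sum_{c=1}^{C}\sum_{h=1}^{k_H}\sum_{w=1}^{k_W} K_{o,c,h,w}\, x_{c,\,i+h-1,\,j+w-1},
\qquad 1 \le i \le H_{in}-k_H+1,\quad 1 \le j \le W_{in}-k_W+1 .
$$

When $H_{in} = k_H$ and $W_{in} = k_W$, the only valid position is $(i,j) = (1,1)$. In that case $z_{o,1,1} = y_o$ for every $o$, and the output is $n_{out} \times 1 \times 1$. In the example below, $C = 3$, $k_H = k_W = 6$ and $n_{out} = 10$. The function recovers $C$ as $108 / (6 \cdot 6) = 3$.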
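require 'nn'

-- Convert an nn.Linear layer into an equivalent nn.SpatialConvolutionMM.
-- You only need to provide the Linear module you want to convert and the
-- spatial size of its field of view, in_size = {width, height}
-- (in_size[1] is used as kW and in_size[2] as kH).
function convertLinear2Conv1x1(linmodule, in_size)
  local kW, kH = in_size[1], in_size[2]
  local nIn = linmodule.weight:size(2)
  assert(nIn % (kW * kH) == 0, 'weight:size(2) must be a multiple of kW*kH')
  local s_in = nIn / (kW * kH)            -- number of input planes
  local s_out = linmodule.weight:size(1)  -- number of output planes
  -- the kernel covers the whole field of view, with stride 1 x 1
  local convmodule = nn.SpatialConvolutionMM(s_in, s_out, kW, kH, 1, 1)
  -- copy() only requires equal element counts; each Linear weight row is
  -- indexed like input:view(C*H*W), i.e. in (plane, row, column) order
  convmodule.weight:copy(linmodule.weight)
  convmodule.bias:copy(linmodule.bias)
  return convmodule
end

local input = torch.rand(3, 6, 6)
local m = nn.Linear(3*6*6, 10)
local mm = convertLinear2Conv1x1(m, {6, 6})
local output_lin = m:forward(input:view(3*6*6))  -- size 10
local output_conv = mm:forward(input)            -- size 10x1x1
-- the two outputs should agree up to floating-point rounding
print((output_lin - output_conv:view(10)):abs():max())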
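The following minimal sketch shows why a plain `copy()` of the weights is enough. It uses pure Lua and does not depend on Torch. It reimplements a Linear layer and a stride-1, unpadded convolution on nested Lua tables, then checks that both give the same 10 numbers on a 3x6x6 input. The helpers `linear`, `flatten`, `reshape_weight` and `conv2d` are hypothetical and exist only for this illustration. The sketch assumes the same (plane, row, column) ordering that Torch uses for contiguous tensors.

```lua
-- Pure-Lua check (no Torch required): a Linear layer vs. a stride-1,
-- unpadded convolution whose kernel is the Linear weight reshaped to
-- nOut x C x kH x kW, evaluated on an input of exactly C x kH x kW.

-- y[o] = b[o] + sum_j Wlin[o][j] * xflat[j]
local function linear(Wlin, b, xflat)
  local y = {}
  for o = 1, #Wlin do
    local s = b[o]
    for j = 1, #xflat do s = s + Wlin[o][j] * xflat[j] end
    y[o] = s
  end
  return y
end

-- Row-major (plane, row, column) flattening, like input:view(C*H*W).
local function flatten(x)
  local flat = {}
  for c = 1, #x do
    for h = 1, #x[c] do
      for w = 1, #x[c][h] do flat[#flat + 1] = x[c][h][w] end
    end
  end
  return flat
end

-- Reshape each weight row into a C x kH x kW kernel using the same order.
local function reshape_weight(Wlin, C, kH, kW)
  local K = {}
  for o = 1, #Wlin do
    K[o] = {}
    for c = 1, C do
      K[o][c] = {}
      for h = 1, kH do
        K[o][c][h] = {}
        for w = 1, kW do
          K[o][c][h][w] = Wlin[o][(c - 1) * kH * kW + (h - 1) * kW + w]
        end
      end
    end
  end
  return K
end

-- Stride-1, unpadded convolution (cross-correlation, no kernel flip).
-- x is C x H x W; the result is nOut x (H-kH+1) x (W-kW+1).
local function conv2d(K, b, x)
  local C, H, W = #x, #x[1], #x[1][1]
  local kH, kW = #K[1][1], #K[1][1][1]
  local y = {}
  for o = 1, #K do
    y[o] = {}
    for i = 1, H - kH + 1 do
      y[o][i] = {}
      for j = 1, W - kW + 1 do
        local s = b[o]
        for c = 1, C do
          for h = 1, kH do
            for w = 1, kW do
              s = s + K[o][c][h][w] * x[c][i + h - 1][j + w - 1]
            end
          end
        end
        y[o][i][j] = s
      end
    end
  end
  return y
end

-- Same sizes as the gist: a 3x6x6 input and 10 outputs.
local C, H, W, nOut = 3, 6, 6, 10
math.randomseed(42)
local x, Wlin, b = {}, {}, {}
for c = 1, C do
  x[c] = {}
  for h = 1, H do
    x[c][h] = {}
    for w = 1, W do x[c][h][w] = math.random() end
  end
end
for o = 1, nOut do
  Wlin[o] = {}
  for j = 1, C * H * W do Wlin[o][j] = math.random() - 0.5 end
  b[o] = math.random() - 0.5
end

local y_lin = linear(Wlin, b, flatten(x))
local y_conv = conv2d(reshape_weight(Wlin, C, H, W), b, x)

-- largest |linear - conv| difference over the 10 outputs
local maxdiff = 0
for o = 1, nOut do
  maxdiff = math.max(maxdiff, math.abs(y_lin[o] - y_conv[o][1][1]))
end
print(#y_conv[1], #y_conv[1][1], maxdiff)
```

On an input larger than the kernel, `conv2d` returns more than one position per plane: (H - kH + 1) by (W - kW + 1) for stride 1. Each position is then the Linear layer applied to one kH x kW window. The converted module therefore reproduces the Linear output as a single 1x1 value only when the input has exactly the original size.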