Skip to content

Instantly share code, notes, and snippets.

@szagoruyko
Last active August 29, 2015 14:27
Show Gist options
  • Save szagoruyko/8b13130090b651cc1399 to your computer and use it in GitHub Desktop.
Save szagoruyko/8b13130090b651cc1399 to your computer and use it in GitHub Desktop.
-- A script that simplifies a trained net by folding every SpatialBatchNormalization into the preceding
-- SpatialConvolution, and every BatchNormalization into the preceding Linear layer.
-- Folds every nn.SpatialBatchNormalization that directly follows a module
-- whose type name contains 'SpatialConvolution' into that convolution, then
-- removes the BN module. Recurses into any child that has `.modules`.
-- Works in place on `net`.
--
-- For output plane c, BN computes on the conv output z = W[c]*x + b[c]:
--   y = (z - mean[c]) * invstd[c] * gamma[c] + beta[c]
-- so the conv is rewritten as
--   W[c] <- W[c] * invstd[c] * gamma[c]
--   b[c] <- (b[c] - mean[c]) * invstd[c] * gamma[c] + beta[c]
-- (gamma/beta are applied only when bn.affine is set).
--
-- NOTE(review): assumes modules[i-1] feeds modules[i], which holds for
-- nn.Sequential but not for parallel containers such as nn.Concat — confirm
-- for the models this is used on.
local function BNtoConv(net)
  -- Walk backwards: net:remove(i) only shifts modules after i, which have
  -- already been visited, so no module is skipped (a forward ipairs walk
  -- skipped the module following each removed BN).
  for i = #net.modules, 1, -1 do
    local m = net.modules[i]
    if m.modules then
      BNtoConv(m)
    elseif i > 1 and torch.typename(m) == 'nn.SpatialBatchNormalization' then
      -- i > 1 guard: a BN at the head of a container has no predecessor,
      -- and torch.typename(nil) would make :find crash.
      local conv = net:get(i - 1)
      local convType = torch.typename(conv)
      if convType and convType:find('SpatialConvolution', 1, true) then
        local bn = m
        local no = conv.nOutputPlane

        -- Per-plane 1/sqrt(var + eps). Newer nn stores running_var; older
        -- nn stores running_std, which already holds the inverse std.
        local invstd
        if bn.running_var then
          invstd = bn.running_var:clone():add(bn.eps):pow(-0.5)
        else
          invstd = bn.running_std
        end

        net:remove(i)

        -- View the weight as (nOutputPlane, rest) so per-plane factors broadcast.
        local conv_w = conv.weight:view(no, -1)
        conv_w:cmul(invstd:view(no, 1):expandAs(conv_w))
        conv.bias:add(-1, bn.running_mean):cmul(invstd)
        if bn.affine then
          conv.bias:cmul(bn.weight):add(bn.bias)
          conv_w:cmul(bn.weight:view(no, 1):expandAs(conv_w))
        end
      end
    end
  end
end
-- Folds every nn.BatchNormalization that directly follows a module whose
-- type name contains 'Linear' into that layer, then removes the BN module.
-- Recurses into any child that has `.modules`. Works in place on `net`.
--
-- For output unit j, BN computes on z = W[j]*x + b[j]:
--   y = (z - mean[j]) * invstd[j] * gamma[j] + beta[j]
-- so the layer is rewritten as
--   W[j] <- W[j] * invstd[j] * gamma[j]
--   b[j] <- (b[j] - mean[j]) * invstd[j] * gamma[j] + beta[j]
-- (gamma/beta are applied only when bn.affine is set).
--
-- NOTE(review): assumes modules[i-1] feeds modules[i], which holds for
-- nn.Sequential but not for parallel containers such as nn.Concat — confirm
-- for the models this is used on.
local function BNToLinear(net)
  -- Walk backwards so net:remove(i) never shifts an unvisited module; the
  -- forward ipairs walk skipped the module after each removed BN, and this
  -- pass is only run once by the caller.
  for i = #net.modules, 1, -1 do
    local m = net.modules[i]
    if m.modules then
      BNToLinear(m)
    elseif i > 1 and torch.typename(m) == 'nn.BatchNormalization' then
      -- i > 1 guard: a BN at the head of a container has no predecessor.
      local linear = net:get(i - 1)
      local linearType = torch.typename(linear)
      if linearType and linearType:find('Linear', 1, true) then
        local bn = m
        -- Rows of linear.weight correspond to output units.
        local no = linear.weight:size(1)

        -- Per-unit 1/sqrt(var + eps). Newer nn stores running_var; older
        -- nn stores running_std, which already holds the inverse std.
        local invstd
        if bn.running_var then
          invstd = bn.running_var:clone():add(bn.eps):pow(-0.5)
        else
          invstd = bn.running_std
        end

        net:remove(i)

        linear.weight:cmul(invstd:view(no, 1):expandAs(linear.weight))
        linear.bias:add(-1, bn.running_mean):cmul(invstd)
        if bn.affine then
          linear.bias:cmul(bn.weight):add(bn.bias)
          linear.weight:cmul(bn.weight:view(no, 1):expandAs(linear.weight))
        end
      end
    end
  end
end
-- Folds all SpatialBatchNormalization layers into their preceding
-- convolutions and all BatchNormalization layers into their preceding Linear
-- layers. Works in place on `net`.
-- Raises (via assert) if any nn.SpatialBatchNormalization or
-- nn.BatchNormalization module remains afterwards.
-- Returns `net` for convenience; the original version returned nothing.
local function incorporateBNtoConvAndLinear(net)
  -- Run `fold` repeatedly until the number of `bnType` modules stops
  -- shrinking. A fixed pass count is not enough when a traversal can skip
  -- modules (e.g. deeply nested containers); a fixpoint is always enough.
  -- Returns how many `bnType` modules are left.
  local function foldUntilStable(fold, bnType)
    local remaining = #net:findModules(bnType)
    while remaining > 0 do
      fold(net)
      local now = #net:findModules(bnType)
      if now == remaining then break end
      remaining = now
    end
    return remaining
  end

  local leftSpatial = foldUntilStable(BNtoConv, 'nn.SpatialBatchNormalization')
  local leftFlat = foldUntilStable(BNToLinear, 'nn.BatchNormalization')

  -- Anything left was not directly preceded by a foldable layer.
  assert(leftSpatial == 0, string.format(
    '%d nn.SpatialBatchNormalization module(s) could not be folded into a preceding convolution',
    leftSpatial))
  assert(leftFlat == 0, string.format(
    '%d nn.BatchNormalization module(s) could not be folded into a preceding Linear layer',
    leftFlat))
  return net
end
-- GPU self-check: loads a serialized model, records its output on random
-- input, folds the BN layers, and asserts the output is unchanged within
-- `tolerance`.
-- @param modelPath optional path to a torch.load-able file whose `.model`
--   field is the net (default './Zoo/GoogLeNet/googlenet_imagenet.t7').
-- @param tolerance optional max allowed absolute difference (default 2e-5).
-- Requires cudnn, cunn and a CUDA device.
local function test(modelPath, tolerance)
  -- cudnn must be loaded before torch.load so cudnn modules deserialize.
  require 'cudnn'
  require 'cunn'
  modelPath = modelPath or './Zoo/GoogLeNet/googlenet_imagenet.t7'
  tolerance = tolerance or 2e-5

  local net = torch.load(modelPath).model
  -- Evaluation mode makes BN use its running statistics, which is what the
  -- folding reproduces.
  net:evaluate()

  local input = torch.randn(32, 3, 227, 227):cuda()
  local reference = net:forward(input):clone()
  incorporateBNtoConvAndLinear(net)

  local maxDiff = (reference - net:forward(input)):abs():max()
  assert(maxDiff < tolerance, string.format(
    'folded net differs from original by %g (tolerance %g)', maxDiff, tolerance))
end
-- Uncomment to run the self-check in test() (needs cudnn, cunn, a CUDA device
-- and the model file referenced inside test()).
--test()
-- The chunk evaluates to the folding function; callers invoke it as fold(net).
return incorporateBNtoConvAndLinear
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment