Last active
August 29, 2015 14:27
-
-
Save szagoruyko/8b13130090b651cc1399 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- A script to simplify a trained net by incorporating every SpatialBatchNormalization
-- into the SpatialConvolution that precedes it, and every BatchNormalization into the preceding Linear.
--- Returns a fresh tensor holding the per-channel factor 1/sqrt(var + eps)
-- that a batch-normalization module applies at inference time.
-- Older nn releases keep this factor in `running_std` (despite its name it is
-- the *inverse* standard deviation, which is why it multiplies the weights);
-- later releases keep `running_var` and `eps` instead, so derive it from those.
local function bnInvStd(bn)
   if bn.running_std then
      return bn.running_std:clone()
   end
   assert(bn.running_var and bn.eps,
      'batch normalization module has neither running_std nor running_var/eps')
   return torch.pow(bn.running_var + bn.eps, -0.5)
end

--- Folds an inference-mode batch-normalization module into the layer feeding
-- it, in place, so that afterwards `layer` alone computes bn(layer(x)):
--   W' = W * invStd * gamma,   b' = (b - mean) * invStd * gamma + beta
-- (the gamma/beta terms are skipped when `bn.affine` is false).
-- @param layer module whose `bias` is updated in place
-- @param bn    batch-normalization module that directly follows `layer`
-- @param w2d   view of `layer.weight` with one row per output channel/unit
local function foldBN(layer, bn, w2d)
   assert(layer.bias, 'cannot fold batch normalization into a layer without bias')
   local nOut = w2d:size(1)
   local scale = bnInvStd(bn)
   if bn.affine then
      scale:cmul(bn.weight)
   end
   -- Scale every row of the weight by its output channel's factor.
   w2d:cmul(scale:view(nOut, 1):expandAs(w2d))
   layer.bias:add(-1, bn.running_mean):cmul(scale)
   if bn.affine then
      layer.bias:add(bn.bias)
   end
end

--- Recursively folds every nn.SpatialBatchNormalization that directly follows
-- (same container, previous index) a module whose type name contains
-- 'SpatialConvolution' into that convolution, then removes the BN module.
-- Works in place; only valid for inference (uses the running statistics).
-- @param net container module (has a `modules` list)
local function BNtoConv(net)
   if not net.modules then return end
   -- Walk backwards: net:remove(i) shifts later modules down by one, which
   -- would make a forward loop skip the module right after a removed one.
   for i = #net.modules, 1, -1 do
      local m = net:get(i)
      if m.modules then
         BNtoConv(m)
      elseif i > 1 and torch.typename(m) == 'nn.SpatialBatchNormalization' then
         local conv = net:get(i - 1)
         local convType = torch.typename(conv)
         if convType and convType:find('SpatialConvolution', 1, true) then
            foldBN(conv, m, conv.weight:view(conv.nOutputPlane, -1))
            net:remove(i)
         end
      end
   end
end

--- Recursively folds every nn.BatchNormalization that directly follows
-- (same container, previous index) a module whose type name contains
-- 'Linear' into that layer, then removes the BN module.
-- Works in place; only valid for inference (uses the running statistics).
-- @param net container module (has a `modules` list)
local function BNToLinear(net)
   if not net.modules then return end
   -- Walk backwards for the same reason as in BNtoConv: removal shifts indices.
   for i = #net.modules, 1, -1 do
      local m = net:get(i)
      if m.modules then
         BNToLinear(m)
      elseif i > 1 and torch.typename(m) == 'nn.BatchNormalization' then
         local linear = net:get(i - 1)
         local linearType = torch.typename(linear)
         if linearType and linearType:find('Linear', 1, true) then
            foldBN(linear, m, linear.weight)
            net:remove(i)
         end
      end
   end
end
--- Folds, in place, every SpatialBatchNormalization into the preceding
-- SpatialConvolution and every BatchNormalization into the preceding Linear,
-- then checks that no batch-normalization module is left.
-- The net should be in evaluate() mode: folding uses the running statistics.
-- @param net container module to simplify (modified in place)
-- @return net, for convenience
local function incorporateBNtoConvAndLinear(net)
   -- Re-run a pass until it stops removing modules, so a BN layer missed by
   -- one pass is still picked up (this replaces calling a pass a fixed
   -- number of times).
   local function runUntilStable(pass, typename)
      local remaining = #net:findModules(typename)
      repeat
         local before = remaining
         pass(net)
         remaining = #net:findModules(typename)
      until remaining == 0 or remaining == before
   end

   runUntilStable(BNtoConv, 'nn.SpatialBatchNormalization')
   runUntilStable(BNToLinear, 'nn.BatchNormalization')

   -- check
   assert(#net:findModules'nn.SpatialBatchNormalization' == 0,
      'some nn.SpatialBatchNormalization modules do not directly follow a SpatialConvolution')
   assert(#net:findModules'nn.BatchNormalization' == 0,
      'some nn.BatchNormalization modules do not directly follow a Linear')
   return net
end
--- Sanity check (needs CUDA, cudnn/cunn and the GoogLeNet model file on disk):
-- the simplified network must reproduce the original network's output on a
-- random batch to within 2e-5.
local function test()
   require 'cudnn'
   require 'cunn'
   local checkpoint = torch.load('./Zoo/GoogLeNet/googlenet_imagenet.t7')
   local model = checkpoint.model
   model:evaluate()

   local batch = torch.randn(32, 3, 227, 227):cuda()
   local reference = model:forward(batch):clone()

   incorporateBNtoConvAndLinear(model)

   local maxAbsDiff = (reference - model:forward(batch)):abs():max()
   assert(maxAbsDiff < 2e-5)
end

-- Uncomment to run the sanity check whenever this file is loaded:
--test()

return incorporateBNtoConvAndLinear
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment