Last active
July 6, 2019 13:34
-
-
Save karpathy/f3ee599538ff78e1bbe9 to your computer and use it in GitHub Desktop.
Batched L2 Normalization Layer for Torch nn package
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--[[
This layer expects an [n x d] Tensor and normalizes each
row to have unit L2 norm.
]]--
local L2Normalize, parent = torch.class('nn.L2Normalize', 'nn.Module')

--- Constructor. The layer has no learnable parameters; all state beyond
-- the nn.Module defaults (scratch buffers) is allocated lazily in the
-- forward/backward passes.
function L2Normalize:__init()
   parent.__init(self)
end
--- Forward pass: divide every row of `input` by its L2 norm.
-- @param input 2D tensor of size [n x d]
-- @return self.output, an [n x d] tensor whose rows have unit L2 norm
-- Side effects: fills self.buffer with the per-row norms ([n x 1]) and
-- self.normSquared with the per-row squared norms ([n x 1]); both are
-- reused by updateGradInput.
-- NOTE(review): an all-zero row yields a zero norm and hence a division
-- by zero (NaNs) — callers are expected to avoid zero rows.
function L2Normalize:updateOutput(input)
   assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got '
      .. input:dim() .. 'D tensor instead')
   self.output:resizeAs(input)
   -- lazily allocate scratch tensors of the same type as the input
   if not self.buffer then self.buffer = input.new() end
   if not self.normSquared then self.normSquared = input.new() end
   -- normSquared[i][1] = sum_j input[i][j]^2
   self.normSquared:sum(self.buffer:cmul(input, input), 2)
   -- buffer[i][1] = ||input[i]||
   self.buffer:sqrt(self.normSquared)
   -- output[i] = input[i] / ||input[i]||
   self.output:copy(input):cdiv(self.buffer:expandAs(input))
   return self.output
end
--- Backward pass: apply the Jacobian of row-wise L2 normalization.
-- Per row, d/dx (x/||x||) = (I*||x||^2 - x*x^T) / ||x||^3, computed as a
-- batch of [d x d] matrices and contracted with gradOutput via bmm.
-- Requires updateOutput to have been called first: it reads self.buffer
-- (per-row norms) and self.normSquared (per-row squared norms) cached there.
-- @param input      [n x d] tensor given to the forward pass
-- @param gradOutput [n x d] gradient w.r.t. the layer output
-- @return self.gradInput, an [n x d] tensor
function L2Normalize:updateGradInput(input, gradOutput)
   assert(input:dim() == 2, 'only mini-batch supported')
   assert(gradOutput:dim() == 2, 'only mini-batch supported')
   local n = input:size(1) -- batch size
   local d = input:size(2) -- dimensionality of vectors
   -- Compute the diagonal term I*||x||^2. Rebuild the cached identity stack
   -- whenever n or d changes: the original cached it once on the first call,
   -- which broke on any subsequent batch of a different size (e.g. the last,
   -- smaller batch of an epoch).
   if not self.eye or self.eye:size(1) ~= n or self.eye:size(2) ~= d then
      self.eye = torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d)
   end
   self.diag = self.diag or self.eye.new()
   self.diag:cmul(self.eye, self.normSquared:view(n,1,1):expand(n,d,d))
   -- subtract the cross term x*x^T for each row
   local b1 = input:view(n,d,1)
   local b2 = input:view(n,1,d)
   self.diag:add(-torch.bmm(b1,b2))
   -- divide by ||x||^3 to obtain the local Jacobian of the normalization
   -- (torch.pow allocates here; acceptable, matching the original behavior)
   self.diag:cdiv(torch.pow(self.buffer,3):view(n,1,1):expand(n,d,d))
   -- chain rule: gradInput[i] = J_i * gradOutput[i], then flatten back to [n x d]
   self.gradInput:resize(n,d,1):bmm(self.diag, gradOutput:view(n,d,1)):resize(n,d)
   return self.gradInput
end
Shouldn't line 31
be followed by :squeeze()
for matching the dimensions? I have some funky gradInput
dimensions otherwise..
Do you mean using norm()
in forward pass? That could be done.
Oops, you're right about squeeze
, fixed!
I saw @soumith gave you other pointers as well.
It's nice to have you join the Torch circle 😉
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Any reason for not using
torch.norm()
? You are doing this, right?
\[\frac{\partial\frac{\boldsymbol{x}}{\sqrt{\boldsymbol{x}^\top\boldsymbol{x}}}}{\partial\boldsymbol{x}} = \frac{\mathbb{I}}{\sqrt{\boldsymbol{x}^\top\boldsymbol{x}}} - \frac{\boldsymbol{x}\boldsymbol{x}^\top}{\sqrt{\left(\boldsymbol{x}^\top\boldsymbol{x}\right)^3}}\]