Created
June 11, 2015 13:43
-
-
Save szagoruyko/6da251c360340cd3c48a to your computer and use it in GitHub Desktop.
BidirectionalSequencer.lua
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- nn.BidirectionalSequencer: a container (subclass of nn.Container) holding
-- two nn.Sequencer instances, one per direction.
local BidirectionalSequencer, parent = torch.class('nn.BidirectionalSequencer', 'nn.Container')

--- Constructor.
-- Wraps each given module in its own nn.Sequencer and registers the wrappers
-- as modules[1] (fed the sequence as given) and modules[2] (fed the reversed
-- sequence by the other methods of this class).
-- @param module_forward  module wrapped into modules[1]
-- @param module_backward module wrapped into modules[2]
-- @param nOutputSize     accepted but not read anywhere in this constructor
function BidirectionalSequencer:__init(module_forward, module_backward, nOutputSize)
   parent.__init(self)

   -- Keep direct handles to the unwrapped modules.
   self.module_forward = module_forward
   self.module_backward = module_backward

   local forward_seq = nn.Sequencer(module_forward)
   local backward_seq = nn.Sequencer(module_backward)
   self.modules[1] = forward_seq
   self.modules[2] = backward_seq

   -- Per-step result tables, filled lazily during the forward/backward passes.
   self.output = {}
   self.gradInput = {}
end
--- Forward pass: run both sequencers and concatenate their per-step outputs.
--
-- `input` is walked with `ipairs`, i.e. it is treated as an array-like table
-- of steps. modules[1] receives the steps in the given order, modules[2]
-- receives them in reversed order.
--
-- For every step i, self.output[i] is resized to (bs, of_ndim + ob_ndim) and
-- filled with of[i] in columns 1..of_ndim and ob[i] in the remaining columns.
--
-- NOTE(review): ob[i] is the i-th output of the *reversed* run, so it was
-- produced from input[#input - i + 1]; it is not re-reversed before being
-- concatenated. updateGradInput/accGradParameters rely on this exact layout,
-- so it is preserved here -- confirm that this alignment is intended.
--
-- @param input table of per-step tensors (the sequencer outputs are indexed
--        as 2-D tensors: dim 1 is taken as bs, dim 2 as the feature width)
-- @return self.output, a table holding one concatenated tensor per input step
function BidirectionalSequencer:updateOutput(input)
   local reverse_input = {}
   for i, v in ipairs(input) do
      reverse_input[#input - i + 1] = v
   end

   local of = self.modules[1]:updateOutput(input)
   local ob = self.modules[2]:updateOutput(reverse_input)
   -- The raw per-direction outputs stay reachable on self, as before.
   self.of = of
   self.ob = ob

   local bs = of[1]:size(1)
   local of_ndim = of[1]:size(2)
   local ob_ndim = ob[1]:size(2)

   local nsteps = 0
   for i, v in ipairs(input) do
      -- Buffers are created once per step index and reused across calls.
      if not self.output[i] then self.output[i] = v.new() end
      local out = self.output[i]
      out:resize(bs, of_ndim + ob_ndim)
      out:narrow(2, 1, of_ndim):copy(of[i])
      out:narrow(2, of_ndim + 1, ob_ndim):copy(ob[i])
      nsteps = i
   end

   -- Drop buffers left over from an earlier, longer sequence so the returned
   -- table has exactly one entry per step of the current input.
   for i = #self.output, nsteps + 1, -1 do
      self.output[i] = nil
   end

   return self.output
end
--- Backward pass w.r.t. the input.
--
-- Each gradOutput[i] is split along dim 2 into the slice belonging to
-- modules[1] (columns 1..of_ndim) and the slice belonging to modules[2]
-- (the following ob_ndim columns), mirroring the layout written by
-- updateOutput. The two sequencers are then back-propagated on the original
-- and on the reversed input respectively, and their input gradients are
-- summed per step.
--
-- @param input      table of per-step tensors, as given to updateOutput
-- @param gradOutput table of per-step gradient tensors, one per output step
-- @return self.gradInput, a table with one gradient tensor per input step
function BidirectionalSequencer:updateGradInput(input, gradOutput)
   local nsteps = #input
   local reverse_input = {}
   for i, v in ipairs(input) do
      reverse_input[nsteps - i + 1] = v
   end

   -- Slice widths are read from the outputs cached by the last forward pass.
   local of_ndim = self.modules[1].output[1]:size(2)
   local ob_ndim = self.modules[2].output[1]:size(2)

   local forward_gradOutput = {}
   local backward_gradOutput = {}
   for i, v in ipairs(gradOutput) do
      forward_gradOutput[i] = v:narrow(2, 1, of_ndim)
      -- The second slice is ob_ndim wide. The previous code used of_ndim here,
      -- which is only correct when both directions have the same output width.
      backward_gradOutput[i] = v:narrow(2, of_ndim + 1, ob_ndim)
   end

   local forward_gradInput = self.modules[1]:updateGradInput(input, forward_gradOutput)
   local backward_gradInput = self.modules[2]:updateGradInput(reverse_input, backward_gradOutput)

   local filled = 0
   for i, v in ipairs(forward_gradInput) do
      if not self.gradInput[i] then self.gradInput[i] = input[1].new() end
      local gi = self.gradInput[i]
      gi:resize(#input[i])
      gi:copy(v)
      -- backward_gradInput is indexed like reverse_input, so the gradient for
      -- input[i] sits at position nsteps - i + 1.
      gi:add(backward_gradInput[nsteps - i + 1])
      filled = i
   end

   -- Drop buffers left over from an earlier, longer sequence.
   for i = #self.gradInput, filled + 1, -1 do
      self.gradInput[i] = nil
   end

   return self.gradInput
end
--- Accumulate parameter gradients in both sequencers.
--
-- Each gradOutput[i] is split along dim 2 into the slice belonging to
-- modules[1] (columns 1..of_ndim) and the slice belonging to modules[2]
-- (the following ob_ndim columns). modules[1] is then called with the
-- original input and modules[2] with the reversed input.
--
-- @param input      table of per-step tensors, as given to updateOutput
-- @param gradOutput table of per-step gradient tensors, one per output step
-- @param scale      passed through unchanged to both sequencers
function BidirectionalSequencer:accGradParameters(input, gradOutput, scale)
   local nsteps = #input
   local reverse_input = {}
   for i, v in ipairs(input) do
      reverse_input[nsteps - i + 1] = v
   end

   -- Slice widths are read from the outputs cached by the last forward pass.
   local of_ndim = self.modules[1].output[1]:size(2)
   local ob_ndim = self.modules[2].output[1]:size(2)

   local forward_gradOutput = {}
   local backward_gradOutput = {}
   for i, v in ipairs(gradOutput) do
      forward_gradOutput[i] = v:narrow(2, 1, of_ndim)
      -- The second slice is ob_ndim wide. The previous code used of_ndim here
      -- and left ob_ndim unused, which breaks when the two widths differ.
      backward_gradOutput[i] = v:narrow(2, of_ndim + 1, ob_ndim)
   end

   self.modules[1]:accGradParameters(input, forward_gradOutput, scale)
   self.modules[2]:accGradParameters(reverse_input, backward_gradOutput, scale)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment