Created: September 13, 2016, 10:02
-
-
Save JoostvDoorn/bd43ce85a101c0aeee03f734ead688fb to your computer and use it in GitHub Desktop.
Conversion between SeqLSTM and cudnn LSTM
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--------------------------------------------------------------------
-- Conversion between nn.SeqLSTM and cudnn.LSTM weight layouts.
--
-- Author: Joost van Doorn <[email protected]>
--------------------------------------------------------------------
require 'rnn'
require 'cudnn'
--- Convert a single-layer nn.SeqLSTM into an equivalent cudnn.LSTM.
-- The returned module has identical weights/biases, so its forward
-- output matches the input module's.
-- @param seqLSTM an nn.SeqLSTM instance (already on the GPU)
-- @return a cudnn.LSTM with the copied parameters
function toCudnnLSTM(seqLSTM)
  local D, H = seqLSTM.inputsize, seqLSTM.outputsize
  local rnn = cudnn.LSTM(D, H, 1)
  local biases = rnn:biases()
  local weights = rnn:weights()
  assert(#biases == 1, "Conversion only supported for 1 layer LSTM")
  -- SeqLSTM packs its weight as [Wx; Wh] along dim 1:
  -- rows 1..D are input-to-hidden, rows D+1..D+H are hidden-to-hidden.
  local Wx = seqLSTM.weight:narrow(1, 1, D)
  local Wh = seqLSTM.weight:narrow(1, D + 1, H)
  -- Gate ordering differs between the two layouts:
  --   SeqLSTM column sections: input, forget, output, memory
  --   cudnn slots 1..4 (input-to-hidden) / 5..8 (hidden-to-hidden): i, f, g, o
  -- gateSlot maps SeqLSTM section k to the cudnn slot index.
  local gateSlot = {1, 2, 4, 3}
  for k = 1, 4 do
    local cols = {(k - 1) * H + 1, k * H}
    local slot = gateSlot[k]
    -- cudnn keeps two bias vectors per gate; fold the SeqLSTM bias into
    -- the first and zero the second so their sum equals SeqLSTM's bias.
    biases[1][slot]:copy(seqLSTM.bias[{cols}])
    biases[1][slot + 4]:fill(0)
    -- cudnn stores weights transposed relative to SeqLSTM's column slices.
    weights[1][slot]:copy(Wx[{{}, cols}]:t())
    weights[1][slot + 4]:copy(Wh[{{}, cols}]:t())
  end
  return rnn
end
--- Convert a single-layer cudnn.LSTM into an equivalent nn.SeqLSTM.
-- Inverse of the SeqLSTM -> cudnn direction: cudnn's two bias vectors
-- per gate are summed into SeqLSTM's single bias.
-- @param rnn a cudnn.LSTM instance (1 layer)
-- @return an nn.SeqLSTM (moved to the GPU) with the copied parameters
function toSeqLSTM(rnn)
  local seqLSTM = nn.SeqLSTM(rnn.inputSize, rnn.hiddenSize)
  seqLSTM:cuda()
  local D, H = seqLSTM.inputsize, seqLSTM.outputsize
  local biases = rnn:biases()
  local weights = rnn:weights()
  assert(#biases == 1, "Conversion only supported for 1 layer LSTM")
  -- SeqLSTM packs its weight as [Wx; Wh] along dim 1 (see toCudnnLSTM).
  local Wx = seqLSTM.weight:narrow(1, 1, D)
  local Wh = seqLSTM.weight:narrow(1, D + 1, H)
  -- Gate ordering: SeqLSTM sections are (input, forget, output, memory);
  -- cudnn slots 1..4 / 5..8 are (i, f, g, o). gateSlot maps SeqLSTM
  -- section k to the cudnn slot index.
  local gateSlot = {1, 2, 4, 3}
  for k = 1, 4 do
    local cols = {(k - 1) * H + 1, k * H}
    local slot = gateSlot[k]
    -- cudnn uses two bias parameters per gate; merge them into one.
    seqLSTM.bias[{cols}]:copy(biases[1][slot] + biases[1][slot + 4])
    -- :t() returns a transposed *view*, so copying into it fills the
    -- corresponding columns of Wx/Wh in place.
    Wx[{{}, cols}]:t():copy(weights[1][slot])
    Wh[{{}, cols}]:t():copy(weights[1][slot + 4])
  end
  return seqLSTM
end
-- Sanity check both conversion directions: for each pair, print the mean
-- output of each module and the mean of their difference (should be ~0).
-- NOTE: constructor and randn call order is kept so the RNG stream matches.
local rnn = nn.SeqLSTM(256, 256)
rnn:cuda()
local rnn2 = toCudnnLSTM(rnn)
local input = torch.randn(2, 1, 256):cuda()

-- Print mean(a(x)), mean(b(x)), and mean(a(x) - b(x)) for a shared input.
local function compareOutputs(a, b)
  print(torch.mean(a:forward(input)))
  print(torch.mean(b:forward(input)))
  print(torch.mean(a:forward(input) - b:forward(input)))
end

compareOutputs(rnn, rnn2)

local rnn2 = cudnn.LSTM(256, 256, 1)
local rnn = toSeqLSTM(rnn2)
compareOutputs(rnn, rnn2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.