Skip to content

Instantly share code, notes, and snippets.

@JoostvDoorn
Created September 13, 2016 10:02
Show Gist options
  • Save JoostvDoorn/bd43ce85a101c0aeee03f734ead688fb to your computer and use it in GitHub Desktop.
Conversion between SeqLSTM and cudnn LSTM
--
-- Author: Joost van Doorn <[email protected]>
--
require 'rnn'
require 'cudnn'
function toCudnnLSTM(seqLSTM)
   -- Build a single-layer cudnn.LSTM and copy the SeqLSTM parameters into it.
   --
   -- SeqLSTM packs its gates column-wise in the order: input, forget, output,
   -- memory; cudnn orders its per-layer parameter slots as: input, forget,
   -- memory, output. gateSlot below maps SeqLSTM slice -> cudnn slot.
   local rnn = cudnn.LSTM(seqLSTM.inputsize, seqLSTM.outputsize, 1)
   local H = seqLSTM.outputsize
   local R = seqLSTM.outputsize
   local D = seqLSTM.inputsize
   local biases = rnn:biases()
   local weights = rnn:weights()
   assert(#biases == 1, "Conversion only supported for 1 layer LSTM")
   -- SeqLSTM stacks input-to-hidden (first D rows) above hidden-to-hidden
   -- (next R rows) in a single weight matrix.
   local Wx = seqLSTM.weight:narrow(1, 1, D)
   local Wh = seqLSTM.weight:narrow(1, D + 1, R)
   -- cudnn slot for each SeqLSTM gate slice (1=input, 2=forget, 3=output, 4=memory).
   local gateSlot = {1, 2, 4, 3}
   for slice = 1, 4 do
      local slot = gateSlot[slice]
      local cols = {(slice - 1) * H + 1, slice * H}
      biases[1][slot]:copy(seqLSTM.bias[{cols}])
      weights[1][slot]:copy(Wx[{{}, cols}]:t())        -- input-to-hidden weights
      weights[1][slot + 4]:copy(Wh[{{}, cols}]:t())    -- hidden-to-hidden weights
      -- cudnn keeps a second bias vector per gate; SeqLSTM has only one,
      -- so the extra cudnn bias is zeroed.
      biases[1][slot + 4]:fill(0)
   end
   return rnn
end
function toSeqLSTM(rnn)
   -- Build a CUDA SeqLSTM and copy the cudnn.LSTM parameters into it.
   --
   -- Inverse of the SeqLSTM -> cudnn conversion: SeqLSTM gate slices are
   -- ordered input, forget, output, memory while cudnn slots are ordered
   -- input, forget, memory, output; gateSlot maps SeqLSTM slice -> cudnn slot.
   local seqLSTM = nn.SeqLSTM(rnn.inputSize, rnn.hiddenSize)
   seqLSTM:cuda()
   local H = seqLSTM.outputsize
   local R = seqLSTM.outputsize
   local D = seqLSTM.inputsize
   local biases = rnn:biases()
   local weights = rnn:weights()
   assert(#biases == 1, "Conversion only supported for 1 layer LSTM")
   -- SeqLSTM stacks input-to-hidden (first D rows) above hidden-to-hidden
   -- (next R rows) in a single weight matrix.
   local Wx = seqLSTM.weight:narrow(1, 1, D)
   local Wh = seqLSTM.weight:narrow(1, D + 1, R)
   -- cudnn slot for each SeqLSTM gate slice (1=input, 2=forget, 3=output, 4=memory).
   local gateSlot = {1, 2, 4, 3}
   for slice = 1, 4 do
      local slot = gateSlot[slice]
      local cols = {(slice - 1) * H + 1, slice * H}
      -- cudnn keeps two bias vectors per gate; SeqLSTM has one, so merge by summing.
      seqLSTM.bias[{cols}]:copy(biases[1][slot] + biases[1][slot + 4])
      Wx[{{}, cols}]:t():copy(weights[1][slot])        -- input-to-hidden weights
      Wh[{{}, cols}]:t():copy(weights[1][slot + 4])    -- hidden-to-hidden weights
   end
   return seqLSTM
end
-- Self-check: convert in both directions and compare module outputs on the
-- same random input. The difference means should print (close to) zero.
local seqModel = nn.SeqLSTM(256, 256)
seqModel:cuda()
local cudnnModel = toCudnnLSTM(seqModel)
local input = torch.randn(2, 1, 256):cuda()
local outSeq = seqModel:forward(input)
local outCudnn = cudnnModel:forward(input)
print(torch.mean(outSeq))
print(torch.mean(outCudnn))
print(torch.mean(outSeq - outCudnn))
-- Reverse direction: start from a freshly-initialized cudnn.LSTM.
local cudnnModel2 = cudnn.LSTM(256, 256, 1)
local seqModel2 = toSeqLSTM(cudnnModel2)
local outSeq2 = seqModel2:forward(input)
local outCudnn2 = cudnnModel2:forward(input)
print(torch.mean(outSeq2))
print(torch.mean(outCudnn2))
print(torch.mean(outSeq2 - outCudnn2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment