A Torch implementation of an LSTM module with an attention mechanism, based on Karpathy's implementation in NeuralTalk2.
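Reading the attention off the code in LSTMAttention.lua below (the notation is mine, with $F$ = frame_per_video frame features $f_1,\dots,f_F$ and previous hidden state $h$):

$$e_i = v_a^\top \tanh(W_h h + U_a f_i), \qquad \alpha = \operatorname{softmax}(e), \qquad z = \textstyle\sum_{i=1}^{F} \alpha_i f_i$$

The context $z$ then enters the first layer's gate pre-activations through an extra z2h linear map alongside the usual i2h and h2h terms.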
gradcheck.lua
-- from https://github.com/karpathy/neuraltalk2/blob/master/misc/gradcheck.lua
-- by Andrej Karpathy
local gradcheck = {}

function gradcheck.relative_error(x, y, h)
  h = h or 1e-12
  if torch.isTensor(x) and torch.isTensor(y) then
    local top = torch.abs(x - y)
    local bottom = torch.cmax(torch.abs(x) + torch.abs(y), h)
    return torch.max(torch.cdiv(top, bottom))
  else
    return math.abs(x - y) / math.max(math.abs(x) + math.abs(y), h)
  end
end

function gradcheck.numeric_gradient(f, x, df, eps)
  df = df or 1.0
  eps = eps or 1e-8
  local n = x:nElement()
  local x_flat = x:view(n)
  local dx_num = x.new(#x):zero()
  local dx_num_flat = dx_num:view(n)
  for i = 1, n do
    local orig = x_flat[i]

    x_flat[i] = orig + eps
    local pos = f(x)
    if torch.isTensor(df) then pos = pos:clone() end

    x_flat[i] = orig - eps
    local neg = f(x)
    if torch.isTensor(df) then neg = neg:clone() end

    local d
    if torch.isTensor(df) then
      d = torch.dot(pos - neg, df) / (2 * eps)
    else
      d = df * (pos - neg) / (2 * eps)
    end
    dx_num_flat[i] = d

    x_flat[i] = orig
  end
  return dx_num
end
--[[
Inputs:
- f is a function that takes a tensor and returns a scalar
- x is the point at which to evaluate f
- dx is the analytic gradient of f at x
--]]
function gradcheck.check_random_dims(f, x, dx, eps, num_iterations, verbose)
  if verbose == nil then verbose = false end
  eps = eps or 1e-4
  local x_flat = x:view(-1)
  local dx_flat = dx:view(-1)
  local relative_errors = torch.Tensor(num_iterations)
  for t = 1, num_iterations do
    -- Make sure the index is really random. We have to reseed inside the
    -- loop because some functions f may be stochastic and may eliminate
    -- their internal randomness by setting a manual seed; in that case we
    -- would always sample the same index unless we reseed on each iteration.
    torch.seed()
    local i = torch.random(x:nElement())
    local orig = x_flat[i]
    x_flat[i] = orig + eps
    local pos = f(x)
    x_flat[i] = orig - eps
    local neg = f(x)
    local d_numeric = (pos - neg) / (2 * eps)
    local d_analytic = dx_flat[i]
    x_flat[i] = orig
    local rel_error = gradcheck.relative_error(d_numeric, d_analytic)
    relative_errors[t] = rel_error
    if verbose then
      print(string.format('  Iteration %d / %d, error = %f',
                          t, num_iterations, rel_error))
      print(string.format('  %f %f', d_numeric, d_analytic))
    end
  end
  return relative_errors
end

return gradcheck
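A minimal usage sketch for the module above (my example, not part of the original gist): numerically check the gradient of a small nn module against its analytic backward.

require 'nn'
local gradcheck = require 'misc.gradcheck'

local mod = nn.Tanh()
local x = torch.randn(4, 3)
local w = torch.randn(4, 3) -- random weights turn the output into a scalar loss

-- scalar-valued function of the input, as numeric_gradient expects
local function f(input)
  return torch.sum(torch.cmul(w, mod:forward(input)))
end

mod:forward(x)
local dx = mod:backward(x, w):clone()  -- analytic gradient of the scalar loss
local dx_num = gradcheck.numeric_gradient(f, x, 1, 1e-6)
print(gradcheck.relative_error(dx, dx_num)) -- expect a small value, e.g. < 1e-4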
LSTMAttention.lua
require 'nn'
require 'nngraph'

local LSTM = {}
function LSTM.lstm(input_size, output_size, rnn_size, n, frame_per_video, seq_per_video, dropout)
  dropout = dropout or 0
  -- TODO: No idea how to choose the attention size...
  -- (att_size is read from a global if one is set, otherwise defaults to 512)
  att_size = att_size or 512
  -- there will be 2*n+2 inputs (x, the n pairs of prev_c/prev_h, and the image)
  local inputs = {}
  table.insert(inputs, nn.Identity()()) -- indices giving the sequence of symbols
  for L = 1, n do
    table.insert(inputs, nn.Identity()()) -- prev_c[L]
    table.insert(inputs, nn.Identity()()) -- prev_h[L]
  end
  table.insert(inputs, nn.Identity()()) -- img
  local x, input_size_L, feats, z
  local outputs = {}
  for L = 1, n do
    -- c, h from previous timesteps
    local prev_h = inputs[L*2+1]
    local prev_c = inputs[L*2]
    -- the input to this layer
    if L == 1 then
      x = inputs[1]
      input_size_L = input_size

      -- Attention model --
      -- TODO: more elegant way to do this?
      -- TODO: share parameters or not?
      feats = inputs[(n+1)*2]
      local e = {}
      local featTable = {nn.SplitTable(1)(feats):split(frame_per_video)}
      local hxW = nn.Linear(rnn_size, att_size)(prev_h)
      local Ua = nn.Linear(input_size, att_size)
      local va = nn.Linear(att_size, 1)
      -- assert(#featTable == frame_per_video, 'Currently only support batch_size = 1')
      for i = 1, frame_per_video do
        -- the clones share parameter and gradient storage with Ua/va
        local va_t = va:clone('weight', 'bias', 'gradWeight', 'gradBias')
        local Ua_t = Ua:clone('weight', 'bias', 'gradWeight', 'gradBias')
        e[i] = va_t(nn.Tanh()(nn.CAddTable()({hxW, nn.Replicate(seq_per_video)(Ua_t(featTable[i]))})))
      end
      z = nn.MM()({nn.SoftMax()(nn.JoinTable(2)(e)), feats})
      -- Attention model end --
    else
      x = outputs[(L-1)*2]
      if dropout > 0 then x = nn.Dropout(dropout)(x):annotate{name='drop_' .. L} end -- apply dropout, if any
      input_size_L = rnn_size
    end

    -- evaluate the input sums at once for efficiency
    local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L}
    local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L}
    local all_input_sums, z2h
    if L == 1 then
      z2h = nn.Linear(input_size, 4 * rnn_size)(z):annotate{name='z2h_'..L}
      all_input_sums = nn.CAddTable()({i2h, h2h, z2h})
    else
      all_input_sums = nn.CAddTable()({i2h, h2h})
    end

    local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
    local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)
    -- decode the gates
    local in_gate = nn.Sigmoid()(n1)
    local forget_gate = nn.Sigmoid()(n2)
    local out_gate = nn.Sigmoid()(n3)
    -- decode the write inputs
    local in_transform = nn.Tanh()(n4)
    -- perform the LSTM update
    local next_c = nn.CAddTable()({
      nn.CMulTable()({forget_gate, prev_c}),
      nn.CMulTable()({in_gate, in_transform})
    })
    -- gated cells form the output
    local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

    table.insert(outputs, next_c)
    table.insert(outputs, next_h)
  end

  -- set up the decoder
  local top_h = outputs[#outputs]
  if dropout > 0 then top_h = nn.Dropout(dropout)(top_h):annotate{name='drop_final'} end
  local proj = nn.Linear(rnn_size, output_size)(top_h):annotate{name='decoder'}
  local logsoft = nn.LogSoftMax()(proj)
  table.insert(outputs, logsoft)

  return nn.gModule(inputs, outputs)
end

return LSTM
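The per-frame attention parameters above are tied with clone('weight', 'bias', 'gradWeight', 'gradBias'), which copies a module but shares its parameter and gradient storage with the original. A quick illustration of the idiom (my sketch, not from the gist):

require 'nn'

local lin = nn.Linear(4, 2)
local tied = lin:clone('weight', 'bias', 'gradWeight', 'gradBias')

lin.weight:fill(0.5)
print(tied.weight[1][1]) -- 0.5: the clone sees the master's update
-- Gradients also accumulate into the shared gradWeight/gradBias, so a single
-- parameter update on the master applies the contributions of every clone.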
--[[
-- Test the attention calculation module
-- [Author]  Haonan Chen
-- [Date]    08/16/2016
-- [Contact] [email protected]
--]]
require 'nn'
require 'nngraph'
local gradcheck = require 'misc.gradcheck'

local tests = {}
local tester = torch.Tester()

seq_per_video = 6
att_size = 15
rnn_size = 15
input_size = 15
frame_per_video = 5
local function build_attention()
  local feats = nn.Identity()()
  local prev_h = nn.Identity()()
  local e = {}
  local featTable = {nn.SplitTable(1)(feats):split(frame_per_video)}
  local nh = nn.Linear(rnn_size, att_size)(prev_h) -- calculate this outside the loop to save time
  local ua = nn.Linear(input_size, att_size)       -- share parameters
  local va = nn.Linear(att_size, 1)
  for i = 1, #featTable do
    -- these clones share parameters with ua/va
    local ua_t = ua:clone('weight', 'bias', 'gradWeight', 'gradBias')
    local va_t = va:clone('weight', 'bias', 'gradWeight', 'gradBias')
    e[i] = va_t(nn.Tanh()(nn.CAddTable()({nh, nn.Replicate(seq_per_video)(ua_t(featTable[i]))})))
  end
  local jo = nn.SoftMax()(nn.JoinTable(2)(e))
  local ca = nn.MM()({jo, feats})
  return nn.gModule({feats, prev_h}, {ca})
end
local function gradCheck()
  local in_feat = torch.rand(frame_per_video, input_size)
  local in_prev_h = torch.rand(seq_per_video, rnn_size)
  local m = build_attention()
  local res = m:forward({in_feat, in_prev_h})

  local w = torch.randn(res:size(1), res:size(2))
  local loss = torch.sum(torch.cmul(w, res))
  local gradOutput = w
  local gradInput, dummy = unpack(m:backward({in_feat, in_prev_h}, gradOutput))

  local function f(x)
    local output = m:forward({x, in_prev_h})
    local loss = torch.sum(torch.cmul(w, output))
    return loss
  end

  local gradInput_num = gradcheck.numeric_gradient(f, in_feat, 1, 1e-6)

  print(gradInput)
  print(gradInput_num)
  local g = gradInput:view(-1)
  local gn = gradInput_num:view(-1)
  for i = 1, g:nElement() do
    local r = gradcheck.relative_error(g[i], gn[i])
    print(i, r)
  end

  tester:assertTensorEq(gradInput, gradInput_num, 1e-6)
  tester:assertlt(gradcheck.relative_error(gradInput, gradInput_num, 1e-8), 5e-4)
end

tests.gradCheck = gradCheck
tester:add(tests)
tester:run()
--[[
-- This file shows how to use the LSTMAttention module (run a forward pass)
-- [Author]  Haonan Chen
-- [Date]    08/02/2016
-- [Contact] [email protected]
--]]
require 'nn'
-- require 'nngraph'
local LSTM = require 'misc.LSTMAttention'

seq_per_video = 20
att_size = 512
rnn_size = 512
input_size = 512
frame_per_video = 15
vocab_size = 80
num_layers = 1
dropout = 0
seq_length = 7
num_state = 2*num_layers
model = LSTM.lstm(input_size, vocab_size + 1, rnn_size, num_layers, frame_per_video, seq_per_video, dropout)

-- clone the LSTM cell once per timestep; the clones share parameters
clones = {model}
for t = 2, seq_length + 2 do
  clones[t] = model:clone('weight', 'bias', 'gradWeight', 'gradBias')
end

-- initialize
function createInitState(batch_size)
  assert(batch_size ~= nil, 'batch size must be provided')
  -- construct the initial state for the LSTM
  if not init_state then init_state = {} end -- lazy init
  for h = 1, num_layers*2 do
    -- note: the init state must be zeros because we also use init_state
    -- to init the grads in the backward call
    if init_state[h] then
      if init_state[h]:size(1) ~= batch_size then
        init_state[h]:resize(batch_size, rnn_size):zero() -- expand the memory
      end
    else
      init_state[h] = torch.zeros(batch_size, rnn_size)
    end
  end
  return init_state
end
state = {[0] = createInitState(seq_per_video)} -- indexed from 0 so we can write state[t-1]
imgs = torch.rand(frame_per_video, input_size) -- fixed across timesteps

inputs = {}
output = {} -- size: seq_length+2 x seq_per_video x (vocab_size+1)
-- forward the network; every step (including the first) feeds a random
-- stand-in for the lookup-table output, so the two original branches collapse
for t = 1, seq_length + 2 do
  lookup_table_out = torch.rand(seq_per_video, input_size)
  inputs[t] = {lookup_table_out, unpack(state[t-1])}
  table.insert(inputs[t], imgs)
  print(inputs[t])

  local out = clones[t]:forward(inputs[t])
  output[t] = out[num_state+1]
  state[t] = {}
  for i = 1, num_state do table.insert(state[t], out[i]) end
end
print(output)
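The demo stops at the forward pass. For completeness, here is a rough sketch of the matching backward pass through the clones (my extrapolation, not part of the gist; doutput holds stand-in loss gradients on each step's log-probabilities):

-- rough backward sketch (assumed, not from the original gist)
doutput = {}
for t = 1, seq_length + 2 do
  doutput[t] = torch.randn(seq_per_video, vocab_size + 1) -- stand-in loss grads
end
dstate = {[seq_length + 2] = createInitState(seq_per_video)} -- zeros
dimgs = torch.zeros(frame_per_video, input_size)
for t = seq_length + 2, 1, -1 do
  -- gradient w.r.t. every output of the clone: num_state states, then logsoft
  local dout = {}
  for i = 1, num_state do table.insert(dout, dstate[t][i]) end
  table.insert(dout, doutput[t])
  local dinputs = clones[t]:backward(inputs[t], dout)
  -- dinputs = {dx, dc_1, dh_1, ..., dimg}: the state grads flow to step t-1,
  -- and the image gradient accumulates across all timesteps
  dstate[t-1] = {}
  for i = 1, num_state do table.insert(dstate[t-1], dinputs[i+1]) end
  dimgs:add(dinputs[num_state + 2])
end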