Skip to content

Instantly share code, notes, and snippets.

@bartvm
Created July 11, 2016 14:48
Show Gist options
  • Save bartvm/ec18e22015db72f8a95692ddecefc0cf to your computer and use it in GitHub Desktop.
Save bartvm/ec18e22015db72f8a95692ddecefc0cf to your computer and use it in GitHub Desktop.
--- Return the plain number behind a scalar that may be a torch-autograd node.
-- Numbers pass through unchanged (e.g. when `nll` is evaluated outside of
-- autograd); anything else is treated as an autograd node and its `.value`
-- field is returned, which is what the max/min comparison needs.
local function raw_value(v)
  if type(v) == "number" then
    return v
  end
  return v.value
end

--- Numerically stable log(exp(a) + exp(b)) for scalars or autograd nodes.
-- The larger operand is factored out so the exponent passed to exp is never
-- positive. On ties `b` becomes the pivot, matching the original ordering.
local function log_add_exp(a, b)
  local hi, lo
  if raw_value(a) > raw_value(b) then
    hi, lo = a, b
  else
    hi, lo = b, a
  end
  return hi + torch.log1p(torch.exp(lo - hi))
end

--- Forward pass over a left-to-right lattice of arcs, merging scores with log-sum-exp.
-- Position 1 is the start state (score 0). Row `i` of `out_arcs` lists the arcs
-- leaving position `i`; arc `j` jumps to `i + lengths[out_arcs[{i, j}]]` and adds
-- `probs[{i, j}]` to the path score. Scores of paths that meet at the same
-- position are merged with log-sum-exp.
-- NOTE(review): whether the result is a log-likelihood or its negative depends
-- on the sign convention of `probs`, which is not visible here — confirm.
-- @param params not read by this function
-- @param probs 2-D tensor (or autograd node) of per-arc scores, indexed `{i, j}`
-- @param out_arcs 2-D tensor of arc ids, indexed `{i, j}`; its size along
--   dim 2 bounds `j`
-- @param out_mask 2-D tensor; `1` marks a valid arc and the first non-1 entry
--   ends row `i`
-- @param lengths lookup from arc id to the number of positions the arc advances
-- @return the merged score at position `seq_len + 1`, where `seq_len = probs:size(1)`
-- @raise error if an arc jumps past `seq_len + 1` or if `seq_len + 1` is unreachable
local function nll(params, probs, out_arcs, out_mask, lengths)
  local seq_len = probs:size(1)
  local max_out_arcs = out_arcs:size(2)
  local final = seq_len + 1
  -- state_nll[k] holds the merged score of every path ending at position k.
  local state_nll = {0}
  for i = 1, seq_len do
    -- A position no arc reaches has no score; skip its outgoing arcs rather
    -- than failing on `nil + probs[{i, j}]`.
    if state_nll[i] ~= nil then
      for j = 1, max_out_arcs do
        -- Valid arcs are packed at the start of the row: stop at the first gap.
        if out_mask[{i, j}] ~= 1 then
          break
        end
        local target = i + lengths[out_arcs[{i, j}]]
        assert(target <= final, "arc jumps past the end of the sequence")
        local contribution = state_nll[i] + probs[{i, j}]
        local current = state_nll[target]
        if current == nil then
          state_nll[target] = contribution
        else
          state_nll[target] = log_add_exp(current, contribution)
        end
      end
    end
  end
  -- Check the end state directly: `#state_nll` is unspecified when some
  -- intermediate position was never reached (the table then has holes).
  local result = state_nll[final]
  assert(result ~= nil, "end of sequence is unreachable")
  return result
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment