Created
July 11, 2016 14:48
-
-
Save bartvm/ec18e22015db72f8a95692ddecefc0cf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- Accumulate a score for every lattice state and return the final state's score.
--
-- State 1 starts at 0. For each state `i` and each active outgoing arc `j`,
-- the arc leads to state `i + lengths[out_arcs[{i, j}]]` and contributes
-- `state_nll[i] + probs[{i, j}]`. When several arcs reach the same state their
-- contributions are merged with a numerically stable log-sum-exp.
--
-- NOTE(review): the merge computes log(exp(a) + exp(b)). That is the right way
-- to sum path probabilities if `probs` holds log-probabilities; if it holds
-- *negative* log-probabilities the signs would need flipping -- confirm
-- against the caller.
--
-- @param params       unused by this function (presumably kept so the
--                     signature suits autograd, which differentiates with
--                     respect to the first argument -- TODO confirm)
-- @param probs        indexed as probs[{i, j}]; probs:size(1) gives the number
--                     of non-final states. Entries are expected to be autograd
--                     nodes exposing a `.value` field.
-- @param out_arcs     indexed as out_arcs[{i, j}]; the id of the j-th arc
--                     leaving state i. out_arcs:size(2) is the maximum number
--                     of outgoing arcs per state.
-- @param out_mask     indexed as out_mask[{i, j}]; 1 while arc j of state i is
--                     real. The first non-1 entry ends that state's arc list.
-- @param lengths      indexed by arc id; how many states the arc advances.
-- @return the accumulated score of state `probs:size(1) + 1`
local function nll(params, probs, out_arcs, out_mask, lengths)
  local seq_len = probs:size(1)
  local max_out_arcs = out_arcs:size(2)
  local final_state = seq_len + 1
  local state_nll = {0}

  for i = 1, seq_len do
    for j = 1, max_out_arcs do
      -- Active arcs are packed at the front of each row.
      if out_mask[{i, j}] ~= 1 then
        break
      end

      -- Re-read on every arc, exactly as the original expression did.
      local source = state_nll[i]
      if source == nil then
        error("state " .. tostring(i)
          .. " has outgoing arcs but was never reached")
      end

      local target = i + lengths[out_arcs[{i, j}]]
      if target > final_state then
        error("arc " .. tostring(j) .. " of state " .. tostring(i)
          .. " ends at state " .. tostring(target)
          .. ", beyond the final state " .. tostring(final_state))
      end

      local contribution = source + probs[{i, j}]
      local current = state_nll[target]
      if current == nil then
        state_nll[target] = contribution
      else
        -- These are autograd nodes, so compare their underlying values.
        -- x is the larger operand, so exp(y - x) cannot overflow.
        local x, y
        if current.value > contribution.value then
          x, y = current, contribution
        else
          x, y = contribution, current
        end
        state_nll[target] = x + torch.log1p(torch.exp(y - x))
      end
    end
  end

  -- state_nll may legitimately contain holes (arcs longer than 1 skip
  -- states), so `#state_nll` is unreliable; check the final state directly.
  assert(state_nll[final_state] ~= nil, "final state was never reached")
  return state_nll[final_state]
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment