Skip to content

Instantly share code, notes, and snippets.

@bartvm
Created July 11, 2016 22:22
Show Gist options
  • Save bartvm/3a09d9b880b3f7efe69954d9fada16bc to your computer and use it in GitHub Desktop.
Save bartvm/3a09d9b880b3f7efe69954d9fada16bc to your computer and use it in GitHub Desktop.
--- Forward pass over the arc lattice: for each position i = 2..seq_len,
-- state_probs[i] becomes the log-sum-exp, over every valid arc entering i, of
--   probs[origin][arc-column] + state_probs[origin]
-- where origin = i - lengths[arc]. Scores are combined additively and merged
-- with log-sum-exp, i.e. they are treated as log-probabilities.
--
-- Free names read here (defined by the enclosing code, not visible in this gist):
--   probs              -- viewed as seq_len * dict_size, so assumed a contiguous
--                         seq_len x dict_size tensor / autograd node
--   in_arcs, in_mask   -- in_arcs[i] holds arc ids entering position i;
--                         in_mask[i] flags the valid ones. Assumes the mask is a
--                         prefix of ones followed by zeros (only the first
--                         sum(mask) entries are used for the max) -- TODO confirm
--   lengths            -- LongTensor of per-arc lengths, indexed by arc id;
--                         assumed >= 1 so every origin < i has already been
--                         computed -- TODO confirm
--   params.state_probs -- parameter whose first seq_len entries are overwritten
--   seq_len            -- number of positions
local dict_size = probs:size(2)
local max_in_arcs = in_arcs:size(2)
local state_probs = torch.narrow(params.state_probs, 1, 1, seq_len)
-- narrow shares storage, so this also zeroes params.state_probs[1..seq_len];
-- position 1 therefore keeps log-probability 0 (probability 1).
state_probs.value:fill(0)
-- Loop-invariant flat view used to gather probs[origin][column] in one index.
local flat_probs = probs:view(seq_len * dict_size)
for i = 2, seq_len do
  local arcs = in_arcs[i]
  local mask = in_mask[i]
  -- Origin state of each incoming arc: i minus the arc's length.
  local origins = (torch.ones(max_in_arcs):long() * i):add(-lengths:index(1, arcs))
  -- NOTE(review): the row-major flat index of probs[origin][arc] is
  -- (origin - 1) * dict_size + arc; the extra + 1 here reads column arc + 1.
  -- Kept as in the original (it may reflect a reserved first column) -- confirm.
  local indices = ((origins - 1) * dict_size + 1):add(arcs)
  local in_arc_probs = torch.index(flat_probs, 1, indices)
  local in_state_probs = torch.index(state_probs, 1, origins)
  local contributions = torch.cmul(in_arc_probs + in_state_probs, mask)
  -- Shift by the max over the valid (prefix) entries for a stable log-sum-exp.
  -- Taken from .value, so the shift is a constant w.r.t. autograd; the
  -- log-sum-exp gradient does not depend on the shift.
  local num_in = torch.sum(mask)
  local max_in_prob = torch.max(contributions.value[{{1, num_in}}])
  -- Mask BEFORE exp: masked entries would otherwise be exp(-max_in_prob),
  -- which overflows to inf for very negative log-probs, and inf * 0 = NaN.
  local shifted = torch.cmul(contributions - max_in_prob, mask)
  local total = torch.sum(torch.cmul(torch.exp(shifted), mask))
  -- log (not log1p): total already includes exp(0) = 1 from the max term.
  state_probs[i] = torch.log(total) + max_in_prob
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment