Created March 6, 2017 21:31
Beam Search (PPL ranked)
--[[
Implement beam search. Finished beams are ranked by perplexity rather than by raw
cumulative log-probability.
]]
function layer:sample_beam(imgs, opt)
  local beam_size = utils.getopt(opt, 'beam_size', 10)
  local batch_size, feat_dim = imgs:size(1), imgs:size(2)
  local function compare(a,b) return a.p > b.p end -- sort candidates by cumulative logprob (descending)
  local function compare_ppl(a, b) return a.ppl < b.ppl end -- sort done beams by perplexity (ascending)
  assert(beam_size <= self.vocab_size+1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed')
  local seq = torch.LongTensor(self.seq_length, batch_size):zero()
  local seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
  local Done_beams = {} -- will contain one done_beams table per image
  -- let's process every image independently for now, for simplicity
  for k=1,batch_size do
    -- create initial states for all beams
    self:_createInitState(beam_size)
    local state = self.init_state
    -- we will write output predictions into tensor seq
    local imgk = imgs[{ {k,k} }]:expand(beam_size, feat_dim) -- k'th image feature expanded out
    local beam_seq = torch.LongTensor(self.seq_length, beam_size):zero()
    local beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size):zero()
    local beam_logprobs_sum = torch.zeros(beam_size) -- running sum of logprobs for each beam
    local logprobs -- logprobs predicted in the last time step, shape (beam_size, vocab_size+1)
    local done_beams = {} -- done_beams for the k-th img
    for t=1,self.seq_length+1 do
      local xt, it, sampleLogprobs
      local new_state
      if t == 1 then
        -- feed in the start tokens
        it = torch.LongTensor(beam_size):fill(self.vocab_size+1)
        xt = self.lookup_table:forward(it)
      else
        --[[
          perform a beam merge. that is,
          for every previous beam we now have many new possibilities to branch out,
          and we need to re-sort our beams to maintain the loop invariant of keeping
          the top beam_size most likely sequences.
        ]]--
        local logprobsf = logprobs:float() -- let's go to CPU for more efficient indexing operations
        local ys, ix = torch.sort(logprobsf, 2, true) -- sorted array of logprobs along each previous beam (last true = descending)
        local candidates = {}
        local cols = math.min(beam_size, ys:size(2))
        local rows = beam_size
        if t == 2 then rows = 1 end -- at the first time step only the first beam is active
        for c=1,cols do -- for each column (word, essentially)
          for q=1,rows do -- for each beam expansion
            -- compute logprob of expanding beam q with the word in (sorted) position c
            local local_logprob = ys[{ q,c }]
            local candidate_logprob = beam_logprobs_sum[q] + local_logprob
            table.insert(candidates, {c=ix[{ q,c }], q=q, p=candidate_logprob, r=local_logprob})
          end
        end
        table.sort(candidates, compare) -- find the best c,q pairs
        -- construct new beams
        new_state = net_utils.clone_list(state)
        local beam_seq_prev, beam_seq_logprobs_prev
        if t > 2 then
          -- we'll need these as reference when we fork beams around
          beam_seq_prev = beam_seq[{ {1,t-2}, {} }]:clone()
          beam_seq_logprobs_prev = beam_seq_logprobs[{ {1,t-2}, {} }]:clone()
        end
        for vix=1,beam_size do
          local v = candidates[vix]
          -- fork beam index q into index vix
          if t > 2 then
            beam_seq[{ {1,t-2}, vix }] = beam_seq_prev[{ {}, v.q }]
            beam_seq_logprobs[{ {1,t-2}, vix }] = beam_seq_logprobs_prev[{ {}, v.q }]
          end
          -- rearrange recurrent states
          for state_ix = 1,#new_state do
            -- copy over the state in previous beam q to the new beam at vix
            new_state[state_ix][vix] = state[state_ix][v.q]
          end
          -- append the new token at the end of this beam
          beam_seq[{ t-1, vix }] = v.c -- c'th word is the continuation
          beam_seq_logprobs[{ t-1, vix }] = v.r -- the raw logprob here
          beam_logprobs_sum[vix] = v.p -- the new (sum) logprob along this beam
          if v.c == self.vocab_size+1 or t == self.seq_length+1 then
            -- END token special case, or we reached the maximum length:
            -- add the beam to the set of done beams
            table.insert(done_beams, {seq = beam_seq[{ {}, vix }]:clone(),
                                      logps = beam_seq_logprobs[{ {}, vix }]:clone(),
                                      p = beam_logprobs_sum[vix],
                                      -- ppl = -beam_logprobs_sum[vix]/(t-1)
                                      ppl = torch.exp(-beam_logprobs_sum[vix]/(t-1)) -- ppl = (p1*p2*..*pn)^(-1/n)
                                     })
            -- we won't consider this beam any more;
            -- otherwise, some finished beams would stay at the top even after reaching <end>
            beam_logprobs_sum[vix] = -1000
          end
        end
        -- encode as vectors
        it = beam_seq[t-1]
        xt = self.lookup_table:forward(it)
      end
      if new_state then state = new_state end -- swap rnn state, if we reassigned beams
      local inputs = {torch.cat(imgk, xt), unpack(state)}
      local out = self.core:forward(inputs)
      logprobs = out[self.num_state+1] -- last element is the output vector
      logprobs[{ {}, self.vocab_size }] = -1e5 -- make UNK very unlikely
      state = {}
      for i=1,self.num_state do table.insert(state, out[i]) end
    end
    -- table.sort(done_beams, compare)
    table.sort(done_beams, compare_ppl)
    seq[{ {}, k }] = done_beams[1].seq -- after sorting by ppl, the first beam has the lowest perplexity
    seqLogprobs[{ {}, k }] = done_beams[1].logps
    -- truncate done_beams to the top beam_size entries
    for j=1+beam_size,#done_beams do table.remove(done_beams, 1+beam_size) end
    table.insert(Done_beams, done_beams)
  end
  -- return the samples and their log likelihoods
  return seq, seqLogprobs, Done_beams
end
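Why rank by perplexity: the final sort uses compare_ppl rather than the raw cumulative-logprob compare, so the returned beam is the one with the lowest per-token perplexity exp(-sum_logprob/n). This normalizes by length, so longer finished beams are not penalized simply for having more log-probability terms in their sum. Below is a minimal standalone Lua sketch with made-up numbers (the beams table and its p/n fields are hypothetical, only mirroring the p value and sequence length stored for each done_beams entry) showing that the two rankings can disagree.

-- Minimal sketch (hypothetical numbers): raw-sum ranking vs. PPL ranking.
-- Each entry mimics a done_beams element: p = summed logprob, n = length.
local beams = {
  {name = 'short beam', p = -4.0, n = 2},  -- exp(-(-4.0)/2) = e^2.0 ~ 7.39
  {name = 'long beam',  p = -5.4, n = 6},  -- exp(-(-5.4)/6) = e^0.9 ~ 2.46
}
for _, b in ipairs(beams) do
  b.ppl = math.exp(-b.p / b.n)  -- same formula as in sample_beam: (p1*p2*..*pn)^(-1/n)
end

-- rank by raw cumulative logprob (the commented-out `compare` sort):
table.sort(beams, function(a, b) return a.p > b.p end)
print('by logprob sum:', beams[1].name)  -- 'short beam' wins (-4.0 > -5.4)

-- rank by perplexity (the `compare_ppl` sort actually used above):
table.sort(beams, function(a, b) return a.ppl < b.ppl end)
print('by perplexity: ', beams[1].name)  -- 'long beam' wins (2.46 < 7.39)

For reference, assuming layer is the constructed language-model instance this method is attached to and imgs is a batch_size x feat_dim feature tensor, the call is local seq, seqLogprobs, Done_beams = layer:sample_beam(imgs, {beam_size = 10}); Done_beams[k] then holds the top beam_size finished beams for image k, already sorted by perplexity.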