@lichengunc
Created March 6, 2017 21:31
Beam Search (PPL ranked)
--[[
Implement beam search
]]
function layer:sample_beam(imgs, opt)
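  -- imgs: (batch_size, feat_dim) tensor of image features; opt: options table with 'beam_size' (default 10).
  -- Returns: seq (seq_length x batch_size LongTensor of word indices), seqLogprobs (their logprobs),
  -- and Done_beams, which holds up to beam_size finished beams per image, ranked by perplexity.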
  local beam_size = utils.getopt(opt, 'beam_size', 10)
  local batch_size, feat_dim = imgs:size(1), imgs:size(2)
  local function compare(a, b) return a.p > b.p end -- rank candidates by cumulative logprob (used when expanding beams)
  local function compare_ppl(a, b) return a.ppl < b.ppl end -- rank finished beams by perplexity (lower is better)
  assert(beam_size <= self.vocab_size+1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed')
  local seq = torch.LongTensor(self.seq_length, batch_size):zero()
  local seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
  local Done_beams = {} -- will collect the done_beams table for each of the batch_size images
  -- let's process every image independently for now, for simplicity
  for k=1,batch_size do
    -- create initial states for all beams
    self:_createInitState(beam_size)
    local state = self.init_state
    -- we will write output predictions into tensor seq
    local imgk = imgs[{ {k,k} }]:expand(beam_size, feat_dim) -- k'th image feature expanded out
    local beam_seq = torch.LongTensor(self.seq_length, beam_size):zero()
    local beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size):zero()
    local beam_logprobs_sum = torch.zeros(beam_size) -- running sum of logprobs for each beam
    local logprobs -- logprobs predicted in last time step, shape (beam_size, vocab_size+1)
    local done_beams = {} -- finished beams for the k-th image
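    -- run the recurrent core for up to seq_length+1 steps: t == 1 feeds the START token,
    -- every later step expands each surviving beam and re-ranks the candidates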
    for t=1,self.seq_length+1 do
      local xt, it, sampleLogprobs
      local new_state
      if t == 1 then
        -- feed in the start tokens
        it = torch.LongTensor(beam_size):fill(self.vocab_size+1)
        xt = self.lookup_table:forward(it)
      else
        --[[
          perform a beam merge. that is,
          for every previous beam we now have many new possibilities to branch out from,
          so we need to re-sort our beams to maintain the loop invariant of keeping
          the top beam_size most likely sequences.
        ]]--
        local logprobsf = logprobs:float() -- let's go to CPU for more efficient indexing operations
        local ys, ix = torch.sort(logprobsf, 2, true) -- sorted array of logprobs along each previous beam (last true = descending)
        local candidates = {}
        local cols = math.min(beam_size, ys:size(2))
        local rows = beam_size
        if t == 2 then rows = 1 end -- at the first expansion step (t == 2) only the first beam is active
        for c=1,cols do -- for each column (word, essentially)
          for q=1,rows do -- for each beam expansion
            -- compute logprob of expanding beam q with the word in (sorted) position c
            local local_logprob = ys[{ q,c }]
            local candidate_logprob = beam_logprobs_sum[q] + local_logprob
            table.insert(candidates, {c=ix[{ q,c }], q=q, p=candidate_logprob, r=local_logprob})
          end
        end
        table.sort(candidates, compare) -- find the best c,q pairs
        -- construct new beams
        new_state = net_utils.clone_list(state)
        local beam_seq_prev, beam_seq_logprobs_prev
        if t > 2 then
          -- we'll need these as reference when we fork beams around
          beam_seq_prev = beam_seq[{ {1,t-2}, {} }]:clone()
          beam_seq_logprobs_prev = beam_seq_logprobs[{ {1,t-2}, {} }]:clone()
        end
        for vix=1,beam_size do
          local v = candidates[vix]
          -- fork beam index q into index vix
          if t > 2 then
            beam_seq[{ {1,t-2}, vix }] = beam_seq_prev[{ {}, v.q }]
            beam_seq_logprobs[{ {1,t-2}, vix }] = beam_seq_logprobs_prev[{ {}, v.q }]
          end
          -- rearrange recurrent states
          for state_ix = 1,#new_state do
            -- copy over state in previous beam q to new beam at vix
            new_state[state_ix][vix] = state[state_ix][v.q]
          end
          -- append the chosen word at the end of this beam
          beam_seq[{ t-1, vix }] = v.c -- c'th word is the continuation
          beam_seq_logprobs[{ t-1, vix }] = v.r -- the raw logprob here
          beam_logprobs_sum[vix] = v.p -- the new (sum) logprob along this beam
          if v.c == self.vocab_size+1 or t == self.seq_length+1 then
            -- END token special case here, or we reached the maximum length.
            -- add the beam to the set of done beams
            table.insert(done_beams, {seq = beam_seq[{ {}, vix }]:clone(),
                                      logps = beam_seq_logprobs[{ {}, vix }]:clone(),
                                      p = beam_logprobs_sum[vix],
                                      -- ppl = -beam_logprobs_sum[vix]/(t-1)
                                      ppl = torch.exp(-beam_logprobs_sum[vix]/(t-1)) -- ppl = (p1*p2*...*pn)^(-1/n)
                                     })
            -- we won't consider this beam any more;
            -- otherwise some beams would stay at the top even after they reached <end>
            beam_logprobs_sum[vix] = -1000
          end
        end
        -- encode as vectors
        it = beam_seq[t-1]
        xt = self.lookup_table:forward(it)
      end
      if new_state then state = new_state end -- swap rnn state, if we reassigned beams
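      -- one step of the core network: the image feature is concatenated with the current word embedding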
      local inputs = {torch.cat(imgk, xt), unpack(state)}
      local out = self.core:forward(inputs)
      logprobs = out[self.num_state+1] -- last element is the output vector
      logprobs[{ {}, self.vocab_size }] = -1e5 -- make UNK very low
      state = {}
      for i=1,self.num_state do table.insert(state, out[i]) end
    end
    -- table.sort(done_beams, compare)
    table.sort(done_beams, compare_ppl)
    seq[{ {}, k }] = done_beams[1].seq -- after the sort, the first beam has the lowest perplexity
    seqLogprobs[{ {}, k }] = done_beams[1].logps
    -- truncate done_beams to at most beam_size entries
    for j=1+beam_size,#done_beams do table.remove(done_beams, 1+beam_size) end
    table.insert(Done_beams, done_beams)
  end
  -- return the samples and their log likelihoods
  return seq, seqLogprobs, Done_beams
end
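
Why rank by perplexity rather than by the cumulative logprob p? The sum of logprobs penalizes long captions, since every extra word adds a negative term, while ppl = exp(-sum_logprobs / n) normalizes by the length n. Below is a minimal, self-contained sketch of the difference; it uses plain Lua (no Torch) and two made-up beams, so the numbers are illustrative only and not part of the gist above.

local beams = {
  {len = 3, p = -3.0}, -- short beam, sum of logprobs = -3.0
  {len = 9, p = -6.3}, -- long beam, sum of logprobs = -6.3
}
for _, b in ipairs(beams) do
  b.ppl = math.exp(-b.p / b.len) -- ppl = (p1*p2*...*pn)^(-1/n)
end
table.sort(beams, function(a, b) return a.p > b.p end)     -- like compare: logprob ranking
print('ranked by logprob -> length ' .. beams[1].len)      -- picks the short beam (3)
table.sort(beams, function(a, b) return a.ppl < b.ppl end) -- like compare_ppl: perplexity ranking
print('ranked by ppl     -> length ' .. beams[1].len)      -- picks the long beam (9)

In sample_beam the same normalization happens when a beam finishes: ppl = torch.exp(-beam_logprobs_sum[vix]/(t-1)), where t-1 is the current beam length, and compare_ppl then picks the finished beam with the lowest perplexity.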