Last active
October 4, 2021 11:50
-
-
Save udibr/67be473cf053d8c38730 to your computer and use it in GitHub Desktop.
beam search for Keras RNN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Variation of https://github.com/ryankiros/skip-thoughts/blob/master/decoding/search.py
def keras_rnn_predict(samples, empty=empty, rnn_model=model, maxlen=maxlen):
    """Return the model's label probabilities for every sample.

    Each sample is a sequence of labels. All samples are padded to a common
    length of `maxlen` with the `empty` label (via `sequence.pad_sequences`)
    and the padded batch is handed to `rnn_model.predict` with verbose
    output disabled.

    samples   -- list of label sequences to score
    empty     -- label used as the padding value
    rnn_model -- model exposing a Keras-style `predict(data, verbose=...)`
    maxlen    -- length of the sequences the model can handle
    """
    padded = sequence.pad_sequences(samples, maxlen=maxlen, value=empty)
    probabilities = rnn_model.predict(padded, verbose=0)
    return probabilities
def beamsearch(predict=keras_rnn_predict,
               k=1, maxsample=400, use_unk=False, oov=oov, empty=empty, eos=eos):
    """Beam-search decode and return `(samples, scores)`.

    Every sample is a list of labels that starts with the `empty` label.
    A sample is finished once its last label equals `eos` or its length
    reaches `maxsample`. Scores are negative log-likelihoods (lower is
    better): the running sum of `-log(prob)` of each appended label.

    predict   -- callable `predict(samples, empty=...)` returning a 2-D
                 array of label probabilities, one row per sample
    k         -- beam width, i.e. how many samples to produce
    maxsample -- length at which an unfinished sample is cut off
    use_unk   -- when false (and `oov` is not None) the `oov` label is
                 given a prohibitive score so it is never chosen
    oov       -- index of the out-of-vocabulary label, or None
    empty     -- label that starts every sample
    eos       -- label that terminates a sample

    The returned lists hold the finished samples first, followed by any
    samples still unfinished when the search stopped.
    """
    finished_samples = []
    finished_scores = []
    live_samples = [[empty]]
    live_scores = [0]

    # Keep expanding while there is something to expand and the beam is
    # not yet filled with finished samples.
    while live_samples and len(finished_samples) < k:
        # One row of label probabilities per live sample.
        probs = predict(live_samples, empty=empty)
        n_labels = probs.shape[1]

        # Score of every (live sample, next label) pair: the sample's
        # running score plus -log of the label probability.
        running = np.array(live_scores)[:, None]
        candidates = running - np.log(probs)
        if oov is not None and not use_unk:
            # Effectively forbid the out-of-vocabulary label.
            candidates[:, oov] = 1e20

        # Keep only as many of the lowest-scoring candidates as there are
        # open slots left in the beam.
        flat = candidates.flatten()
        open_slots = k - len(finished_samples)
        best = np.argsort(flat)[:open_slots]
        best_scores = flat[best]

        # A flat index encodes (row of parent sample, label to append).
        extended = []
        for flat_index in best:
            parent, label = divmod(flat_index, n_labels)
            extended.append(live_samples[parent] + [label])

        # Split the extended samples into finished and still-live ones.
        live_samples = []
        live_scores = []
        for sample, score in zip(extended, best_scores):
            if sample[-1] == eos or len(sample) >= maxsample:
                finished_samples.append(sample)
                finished_scores.append(score)
            else:
                live_samples.append(sample)
                live_scores.append(score)

    return finished_samples + live_samples, finished_scores + live_scores
@jeetp465, is your issue resolved?
@jeetp465, @andersonzhu: as I understand it, beam search is not part of the model definition. It is the way we decode the output of the LSTM (RNN), so beam search is used at the point where you generate words from the LSTM's predicted output. The output of full_model.predict() is passed to the beam search to obtain the output captions.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Interesting naming conventions