-
-
Save kastnerkyle/0edc9d569009b84f19265878344aa7f9 to your computer and use it in GitHub Desktop.
beam search for Keras RNN
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Variation of https://github.com/ryankiros/skip-thoughts/blob/master/decoding/search.py
def keras_rnn_predict(samples, empty=empty, rnn_model=model, maxlen=maxlen):
    """Return the RNN's label probabilities for each sample.

    Each sample is brought to length `maxlen` (the sequence length the model
    can handle) with `sequence.pad_sequences`, using `empty` as the fill
    value, and the resulting batch is handed to `rnn_model.predict`.
    The model output is returned unchanged.
    """
    padded_batch = sequence.pad_sequences(samples, maxlen=maxlen, value=empty)
    label_probs = rnn_model.predict(padded_batch, verbose=0)
    return label_probs
def beamsearch(predict=keras_rnn_predict,
               k=1, maxsample=400, use_unk=False, oov=oov, empty=empty, eos=eos):
    """Beam-search up to `k` label sequences and their NLL scores.

    Every sequence starts with the `empty` label (it is kept in the returned
    sequences) and grows one label per iteration until it either ends with
    `eos` or reaches `maxsample` labels.

    `predict` is called as `predict(samples, empty=empty)` and must return an
    array with one row per sample and one column per label, holding the
    probability of each label being the next one.

    When `use_unk` is false and `oov` is not None, the `oov`
    (out-of-vocabulary) label is given a huge score so it is effectively
    never selected.

    Returns a pair `(samples, scores)`: finished sequences first, followed by
    any sequences still alive when the search stopped. A score is the
    accumulated negative log probability of the sequence (lower is better).
    """
    finished_samples = []
    finished_scores = []
    active_samples = [[empty]]
    active_scores = [0]

    # Keep going while something can still be extended and fewer than k
    # sequences have finished.
    while active_samples and len(finished_samples) < k:
        # Next-label probabilities for every active sequence.
        probs = predict(active_samples, empty=empty)
        voc_size = probs.shape[1]

        # Candidate score = score so far + (-log p(next label)).
        cand_scores = np.array(active_scores)[:, None] - np.log(probs)
        if not use_unk and oov is not None:
            cand_scores[:, oov] = 1e20
        cand_flat = cand_scores.flatten()

        # The beam shrinks by one for every sequence that already finished.
        beam_width = k - len(finished_samples)
        best_flat = cand_flat.argsort()[:beam_width]
        best_scores = cand_flat[best_flat]

        # Extend the chosen parents and sort the results into finished /
        # still-active in a single pass.
        next_samples = []
        next_scores = []
        for flat_idx, score in zip(best_flat, best_scores):
            # A flat index encodes (parent row, new label).
            parent, label = divmod(flat_idx, voc_size)
            grown = active_samples[parent] + [label]
            if grown[-1] == eos or len(grown) >= maxsample:
                finished_samples.append(grown)
                finished_scores.append(score)
            else:
                next_samples.append(grown)
                next_scores.append(score)

        active_samples = next_samples
        active_scores = next_scores

    return finished_samples + active_samples, finished_scores + active_scores
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment