Last active
May 30, 2019 14:40
-
-
Save pskrunner14/e43aa89adacd5bfde448e1ebb367ea2f to your computer and use it in GitHub Desktop.
Beam Search example script from scratch in NumPy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import numpy as np | |
def topk(a, k):
    """Return the `k` largest values of `a` and their indices.

    The order *within* the returned top-k is arbitrary — argpartition only
    guarantees the k largest land in the last k slots, in linear O(n) time
    (a full argsort would cost O(n log n)).
    """
    top_idx = np.argpartition(a, -k)[-k:]
    top_vals = a[top_idx]
    return top_vals, top_idx
def decode(n):
    """Stand-in decoder step: a length-`n` vector of uniform pseudo-probabilities.

    In a real model this would be the next-token distribution; here it draws
    from NumPy's global RNG (np.random.random is the same sampler stream as
    np.random.rand).
    """
    return np.random.random(n)
def beam_search(n, k, max_len):
    """Performs beam search decoding on sequences
    having an arbitrary probability distribution.

    Reference: http://web.stanford.edu/class/cs224n/slides/cs224n-2019-lecture08-nmt.pdf

    Args:
    -----
        n: the total size of the vocabulary of tokens.
        k: the beam size.
        max_len: maximum allowed length of the sequence.

    Returns:
    --------
        tuple(np.ndarray, float): best pair of (sequence, prob), where the
        sequence is an integer token-index array (leading 0 is the
        start-of-sequence placeholder) and prob is the score averaged
        over max_len decoding steps.
    """
    # keep track of topk hypotheses and their scores at all times;
    # token 0 acts as the start-of-sequence placeholder
    scores = np.array([.0])
    hypotheses = np.array([[0]])
    for _ in range(max_len):
        # candidate extensions gathered across all current hypotheses.
        # FIX: use the hypotheses' integer dtype — the original float
        # np.empty silently promoted every token index to float64 when
        # concatenated, so the returned sequence was a float array.
        local_hypoth = np.empty((0, hypotheses.shape[1] + 1),
                                dtype=hypotheses.dtype)
        local_scores = np.empty((0, ))
        # expand each hypothesis into k candidate hypotheses
        for j in range(len(hypotheses)):
            hypoth = hypotheses[j]
            score = scores[j]
            # next-step probability distribution over the vocabulary
            probs = decode(n)
            # top-k next tokens for this hypothesis
            vals, idx = topk(probs, k)
            # append each chosen token to a copy of the current hypothesis
            candidates = np.append(
                np.repeat(np.expand_dims(hypoth, axis=0), k, axis=0),
                np.expand_dims(idx, axis=1), axis=1)
            local_hypoth = np.concatenate((local_hypoth, candidates))
            # scores accumulate additively (sum of per-step values)
            local_scores = np.append(local_scores, vals + score)
        # prune: keep only the k best of the k*len(hypotheses) candidates
        scores, keep = topk(local_scores, k)
        hypotheses = local_hypoth[keep]
    # return the best sequence with its length-normalized score
    best = np.argmax(scores)
    return hypotheses[best], scores[best] / max_len
def main():
    """Interactively collect parameters, run beam search, and report timing."""
    vocab_size = int(input('Enter total size of dictionary of tokens: '))
    beam_size = int(input('Enter beam size (k): '))
    seq_len = int(input('Enter the maximum allowed length of a sequence: '))
    started = time.time()
    print('Best pair of (sequence, prob):')
    print(beam_search(vocab_size, beam_size, seq_len))
    print(f'ran in {(time.time() - started):.4f}s')
# Script entry point: run the interactive demo only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment