Skip to content

Instantly share code, notes, and snippets.

@pskrunner14
Last active May 30, 2019 14:40
Show Gist options
  • Save pskrunner14/e43aa89adacd5bfde448e1ebb367ea2f to your computer and use it in GitHub Desktop.
Beam Search example script from scratch in NumPy
import time
import numpy as np
def topk(a, k):
    """Return the k largest values of ``a`` together with their indices.

    Uses ``np.argpartition`` (linear O(n) selection) rather than a full
    O(n log n) argsort; the k results are therefore NOT sorted among
    themselves, which is fine for beam-search pruning.
    """
    partitioned = np.argpartition(a, -k)
    top_idx = partitioned[-k:]
    return a[top_idx], top_idx
def decode(n):
    """Stand-in decoder step: one pseudo-random score per vocabulary token.

    ``np.random.random`` draws from the same uniform-[0, 1) stream as
    ``np.random.rand``, so a seeded run is reproducible either way.
    """
    return np.random.random(n)
def beam_search(n, k, max_len):
    """Performs beam search decoding on sequences
    having an arbitrary probability distribution.

    Reference: http://web.stanford.edu/class/cs224n/slides/cs224n-2019-lecture08-nmt.pdf

    Args:
    -----
    n: the total size of the vocabulary of tokens.
    k: the beam size.
    max_len: maximum allowed length of the sequence.

    Returns:
    --------
    tuple(np.ndarray, float): best pair of (sequence, prob).
    """
    # the k live hypotheses and their accumulated scores; start with a
    # single seed sequence [0] at score 0
    scores = np.array([.0])
    hypotheses = np.array([[0]])
    for _ in range(max_len):
        # candidate pool built from every expansion of every live beam
        cand_seqs = np.empty((0, hypotheses.shape[1] + 1))
        cand_scores = np.empty((0,))
        for beam, beam_score in zip(hypotheses, scores):
            # score the next time step for this beam
            step_probs = decode(n)
            top_vals, top_idx = topk(step_probs, k)
            # extend this beam with each of its k best next tokens:
            # repeat the prefix k times and glue one new token per row
            extended = np.append(
                np.repeat(np.expand_dims(beam, axis=0), k, axis=0),
                np.expand_dims(top_idx, axis=1),
                axis=1,
            )
            cand_seqs = np.concatenate((cand_seqs, extended))
            # candidate score = prefix score + step score
            cand_scores = np.append(cand_scores, top_vals + beam_score)
        # prune the pool back down to the k best candidates
        best_vals, best_idx = topk(cand_scores, k)
        scores = best_vals
        hypotheses = cand_seqs[best_idx]
    # pick the highest-scoring survivor; score is length-normalized
    winner = np.argmax(scores)
    return hypotheses[winner], scores[winner] / max_len
def main():
    """Interactively collect search parameters, run beam search, and time it."""
    vocab_size = int(input('Enter total size of dictionary of tokens: '))
    beam_size = int(input('Enter beam size (k): '))
    max_len = int(input('Enter the maximum allowed length of a sequence: '))
    start = time.time()
    print('Best pair of (sequence, prob):')
    print(beam_search(vocab_size, beam_size, max_len))
    print(f'ran in {(time.time() - start):.4f}s')


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment