@tamuhey
Last active January 21, 2020 10:44
Simple and efficient beam search function for PyTorch
import torch


def beamsearch(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Beam search over a sequence of per-step probabilities.

    Args:
        probs: tensor of shape (length, d) with d > 0. All items in `probs`
            are assumed to lie in [0, 1] (e.g. the output of a softmax).
        k: beam width

    Returns:
        Tensor of shape (min(k, d ** length), length) holding the top index
        sequences, best first.
    """
    assert len(probs.shape) == 2
    if len(probs) == 0:
        return torch.zeros(k, 0)
    # We calculate the top-k argmax of E = p0 * p1 * ... * p(n-1).
    # To avoid over/underflow, evaluate log(E) instead of E,
    # so E can be accumulated by addition.
    data = probs.to(torch.double).log()
    _, m = data.shape
    scores, candidates = torch.topk(data[0], k=min(k, m))
    candidates = candidates[:, None]
    for row in data[1:]:
        # Extend every current beam with every symbol, then keep the best k.
        z = (scores[:, None] + row[None, :]).flatten()
        scores, flat_idx = torch.topk(z, k=min(k, len(z)))
        i, j = flat_idx // m, flat_idx % m
        candidates = torch.cat([candidates[i], j[:, None]], dim=-1)
    return candidates
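A minimal usage sketch (not part of the original gist), assuming the `beamsearch` function above is in scope. Because each step's probabilities are independent here, the best beam coincides with the per-step argmax:

# Hypothetical example: 4 time steps over a vocabulary of 3 symbols.
probs = torch.tensor([
    [0.1, 0.7, 0.2],
    [0.6, 0.3, 0.1],
    [0.2, 0.2, 0.6],
    [0.5, 0.4, 0.1],
])
beams = beamsearch(probs, k=2)
print(beams.shape)  # torch.Size([2, 4]); index sequences, best beam first
print(beams[0])     # tensor([1, 0, 2, 0]) == probs.argmax(dim=1)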
from hypothesis import given, strategies as st


@given(
    st.integers(0, 200),
    st.integers(1, 100),
    st.integers(1, 10),
    st.integers(-1000, 1000),
)
def test_beamsearch(n, m, k, s):
    torch.manual_seed(s)
    data = torch.randn(n, m).softmax(1)
    output = beamsearch(data, k)
    if n != 0:
        assert output.shape == (min(m ** n, k), n)
        assert all(output[0] == data.argmax(1))
        assert all(output[0] == beamsearch(data, 1)[0])
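A `@given`-decorated test can also be run without a test runner: calling it directly makes Hypothesis generate and check the sampled (n, m, k, s) cases. A small sketch, not part of the original gist:

if __name__ == "__main__":
    # Raises AssertionError (with a minimal failing example) on failure.
    test_beamsearch()
    print("all hypothesis cases passed")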