"""Stochastic beam search. | |
Implements "Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for | |
Sampling Sequences Without Replacement" (https://arxiv.org/abs/1903.06059)""" | |
import math | |
import torch | |
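

# Illustrative sketch added for exposition (not part of the original gist): the
# Gumbel-top-k trick in its simplest form. Perturbing the log-probabilities with
# independent Gumbel(0, 1) noise and taking the top k indices draws k distinct
# items, i.e. a sample of size k without replacement from the categorical
# distribution defined by `logprobs`. The stochastic beam search below applies
# the same idea hierarchically over partial sequences.
def gumbel_top_k_demo(logprobs, k):
    gumbels = -torch.log(-torch.log(torch.rand_like(logprobs)))
    return torch.topk(logprobs + gumbels, k=k).indices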


def log1mexp(a):
    """Computes log(1 - exp(a)) for a < 0 in a numerically stable way."""
    # Use expm1 when a is close to zero and log1p otherwise, switching at
    # a = -log(2) (the standard recipe for log(1 - exp(a))).
    a1 = torch.log(-torch.expm1(a))
    a2 = torch.log1p(-torch.exp(a))
    return torch.where(a > -math.log(2), a1, a2)


def shift_gumbel(g, t):
    """Shifts the Gumbel-perturbed values `g` so that their maximum over the last
    dimension equals `t` (the parent's perturbed log-probability), using the
    numerically stable formulation from the paper."""
    z = torch.max(g, dim=-1, keepdim=True).values
    v = t[..., None] - g + log1mexp(g - z)
    return t[..., None] - v.relu() - torch.nn.functional.softplus(-v.abs())


def stochastic_beam_search(model, input_ids, n_tokens, beam_width, temperature=1.0):
    """Samples `beam_width` continuations of `input_ids` without replacement by
    running beam search over Gumbel-perturbed sequence log-probabilities."""
    assert input_ids.shape[0] == 1
    device = input_ids.device
    past_key_values = None

    # Initialize the beam: a single prefix with log-probability 0, replicated so
    # the key/value cache has one row per eventual beam.
    input_ids = input_ids.repeat(beam_width, 1)
    phi_s = torch.zeros([1], device=device)
    g_phi_s = torch.zeros([1], device=device)
    cur_beam_width = 1

    for _ in range(n_tokens):
        # Feed only the newly appended tokens once the cache has been populated.
        input_ids_in = input_ids if past_key_values is None else input_ids[:, -1:]
        with torch.no_grad():
            model_output = model(
                input_ids_in,
                use_cache=True,
                past_key_values=past_key_values,
            )
        past_key_values = model_output.past_key_values
        logits = model_output.logits[:cur_beam_width, -1, :].float() / temperature
        logprobs = torch.nn.functional.log_softmax(logits, dim=1)
        # Total log-probability of each candidate extension (parent + next token).
        phi_s_prime = phi_s[:, None] + logprobs
        # Perturb with Gumbel(0, 1) noise, then condition so that the maximum over
        # each parent's children equals the parent's own perturbed log-probability.
        g_phi_s_prime = phi_s_prime - torch.log(-torch.log(torch.rand_like(logits)))
        g_phi_s_prime = shift_gumbel(g_phi_s_prime, g_phi_s)
        # Map flattened candidate indices back to (parent beam, next token) pairs.
        src = torch.arange(cur_beam_width, device=device).repeat_interleave(logits.shape[1])
        y_prime = torch.arange(logits.shape[1], device=device).repeat(cur_beam_width)
        # Keep the `beam_width` candidates with the largest perturbed log-probabilities.
        g_phi_s, indices = torch.topk(g_phi_s_prime.flatten(), k=beam_width)
        phi_s = phi_s_prime.flatten()[indices]
        input_ids = torch.cat([input_ids[src[indices]], y_prime[indices, None]], dim=1)
        # Reorder the cache so each row follows its selected parent beam; without
        # this the cached history no longer matches the reordered beams. Assumes
        # the legacy tuple-of-tuples cache format returned by transformers.
        past_key_values = tuple(
            tuple(x[src[indices]] for x in layer) for layer in past_key_values
        )
        cur_beam_width = g_phi_s.shape[0]

    return input_ids
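

# Example usage (a sketch, assuming a Hugging Face causal LM that returns the
# legacy tuple KV cache; "gpt2" is only an illustrative checkpoint and not part
# of the original gist):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
#     input_ids = tokenizer("The meaning of life is", return_tensors="pt").input_ids
#     beams = stochastic_beam_search(model, input_ids, n_tokens=20, beam_width=4)
#     for seq in beams:
#         print(tokenizer.decode(seq))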