Skip to content

Instantly share code, notes, and snippets.

@keyboardAnt
Last active November 11, 2023 11:48
Show Gist options
  • Save keyboardAnt/2f3db565bebabc57e97ef792d6f2a6ab to your computer and use it in GitHub Desktop.
Save keyboardAnt/2f3db565bebabc57e97ef792d6f2a6ab to your computer and use it in GitHub Desktop.
# Reference: https://static.sched.com/hosted_files/pytorch2023/c0/Accelerating%20Generative%20AI%20PTC%20%282%29.pdf?page=41
import torch
def speculative_decode(
    model: "LLaMA",
    draft_model: "LLaMA",
    cur_token: torch.Tensor,
    input_pos: int,
    speculate_k: int,
    **sampling_kwargs
) -> torch.Tensor:
    """Run one round of speculative decoding and return the newly produced tokens.

    The draft model proposes ``speculate_k`` tokens one at a time. The target
    model then scores ``cur_token`` followed by all draft tokens in a single
    forward pass. Each draft token is accepted with probability
    ``min(1, q / p)``, where ``q`` is the target probability of that token and
    ``p`` is the draft probability.

    At the first rejected position, a replacement token is sampled from the
    normalized residual ``max(0, q - p)`` and the remaining drafts are
    dropped. If every draft token is accepted, one extra token is sampled
    from the target distribution at the final position.

    Args:
        model: Target model. The annotation is a forward-reference string
            because ``LLaMA`` is not imported in this file.
        draft_model: Proposal model used to generate the draft tokens.
        cur_token: The most recent token. It must contain exactly one element
            (it is reshaped with ``view(1)``).
        input_pos: Sequence position of ``cur_token``.
        speculate_k: Number of tokens the draft model proposes; must be >= 1.
        **sampling_kwargs: Forwarded unchanged to ``decode_n_tokens`` and
            ``logits_to_probs``.

    Returns:
        The accepted draft-token prefix concatenated with one token sampled
        from the target (or residual) distribution. The result presumably has
        ``speculate_k + 1`` tokens when all drafts are accepted, assuming each
        draft and sampled token is a single element (TODO confirm).

    Raises:
        ValueError: If ``speculate_k`` is less than 1.
    """
    if speculate_k < 1:
        raise ValueError(f"speculate_k must be >= 1, got {speculate_k}")

    device = cur_token.device
    orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=device)

    # 1) Draft model inference, sequentially. A clone of the position tensor
    #    is passed; presumably decode_n_tokens advances it in place, while
    #    orig_input_pos is reused below (TODO confirm).
    draft_tokens, draft_probs = decode_n_tokens(
        draft_model,
        cur_token.view(1, -1),
        orig_input_pos.clone(),
        speculate_k,
        **sampling_kwargs
    )
    draft_tokens = torch.cat(draft_tokens)
    draft_probs = torch.stack(draft_probs)

    # 2) Parallel inference on the target model: cur_token plus all draft
    #    tokens at positions input_pos .. input_pos + speculate_k.
    target_logits = model_forward(
        model,
        torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
        torch.arange(input_pos, input_pos + speculate_k + 1, device=device),
    )
    target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)

    # 3) Accept or reject each draft token. q: target prob, p: draft prob.
    #    q >= p: always accept the draft token.
    #    q < p: accept it with probability q / p.
    rows = torch.arange(speculate_k, device=device)
    p = draft_probs[rows, draft_tokens]
    q = target_probs[rows, draft_tokens]
    # clamp keeps everything on `device`; there is no CPU scalar tensor to mix in.
    accept_draft_prob = (q / p).clamp(max=1.0)
    rejected_locations = (
        torch.rand_like(accept_draft_prob) > accept_draft_prob
    ).nonzero()

    if rejected_locations.shape[0] == 0:
        # All draft tokens were accepted: sample one more token from the
        # target distribution at the final position.
        last_token = multinomial_sample_one_no_sync(target_probs[-1])
        # Fill the last draft token into the draft model at position
        # input_pos + speculate_k. Presumably decode_n_tokens never ran the
        # draft model on its own final output, so this keeps the draft
        # model's state in step with the accepted sequence (TODO confirm
        # against decode_n_tokens).
        model_forward(
            draft_model,
            draft_tokens[-1].view(1, -1),
            orig_input_pos + speculate_k,
        )
        return torch.cat([draft_tokens, last_token])

    # The first rejected index equals the number of accepted draft tokens.
    accept_length = rejected_locations[0].item()
    p_dist = draft_probs[accept_length]
    q_dist = target_probs[accept_length]
    # Resample from the residual distribution max(0, q - p), normalized.
    residual = (q_dist - p_dist).clamp_min(0.0)
    residual_mass = residual.sum()
    # If p and q are both normalized, a rejection (which requires q < p at the
    # drafted token) implies residual_mass > 0 in exact arithmetic.
    # Floating-point cancellation can still leave zero mass, which would make
    # the division produce NaNs. In that case fall back to the target
    # distribution.
    new = torch.where(residual_mass > 0, residual / residual_mass, q_dist)
    next_token = multinomial_sample_one_no_sync(new)
    return torch.cat([draft_tokens[:accept_length], next_token])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment