@thomwolf
Last active October 25, 2025 20:25
Sample the next token from a probability distribution using top-k and/or nucleus (top-p) sampling
import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Note: logits is modified in place and also returned.
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9
# Get logits with a forward pass in our model (input is pre-defined)
logits = model(input)
# Keep only the last token predictions of the first batch item (batch size 1), apply a temperature coefficient and filter
logits = logits[0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)
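The assert in the function notes that it handles a single 1-D logits vector and "could be updated for more". One way that update might look is sketched below. The name `top_k_top_p_filtering_batched` is hypothetical and not part of the gist; the sketch assumes logits of shape (batch size, vocabulary size) and, unlike the function above, returns a new tensor instead of modifying its input.

```python
import torch
import torch.nn.functional as F


def top_k_top_p_filtering_batched(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    """Hypothetical batched variant; logits has shape (batch size, vocabulary size)."""
    logits = logits.clone()  # leave the caller's tensor untouched
    top_k = min(top_k, logits.size(-1))  # safety check
    if top_k > 0:
        # Per-row value of the k-th largest logit, shape (batch size, 1)
        kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, filter_value)

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mark tokens past the threshold, shifted right so the first token
        # that crosses the threshold is kept
        sorted_remove = cumulative_probs > top_p
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order, row by row
        remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove, filter_value)
    return logits


batch = torch.tensor([[6.0, 5.0, 4.0], [1.0, 3.0, 2.0]])
filtered = top_k_top_p_filtering_batched(batch, top_k=2)
next_tokens = torch.multinomial(F.softmax(filtered, dim=-1), 1)  # shape (2, 1)
```

The scatter step is the main difference from the 1-D code: indexing with `sorted_indices[sorted_indices_to_remove]` flattens across rows, so the mask is scattered back per row instead.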
@calvinmccarter

@ZhangShiyue No, the softmax is applied inside top_k_top_p_filtering, on line 18 of the gist (where cumulative_probs is computed). The filtering threshold is applied afterwards, on line 21 (sorted_indices_to_remove = cumulative_probs > top_p).

@calvinmccarter

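Also note that setting specific indices of the logits to -inf is sufficient: exp(-inf) = 0, so those tokens drop out of the softmax sum. You don't need to renormalize the logits, because softmax is translation invariant, and renormalizing the kept probabilities only shifts their log-probabilities by a constant.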
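A quick numerical check of this point, as a sketch on arbitrary random logits (the seed, vocabulary size of 10, and k=3 are assumptions chosen only for illustration): masking at the logit level and then applying softmax gives the same distribution as applying softmax first, zeroing the dropped tokens, and renormalizing.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(10)               # arbitrary example logits
keep = torch.topk(logits, 3).indices   # indices of the tokens we keep

# (a) Filter at the logit level: dropped tokens get -inf, then softmax
masked_logits = torch.full_like(logits, float("-inf"))
masked_logits[keep] = logits[keep]
probs_masked = torch.softmax(masked_logits, dim=-1)

# (b) Softmax first, zero out dropped tokens, then renormalize
full_probs = torch.softmax(logits, dim=-1)
probs_renorm = torch.zeros_like(full_probs)
probs_renorm[keep] = full_probs[keep]
probs_renorm = probs_renorm / probs_renorm.sum()

print(torch.allclose(probs_masked, probs_renorm))

# Translation invariance: adding a constant to every logit changes nothing
print(torch.allclose(torch.softmax(logits + 5.0, dim=-1), full_probs))
```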

@ZhangShiyue

ZhangShiyue commented Oct 25, 2025

Let me give a concrete example of top-k (k=2) sampling below to clarify what I meant. I think I mistakenly claimed that this "obviously leads to a different distribution from the logits level filtering" and that "These two are different." On second thought, I think the two are the same.

If using the code above:

>>> logits = torch.tensor([6., 5., 4.]) # assume logits are already sorted
>>> logits_filtered = torch.tensor([6., 5., float('-inf')])
>>> torch.softmax(logits_filtered, dim=0)
tensor([0.7311, 0.2689, 0.0000])
# now you sample from [0.7311, 0.2689, 0.0000]

What I suggested:

>>> logits = torch.tensor([6., 5., 4.]) # assume logits are already sorted
>>> probs = torch.softmax(logits, dim=0)  # tensor([0.6652, 0.2447, 0.0900])
>>> probs = probs[:2]
>>> probs / probs.sum() # renormalize
tensor([0.7311, 0.2689]) # now you sample from [0.7311, 0.2689]

So they are the same. Lol. My bad.
