@thomwolf
Last active October 25, 2025 20:25
Sample the next token from a probability distribution using top-k and/or nucleus (top-p) sampling
import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution of shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value  # note: modifies logits in place

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9

# Get logits with a forward pass in our model (model and input are pre-defined)
logits = model(input)

# Keep only the last token predictions of the first batch item (batch size 1), apply a temperature coefficient and filter
logits = logits[0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)
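
The assertion above restricts the function to a single sequence. The following is a minimal sketch of a batched variant, assuming logits of shape (batch_size, vocab_size). It is not the gist's code. `top_k_top_p_filtering_batched` is a hypothetical name, and the `scatter` call is just one way to map the sorted mask back to vocabulary order. Unlike the original, it clones its input instead of modifying it in place.

```python
import torch
import torch.nn.functional as F


def top_k_top_p_filtering_batched(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    """Hypothetical batched variant: logits has shape (batch_size, vocab_size)."""
    logits = logits.clone()  # work on a copy instead of modifying the caller's tensor
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # safety check, as in the original
        # k-th largest logit of each row, kept as a column so it broadcasts
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, filter_value)
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order, row by row
        to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
        logits = logits.masked_fill(to_remove, filter_value)
    return logits


# Usage with made-up logits for a batch of two sequences
logits = torch.tensor([[6.0, 5.0, 4.0], [1.0, 2.0, 3.0]])
filtered = top_k_top_p_filtering_batched(logits, top_p=0.7)
probabilities = F.softmax(filtered, dim=-1)
next_tokens = torch.multinomial(probabilities, 1)  # shape (batch_size, 1)
```

Because every row is handled with `dim=-1` operations, no Python loop over the batch is needed. The trade-off is the extra `scatter` step, which is what the original comment means by "the code would be less clear".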
@calvinmccarter

@ZhangShiyue Filtering takes place on the probabilities because of the softmax on line 18 (`cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)`).

@ZhangShiyue

> @ZhangShiyue Filtering takes place on the probabilities because of the softmax on line 18.

No, filtering takes place at the logit level because of `filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)`.
By filtering, I mean setting certain logits to -inf.

In the code above, softmax happens after the logits are masked. But my mental model is: first apply softmax, then filter the probabilities (setting certain probabilities to 0), then sample. These two are different.

@calvinmccarter

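@ZhangShiyue No, softmax occurs on line 18 (`cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)`), within the implementation of top_k_top_p_filtering. Filtering is applied on line 21 (`sorted_indices_to_remove = cumulative_probs > top_p`), after the softmax.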
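
To make the order of operations concrete, here is a small trace of the top-p branch on toy logits `[6., 5., 4.]` with a hypothetical `top_p = 0.7`. Both values are chosen only for illustration. The keep/remove decision is made by comparing cumulative probabilities (after the softmax) against the threshold. The resulting mask is then written into a copy of the logits.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([6.0, 5.0, 4.0])  # toy logits (assumption), already sorted
top_p = 0.7                             # hypothetical threshold

# Steps from the top-p branch of top_k_top_p_filtering
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
probs = F.softmax(sorted_logits, dim=-1)        # ~[0.665, 0.245, 0.090]
cumulative_probs = torch.cumsum(probs, dim=-1)  # ~[0.665, 0.910, 1.000]
sorted_indices_to_remove = cumulative_probs > top_p  # decision made on probabilities
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False

# ...but the mask is applied to (a copy of) the logits
filtered_logits = logits.clone()
filtered_logits[sorted_indices[sorted_indices_to_remove]] = -float("inf")
print(filtered_logits)  # expected values: 6., 5., -inf
```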

@calvinmccarter

Also note that setting specific indices of the logits to -inf is sufficient. You don't need to renormalize logits because softmax is translation invariant.
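
A short derivation of why the mask alone suffices. The notation is introduced here for illustration: z is the logit vector, p = softmax(z), and S is the set of kept tokens. Because e^(-inf) = 0, the masked tokens drop out of the denominator. Shifting every logit by the same constant (which is all that "renormalizing" logits would do) leaves the softmax unchanged.

```latex
\begin{aligned}
\tilde z_i &= \begin{cases} z_i & i \in S \\ -\infty & i \notin S \end{cases} \\
\mathrm{softmax}(\tilde z)_i
  &= \frac{e^{\tilde z_i}}{\sum_{j} e^{\tilde z_j}}
   = \frac{\mathbb{1}[i \in S]\, e^{z_i}}{\sum_{j \in S} e^{z_j}}
   = \frac{\mathbb{1}[i \in S]\, p_i}{\sum_{j \in S} p_j} \\
\mathrm{softmax}(\tilde z + c\,\mathbf{1}) &= \mathrm{softmax}(\tilde z) \quad \text{for any constant } c
\end{aligned}
```

The right-hand side of the second line is exactly "softmax, zero out the removed tokens, renormalize". The last equality divides numerator and denominator by the full normalizer of z.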

@ZhangShiyue

ZhangShiyue commented Oct 25, 2025

Let me give a concrete example of top-k (k=2) sampling below to clarify what I meant. I also think I was wrong to claim "obviously leads to a different distribution from the logits level filtering." and "These two are different." On second thought, I think the two are the same.

Using the code above:

>>> logits = torch.tensor([6., 5., 4.])  # assume logits are already sorted
>>> logits_filtered = torch.tensor([6., 5., float('-inf')])
>>> torch.softmax(logits_filtered, dim=0)
tensor([0.7311, 0.2689, 0.0000])
# now you sample from [0.7311, 0.2689, 0.0000]

What I suggested:

>>> logits = torch.tensor([6., 5., 4.]) # assume logits are already sorted
>>> probs = torch.softmax(logits, dim=0)  # tensor([0.6652, 0.2447, 0.0900])
>>> probs = probs[:2]
>>> probs / probs.sum() # renormalize
tensor([0.7311, 0.2689]) # now you sample from [0.7311, 0.2689]

So they are the same. Lol. My bad.
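
The same equivalence can be checked numerically beyond the three-token example. This is a minimal sketch that assumes random logits over a hypothetical 50-token vocabulary and k = 5. It compares mask-then-softmax (what the gist does) with softmax-then-zero-then-renormalize (the alternative described above). The two should agree up to floating-point error.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(50)  # stand-in for one row of model logits (assumption)
k = 5                     # hypothetical top-k value

# (a) Logit-level filtering, as in the gist: mask to -inf, then softmax
kth = torch.topk(logits, k).values[-1]
masked = logits.masked_fill(logits < kth, float("-inf"))
probs_a = torch.softmax(masked, dim=0)

# (b) Probability-level filtering: softmax, zero out, renormalize
probs = torch.softmax(logits, dim=0)
probs_b = probs.masked_fill(logits < kth, 0.0)
probs_b = probs_b / probs_b.sum()

print(torch.allclose(probs_a, probs_b))  # expected: True
```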
