
@thomwolf
Last active October 25, 2025 20:25
Sample the next token from a probability distribution using top-k and/or nucleus (top-p) sampling
import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9
# Get logits with a forward pass in our model (input is pre-defined)
logits = model(input)
# Keep only the last token predictions of the first batch item (batch size 1), apply a temperature coefficient and filter
logits = logits[0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)

@mataney

mataney commented Sep 25, 2019

Hi, I added a PR about extending this to num_samples > 1.
huggingface/transformers#1333
:)

@Ace-ezer

Ace-ezer commented Nov 20, 2019

I am getting the error below; also, could you please elaborate on what "input" is?

logits = model.forward(input)
Traceback (most recent call last):

File "", line 1, in
logits = model.forward(input)

TypeError: forward() missing 1 required positional argument: 'mc_token_ids'

@TheEdoardo93

You can refer to run_generation.py in Transformers for full source code working on a real model, i.e. OpenAI GPT-2.

I am getting the error below; also, could you please elaborate on what "input" is?

logits = model.forward(input)
Traceback (most recent call last):

File "", line 1, in
logits = model.forward(input)

TypeError: forward() missing 1 required positional argument: 'mc_token_ids'

@aj7tesh

aj7tesh commented Feb 14, 2020

sample from the smallest set whose cumulative probability mass exceeds p for next words
What exactly does this mean? Say I set p = 0.9: will it then keep only those next tokens whose individual probability is > 0.9, or something else?

@chungyilinxrspace

@thomwolf
Hi, I have recently been learning about temperature sampling and nucleus sampling,
and I read the paper "The Curious Case of Neural Text Degeneration", where they rescale the original distribution to a new distribution.

In the top_k_top_p_filtering function, the logit scores are set to the filter value, but the probability distribution is not renormalized.
Is "changing the probability distribution" necessary for top-p sampling?
Thank you ~

@tbazin

tbazin commented Mar 3, 2020

@chungyilinxrspace
In the top_k_top_p_filtering function, the logit scores are set to the filter value, but the probability distribution is not renormalized.
Is "changing the probability distribution" necessary for top-p sampling?

Hi!
TL;DR: The filtering function provided operates on the logits and not on the probabilities.

After filtering, the logits are converted to class probabilities via the call to F.softmax, which ensures both that the filtered classes have zero probability (since they have logit value float("-inf")) and that the remaining probabilities define a proper, normalized probability distribution. Hence the probability distribution is indeed "changed".
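A minimal sketch of that point, assuming PyTorch: positions whose logit is -inf receive exactly zero probability after softmax, and the remaining mass renormalizes to sum to one.

```python
import torch
import torch.nn.functional as F

# Two surviving logits and two filtered ones (set to -inf).
logits = torch.tensor([2.0, 1.0, float('-inf'), float('-inf')])
probs = F.softmax(logits, dim=-1)
# The filtered positions get exactly zero probability, and the
# kept positions form a proper distribution summing to one.
print(probs)
```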

@nilinykh

nilinykh commented Jun 3, 2020

Hello all!
First, thank you for a very nice piece of code.

I have a more general question about nucleus sampling itself, maybe someone will be willing to clarify several things for me.
How do we choose k and p? As far as I understand, every time we generate text the output will be different, even if k and p stay the same. In other words, one cannot get a stable generated output (unlike with greedy or beam search).
Is there a good approximation of what values these parameters should take? Or should it be based solely on empirical observations for a particular problem? If the latter, can anyone point me to basic ideas on how changing k and/or p affects the generated output in general?

@Hyman25

Hyman25 commented Oct 27, 2020

@mdda
The indexing line (indices_to_remove = sorted_indices[sorted_indices_to_remove]) does indeed produce a list of indices, and your code helps.

@JiyangZhang

How can I return multiple sampled sequences?
My understanding is to run nucleus sampling over the whole sequence multiple times.

Thanks!


@tapdiego-amzn

Thank you for this code. Is it distributed under some Open Source license or are there otherwise any limitations on its use?

@kushalj001

 # Shift the indices to the right to keep also the first token above the threshold
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
  sorted_indices_to_remove[..., 0] = 0

What is the significance of these lines? Cannot get my head around them.
Thanks

@umiswing

@thomwolf
Hello! I'm trying to modify the code to support batch sizes greater than one. I'm having trouble making the top_p branch support 2D input: I couldn't find an appropriate PyTorch API to index the 2D tensor (the indices_to_remove = sorted_indices[sorted_indices_to_remove] indexing), and implementing it with a for loop is too slow. Could you provide some suggestions about the implementation?

@nicofirst1

@umiswing I'm also looking for a batched version of this, did you find anything?

@umiswing

umiswing commented Aug 8, 2022

@umiswing I'm also looking for a batched version of this, did you find anything?

@nicofirst1 I modified it to a batched version, but I haven't tested it much. I hope it helps.
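For reference, here is a sketch of one way to batch the gist's function (the name top_k_top_p_filtering_batched is hypothetical, and the code is lightly tested only). It assumes logits of shape (batch_size, vocab_size) and uses scatter to map the sorted mask back to vocabulary order, avoiding the slow per-row loop:

```python
import torch
import torch.nn.functional as F

def top_k_top_p_filtering_batched(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Sketch of a batched variant: logits has shape (batch_size, vocab_size)."""
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # k-th largest value per row; mask everything strictly below it
        kth_values = torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_values, filter_value)
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Scatter the sorted mask back to the original vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits
```

Note this version returns a new tensor (via masked_fill) rather than modifying logits in place.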

@Debolena7

Debolena7 commented Feb 16, 2023

 # Shift the indices to the right to keep also the first token above the threshold
  sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
  sorted_indices_to_remove[..., 0] = 0

What is the significance of these lines? Cannot get my head around them. Thanks

Can anybody please explain these lines?
Is it to keep at least one value?

@LeeSinLiang

LeeSinLiang commented Apr 24, 2023

@not-hermione
Consider this example:

x = torch.arange(5,0,-1) 
cumulative_probs = torch.cumsum(x, dim=0) # tensor([ 5,  9, 12, 14, 15])
sorted_indices_to_remove = cumulative_probs > 13 # tensor([False, False, False,  True,  True])

We want to create a boolean mask called sorted_indices_to_remove to identify which indices in cumulative_probs need to be removed. Specifically, we want to remove indices where the corresponding value in cumulative_probs is greater than 13.

Notice the index corresponding to value 14, the first value to exceed the threshold, is also marked as True in sorted_indices_to_remove, but we don't want to remove it: nucleus sampling keeps the first token that crosses the threshold.

To address this issue, we use the following two lines of code:

sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0 

These two lines shift the values in sorted_indices_to_remove to the right by one along the last dimension and then set the first value along the last dimension to False.
This ensures the index corresponding to value 14 in cumulative_probs is no longer marked as True in sorted_indices_to_remove, so that token is kept.
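The same shift, runnable end-to-end on that example (a sketch; the `...` indexing in the original generalizes this to higher-dimensional masks):

```python
import torch

cumulative_probs = torch.tensor([5, 9, 12, 14, 15])
mask = cumulative_probs > 13     # [False, False, False, True, True]
mask[1:] = mask[:-1].clone()     # shift right by one
mask[0] = False                  # always keep the most probable token
# The first value crossing the threshold (14) is now kept;
# only the tail beyond it (15) remains marked for removal.
print(mask)
```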

@Debolena7

Debolena7 commented Apr 24, 2023

@LeeSinLiang
thanks a lot for your time and answer.

@yuchenlin

yuchenlin commented Jul 5, 2023

Does anyone else get this error? RuntimeError: scatter(): Expected self.dtype to be equal to src.dtype

I changed the line to be

indices_to_remove = torch.zeros_like(logits, dtype=sorted_indices_to_remove.dtype).scatter_(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove )

and it works now.

@calvinmccarter

Note that this implementation does not just take the top_p probability mass. It also includes the probability mass of the token that straddles the top_p boundary. Here is a (numpy, not pytorch) implementation which always samples exactly from the top_p probability mass: https://gist.github.com/calvinmccarter/eaa9ee398606352e6e1df4b50e62881c

@ZhangShiyue

Hi, thanks for this code snippet!

I recently realized that I had a fundamental question/confusion about top-k and top-p sampling. My mental implementation is different from the code above and what people usually use. Although we often say that we take top k/p probabilities, in fact, the filtering occurs at the logits level.

But in my mind it would be: softmax -> take topk/p probs -> renormalize by sum -> then sample. Especially the original top-p paper (https://arxiv.org/pdf/1904.09751) suggested this type of implementation, which obviously leads to a different distribution from the logits level filtering.

Can someone help me understand this? What's the "right" way to implement top-k/p? Thanks a lot!

@calvinmccarter

@ZhangShiyue Filtering takes place on the probabilities because of the softmax inside top_k_top_p_filtering (the cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) line).

@ZhangShiyue

@ZhangShiyue Filtering takes place on the probabilities because of the softmax inside top_k_top_p_filtering (the cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) line).

No, filtering takes place on the logits level because of filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
By filtering, I mean setting certain logits to -inf.

In the code above, softmax happens after the logit masking. But my mental model is: softmax first, then filter the probs (setting certain probs to 0), then sample. These two are different.

@calvinmccarter

@ZhangShiyue No, softmax occurs within the implementation of top_k_top_p_filtering (on the cumulative_probs line). Filtering is applied after that softmax, via sorted_indices_to_remove = cumulative_probs > top_p.

@calvinmccarter

Also note that setting specific indices of the logits to -inf is sufficient. You don't need to renormalize logits because softmax is translation invariant.
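A quick sketch of that translation invariance, assuming PyTorch: subtracting any constant from all logits leaves the softmax output unchanged.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([6., 5., 4.])
shifted = logits - 100.0  # translate every logit by the same constant
# softmax(logits + c) == softmax(logits) for any constant c, since the
# shared factor exp(c) cancels in the normalization.
print(torch.allclose(F.softmax(logits, dim=0), F.softmax(shifted, dim=0)))
```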

@ZhangShiyue

ZhangShiyue commented Oct 25, 2025

Let me give a concrete example of top-k (k=2) sampling below to clarify what I meant. I think I mistakenly claimed it "obviously leads to a different distribution from the logits level filtering" and that "These two are different." On second thought, the two are the same.

if using the code above

>>> logits = torch.tensor([6., 5., 4.]) # assume logits are already sorted
>>> logits_filtered = torch.tensor([6., 5., float('-inf')])
>>> torch.softmax(logits_filtered, dim=0)
tensor([0.7311, 0.2689, 0.0000])
# now you sample from [0.7311, 0.2689, 0.0000]

what I suggested

>>> logits = torch.tensor([6., 5., 4.]) # assume logits are already sorted
>>> probs = torch.softmax(logits, dim=0)  # tensor([0.6652, 0.2447, 0.0900])
>>> probs = probs[:2]
>>> probs / probs.sum() # renormalize
tensor([0.7311, 0.2689]) # now you sample from [0.7311, 0.2689]

So they are the same. Lol. My bad.
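That equivalence can be checked directly (a small sketch, assuming PyTorch):

```python
import torch

logits = torch.tensor([6., 5., 4.])

# Path 1: mask the dropped logit to -inf, then softmax (what the gist does).
masked = logits.clone()
masked[2] = float('-inf')
p1 = torch.softmax(masked, dim=0)[:2]

# Path 2: softmax first, truncate to the top 2, then renormalize.
probs = torch.softmax(logits, dim=0)[:2]
p2 = probs / probs.sum()

# Both paths yield the same distribution over the kept tokens.
print(torch.allclose(p1, p2))
```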
