thomwolf/1a5a29f6962089e871b94cbd09daf317
```python
import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Note: `logits` is modified in place and also returned.
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map the mask from sorted order back to vocabulary ids and filter them out
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9

# Get logits with a forward pass in our model (`model` and `input` are pre-defined here;
# the model is assumed to return logits of shape (batch_size, sequence_length, vocab_size))
logits = model(input)

# Keep only the last token predictions of the first batch item (batch size 1), apply a temperature coefficient and filter
logits = logits[0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)
```
I tried it out here and was unimpressed with the sample quality compared to top_k=40, temperature=0.8.
Has anyone had better luck?
@mdda, I don't think I understand your issue.
Why would the masking operation on the RHS produce a list of indices and not a LongTensor?
torch.sort produces a tuple (Tensor, LongTensor), and sorted_indices_to_remove is a ByteTensor. When we index the LongTensor with a ByteTensor, we get another LongTensor that keeps only the masked elements (so it is indeed not the same size, which is intended).
We can then set those elements to -inf in the last indexing operation.
Your solution does the same thing, using a ByteTensor to mask instead of a LongTensor to index in the last operation, so the results are identical (I tested it by comparing the outputs).
So I'm curious, what error was raised in your tests?
@8enmann it's been working slightly better than top-40 in my tests (dialog generation), but I must say the variance of my personal evaluation is quite high. I usually use a slightly lower temperature, 0.7, and a top_p of 0.9.
@thomwolf the paper suggested temperature 1.0, so that's what I'd been using (with top_p=0.9). Reducing the temperature is giving much better results!
I was getting an error using the original code; @mdda's edit fixed it for me. Stack trace below. Note: batch size was 1, not sure if that matters.
@yaroslavvb and I were talking about an automated way to tune these hyperparameters and thought about using LAMBADA, or even the likelihood of selecting the "gold" word as mentioned in the paper. We haven't tried it.
```
Traceback (most recent call last):
  File "generate.py", line 131, in <module>
    main()
  File "generate.py", line 64, in main
    softmax = hidden_to_softmax(model, pred_hid[-1], top_k=args.top_k, temperature=args.temperature, top_p=args.top_p)
  File "generate.py", line 90, in hidden_to_softmax
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
  File "generate.py", line 127, in top_k_top_p_filtering
    logits[indices_to_remove] = filter_value
IndexError: index 36955 is out of bounds for dimension 0 with size 1
```
@8enmann Interesting, thanks! I'd like to get to the root of this error, if you're OK with that. What is the shape of your input logits tensor?
Mine is torch.Size([50262]) in my current testing setup.
[update] Oh yes, that's the problem: I'm filtering logits that were already reduced with logits = logits[0, -1, :] / temperature, and not what I showed here.
Let's fix the gist for batch size 1 for now then (that's the main use case anyway).
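A minimal sketch that reproduces this failure mode, assuming the logits arrive with an extra leading batch dimension, i.e. shape (1, vocab_size) instead of (vocab_size,). Boolean-masking the sorted indices yields a flat tensor of vocabulary ids, and using those ids on the 2-D tensor indexes dimension 0, which has size 1, so it raises the same kind of IndexError as the traceback above. The 10-token vocabulary and the choice of which tokens to filter are made up for illustration.

```python
import torch

torch.manual_seed(0)
logits_2d = torch.randn(1, 10)        # shape (batch=1, vocab): one dimension too many
logits_1d = logits_2d[0].clone()      # the 1-D shape the gist assumes

# Pretend the 5 least likely tokens should be filtered out.
_, sorted_indices = torch.sort(logits_2d, descending=True)
mask = torch.zeros_like(sorted_indices, dtype=torch.bool)
mask[..., 5:] = True
vocab_ids = sorted_indices[mask]      # flat tensor of 5 vocabulary ids in [0, 10)

error_message = None
try:
    # On the 2-D tensor these ids index dimension 0, which has size 1.
    logits_2d[vocab_ids] = -float("inf")
except IndexError as err:
    error_message = str(err)
print(error_message)

logits_1d[vocab_ids] = -float("inf")  # the same ids work on the 1-D tensor
```

Squeezing the batch dimension first (for example with logits[0, -1, :] on a (batch, seq, vocab) output) avoids the problem for batch size 1.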
Hi, I added a PR extending this to num_samples > 1:
huggingface/transformers#1333
:)
I am getting this error; also, could you please elaborate on that "input"?
```
logits = model.forward(input)
Traceback (most recent call last):
File "", line 1, in
logits = model.forward(input)
TypeError: forward() missing 1 required positional argument: 'mc_token_ids'
```
You can refer to run_generation.py in Transformers for full, working source code on a real model, e.g. OpenAI GPT-2.
"sample from the smallest set whose cumulative probability mass exceeds p for next words"
What exactly does this mean? Let's say I set p = 0.9: will it keep only the next tokens that have probability > 0.9, or something else?
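To illustrate with made-up numbers: p = 0.9 does not mean "keep tokens whose probability exceeds 0.9". It keeps the smallest set of most likely tokens whose cumulative probability reaches 0.9. The sketch below applies the gist's threshold-and-shift masking to a toy 5-token distribution in which no single token has probability above 0.9, yet three tokens are kept.

```python
import torch

# Made-up next-token probabilities over a toy 5-token vocabulary.
probs = torch.tensor([0.05, 0.45, 0.20, 0.28, 0.02])
top_p = 0.9

sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)  # ~[0.45, 0.73, 0.93, 0.98, 1.00]

# Same masking logic as the gist: mark tokens past the threshold,
# then shift right so the token that crosses top_p is still kept.
sorted_to_remove = cumulative_probs > top_p
sorted_to_remove[1:] = sorted_to_remove[:-1].clone()
sorted_to_remove[0] = False

nucleus = sorted_indices[~sorted_to_remove]
print(nucleus.tolist())             # [1, 3, 2]
print(probs[nucleus].sum().item())  # kept probability mass, about 0.93
```

Sampling then happens only among tokens 1, 3 and 2, after renormalizing their probabilities.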
@thomwolf
Hi, I have recently been learning about temperature sampling and nucleus sampling.
I read the paper "The Curious Case of Neural Text Degeneration", which rescales the original distribution into a new distribution.
In the top_k_top_p_filtering function, the logit scores are set to zero, but the probability distribution is not changed.
Is changing the probability distribution necessary for top-p sampling?
Thank you ~
@chungyilinxrspace
> In the top_k_top_p_filtering function, it set the logit score to zero but doesn't change the probability distribution.
> Does "Change the probability distribution" is necessary for top-p sampling?
Hi!
TL;DR: The filtering function operates on the logits, not on the probabilities.
After filtering, the logits are converted to class probabilities via the call to F.softmax, which ensures both that the filtered classes have zero probability (since their logit value is float("-inf")) and that the remaining probabilities form a proper, rescaled probability distribution. Hence the probability distribution is indeed "changed".
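A quick check of this, as a sketch with arbitrary logit values: setting logits to -inf before F.softmax gives those tokens exactly zero probability, and the surviving probabilities equal the original ones renormalized over the kept set, which is the rescaling described in the paper.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
filtered = logits.clone()
filtered[2:] = -float("inf")   # pretend the last two tokens were filtered out

probs = F.softmax(filtered, dim=-1)
print(probs)        # last two entries are exactly 0
print(probs.sum())  # the kept entries sum to 1

# The kept probabilities equal the original ones renormalized over the kept set.
original = F.softmax(logits, dim=-1)
rescaled = original[:2] / original[:2].sum()
print(torch.allclose(probs[:2], rescaled))  # True
```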
Hello all!
First, thank you for a very nice piece of code.
I have a more general question about nucleus sampling itself; maybe someone will be willing to clarify a few things for me.
How do we choose k and p? As far as I understand, every time we generate text the output will be different, whether k and p stay the same or change. In other words, one cannot get a stable generated output (unlike with greedy or beam search).
Is there a good rule of thumb for these parameters? Or should they be chosen purely from empirical observations on a particular problem? If the latter, can anyone point me towards basic ideas on how changing k and/or p affects the generated output in general?
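On the stability point only (this does not answer how to pick k or p): the run-to-run variation comes from the random draw in torch.multinomial, not from k or p themselves. If repeatable outputs are needed, for example while comparing settings, one option is to fix the random seed. A small sketch with a made-up filtered distribution and a dedicated torch.Generator:

```python
import torch

probs = torch.tensor([0.5, 0.3, 0.2])  # made-up filtered next-token distribution


def sample_tokens(seed, n=10):
    # A dedicated, seeded generator makes the random draws repeatable.
    gen = torch.Generator().manual_seed(seed)
    return torch.multinomial(probs, n, replacement=True, generator=gen).tolist()


print(sample_tokens(0) == sample_tokens(0))  # True: same seed, same tokens
```

Different seeds still give different samples, so this only makes a given configuration reproducible.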
@mdda
Line 24 does indeed produce a list of indices, and your code helps.
How can I return multiple sampled sequences?
My understanding is that you run nucleus sampling over the whole sequence multiple times.
Thanks!
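One straightforward way, sketched under loud assumptions: repeat the whole sampling loop once per sequence. toy_next_logits and its random lookup table are hypothetical stand-ins for a real model's forward pass, and top_p_filter is a trimmed, top-p-only copy of the gist's filtering logic. Batching the sequences (as in the PR mentioned earlier in this thread) would be faster; this version favours clarity.

```python
import torch
import torch.nn.functional as F

VOCAB = 6
torch.manual_seed(0)
# Hypothetical stand-in for a language model: next-token logits depend only on
# the previous token (a real model would condition on the whole prefix).
TABLE = torch.randn(VOCAB, VOCAB)


def toy_next_logits(prefix):
    return TABLE[prefix[-1]].clone()  # clone so in-place filtering leaves TABLE intact


def top_p_filter(logits, top_p):
    # Top-p part of the gist's function (1-D logits, modified in place).
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative > top_p
    remove[1:] = remove[:-1].clone()
    remove[0] = False
    logits[sorted_indices[remove]] = -float("inf")
    return logits


def sample_sequences(prompt, num_sequences, length, top_p=0.9):
    sequences = []
    for _ in range(num_sequences):  # each sequence is an independent sampling run
        seq = list(prompt)
        for _ in range(length):
            probs = F.softmax(top_p_filter(toy_next_logits(seq), top_p), dim=-1)
            seq.append(torch.multinomial(probs, 1).item())
        sequences.append(seq)
    return sequences


for s in sample_sequences([0], num_sequences=3, length=5):
    print(s)
```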
https://github.com/BenjaminWegener/gpt2_torch_nucleus as a working example
Thank you for this code. Is it distributed under some Open Source license or are there otherwise any limitations on its use?
```python
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
```
What is the significance of these lines? I can't get my head around them.
Thanks
@thomwolf
Hello! I'm trying to modify the code to support batch sizes greater than one, but I'm having trouble making top_p support 2-D input. I couldn't find the appropriate PyTorch API to index the 2-D tensor (lines 26-27 in the code), so I implemented it with a for loop, which is too slow. Could you provide some suggestions about the implementation?
@umiswing I'm also looking for a batched version of this, did you find anything?
@nicofirst1 I modified it into a batched version, but I haven't tested it much. I hope it helps.
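The batched code from that comment isn't shown here, so as an independent sketch (not that implementation): one loop-free approach for logits of shape (batch_size, vocab_size) is to build the same threshold-and-shift mask in sorted order, overwrite the removed entries with masked_fill, and then use scatter to put each row back in vocabulary order. The name top_p_filtering_batched is hypothetical, and top-k is left out for brevity.

```python
import torch
import torch.nn.functional as F


def top_p_filtering_batched(logits, top_p=0.9, filter_value=-float("inf")):
    """Nucleus filtering for logits of shape (batch_size, vocab_size); returns a new tensor."""
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Same rule as the 1-D gist, applied to every row through `...` indexing.
    sorted_to_remove = cumulative_probs > top_p
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Filter in sorted order, then scatter each row back to vocabulary order.
    # sorted_indices is a permutation per row, so every output entry gets written.
    sorted_logits = sorted_logits.masked_fill(sorted_to_remove, filter_value)
    return torch.empty_like(logits).scatter(-1, sorted_indices, sorted_logits)


logits = torch.tensor([[2.0, 0.5, 1.0, -1.0],
                       [0.0, 3.0, 0.1, 0.2]])
filtered = top_p_filtering_batched(logits, top_p=0.9)
probs = F.softmax(filtered, dim=-1)
next_tokens = torch.multinomial(probs, 1)  # shape (batch_size, 1)
print(filtered)
```

Filtering the sorted copy and scattering it back avoids the variable-length index lists that make the 1-D version hard to batch.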
Can anybody please explain these lines?
Is it to keep at least one value?
@not-hermione
Consider this example:
```python
x = torch.arange(5, 0, -1)
cumulative_probs = torch.cumsum(x, dim=0)  # tensor([ 5, 9, 12, 14, 15])
sorted_indices_to_remove = cumulative_probs > 13  # tensor([False, False, False, True, True])
```
We want to create a boolean mask called sorted_indices_to_remove to identify which indices in cumulative_probs need to be removed. Specifically, we want to remove indices where the corresponding value in cumulative_probs is greater than 13.
Notice that the index corresponding to value 14, the first one to cross the threshold, is also marked as True in sorted_indices_to_remove, but we don't want to remove it: keeping it is what makes the kept mass reach the threshold.
To address this issue, we use the following two lines of code:
```python
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
```
These two lines shift the values in sorted_indices_to_remove to the right by one along the last dimension and then set the first value along the last dimension to False.
This ensures that the index corresponding to value 14 in cumulative_probs is not marked as True in sorted_indices_to_remove (the mask becomes tensor([False, False, False, False, True])). Setting the first value to False also guarantees that at least one token is always kept.
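An equivalent way to read the shift, sketched on the same integer example: after shifting, a token is removed exactly when the cumulative mass of the tokens ranked before it already exceeds the threshold, so one can compute cumulative - x directly instead of shifting. With floating-point probabilities the two forms can differ by rounding right at the threshold, so treat this as an illustration rather than a drop-in replacement.

```python
import torch

x = torch.arange(5, 0, -1)           # tensor([5, 4, 3, 2, 1]), already sorted
cumulative = torch.cumsum(x, dim=0)  # tensor([ 5,  9, 12, 14, 15])
threshold = 13

# Gist version: threshold test, then shift right by one.
shifted = cumulative > threshold
shifted[1:] = shifted[:-1].clone()
shifted[0] = False

# Equivalent form: remove a token only if the mass before it already exceeds the threshold.
mass_before = cumulative - x         # tensor([ 0,  5,  9, 12, 14])
direct = mass_before > threshold

print(shifted.tolist())              # [False, False, False, False, True]
print(torch.equal(shifted, direct))  # True
```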
@LeeSinLiang
Thanks a lot for your time and the answer.
Does anyone else get this error? RuntimeError: scatter(): Expected self.dtype to be equal to src.dtype
I changed the line to
```python
indices_to_remove = torch.zeros_like(logits, dtype=sorted_indices_to_remove.dtype).scatter_(
    dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
```
so that it works now.
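For context on that error, a sketch with made-up logits and top_p = 0.8: scatter_ requires src to have the same dtype as the tensor being written into, and torch.zeros_like(logits) alone would be floating point while the mask is bool. Creating the zeros with a bool dtype fixes it, and in the 1-D case the resulting mask matches the one obtained with the original sorted_indices[sorted_indices_to_remove] indexing.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([0.1, 2.0, -1.0, 1.5, 0.3])
top_p = 0.8

sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_to_remove = cumulative_probs > top_p
sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
sorted_to_remove[..., 0] = False

# Map the mask from sorted order back to vocabulary order.
# A float zeros tensor here would not match the bool src and triggers the dtype error.
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
    dim=-1, index=sorted_indices, src=sorted_to_remove
)

# Same mask as the original 1-D indexing approach.
by_index = torch.zeros_like(indices_to_remove)
by_index[sorted_indices[sorted_to_remove]] = True
print(indices_to_remove.tolist())  # [True, False, True, False, True]
print(torch.equal(indices_to_remove, by_index))  # True

logits = logits.masked_fill(indices_to_remove, -float("inf"))
```

Unlike the index-list version, the scatter form keeps the mask the same shape as the logits, which is why it also extends to batched input.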
Note that this implementation does not take exactly the top_p probability mass: it also includes the probability mass of the token that straddles the top_p boundary. Here is a (NumPy, not PyTorch) implementation that always samples from exactly the top_p probability mass: https://gist.github.com/calvinmccarter/eaa9ee398606352e6e1df4b50e62881c .
Line 24:
```python
indices_to_remove = sorted_indices[sorted_indices_to_remove]
```
does not seem to do what's intended, since the masking operation on the RHS seems to produce a list of indices from sorted_indices (but the shape is different from the logits that got sorted). I had to go with something like:
This is on PyTorch 1.1 (and 1.0.1, which I was using before; I thought I must be going crazy).
HTH (please let me know if the above is also wrong...)