```python
import torch
import torch.nn.functional as F


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits


# Here is how to use this function for top-p sampling
temperature = 1.0
top_k = 0
top_p = 0.9

# Get logits with a forward pass in our model (input is pre-defined)
logits = model(input)

# Keep only the last token predictions of the first batch item (batch size 1),
# apply a temperature coefficient and filter
logits = logits[0, -1, :] / temperature
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

# Sample from the filtered distribution
probabilities = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, 1)
```
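The assert at the top of the function restricts it to a single 1-D vector of logits (batch size 1). As the code comment notes, it could be extended to batches; below is a minimal sketch of one possible batched variant. The name `top_k_top_p_filtering_batched` and the use of `masked_fill`/`scatter` are our additions for illustration, not part of the gist.

```python
import torch
import torch.nn.functional as F


def top_k_top_p_filtering_batched(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Same idea as above, but for logits of shape (batch_size, vocabulary_size). """
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # safety check
        # Mask every token whose logit is below the k-th largest logit of its row
        kth_values = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_values, filter_value)

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Mask tokens once the cumulative probability exceeds top_p ...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ... but shift right so the first token crossing the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the mask from sorted order back to the original token order
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits
```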
@ZhangShiyue Filtering takes place on the probabilities because of the softmax inside `top_k_top_p_filtering` (the `F.softmax` call that computes `cumulative_probs`).
No, the filtering takes place at the logits level, via `filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)`.
By filtering, I mean setting certain logits to -inf.
In the code above, the softmax happens after the logits are masked. But my mental model is: first softmax, then filter the probabilities (setting certain probs to 0), then sample. These two are different.
@ZhangShiyue No, a softmax occurs inside the implementation of `top_k_top_p_filtering` (the `F.softmax` call that computes `cumulative_probs`); the filtering (`sorted_indices_to_remove = cumulative_probs > top_p`) is applied after that softmax.
Also note that setting the filtered logits to -inf is sufficient. You don't need to renormalize the logits: exp(-inf) = 0, so the final softmax automatically renormalizes over the remaining tokens, and softmax is invariant to adding the same constant to all logits, so nothing needs to be shifted either.
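A small sketch (ours, assuming PyTorch) checking both facts, using the same numbers as the example that follows:

```python
import torch

logits = torch.tensor([6., 5., 4.])

# Softmax is invariant to adding the same constant to every logit
print(torch.allclose(torch.softmax(logits, dim=0), torch.softmax(logits + 100., dim=0)))  # True

# A -inf logit gets probability exp(-inf) = 0, and the survivors renormalize automatically
masked = torch.tensor([6., 5., float('-inf')])
print(torch.softmax(masked, dim=0))  # tensor([0.7311, 0.2689, 0.0000])
```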
Let me give a concrete example of top-k (k=2) sampling below to clarify what I meant. I also think I was mistaken when I claimed that it "obviously leads to a different distribution from the logits-level filtering" and that "these two are different." On second thought, the two are the same.
If using the code above (logits-level filtering):

```python
>>> logits = torch.tensor([6., 5., 4.])  # assume logits are already sorted
>>> logits_filtered = torch.tensor([6., 5., float('-inf')])
>>> torch.softmax(logits_filtered, dim=0)
tensor([0.7311, 0.2689, 0.0000])
# now you sample from [0.7311, 0.2689, 0.0000]
```

What I suggested (probability-level filtering):

```python
>>> logits = torch.tensor([6., 5., 4.])  # assume logits are already sorted
>>> probs = torch.softmax(logits, dim=0)  # tensor([0.6652, 0.2447, 0.0900])
>>> probs = probs[:2]
>>> probs / probs.sum()  # renormalize
tensor([0.7311, 0.2689])
# now you sample from [0.7311, 0.2689]
```
So they are the same. Lol. My bad.
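The same equivalence can also be checked programmatically for arbitrary logits and k. This is a small sketch we added for illustration, not part of the original discussion:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(50)  # random, unsorted logits
k = 5

# (a) logits-level filtering: mask everything outside the top-k to -inf, then softmax
kth_value = torch.topk(logits, k).values[-1]
probs_a = torch.softmax(logits.masked_fill(logits < kth_value, float('-inf')), dim=0)

# (b) probability-level filtering: softmax first, keep the top-k probs, renormalize
probs = torch.softmax(logits, dim=0)
topk_probs, topk_idx = torch.topk(probs, k)
probs_b = torch.zeros_like(probs)
probs_b[topk_idx] = topk_probs / topk_probs.sum()

print(torch.allclose(probs_a, probs_b))  # True: the two distributions are identical
```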
Hi, thanks for this code snippet!
I recently realized that I had a fundamental question/confusion about top-k and top-p sampling. My mental implementation is different from the code above and from what people usually use. Although we often say that we take the top-k/top-p probabilities, the filtering here actually happens at the logits level. In my mind it would be: softmax -> take the top-k/p probs -> renormalize by their sum -> then sample. In particular, the original top-p paper (https://arxiv.org/pdf/1904.09751) suggests this type of implementation, which obviously leads to a different distribution from the logits-level filtering.
Can someone help me understand this? What's the "right" way to implement top-k/p? Thanks a lot!