@wassname
Last active September 11, 2025 06:36
how to sample from logits in huggingface transformers
# how to sample from logits in huggingface transformers
import torch
from transformers.generation.logits_process import MinPLogitsWarper, LogitNormalization

logits_processors = [
    MinPLogitsWarper(min_p=0.1),
    LogitNormalization(),  # always keep this last so logits renormalize after filtering
]

# `o` is the model's forward output; take the logits at the last position
logits = o.logits[:, -1].clone()
# logits[:, banned_token_ids] = -float("inf")  # optionally ban specific tokens

for proc in logits_processors:
    logits = proc(input_ids, logits)
logp = logits.log_softmax(dim=-1)

if do_sample:
    new_token_id = torch.multinomial(logp.exp(), num_samples=1)
else:
    new_token_id = logp.argmax(dim=-1, keepdim=True)