@vndee
Last active November 2, 2024 10:03
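import math
import random


def sample_next_token(logits, temperature=1.0):
    """Sample the index of the next token from language model logits."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Apply temperature: values < 1 sharpen the distribution, values > 1 flatten it
    scaled_logits = [l / temperature for l in logits]
    # Convert to probabilities with a numerically stable softmax
    # (subtracting the max logit keeps math.exp from overflowing)
    max_logit = max(scaled_logits)
    exp_logits = [math.exp(l - max_logit) for l in scaled_logits]
    total = sum(exp_logits)
    probs = [e / total for e in exp_logits]
    # Draw one token index, weighted by its probability
    return random.choices(range(len(logits)), weights=probs, k=1)[0]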
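To make the effect of `temperature` visible, the probability step of `sample_next_token` can be pulled out and printed for a few temperatures. This is a minimal sketch: the helper name `softmax_with_temperature` and the toy logits are illustrative assumptions, not part of the original gist.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities.

    Hypothetical helper that mirrors the probability step of sample_next_token.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [l / temperature for l in logits]
    # Subtract the max before exponentiating for numerical stability
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.1]  # toy logits for an assumed 3-token vocabulary
    for t in (0.5, 1.0, 2.0):
        probs = softmax_with_temperature(logits, t)
        print(t, [round(p, 3) for p in probs])
```

Lower temperatures push probability mass toward the highest logit (approaching greedy argmax decoding as the temperature nears 0), while higher temperatures spread it more evenly across tokens (approaching a uniform distribution).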