@vndee
Created November 1, 2024 13:45
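import math


def sample_next_token(rng, logits, temperature=1.0):
    """Sample the index of the next token from language-model logits.

    Relies on a weighted ``random_choice(rng, items, weights)`` helper
    that is defined elsewhere, not in this snippet.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Apply temperature: T < 1 sharpens the distribution, T > 1 flattens it
    scaled_logits = [l / temperature for l in logits]
    # Softmax; subtracting the max logit keeps math.exp from overflowing
    max_logit = max(scaled_logits)
    exp_logits = [math.exp(l - max_logit) for l in scaled_logits]
    total = sum(exp_logits)
    probs = [e / total for e in exp_logits]
    # Pick a token index with probability probs[i] via the weighted choice helper
    return random_choice(rng, range(len(logits)), probs)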
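The function above calls a `random_choice` helper that the gist does not show. Below is a minimal self-contained sketch, assuming `rng` is a `random.Random` instance and that `random_choice` does inverse-CDF sampling over cumulative weights. Both `random_choice` and the `softmax` helper here are hypothetical stand-ins written for illustration, not the author's code. The demo prints the distribution at a few temperatures and tallies samples at T = 0.5.

```python
import bisect
import itertools
import math
import random


def random_choice(rng, items, weights):
    """Hypothetical weighted choice: inverse-CDF sampling over cumulative weights."""
    items = list(items)
    cumulative = list(itertools.accumulate(weights))
    # Draw a point in [0, total) and find the first bucket whose upper edge exceeds it
    r = rng.random() * cumulative[-1]
    index = bisect.bisect_right(cumulative, r)
    # Guard against floating-point rounding landing exactly on the last boundary
    return items[min(index, len(items) - 1)]


def softmax(logits, temperature=1.0):
    """Hypothetical helper: temperature-scaled softmax, stabilised by the max logit."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next_token(rng, logits, temperature=1.0):
    """Same logic as the gist's function, with softmax factored out."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    probs = softmax(logits, temperature)
    return random_choice(rng, range(len(logits)), probs)


rng = random.Random(0)  # seeded so runs are reproducible
logits = [2.0, 1.0, 0.1]

# Lower temperature concentrates probability mass on the largest logit
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in softmax(logits, t)])

# Empirical frequencies should track softmax(logits, 0.5)
counts = [0, 0, 0]
for _ in range(10_000):
    counts[sample_next_token(rng, logits, temperature=0.5)] += 1
print(counts)
```

If a custom helper isn't needed, the standard library's `random.Random.choices(population, weights=...)` does weighted sampling directly, so `rng.choices(range(len(logits)), weights=probs)[0]` would be an equivalent drop-in for the final line of `sample_next_token`.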