@calvinmccarter
Created August 31, 2024 20:43
Cleaner version of top-p (nucleus) sampling
import numpy as np


def top_p_sampling(n_bins, probs, rng, top_p):
    """A modified implementation of nucleus sampling.

    For the class straddling the top_p boundary, the probability mass
    beyond top_p is discarded, but the class itself does not receive
    zero probability mass. This differs from
    https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 ,
    and is more mathematically elegant, in my humble opinion.

    Parameters
    ----------
    n_bins : int
        The number of classes (i.e. vocab size) to sample from.
    probs : np.ndarray with n_bins elements
        The sampling probabilities (not logits) for each bin.
    rng : np.random.RandomState
        An instantiated random number generator.
    top_p : float in (0, 1]
        Top-p probability filter.

    Returns
    -------
    chosen : np.ndarray (0-d)
        The sampled integer in [0, n_bins).
    """
    probs = probs.ravel()  # currently assumes only one sample
    sort_indices = np.argsort(probs)[::-1]  # descending order of probability
    sort_probs = probs[sort_indices]
    cumsum_probs = np.cumsum(sort_probs)
    # Clip the cumulative mass at top_p, then recover per-class masses:
    # the class straddling the boundary keeps only its mass below top_p.
    unnorm_probs = np.diff(np.minimum(cumsum_probs, top_p), prepend=0.)
    unnorm_probs = unnorm_probs[np.argsort(sort_indices)]  # undo the sort
    norm_probs = unnorm_probs / np.sum(unnorm_probs)
    chosen = np.array(rng.choice(n_bins, p=norm_probs))
    return chosen
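To see how the clipping step differs from standard nucleus sampling, here is a worked sketch of the filtering math on a small made-up distribution (the probabilities and top_p value below are illustrative, not from the gist). With probs = [0.5, 0.3, 0.15, 0.05] and top_p = 0.7, the second class straddles the boundary: standard top-p would keep its full 0.3 mass, while this variant keeps only the 0.2 that lies below top_p.

```python
import numpy as np

# Assumed example distribution and threshold for illustration.
probs = np.array([0.5, 0.3, 0.15, 0.05])
top_p = 0.7

sort_indices = np.argsort(probs)[::-1]           # descending: [0, 1, 2, 3]
cumsum_probs = np.cumsum(probs[sort_indices])    # [0.5, 0.8, 0.95, 1.0]
# Clip cumulative mass at top_p, then diff back to per-class masses.
unnorm = np.diff(np.minimum(cumsum_probs, top_p), prepend=0.0)  # [0.5, 0.2, 0.0, 0.0]
unnorm = unnorm[np.argsort(sort_indices)]        # undo the sort
norm = unnorm / np.sum(unnorm)                   # [5/7, 2/7, 0, 0]
print(norm)
```

The straddling class ends up with 2/7 of the renormalized mass rather than the 0.375 it would get under standard top-p (where both surviving classes keep their full mass before renormalization), and the classes past the boundary still receive exactly zero.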