Simple perplexity for Hugging Face models, similar to llama.cpp
# Directly taken from https://huggingface.co/spaces/evaluate-measurement/perplexity/blob/main/perplexity.py
# TODO: replace with a strided version https://github.com/huggingface/transformers/issues/9648#issuecomment-812981524
import numpy as np
import torch
import itertools
from torch.nn import CrossEntropyLoss
from tqdm.auto import tqdm
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM


def nll_loss_no_mean(logits, labels):
    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1228
    logits = logits.float()
    # Shift so that tokens < n predict n
    vocab_size = logits.shape[-1]
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction="none")
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    return loss_fct(shift_logits, shift_labels)


def create_batch(input_ids, loss_mask, batch_i, batch_size, stride):
    text_len = input_ids.size(1)

    # collect the (begin, end) indices of up to `batch_size` consecutive chunks of length `stride`
    begin_locs, end_locs = [], []
    for j in range(batch_size):
        begin_loc = batch_i + j * stride
        if begin_loc >= text_len:
            break
        end_loc = min(begin_loc + stride, text_len)  # the final chunk may be shorter than `stride`
        begin_locs.append(begin_loc)
        end_locs.append(end_loc)

    # slice out the chunks; right-pad the final short chunk so the batch can be stacked
    # (padded positions have loss_mask=False, so they are ignored below)
    b_input_ids, b_loss_mask = [], []
    for b, e in zip(begin_locs, end_locs):
        ids, mask = input_ids[:, b:e], loss_mask[:, b:e]
        pad = stride - (e - b)
        if pad > 0:
            ids = torch.cat([ids, ids.new_full((1, pad), int(ids[0, 0]))], dim=1)
            mask = torch.cat([mask, mask.new_zeros((1, pad))], dim=1)
        b_input_ids.append(ids)
        b_loss_mask.append(mask)
    b_input_ids = torch.stack(b_input_ids, dim=1).squeeze(0)
    b_loss_mask = torch.stack(b_loss_mask, dim=1).squeeze(0)

    # create targets: same as the inputs, with ignored positions set to -100
    # (-100 is the default ignore_index value in torch.nn.CrossEntropyLoss)
    target_ids = b_input_ids.clone()
    target_ids[~b_loss_mask] = -100

    return b_input_ids, target_ids


@torch.no_grad()
def batched_perplexity(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, dataset: Dataset = None, batch_size=32, stride=512):
    """
    Better perplexity calculation for causal language models.

    Args:
        model: A pretrained causal language model
        tokenizer: The tokenizer used to preprocess the data
        dataset: A dataset to calculate perplexity on. If None, the wikitext-2 test set is used.
        batch_size: The batch size to use for perplexity calculation
        stride: The stride to use for perplexity calculation. Important: changing this will change your results.

    Comparison against other implementations:
    - https://huggingface.co/docs/transformers/perplexity - takes the mean of means, giving it the wrong value
    - https://github.com/huggingface/evaluate/blob/main/metrics/perplexity/perplexity.py - complex, and crops sentences, so it's not comparable
    - https://github.com/ggerganov/llama.cpp/tree/master/examples/perplexity - good, but in cpp
    - https://github.com/huggingface/transformers/issues/9648#issuecomment-812981524 - doesn't use special tokens

    Limitations of this implementation:
    - if a token is at the start of a strided window, it has no context, so its perplexity is higher. TODO: have overlapping windows
    - uses special tokens, which makes it hard to compare to scores that do not
    """
    if dataset is None:
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:10%]")
        dataset = dataset.filter(lambda x: len(x["text"]) > 0)["text"]
    device = next(iter(model.parameters())).device

    enc = tokenizer(dataset, add_special_tokens=True, return_special_tokens_mask=True)
    input_ids = torch.tensor(list(itertools.chain(*enc.input_ids))).to(torch.long).unsqueeze(0)
    # without padding or truncation we don't need the attention mask, but we do need the special_tokens mask
    attention_mask = torch.tensor(list(itertools.chain(*enc.attention_mask))).to(torch.bool).unsqueeze(0)
    special_tokens_mask = torch.tensor(list(itertools.chain(*enc.special_tokens_mask))).to(torch.bool).unsqueeze(0)
    # let's not calc the perplexity on special_tokens
    loss_mask = attention_mask & ~special_tokens_mask

    text_len = input_ids.size(1)
    lls = []
    for i in tqdm(range(0, text_len, batch_size * stride)):
        b_input_ids, target_ids = create_batch(input_ids, loss_mask, i, batch_size, stride)
        b_input_ids = b_input_ids.to(device)
        target_ids = target_ids.to(device)

        logits = model(b_input_ids).logits
        log_likelihood = nll_loss_no_mean(logits, target_ids)

        # only keep losses for tokens we actually evaluated; ignored (-100) positions
        # come back as 0 and would deflate the mean
        valid = target_ids[..., 1:].contiguous().view(-1) != -100
        lls.extend(log_likelihood[valid].cpu().tolist())

    lls = torch.tensor(lls)
    ppl = lls.mean().exp()
    return ppl.cpu().item()
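

# A minimal usage sketch for batched_perplexity. The checkpoint name ("gpt2") and the
# batch size are illustrative assumptions, not part of the original gist; by default the
# function downloads a slice of the wikitext-2 test set and returns one corpus-level perplexity.
if __name__ == "__main__":
    model_name = "gpt2"  # assumption: any Hugging Face causal LM checkpoint works here
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).eval()
    ppl = batched_perplexity(model, tokenizer, dataset=None, batch_size=8, stride=512)
    print(f"perplexity: {ppl:.2f}")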
import torch
import numpy as np
from tqdm.auto import tqdm
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, PreTrainedTokenizerBase
@torch.no_grad()
def compute_perplexity(text: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, stride=8, max_length=512, batch_size=4):
"""
Efficient corpus perplexity calculation using strided windows.
Args:
model: A pretrained language model
tokenizer: The tokenizer used to preprocess the data
dataset: A dataset to calculate perplexity on. If None, the wikitext-2 test set is used.
batch_size: The batch size to use for perplexity calculation
stride: The stride to use for perplexity calculation - Important, changing this will change your results
Comparison again other implementations:
- https://huggingface.co/docs/transformers/perplexity - takes the mean of means giving it the wrong value
- https://github.com/huggingface/evaluate/blob/main/metrics/perplexity/perplexity.py - compelx and crops sentances so it's not comparable
- https://github.com/ggerganov/llama.cpp/tree/master/examples/perplexity - good but in cpp
- https://github.com/huggingface/transformers/issues/9648#issuecomment-812981524 - doesn't use special tokens
@url: https://gist.github.com/wassname/4af760435447d38a3012c6e39abb58e1
"""
    device = model.device

    # Tokenize the corpus once
    encodings = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    seq_len = encodings.input_ids.size(1)

    # Running totals of negative log likelihood and evaluated-token count
    nlls, counts = 0, 0

    # Per-token loss so we can mask it ourselves
    loss_fn = CrossEntropyLoss(reduction="none")

    # Process corpus in strided windows
    for i in tqdm(range(0, seq_len, stride * batch_size)):
        # Prepare batch windows
        input_ids_list, target_masks_list = [], []
        for j in range(batch_size):
            # Window start position
            start_idx = i + j * stride
            if start_idx >= seq_len:
                break

            # Extract window with context
            end_idx = min(start_idx + max_length, seq_len)
            ids = encodings.input_ids[0, start_idx:end_idx].clone()

            # Skip windows that are too small
            if len(ids) < 2:
                continue

            # Add a BOS token for the initial window (if the tokenizer has one)
            if start_idx == 0 and tokenizer.bos_token_id is not None:
                ids = torch.cat([torch.tensor([tokenizer.bos_token_id]), ids])

            # Create evaluation mask (1 for tokens to evaluate, 0 otherwise).
            # For overlapping windows, only evaluate tokens beyond the overlap with the
            # previous window, so each token is scored exactly once (with max context).
            eval_mask = torch.zeros_like(ids)
            if start_idx == 0:
                eval_offset = 0
            else:
                prev_end = min(start_idx - stride + max_length, seq_len)
                eval_offset = prev_end - start_idx
                if eval_offset >= len(ids):
                    continue  # every token in this window was already scored
            eval_mask[eval_offset:] = 1
            input_ids_list.append(ids)
            target_masks_list.append(eval_mask)

        if not input_ids_list:
            continue

        # Create padded batch tensors
        batch = tokenizer.pad({"input_ids": input_ids_list}, return_tensors="pt")
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Pad the per-window evaluation masks to the same length
        max_len = input_ids.size(1)
        padded_masks = []
        for mask in target_masks_list:
            padding = torch.zeros(max_len - len(mask), dtype=torch.long)
            padded_masks.append(torch.cat([mask, padding]))
        target_masks = torch.stack(padded_masks).to(device)

        # Forward pass
        outputs = model(input_ids, attention_mask=attention_mask)

        # Compute loss on shifted sequences
        shift_logits = outputs.logits[:, :-1].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        shift_masks = target_masks[:, 1:].contiguous() * attention_mask[:, 1:].contiguous()

        # Calculate NLL only for targeted tokens
        loss = loss_fn(shift_logits.transpose(1, 2), shift_labels)
        masked_loss = (loss * shift_masks).sum()
        token_count = shift_masks.sum()

        # Accumulate results
        nlls += masked_loss.item()
        counts += token_count.item()

    # Return corpus-level perplexity
    return np.exp(nlls / counts) if counts > 0 else float("inf")
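

# Example setup for the usage below. The checkpoint name ("gpt2") is only an illustration,
# not part of the original gist -- swap in whichever causal LM you want to evaluate. Note
# that compute_perplexity pads batches, so the tokenizer needs a pad token; reusing the
# EOS token is a common workaround.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # assumption: any causal LM checkpoint works here
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name).eval()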
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:10%]")
dataset = dataset.filter(lambda x: len(x["text"]) > 0)
text = "\n\n".join(dataset["text"])
compute_perplexity(text, model, tokenizer)

Comment from @wassname (author):
Ideally, but some models seem to be different, so I thought it best to ensure it's done explicitly. There are a few ways people seem to calculate perplexity, so it's good to be able to see and modify it easily.
