Simple perplexity for Hugging Face models, similar to llama.cpp
# Directly taken from https://huggingface.co/spaces/evaluate-measurement/perplexity/blob/main/perplexity.py
# TODO replace with a strided version https://github.com/huggingface/transformers/issues/9648#issuecomment-812981524
import numpy as np
import torch
import itertools
from torch.nn import CrossEntropyLoss
from tqdm.auto import tqdm
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
def nll_loss_no_mean(logits, labels):
    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1228
    logits = logits.float()
    # Shift so that tokens < n predict n
    vocab_size = logits.shape[-1]
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens; keep per-token losses (reduction="none") so the caller can average over the whole corpus
    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction="none")
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    return loss_fct(shift_logits, shift_labels)
def create_batch(input_ids, loss_mask, batch_i, batch_size, stride):
    text_len = input_ids.size(1)

    # create batch indices
    begin_locs, end_locs, trg_lens = [], [], []
    for j in range(batch_size):
        j = batch_i + j * stride
        # drop a final partial window: all windows in the batch must have equal length so they can be stacked
        # (this loses at most stride - 1 trailing tokens)
        if j + stride > text_len:
            break
        begin_loc = max(j, 0)
        end_loc = min(j + stride, text_len)
        trg_len = end_loc - j  # always equals stride now that partial windows are dropped
        begin_locs.append(begin_loc)
        end_locs.append(end_loc)
        trg_lens.append(trg_len)
    # create batch
    b_input_ids = [input_ids[:, b:e] for b, e in zip(begin_locs, end_locs)]
    b_input_ids = torch.stack(b_input_ids, dim=1).squeeze(0)
    b_loss_mask = [loss_mask[:, b:e] for b, e in zip(begin_locs, end_locs)]
    b_loss_mask = torch.stack(b_loss_mask, dim=1).squeeze(0)

    # create target
    target_ids = torch.ones_like(b_input_ids) * -100  # -100 is the default ignore_index value in torch.nn.CrossEntropyLoss
    target_end_locs = [sen.size(-1) for sen in b_input_ids]
    for i, (b, e) in enumerate(zip(trg_lens, target_end_locs)):
        labels = b_input_ids[i, -b:e].clone()
        target_ids[i, -b:e] = labels
    target_ids[~b_loss_mask] = -100
    return b_input_ids, target_ids
@torch.no_grad()
def batched_perplexity(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, dataset: Dataset = None, batch_size=32, stride=512):
    """
    Better perplexity calculation for causal language models.

    Args:
        model: A pretrained language model
        tokenizer: The tokenizer used to preprocess the data
        dataset: A dataset to calculate perplexity on. If None, the wikitext-2 test set is used.
        batch_size: The batch size to use for perplexity calculation
        stride: The stride to use for perplexity calculation. Important: changing this will change your results.

    Comparison against other implementations:
    - https://huggingface.co/docs/transformers/perplexity - takes the mean of means, which gives the wrong value
    - https://github.com/huggingface/evaluate/blob/main/metrics/perplexity/perplexity.py - complex, and crops sentences, so it's not comparable
    - https://github.com/ggerganov/llama.cpp/tree/master/examples/perplexity - good, but in C++
    - https://github.com/huggingface/transformers/issues/9648#issuecomment-812981524 - doesn't use special tokens

    Limitations of this implementation:
    - if a token is at the start of a strided window, it has no context, so its perplexity is higher. TODO: use overlapping windows
    - uses special tokens, which makes it hard to compare to scores that do not
    """
    if dataset is None:
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:10%]")
        dataset = dataset.filter(lambda x: len(x["text"]) > 0)["text"]
    device = next(iter(model.parameters())).device

    enc = tokenizer(dataset, add_special_tokens=True, return_special_tokens_mask=True)
    input_ids = torch.tensor(list(itertools.chain(*enc.input_ids))).to(torch.long).unsqueeze(0)
    # without padding or truncation we don't need the attention mask, but we do need the special tokens mask
    attention_mask = torch.tensor(list(itertools.chain(*enc.attention_mask))).to(torch.bool).unsqueeze(0)
    special_tokens_mask = torch.tensor(list(itertools.chain(*enc.special_tokens_mask))).to(torch.bool).unsqueeze(0)
    # don't calculate perplexity on special tokens
    loss_mask = attention_mask & ~special_tokens_mask
    text_len = input_ids.size(1)
    lls = []
    # step the window start by batch_size * stride tokens; stop while at least one full window remains
    for i in tqdm(range(0, text_len - stride + 1, batch_size * stride)):
        b_input_ids, target_ids = create_batch(input_ids, loss_mask, i, batch_size, stride)
        b_input_ids = b_input_ids.to(device)
        target_ids = target_ids.to(device)
        logits = model(b_input_ids).logits
        log_likelihood = nll_loss_no_mean(logits, target_ids)
        # keep only positions with a real target: ignore_index (-100) positions come back as zero loss
        # and would otherwise dilute the corpus mean
        valid = target_ids[..., 1:].reshape(-1) != -100
        lls.extend(log_likelihood[valid].cpu().tolist())
    lls = torch.tensor(lls)
    ppl = lls.mean().exp()
    return ppl.cpu().item()
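For context, here is a minimal usage sketch (not part of the gist). The model name "gpt2" and the device handling are illustrative assumptions; any causal LM checkpoint should work the same way.

# Minimal usage sketch. Assumption: "gpt2" is an arbitrary small checkpoint chosen for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

# dataset=None falls back to the first 10% of the wikitext-2 test split
ppl = batched_perplexity(model, tokenizer, dataset=None, batch_size=8, stride=512)
print(f"perplexity: {ppl:.2f}")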
import torch
import numpy as np
from tqdm.auto import tqdm
from torch.nn import CrossEntropyLoss
from datasets import load_dataset
from transformers import PreTrainedModel, PreTrainedTokenizerBase
@torch.no_grad()
def compute_perplexity(text: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase, stride=8, max_length=512, batch_size=4):
    """
    Efficient corpus perplexity calculation using strided windows.

    Args:
        text: The corpus text to calculate perplexity on
        model: A pretrained language model
        tokenizer: The tokenizer used to preprocess the data
        stride: The step between window start positions. Important: changing this will change your results.
        max_length: The maximum window length in tokens
        batch_size: The batch size to use for perplexity calculation

    Comparison against other implementations:
    - https://huggingface.co/docs/transformers/perplexity - takes the mean of means, which gives the wrong value
    - https://github.com/huggingface/evaluate/blob/main/metrics/perplexity/perplexity.py - complex, and crops sentences, so it's not comparable
    - https://github.com/ggerganov/llama.cpp/tree/master/examples/perplexity - good, but in C++
    - https://github.com/huggingface/transformers/issues/9648#issuecomment-812981524 - doesn't use special tokens

    @url: https://gist.github.com/wassname/4af760435447d38a3012c6e39abb58e1
    """
    device = model.device

    # Tokenize corpus
    encodings = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    seq_len = encodings.input_ids.size(1)

    # Initialize tracking variables
    nlls, counts = 0, 0

    # Configure loss function
    loss_fn = CrossEntropyLoss(reduction="none")

    # Process corpus in strided windows
    for i in tqdm(range(0, seq_len, stride * batch_size)):
        # Prepare batch windows
        input_ids_list, target_masks_list = [], []
        for j in range(batch_size):
            # Window start position
            start_idx = i + j * stride
            if start_idx >= seq_len:
                break

            # Extract window with context
            end_idx = min(start_idx + max_length, seq_len)
            ids = encodings.input_ids[0, start_idx:end_idx].clone()

            # Skip windows that are too small
            if len(ids) < 2:
                continue

            # Add BOS token for the initial window
            if start_idx == 0:
                ids = torch.cat([torch.tensor([tokenizer.bos_token_id]), ids])
            # Create evaluation mask (1 for tokens to evaluate, 0 otherwise).
            # For every window after the first, skip the first `stride` tokens so each evaluated
            # token has at least `stride` tokens of in-window context. Because windows advance by
            # `stride` but evaluate most of their length, tokens are scored several times with
            # different context lengths; the final perplexity averages over those evaluations.
            eval_mask = torch.zeros_like(ids)
            eval_offset = 0 if start_idx == 0 else stride
            eval_mask[eval_offset:] = 1

            input_ids_list.append(ids)
            target_masks_list.append(eval_mask)
        if not input_ids_list:
            continue

        # Create padded batch tensors
        batch = tokenizer.pad({"input_ids": input_ids_list}, return_tensors="pt")
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Create padded target masks
        max_len = input_ids.size(1)
        padded_masks = []
        for mask in target_masks_list:
            padding = torch.zeros(max_len - len(mask), dtype=torch.long)
            padded_masks.append(torch.cat([mask, padding]))
        target_masks = torch.stack(padded_masks).to(device)

        # Forward pass
        outputs = model(input_ids, attention_mask=attention_mask)

        # Compute loss on shifted sequences
        shift_logits = outputs.logits[:, :-1].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        shift_masks = target_masks[:, 1:].contiguous() * attention_mask[:, 1:].contiguous()

        # Calculate NLL only for targeted tokens
        loss = loss_fn(shift_logits.transpose(1, 2), shift_labels)
        masked_loss = (loss * shift_masks).sum()
        token_count = shift_masks.sum()

        # Accumulate results
        nlls += masked_loss.item()
        counts += token_count.item()

    # Return corpus-level perplexity
    return np.exp(nlls / counts) if counts > 0 else float("inf")
# Example: evaluate on a slice of the wikitext-2 test set (assumes `model` and `tokenizer` are already loaded)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:10%]")
dataset = dataset.filter(lambda x: len(x["text"]) > 0)["text"]
text = "\n\n".join(dataset)
compute_perplexity(text, model, tokenizer)
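The snippet above assumes `model` and `tokenizer` already exist in scope. A minimal sketch of one way to provide them; the "gpt2" checkpoint and the pad-token choice are illustrative assumptions, not part of the gist:

# Sketch only: "gpt2" is an illustrative checkpoint; any causal LM should work.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# compute_perplexity pads batches, so the tokenizer needs a pad token;
# GPT-2-style tokenizers don't ship one, so reuse the EOS token here.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

ppl = compute_perplexity(text, model, tokenizer, stride=8, max_length=512, batch_size=4)
print(f"perplexity: {ppl:.2f}")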
Ideally, but some models seem to be different, so I thought it best to ensure it's done explicitly. There are a few ways people seem to calculate perplexity, so it's good to be able to see and modify it easily.
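For reference, both files above estimate the same quantity: the exponential of the mean per-token negative log-likelihood. A minimal sketch of that definition (the helper name is just for illustration):

import math

def perplexity_from_nlls(nlls):
    """Corpus perplexity from per-token negative log-likelihoods (in nats)."""
    return math.exp(sum(nlls) / len(nlls))

# e.g. three tokens with NLLs of 2.0, 3.0 and 4.0 nats -> exp(3.0) ≈ 20.09
print(perplexity_from_nlls([2.0, 3.0, 4.0]))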