Skip to content

Instantly share code, notes, and snippets.

@glinscott
Last active April 18, 2023 01:49
Show Gist options
  • Save glinscott/4ca82be3db30b987859c380a0d8304c0 to your computer and use it in GitHub Desktop.
Save glinscott/4ca82be3db30b987859c380a0d8304c0 to your computer and use it in GitHub Desktop.
GPU Perplexity for Llama to match llama.cpp implementation
import sys
import torch
# Target device: the model is built under this device and every input chunk
# is moved here before the forward pass.
device: str = "cuda"

# Disabled alternative: evaluate GPT-2 large instead of the local Llama
# checkpoint. Uncomment (and drop the Llama loading below) to use it.
# from transformers import GPT2LMHeadModel, GPT2TokenizerFast
# model_id = "gpt2-large"
# model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
# tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
from transformers import LlamaForCausalLM, LlamaTokenizer
# Tokenizer and weights are both read from the current working directory (".").
tokenizer = LlamaTokenizer.from_pretrained(".")

# torch.device used as a context manager (PyTorch 2.0+) sets the default device
# for tensors created inside the block. Presumably this is so the weights are
# materialized directly on `device` instead of on the CPU first.
with torch.device(device):
    model = LlamaForCausalLM.from_pretrained(
        ".",
        _fast_init=True,
    )
from datasets import load_dataset
# Read the whole evaluation text as one string. The encoding is given
# explicitly: without it, open() uses the locale's preferred encoding, which is
# not UTF-8 on every platform (e.g. cp1252 on Windows). On such a locale any
# non-ASCII text in the (presumably UTF-8) WikiText file would fail to decode
# or be garbled.
with open('wiki.test.raw', 'r', encoding='utf-8') as f:
    test = f.read()

# Tokenize the full text in one call. The loop below reads
# encodings.input_ids as a 2-D tensor whose second dimension is the sequence.
encodings = tokenizer(test, return_tensors="pt")
from tqdm import tqdm
# Evaluation windowing.
#   max_length: the text is split into consecutive, non-overlapping chunks of
#               this many tokens, and each chunk is evaluated independently.
#   stride:     reused here as the number of leading label positions masked
#               out in every chunk (NOT a sliding-window step). Every scored
#               token therefore has at least `stride` tokens of in-chunk
#               context.
max_length = 2048
stride = 512
seq_len = encodings.input_ids.size(1)
print(seq_len)

nlls = []  # per-chunk summed negative log-likelihood (tensors)
n_scored_total = 0  # number of tokens that actually contributed to `nlls`
for idx, begin_loc in enumerate(range(0, seq_len, max_length)):
    end_loc = min(begin_loc + max_length, seq_len)  # the last chunk may be short
    input_ids = encodings.input_ids[:, begin_loc:end_loc]
    target_ids = input_ids.clone()
    target_ids[:, :stride] = -100  # -100 = CrossEntropyLoss's default ignore_index

    # The causal-LM loss shifts labels by one (logits[:, :-1] predict
    # labels[:, 1:]). So the scored tokens are the unmasked labels from
    # position 1 onwards, and the reported loss is their mean.
    n_scored = int((target_ids[:, 1:] != -100).sum())
    if n_scored == 0:
        # A trailing chunk of <= `stride` tokens has nothing to score. Its mean
        # loss would be 0/0 = NaN and would poison every later perplexity.
        continue

    with torch.no_grad():
        outputs = model(input_ids.to(device), labels=target_ids.to(device))
    # Undo the averaging with the number of tokens that were actually scored
    # (not the chunk length), so the final division gives a true per-token
    # mean. For full chunks this yields the same result as before; only a
    # short final chunk was mis-weighted.
    nlls.append(outputs.loss * n_scored)
    n_scored_total += n_scored

    # Running perplexity after each chunk, printed as "[chunk]ppl,".
    ppl = torch.exp(torch.stack(nlls).sum() / n_scored_total)
    sys.stdout.write("[%d]%.4f," % (idx + 1, ppl.item()))
    sys.stdout.flush()
sys.stdout.write("\n")  # terminate the progress line before the final result

if n_scored_total == 0:
    raise ValueError(
        "no tokens were scored: the input must be longer than `stride` tokens"
    )
ppl = torch.exp(torch.stack(nlls).sum() / n_scored_total)
print(ppl)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment