GPU Perplexity for Llama to match llama.cpp implementation
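# GPU perplexity for Llama, intended to match the llama.cpp perplexity calculation.
#
# Assumed inputs (implied by the paths used below, not stated elsewhere in the gist):
#   * a Llama checkpoint in Hugging Face format in the current directory (".")
#   * the WikiText-2 raw test split saved as ./wiki.test.raw
#
# The text is split into 2048-token chunks; within each chunk the first 512 tokens
# only provide context and are excluded from the loss via the -100 label mask.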
import sys
import torch

device = "cuda"

# Alternative GPT-2 setup, left disabled:
"""
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
model_id = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
"""

# Load the Llama tokenizer and model from the current directory
# (expects a checkpoint in Hugging Face format).
from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained(".")
with torch.device(device):
    model = LlamaForCausalLM.from_pretrained(".", _fast_init=True)

from datasets import load_dataset  # unused; the raw file is read directly below

# Read the evaluation text and tokenize it as one long sequence.
with open('wiki.test.raw', 'r') as f:
    test = f.read()
encodings = tokenizer(test, return_tensors="pt")

from tqdm import tqdm  # unused; the tqdm-based loop below is commented out

max_length = 2048  # tokens evaluated per chunk (Llama context window)
stride = 512       # leading tokens of each chunk used as context only (masked from the loss)
seq_len = encodings.input_ids.size(1)
print(seq_len)

nlls = []
prev_end_loc = 0
#for begin_loc in tqdm(range(0, seq_len, stride)):
for idx, begin_loc in enumerate(range(0, seq_len, max_length)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # equals max_length except possibly on the last chunk
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :stride] = -100  # first `stride` tokens are context only, not scored

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # loss is calculated with CrossEntropyLoss, which averages over the scored
        # tokens. Multiply it by trg_len to get a summation instead of an average;
        # the true average over all tokens is taken in the last step below.
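        # Net effect across chunks: ppl = exp( sum_i(loss_i * trg_len_i) / end_loc ),
        # i.e. the exponential of the per-token negative log-likelihood over the file.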
        neg_log_likelihood = outputs.loss * trg_len

    #print(target_ids)
    #print(outputs.logits.size())
    nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
    sys.stdout.write("[%d]%.4f," % (idx + 1, ppl.item()))
    sys.stdout.flush()

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

# Final perplexity over the full test file.
ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
print(ppl)
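
# --- Optional sanity check: a sketch added here, not part of the original gist. ---
# For causal LMs, transformers computes the labeled loss by shifting: the logit at
# position t is scored against the label at position t+1, with -100 labels ignored.
# Recomputing that by hand for the last processed chunk should match outputs.loss
# up to floating-point noise. Variable names below (shift_logits, manual_loss) are
# illustrative only.
import torch.nn.functional as F

shift_logits = outputs.logits[:, :-1, :]   # predictions for positions 1..T-1
shift_labels = target_ids[:, 1:]           # targets shifted left by one
manual_loss = F.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)),
    shift_labels.reshape(-1),
    ignore_index=-100,  # masked context tokens do not contribute
)
print("outputs.loss:", outputs.loss.item(), " manual:", manual_loss.item())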