@tomasruizt
Last active August 20, 2024 16:13
An example with Llama3, where batching the input prompts modifies the logits that come out.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, device_map="auto")  # device_map="auto" already places the weights on the GPU, so no extra .to(device) is needed
model.eval()
def get_logits(prompts: list[str]) -> tuple[torch.Tensor, ...]:
    set_seed(0)
    messages = [[{"role": "user", "content": p}] for p in prompts]
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True).to(device)
    outputs = model.generate(**inputs, max_new_tokens=10, do_sample=False, top_p=None, temperature=None, output_logits=True, return_dict_in_generate=True, pad_token_id=tokenizer.eos_token_id)
    # One logits tensor of shape (batch_size, vocab_size) per generated token.
    return outputs.logits
prompt = "Is Berlin the capital of France? Answer only with 'Yes' or 'No'."
for i in [1, 2, 3]:
    # l1 holds the logits of the first generated token for every prompt in the batch.
    l1, l2 = get_logits([prompt] * i)
    print(l1)
    print()
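
To quantify the effect, the following sketch (reusing get_logits and prompt from above; the names single, batched, and max_diff are just illustrative) compares the first-token logits of the unbatched call with those of the same prompt duplicated in a batch of two:

# Compare the logits of the first generated token for batch size 1 vs. batch size 2.
single = get_logits([prompt])[0][0]       # first generation step, only sequence in the batch
batched = get_logits([prompt] * 2)[0][0]  # first generation step, first of two identical sequences

max_diff = (single.float() - batched.float()).abs().max().item()
print("identical:", torch.equal(single, batched))
print(f"max |difference| in first-token logits: {max_diff:.6f}")

Small but nonzero differences here are typically attributed to the changed batch shape: with bfloat16 kernels, a different batch size can change how the underlying matrix multiplications are tiled and reduced, which perturbs the logits slightly even for identical prompts.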