Skip to content

Instantly share code, notes, and snippets.

@hartmannr76
Created August 22, 2024 10:44
Show Gist options
  • Save hartmannr76/20cfad88e070068b8cc2a87e30d5335a to your computer and use it in GitHub Desktop.
Save hartmannr76/20cfad88e070068b8cc2a87e30d5335a to your computer and use it in GitHub Desktop.
True/False question answering from Gemma 2 with probabilities
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
# Score the candidate answers "True" and "False" as the model's next token
# after a few-shot True/False prompt, and print each candidate's
# log-probability and probability.

# Inference only: no gradients are needed anywhere in this script.
torch.set_grad_enabled(False)

model_path = "google/gemma-2-2b-it"
# Placeholder: replace with your own Hugging Face access token.
access_token = '<your_token>'

tokenizer = AutoTokenizer.from_pretrained(
    model_path, use_safetensors=True, token=access_token
)
model = AutoModelForCausalLM.from_pretrained(
    model_path, use_safetensors=True, token=access_token, device_map="auto"
)

# Should match for your model
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# One worked example (question + answer), then the real question, ending on an
# open model turn so that the very next token is the answer.
prompt_text = (
    USER_CHAT_TEMPLATE.format(prompt='Answer this question with either True or False. Josh is a human name.')
    + MODEL_CHAT_TEMPLATE.format(prompt='True')
    + USER_CHAT_TEMPLATE.format(prompt='Answer this question with either True or False. A hairbrush is an animal.')
    + "<start_of_turn>model\n"
)
# Follow the device the model was actually placed on (device_map="auto")
# instead of assuming CUDA is available.
inputs = tokenizer.encode(prompt_text, return_tensors="pt").to(model.device)

# Token ids of the candidate answers, without BOS/EOS. Only the first token of
# each candidate is scored below, and building a tensor here requires both
# candidates to tokenize to the same length.
# NOTE(review): presumably 'True' and 'False' are single tokens for this
# tokenizer (the example output suggests so) - confirm for other models.
alt_sequences = tokenizer(
    ['True', 'False'], add_special_tokens=False, return_tensors="pt"
).to(model.device)
batch_token_ids = alt_sequences["input_ids"]

# One beam per candidate, so the returned scores have one row per candidate.
# Only the first generated step is read below (score[0]), so a single new
# token is enough; generating more would only cost time.
outputs = model.generate(
    inputs,
    max_new_tokens=1,
    num_beams=len(batch_token_ids),
    return_dict_in_generate=True,
    output_scores=True,
)

# Look up, for each candidate, the score of its token in the step scores.
# NOTE(review): this passes the candidate ids in place of generated sequences
# and relies on row i being matched with beam i - confirm against the
# installed transformers version.
transition_scores = model.compute_transition_scores(
    batch_token_ids, outputs.scores
)

for tok, score in zip(batch_token_ids, transition_scores):
    token_id = tok[0].item()
    log_prob = score[0].item()
    # | token | token string | log-probability | probability
    print(f"| {token_id:5d} | {tokenizer.decode(token_id):8s} | {log_prob:.4f} | {np.exp(log_prob):.2%}")

# Example output
# | 5036 | True | -6.8411 | 0.11%
# | 8393 | False | -0.0013 | 99.87%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment