Skip to content

Instantly share code, notes, and snippets.

@danyaljj
Created January 16, 2025 16:18
Show Gist options
  • Save danyaljj/e352a0d7c550b55b4d7379296f6e8c55 to your computer and use it in GitHub Desktop.
Save danyaljj/e352a0d7c550b55b4d7379296f6e8c55 to your computer and use it in GitHub Desktop.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load GPT-2 model and tokenizer
# (the "gpt2" checkpoint is the 124M-parameter small model from the Hugging Face hub)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
# Evaluation mode: disables dropout so forward passes are deterministic.
model.eval()
def compare_outputs(prompt, top_k=3, max_length=25):
    """Compare the model's generated continuation of *prompt* with its
    raw next-token probability distribution.

    Prints the greedy generation up to ``max_length`` total tokens, then
    the ``top_k`` most probable next tokens (after the prompt) with their
    softmax probabilities.

    Args:
        prompt: Input text fed to the model.
        top_k: Number of highest-probability next-token candidates to show.
        max_length: Maximum total sequence length passed to ``generate``.
    """
    print('Prompt:', prompt)
    # Tokenize and move tensors to whatever device the model lives on.
    inputs = tokenizer(prompt, return_attention_mask=True, return_tensors="pt").to(model.device)
    decoded_ids = model.generate(**inputs, return_dict_in_generate=False, output_scores=False, max_length=max_length)
    decoded_text = tokenizer.batch_decode(decoded_ids, skip_special_tokens=True)[0]
    print(' --> First few tokens from model\'s output:', decoded_text)
    # Single forward pass to inspect the next-token distribution.
    # no_grad: inference only — skip building the autograd graph.
    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'])
    logits = outputs.logits
    # Logits at the last position predict the token that follows the prompt.
    last_token_logits = logits[0, -1, :]
    # Softmax converts logits to a probability distribution over the vocab.
    probabilities = torch.softmax(last_token_logits, dim=-1)
    top_k_probs, top_k_indices = torch.topk(probabilities, top_k)
    # Decode and report each of the top-k candidates with its probability.
    for i in range(top_k):
        token = tokenizer.decode(top_k_indices[i].item())
        prob = top_k_probs[i].item()
        print(f"{i}-th top token: {token}, Probability: {prob:.4f}")
compare_outputs("Corruption involving the contractors is the chief culprit for the prison's problems, according to a recent")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment