Created
January 16, 2025 16:18
-
-
Save danyaljj/e352a0d7c550b55b4d7379296f6e8c55 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the pretrained GPT-2 tokenizer and language model, then put the model
# into inference mode (disables dropout and similar train-time behavior).
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
def compare_outputs(prompt):
    """Print the model's greedy continuation of ``prompt`` and the top-3
    most probable next tokens (with probabilities) after the prompt.

    Args:
        prompt: Input text fed to the module-level GPT-2 tokenizer/model.
    """
    print('Prompt:', prompt)
    inputs = tokenizer(prompt, return_attention_mask=True, return_tensors="pt").to(model.device)
    # Generate a short continuation. Note max_length counts the prompt
    # tokens too. pad_token_id is set explicitly because GPT-2 has no pad
    # token; this suppresses the HF warning and does not change the greedy
    # output.
    decoded_ids = model.generate(
        **inputs,
        return_dict_in_generate=False,
        output_scores=False,
        max_length=25,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded_text = tokenizer.batch_decode(decoded_ids, skip_special_tokens=True)[0]
    print(' --> First few token from model\'s output:', decoded_text)

    # Single forward pass to get next-token logits. no_grad() avoids
    # building an autograd graph we never use. The original's
    # clone().detach().to(model.device) on input_ids was redundant:
    # `inputs` is already on model.device and requires no grad.
    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'])
    logits = outputs.logits

    # Distribution over the vocabulary for the token immediately after the
    # prompt (last position of the sequence).
    last_token_logits = logits[0, -1, :]
    probabilities = torch.softmax(last_token_logits, dim=-1)

    # Report the 3 most probable next tokens with their probabilities.
    top_k = 3
    top_k_probs, top_k_indices = torch.topk(probabilities, top_k)
    for i in range(top_k):
        token = tokenizer.decode(top_k_indices[i].item())
        prob = top_k_probs[i].item()
        print(f"{i}-th top token: {token}, Probability: {prob:.4f}")
compare_outputs("Corruption involving the contractors is the chief culprit for the prison's problems, according to a recent") | |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.