Token explorer
#!/usr/bin/env python3
# very cool idea from Will Kurt: https://x.com/willkurt/status/1900080101431066673
# pip install transformers torch readchar
import os

import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import readchar

# Number of candidate tokens to show
TOP_K = 10

def clear_screen():
    os.system("cls" if os.name == "nt" else "clear")

def get_candidate_tokens(model, tokenizer, current_text, top_k=TOP_K):
    # GPT-2 cannot score an empty sequence, so fall back to the BOS token
    text = current_text if current_text else tokenizer.bos_token
    # Encode the current text and get the logits for the next token
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    next_token_logits = logits[0, -1, :]
    probs = F.softmax(next_token_logits, dim=-1)
    topk_probs, topk_indices = torch.topk(probs, top_k)
    candidate_tokens = []
    for prob, token_id in zip(topk_probs, topk_indices):
        token_str = tokenizer.decode(token_id.item())
        candidate_tokens.append((token_str, prob.item()))
    return candidate_tokens

def display_state(current_text, candidate_tokens, selection_index):
    clear_screen()
    print("Current Text:\n" + current_text + "\n")
    print("Candidate Tokens:")
    for idx, (token, prob) in enumerate(candidate_tokens):
        pointer = "->" if idx == selection_index else "  "
        print(f"{pointer} {token!r} (prob: {prob:.4f})")
    print(
        "\nUse UP/DOWN to navigate, RIGHT to select token, LEFT to remove last token, Q to quit."
    )

def main():
    # Load the model and tokenizer (using GPT-2 as an example)
    model_name = "gpt2"
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.eval()

    current_text = ""
    selection_index = 0

    # Optionally start from an initial prompt
    print("Enter your initial prompt (or leave empty):")
    initial_prompt = input().strip()
    if initial_prompt:
        current_text = initial_prompt

    while True:
        candidate_tokens = get_candidate_tokens(model, tokenizer, current_text)
        # Ensure the selection index is valid
        if selection_index >= len(candidate_tokens):
            selection_index = 0
        display_state(current_text, candidate_tokens, selection_index)

        key = readchar.readkey()
        if key == readchar.key.UP:
            selection_index = (selection_index - 1) % len(candidate_tokens)
        elif key == readchar.key.DOWN:
            selection_index = (selection_index + 1) % len(candidate_tokens)
        elif key == readchar.key.RIGHT:
            # Append the currently selected token to the prompt
            token_to_add = candidate_tokens[selection_index][0]
            current_text += token_to_add
            selection_index = 0  # Reset selection for the new context
        elif key == readchar.key.LEFT:
            # Remove the last token by re-tokenizing the current text and dropping the last token
            tokens = tokenizer.tokenize(current_text)
            if tokens:
                tokens = tokens[:-1]
                current_text = tokenizer.convert_tokens_to_string(tokens)
            selection_index = 0
        elif key.lower() == "q":
            print("Exiting Token Explorer.")
            break

if __name__ == "__main__":
    main()
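
A quick way to sanity-check the candidate generation outside the interactive loop, assuming the gist is saved locally as token_explorer.py (a hypothetical filename) and the dependencies above are installed; the prompt string is just an example:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

from token_explorer import get_candidate_tokens  # hypothetical module name for this gist

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

# Print the top-10 next-token candidates for an example prompt
for token, prob in get_candidate_tokens(model, tokenizer, "The weather today is"):
    print(f"{token!r} (prob: {prob:.4f})")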