@kgourgou · Created March 13, 2025 18:52
Token explorer
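An interactive terminal tool for stepping through a language model's next-token distribution: at each step it shows the top-k candidate tokens with their probabilities and lets you build up a text one token at a time with the arrow keys.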
#!/usr/bin/env python3
# Very cool idea from Will Kurt: https://x.com/willkurt/status/1900080101431066673
# pip install transformers torch readchar
import os
import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import readchar
# Number of candidate tokens to show
TOP_K = 10
def clear_screen():
    os.system("cls" if os.name == "nt" else "clear")
def get_candidate_tokens(model, tokenizer, current_text, top_k=TOP_K):
    # Encode the current text and get the logits for the next token.
    if current_text:
        inputs = tokenizer(current_text, return_tensors="pt")
    else:
        # GPT-2 adds no BOS token, so an empty prompt would produce an empty
        # input; fall back to the end-of-text token as the starting context.
        inputs = tokenizer(tokenizer.bos_token, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    next_token_logits = outputs.logits[0, -1, :]
    probs = F.softmax(next_token_logits, dim=-1)
    topk_probs, topk_indices = torch.topk(probs, top_k)
    candidate_tokens = []
    for prob, token_id in zip(topk_probs, topk_indices):
        token_str = tokenizer.decode(token_id.item())
        candidate_tokens.append((token_str, prob.item()))
    return candidate_tokens
def display_state(current_text, candidate_tokens, selection_index):
    clear_screen()
    print("Current Text:\n" + current_text + "\n")
    print("Candidate Tokens:")
    for idx, (token, prob) in enumerate(candidate_tokens):
        pointer = "->" if idx == selection_index else "  "
        print(f"{pointer} {token!r} (prob: {prob:.4f})")
    print(
        "\nUse UP/DOWN to navigate, RIGHT to select token, LEFT to remove last token, Q to quit."
    )
def main():
    # Load the model and tokenizer (using GPT-2 as an example).
    model_name = "gpt2"
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_name)
    model.eval()

    current_text = ""
    selection_index = 0

    # Optionally seed the session with an initial prompt.
    print("Enter your initial prompt (or leave empty):")
    initial_prompt = input().strip()
    if initial_prompt:
        current_text = initial_prompt

    while True:
        candidate_tokens = get_candidate_tokens(model, tokenizer, current_text)
        # Ensure the selection index is valid.
        if selection_index >= len(candidate_tokens):
            selection_index = 0
        display_state(current_text, candidate_tokens, selection_index)
        key = readchar.readkey()
        if key == readchar.key.UP:
            selection_index = (selection_index - 1) % len(candidate_tokens)
        elif key == readchar.key.DOWN:
            selection_index = (selection_index + 1) % len(candidate_tokens)
        elif key == readchar.key.RIGHT:
            # Append the currently selected token to the prompt.
            token_to_add = candidate_tokens[selection_index][0]
            current_text += token_to_add
            selection_index = 0  # Reset selection for the new context.
        elif key == readchar.key.LEFT:
            # Remove the last token by re-tokenizing the current text
            # and dropping the final token.
            tokens = tokenizer.tokenize(current_text)
            if tokens:
                tokens = tokens[:-1]
                current_text = tokenizer.convert_tokens_to_string(tokens)
            selection_index = 0
        elif key.lower() == "q":
            print("Exiting Token Explorer.")
            break


if __name__ == "__main__":
    main()
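The loader above is hard-wired to GPT-2, but nothing else in the script depends on that choice: get_candidate_tokens() only needs next-token logits and a tokenizer. Below is a minimal sketch of a model-agnostic loading step using the generic Auto classes from transformers; the checkpoint name "distilgpt2" is just an illustrative stand-in, not something the original uses.

from transformers import AutoModelForCausalLM, AutoTokenizer

def load_causal_lm(model_name: str = "distilgpt2"):
    # Any causal LM checkpoint from the Hugging Face Hub should work here,
    # as long as its tokenizer defines a bos_token for the empty-prompt case.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()
    return model, tokenizer

Swapping these two lines into main() leaves the rest of the explorer unchanged.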