Skip to content

Instantly share code, notes, and snippets.

@attentionmech
Created April 11, 2025 11:21
Show Gist options
  • Save attentionmech/8d446c462d42a9fbc51c77d95ca6816b to your computer and use it in GitHub Desktop.
Save attentionmech/8d446c462d42a9fbc51c77d95ca6816b to your computer and use it in GitHub Desktop.
logit lens animation gpt2
import argparse
import os
import re
import sys
import time

import torch
from nnsight import LanguageModel
def clear_terminal():
    """Wipe the visible terminal by shelling out to the platform's clear command.

    Windows (``os.name == 'nt'``) uses ``cls``; every other platform uses ``clear``.
    """
    if os.name == 'nt':
        os.system('cls')
    else:
        os.system('clear')
def print_logit_lens_terminal(frames, delay=1.0, max_tokens_per_layer=5, prompt="", max_chars=50):
    """Play back logit-lens frames as a terminal animation.

    For every frame the terminal is cleared. The function then prints the
    initial ``prompt``, the text appended since the first frame, and one
    ``[LAYER i]`` line per entry of the frame's ``layer_words``.

    Args:
        frames: Sequence of frame dicts. Each needs a ``'prompt'`` key (the
            full text at that step) and a ``'layer_words'`` key (one list of
            decoded token strings per layer). These are the dicts built by
            ``autoregressive_logit_lens_animation``.
        delay: Seconds to pause after each frame, including frames that
            failed to render.
        max_tokens_per_layer: Number of trailing tokens shown per layer.
        prompt: Text printed at the top of every frame.
        max_chars: Number of trailing characters of each layer line to show.
    """
    # Compiled once rather than per layer. Runs of '-' / '_' collapse to a
    # single space so decorative tokens don't swamp the short layer line.
    dash_runs = re.compile(r"[-_]+")
    for frame_idx, frame in enumerate(frames):
        try:
            clear_terminal()
            print(f"{prompt} ")
            # Text generated since the first frame (not just the newest token).
            generated = frame['prompt'][len(frames[0]['prompt']):].strip() or "(none)"
            print(generated + "\n")
            for i, tokens_row in enumerate(frame['layer_words']):
                token_str = " | ".join(tokens_row[-max_tokens_per_layer:]).replace("\n", " ")
                token_str = dash_runs.sub(" ", token_str)
                print(f"[LAYER {i}]: {token_str[-max_chars:]}\n")
            sys.stdout.flush()
        except Exception as e:
            # Best effort: report the broken frame and keep animating.
            print(f"[Error in frame {frame_idx}]: {str(e)}")
        # Sleep outside the try. Otherwise an error message would be wiped by
        # the next frame's clear_terminal() before it could be read.
        time.sleep(delay)
def compute_logit_lens(prompt, model, last_n=10):
    """Run a logit-lens pass over ``prompt``.

    Each block in ``model.transformer.h`` has its output projected through
    ``model.transformer.ln_f`` and ``model.lm_head``. The function then
    softmaxes the result and keeps the top-1 token and probability at every
    position.

    Args:
        prompt: Text to trace through the model (a single prompt).
        model: nnsight ``LanguageModel`` exposing ``transformer.h``,
            ``transformer.ln_f``, ``lm_head`` and ``tokenizer``.
        last_n: Number of trailing positions to keep; must be >= 1.

    Returns:
        ``(max_probs, words, input_words)``:
        ``max_probs`` is a NumPy array ``(num_layers, min(last_n, seq_len))``
        of top-1 probabilities. ``words[layer]`` is the list of decoded top-1
        tokens for those positions. ``input_words`` holds the decoded trailing
        input tokens.

    Raises:
        ValueError: if ``last_n`` < 1 (``[-0:]`` would silently keep everything).
    """
    if last_n < 1:
        raise ValueError(f"last_n must be >= 1, got {last_n}")
    print(f"Computing logit lens for prompt: '{prompt}'")
    probs_layers = []
    layers = model.transformer.h
    with model.trace() as tracer:
        with tracer.invoke(prompt) as invoker:
            for layer in layers:
                layer_out = model.lm_head(model.transformer.ln_f(layer.output[0]))
                probs = torch.nn.functional.softmax(layer_out, dim=-1).save()
                probs_layers.append(probs)
            input_ids = invoker.inputs[0][0]["input_ids"][0]

    # Each saved softmax keeps the leading batch axis (size 1 for one prompt).
    # Flatten it away so the stack is (layers, seq, vocab). The old extra
    # unsqueeze(0) produced (layers, 1, seq, vocab), which made the slicing
    # below act on the batch axis and decode whole sequences as one "token".
    per_layer = []
    for saved in probs_layers:
        layer_probs = saved.value
        per_layer.append(layer_probs.reshape(-1, layer_probs.shape[-1]))
    probs = torch.stack(per_layer, dim=0)
    max_probs, tokens = probs.max(dim=-1)  # both (layers, seq)

    input_words = [model.tokenizer.decode(t) for t in input_ids[-last_n:]]
    words = [
        [model.tokenizer.decode(t.cpu()) for t in layer_tokens[-last_n:]]
        for layer_tokens in tokens
    ]
    max_probs = max_probs[:, -last_n:]
    return max_probs.detach().cpu().numpy(), words, input_words
def autoregressive_logit_lens_animation(prompt, model, temperature=1.0, max_steps=5):
    """Sample ``max_steps`` tokens, recording one logit-lens frame per step.

    Each step runs ``compute_logit_lens`` on the current text and stores the
    result as a frame. It then samples the next token from the model's final
    logits and appends the decoded token to the text.

    Args:
        prompt: Initial text.
        model: nnsight ``LanguageModel`` (must expose ``tokenizer`` and ``device``).
        temperature: Softmax temperature for sampling. ``0`` selects greedy
            (argmax) decoding instead of dividing the logits by zero.
        max_steps: Number of tokens to generate (one frame each).

    Returns:
        List of frame dicts with keys ``'step'``, ``'prompt'``, ``'max_probs'``,
        ``'layer_words'`` and ``'input_words'``.
    """
    frames = []
    current_prompt = prompt
    print(f"Starting generation with initial prompt: '{prompt}'")
    for step in range(max_steps):
        print(f"Generating step {step + 1}/{max_steps}...")
        max_probs, layer_words, input_words = compute_logit_lens(current_prompt, model)
        frames.append({
            'step': step,
            'prompt': current_prompt,
            'max_probs': max_probs,
            'layer_words': layer_words,
            'input_words': input_words,
        })
        # NOTE(review): the text is re-tokenized each step, so the sampled
        # token may tokenize differently once appended — confirm acceptable.
        inputs = model.tokenizer(current_prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
        last_logits = outputs.logits[0, -1]
        if temperature == 0:
            # logits / 0 gives inf/NaN probabilities and torch.multinomial
            # raises; zero temperature means greedy decoding.
            next_token_id = last_logits.argmax(dim=-1, keepdim=True)
        else:
            probs = torch.nn.functional.softmax(last_logits / temperature, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
        next_token = model.tokenizer.decode(next_token_id)
        if next_token.strip() == "":
            next_token = " " + next_token
        print(f"Generated token: '{next_token}'")
        current_prompt += next_token
    print(f"Final text: '{current_prompt}'")
    return frames
def main():
    """CLI entry point: load GPT-2, generate tokens, then animate the logit lens.

    Any exception during loading, generation or animation is reported as
    ``Error: ...``. The full traceback is printed as well when ``--debug``
    is given.
    """
    parser = argparse.ArgumentParser(description="Plain Logit Lens")
    parser.add_argument("--prompt", type=str, default="The meaning of life is", help="Initial text prompt")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    parser.add_argument("--max_tokens", type=int, default=5, help="Max number of tokens to generate")
    parser.add_argument("--tokens_per_layer", type=int, default=5, help="Max tokens to display per layer")
    parser.add_argument("--delay", type=float, default=1.0, help="Delay between frames in seconds")
    parser.add_argument("--debug", action="store_true", help="Print additional debugging information")
    args = parser.parse_args()
    if len(sys.argv) == 1:
        print("No arguments provided. Using default prompt: 'The meaning of life is'")
    print("Loading model...")
    try:
        model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
        print("Model loaded successfully!")
        frames = autoregressive_logit_lens_animation(
            prompt=args.prompt,
            model=model,
            temperature=args.temperature,
            max_steps=args.max_tokens,
        )
        if not frames:
            print("Error: No frames generated!")
            return
        print(f"Generated {len(frames)} frames. Starting animation...")
        print_logit_lens_terminal(
            frames,
            delay=args.delay,
            # Honour --tokens_per_layer; this was previously hard-coded to 5.
            max_tokens_per_layer=args.tokens_per_layer,
            prompt=args.prompt,
        )
        # Keep the last frame on screen until the user presses Enter.
        input("\n")
    except Exception as e:
        print(f"Error: {str(e)}")
        if args.debug:
            import traceback
            print(traceback.format_exc())
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment