Created
April 11, 2025 11:21
-
-
Save attentionmech/8d446c462d42a9fbc51c77d95ca6816b to your computer and use it in GitHub Desktop.
logit lens animation gpt2
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from nnsight import LanguageModel | |
import torch | |
import argparse | |
import os | |
import sys | |
import time | |
def clear_terminal():
    """Wipe the visible terminal by running the platform's clear command.

    Uses ``cls`` on Windows (``os.name == "nt"``) and ``clear`` everywhere else.
    """
    command = "cls" if os.name == "nt" else "clear"
    os.system(command)
def print_logit_lens_terminal(frames, delay=1.0, max_tokens_per_layer=5, prompt="", max_chars=50):
    """Play back recorded logit-lens frames as a terminal animation.

    For every frame the screen is cleared, then the function prints:
      * ``prompt``;
      * the text generated so far (this frame's prompt minus the first
        frame's prompt), or "(none)" when nothing has been generated yet;
      * one ``[LAYER i]`` line per entry of ``frame['layer_words']``.

    Args:
        frames: list of frame dicts; each must provide ``'prompt'`` (str) and
            ``'layer_words'`` (one list of decoded token strings per layer).
        delay: seconds to sleep after drawing each frame.
        max_tokens_per_layer: how many of the most recent tokens to show per
            layer; values <= 0 show none.
        prompt: text printed at the top of every frame.
        max_chars: keep only the last ``max_chars`` characters of each layer
            line; values <= 0 print an empty line.
    """
    import re  # hoisted out of the per-layer loop; the file header does not import re

    # Runs of '-' / '_' are collapsed to a single space to keep rows readable.
    separator_runs = re.compile(r"[-_]+")

    def _tail(seq, n):
        # Last n items of seq. Unlike seq[-n:], this is empty for n <= 0
        # (seq[-0:] would return the whole sequence).
        return seq[max(len(seq) - n, 0):]

    for frame_idx, frame in enumerate(frames):
        try:
            clear_terminal()
            print(f"{prompt} ")
            # Everything appended after the initial prompt, i.e. the generated text.
            generated = frame['prompt'][len(frames[0]['prompt']):].strip() or "(none)"
            print(generated + "\n")
            for i, tokens_row in enumerate(frame['layer_words']):
                token_str = " | ".join(_tail(tokens_row, max_tokens_per_layer)).replace("\n", " ")
                token_str = separator_runs.sub(" ", token_str)
                print(f"[LAYER {i}]: {_tail(token_str, max_chars)}\n")
            sys.stdout.flush()
            time.sleep(delay)
        except Exception as e:
            # Deliberate best effort: report the bad frame and keep animating.
            print(f"[Error in frame {frame_idx}]: {str(e)}")
def compute_logit_lens(prompt, model, last_n=10):
    """Run a logit-lens pass over ``prompt`` and return per-layer top-1 predictions.

    Each block in ``model.transformer.h`` has its output hidden state projected
    through ``model.transformer.ln_f`` and ``model.lm_head``. The result is
    softmaxed over the vocabulary and reduced to its most likely token.

    Args:
        prompt: text to analyse.
        model: nnsight ``LanguageModel`` exposing ``transformer.h``,
            ``transformer.ln_f``, ``lm_head`` and ``tokenizer`` (GPT-2 layout).
        last_n: number of trailing token positions to report; must be >= 1.

    Returns:
        ``(max_probs, words, input_words)``:
            max_probs: numpy array ``(num_layers, min(seq, last_n))`` holding the
                top-1 probability at each reported position.
            words: one list per layer of decoded top-1 tokens for those positions.
            input_words: decoded input tokens for the same positions.

    Raises:
        ValueError: if ``last_n < 1``.
    """
    if last_n < 1:
        # x[-0:] keeps the whole sequence, so 0 would silently mean "everything".
        raise ValueError(f"last_n must be >= 1, got {last_n}")
    print(f"Computing logit lens for prompt: '{prompt}'")
    probs_layers = []
    layers = model.transformer.h
    with model.trace() as tracer:
        with tracer.invoke(prompt) as invoker:
            for layer in layers:
                # output[0]: the block's hidden state (the block output is indexed like a tuple).
                layer_out = model.lm_head(model.transformer.ln_f(layer.output[0]))
                probs = torch.nn.functional.softmax(layer_out, dim=-1).save()
                probs_layers.append(probs)
    input_ids = invoker.inputs[0][0]["input_ids"][0]

    # Saved tensors are expected as (batch=1, seq, vocab); drop the batch axis so
    # the stack is (num_layers, seq, vocab). The former unsqueeze(0) + cat kept
    # that batch axis, so the [-last_n:] slices below hit the size-1 batch axis
    # and each layer decoded the whole sequence as a single "word".
    layer_probs = [p.value for p in probs_layers]
    layer_probs = [v[0] if v.dim() == 3 else v for v in layer_probs]
    probs = torch.stack(layer_probs, dim=0)
    max_probs, tokens = probs.max(dim=-1)  # both (num_layers, seq)

    input_words = [model.tokenizer.decode(t) for t in input_ids[-last_n:]]
    words = [
        [model.tokenizer.decode(t.cpu()) for t in layer_tokens[-last_n:]]
        for layer_tokens in tokens
    ]
    max_probs = max_probs[:, -last_n:]
    return max_probs.detach().cpu().numpy(), words, input_words
def autoregressive_logit_lens_animation(prompt, model, temperature=1.0, max_steps=5):
    """Generate ``max_steps`` tokens, recording a logit-lens frame before each one.

    Args:
        prompt: initial text.
        model: nnsight ``LanguageModel`` with a ``tokenizer`` and ``device``;
            it is passed to ``compute_logit_lens`` and also called directly to
            obtain next-token logits.
        temperature: softmax temperature used for sampling. Values <= 0 select
            the most likely token (greedy decoding) instead of sampling, since
            dividing the logits by 0 would yield inf/NaN probabilities.
        max_steps: number of tokens to generate (and frames to record).

    Returns:
        List of frame dicts with keys ``'step'``, ``'prompt'``, ``'max_probs'``,
        ``'layer_words'`` and ``'input_words'``. Frame ``k`` describes the text
        as it was *before* the ``k+1``-th token was appended.
    """
    frames = []
    current_prompt = prompt
    print(f"Starting generation with initial prompt: '{prompt}'")
    for step in range(max_steps):
        print(f"Generating step {step + 1}/{max_steps}...")
        max_probs, layer_words, input_words = compute_logit_lens(current_prompt, model)
        frames.append({
            'step': step,
            'prompt': current_prompt,
            'max_probs': max_probs,
            'layer_words': layer_words,
            'input_words': input_words,
        })
        # The text is re-tokenized each step so sampling sees exactly what the
        # logit lens above analysed.
        inputs = model.tokenizer(current_prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs)
        last_logits = outputs.logits[0, -1]
        if temperature > 0:
            probs = torch.nn.functional.softmax(last_logits / temperature, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
        else:
            # Greedy fallback; keepdim keeps the same (1,) shape multinomial returns.
            next_token_id = last_logits.argmax(dim=-1, keepdim=True)
        next_token = model.tokenizer.decode(next_token_id)
        # NOTE(review): whitespace-only tokens get an extra leading space here;
        # intent is unclear from this file, behaviour kept as-is.
        if next_token.strip() == "":
            next_token = " " + next_token
        print(f"Generated token: '{next_token}'")
        current_prompt += next_token
    print(f"Final text: '{current_prompt}'")
    return frames
def main():
    """Command-line entry point: load GPT-2, generate, then animate the logit lens.

    Parses CLI options, loads ``openai-community/gpt2`` through nnsight, records
    one logit-lens frame per generated token and plays them back in the terminal.
    It then waits for Enter. Any exception is reported on stdout; the traceback
    is printed too when ``--debug`` is given.
    """
    parser = argparse.ArgumentParser(description="Plain Logit Lens")
    parser.add_argument("--prompt", type=str, default="The meaning of life is", help="Initial text prompt")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    parser.add_argument("--max_tokens", type=int, default=5, help="Max number of tokens to generate")
    parser.add_argument("--tokens_per_layer", type=int, default=5, help="Max tokens to display per layer")
    parser.add_argument("--delay", type=float, default=1.0, help="Delay between frames in seconds")
    parser.add_argument("--debug", action="store_true", help="Print additional debugging information")
    args = parser.parse_args()
    if len(sys.argv) == 1:
        print("No arguments provided. Using default prompt: 'The meaning of life is'")
    print("Loading model...")
    try:
        model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)
        print("Model loaded successfully!")
        frames = autoregressive_logit_lens_animation(
            prompt=args.prompt,
            model=model,
            temperature=args.temperature,
            max_steps=args.max_tokens,
        )
        if len(frames) == 0:
            print("Error: No frames generated!")
            return
        print(f"Generated {len(frames)} frames. Starting animation...")
        print_logit_lens_terminal(
            frames,
            delay=args.delay,
            # Previously hard-coded to 5, which silently ignored --tokens_per_layer.
            max_tokens_per_layer=args.tokens_per_layer,
            prompt=args.prompt,
        )
        input("\n")  # keep the last frame on screen until Enter is pressed
    except Exception as e:
        # Top-level boundary: report instead of crashing with a raw traceback.
        print(f"Error: {str(e)}")
        if args.debug:
            import traceback
            print(traceback.format_exc())
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment