Created
May 15, 2025 09:35
-
-
Save attentionmech/fef6e70e0473e0f88d17d30bb7ffcf4e to your computer and use it in GitHub Desktop.
forward pass activation dynamics
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import argparse | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib import cm | |
from matplotlib import colormaps | |
from mpl_toolkits.mplot3d import Axes3D | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
from sklearn.decomposition import PCA | |
from sklearn.manifold import TSNE | |
try: | |
import umap | |
has_umap = True | |
except ImportError: | |
has_umap = False | |
# === Collect hidden states === | |
def collect_hidden_vectors(model, tokenizer, prompt, num_generate=10, temperature=1.0, device='cpu'):
    """Autoregressively sample tokens and record per-layer hidden states.

    After each forward pass, the hidden vector at the final sequence
    position is captured from every transformer layer (the embedding
    output, hidden_states[0], is skipped), so the result holds
    num_generate * num_layers rows in generation order.

    Args:
        model: causal LM returning logits and hidden states (e.g. GPT-2).
        tokenizer: matching tokenizer producing "input_ids".
        prompt: text used to seed generation.
        num_generate: number of tokens to sample.
        temperature: softmax temperature for sampling (must be > 0).
        device: torch device string for the input ids.

    Returns:
        np.ndarray of shape (num_generate * num_layers, hidden_dim).
    """
    prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
    sequence = prompt_ids.clone()
    collected = []
    with torch.no_grad():
        for _ in range(num_generate):
            outputs = model(sequence, output_hidden_states=True)
            # Final-position activation from each transformer layer;
            # index [1:] drops the embedding-layer output.
            for layer_states in outputs.hidden_states[1:]:
                collected.append(layer_states[0, -1, :].cpu().numpy())
            # Temperature-scaled sampling of the next token.
            scaled = outputs.logits[:, -1, :] / temperature
            distribution = torch.softmax(scaled, dim=-1)
            sampled = torch.multinomial(distribution, num_samples=1)
            sequence = torch.cat([sequence, sampled], dim=1)
    return np.stack(collected)
# === Project vectors to 3D === | |
def project_vectors(vectors, method="pca"):
    """Reduce high-dimensional vectors to 3 components for plotting.

    Supported methods are "pca", "tsne", and "umap" (the latter only
    when umap-learn is installed).

    Raises:
        ImportError: if "umap" is requested but umap-learn is missing.
        ValueError: if the method name is not recognized.
    """
    if method == "pca":
        reducer = PCA(n_components=3)
    elif method == "tsne":
        reducer = TSNE(n_components=3, init="random", random_state=42, perplexity=5)
    elif method == "umap":
        if not has_umap:
            raise ImportError("UMAP is not installed. Install it with `pip install umap-learn`.")
        reducer = umap.UMAP(n_components=3)
    else:
        raise ValueError(f"Unknown projection method: {method}")
    return reducer.fit_transform(vectors)
# === Plotting === | |
def plot_trajectory(points_3d, alpha=0.6, title="GPT-2 Belief Trajectory", colormap="plasma"):
    """Draw a 3D polyline through points_3d, coloring segments by time.

    Each consecutive pair of points becomes one line segment whose color
    advances along the chosen colormap, giving a time-progression
    gradient on a black background with a matching colorbar.

    Args:
        points_3d: array-like of shape (N, 3); N must be >= 2.
        alpha: segment transparency.
        title: figure title.
        colormap: matplotlib colormap name.

    Raises:
        ValueError: if fewer than two points are supplied.
    """
    num_points = len(points_3d)
    # Guard: with < 2 points there is no segment to draw, and the color
    # computation below would divide by zero.
    if num_points < 2:
        raise ValueError("plot_trajectory requires at least two points.")
    cmap = colormaps[colormap]
    # One color per segment, spanning the full colormap range.
    colors = [cmap(i / (num_points - 1)) for i in range(num_points - 1)]
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')
    for i in range(num_points - 1):
        x = [points_3d[i, 0], points_3d[i + 1, 0]]
        y = [points_3d[i, 1], points_3d[i + 1, 1]]
        z = [points_3d[i, 2], points_3d[i + 1, 2]]
        ax.plot(x, y, z, color=colors[i], linewidth=2.5, alpha=alpha)
    ax.set_facecolor("#000000")
    fig.patch.set_facecolor("#000000")
    ax.axis('off')
    # Normalize to num_points - 1 so the colorbar scale matches the
    # segment colors computed above (previously off by one).
    sm = cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(0, num_points - 1))
    sm.set_array([])
    cbar = plt.colorbar(sm, shrink=0.5, aspect=10, ax=ax)
    cbar.set_label("Time progression", color='white')
    cbar.ax.yaxis.set_tick_params(color='white')
    plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='white')
    plt.title(title, fontsize=14, color='white', pad=20)
    plt.tight_layout()
    plt.show()
# === Main CLI === | |
def main():
    """CLI entry point: sample from GPT-2 and plot its hidden-state trajectory.

    Parses arguments, loads GPT-2 on GPU when available, collects
    per-layer hidden vectors during generation, projects them to 3D,
    and renders the trajectory plot.
    """
    parser = argparse.ArgumentParser(description="Visualize GPT-2's belief trajectory in latent space.")
    parser.add_argument('--prompt', type=str, required=True, help="Input prompt.")
    parser.add_argument('--num-generate', type=int, default=10, help="Number of tokens to generate.")
    parser.add_argument('--projection', type=str, default="pca", choices=["pca", "tsne", "umap"],
                        help="Dimensionality reduction method.")
    parser.add_argument('--temperature', type=float, default=1.0,
                        help="Sampling temperature (higher = more random, lower = more greedy).")
    parser.add_argument('--alpha', type=float, default=0.6, help="Line transparency.")
    parser.add_argument('--title', type=str, default="GPT-2 Belief Trajectory", help="Plot title.")
    # New: expose plot_trajectory's colormap parameter (same default),
    # which was previously hard-coded from the CLI's point of view.
    parser.add_argument('--colormap', type=str, default="plasma",
                        help="Matplotlib colormap for the time gradient.")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2", output_hidden_states=True).to(device).eval()

    print(f"[INFO] Prompt: \"{args.prompt}\" | Generating {args.num_generate} tokens | Temp={args.temperature}")
    vectors = collect_hidden_vectors(
        model, tokenizer, args.prompt,
        num_generate=args.num_generate,
        temperature=args.temperature,
        device=device
    )
    print(f"[INFO] Collected {vectors.shape[0]} hidden vectors. Projecting with {args.projection.upper()}...")
    points_3d = project_vectors(vectors, args.projection)
    plot_trajectory(points_3d, alpha=args.alpha, title=args.title, colormap=args.colormap)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment