Skip to content

Instantly share code, notes, and snippets.

@attentionmech
Created May 15, 2025 09:35
Show Gist options
  • Save attentionmech/fef6e70e0473e0f88d17d30bb7ffcf4e to your computer and use it in GitHub Desktop.
Save attentionmech/fef6e70e0473e0f88d17d30bb7ffcf4e to your computer and use it in GitHub Desktop.
forward pass activation dynamics
import torch
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import colormaps
from mpl_toolkits.mplot3d import Axes3D
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
# Optional dependency: UMAP is only used by the "umap" projection method.
try:
    import umap
except ImportError:
    has_umap = False
else:
    has_umap = True
# === Collect hidden states ===
def collect_hidden_vectors(model, tokenizer, prompt, num_generate=10, temperature=1.0, device='cpu'):
    """Generate tokens autoregressively and record per-layer activations.

    At each of ``num_generate`` steps the whole current sequence is run through
    ``model``. For every entry of ``outputs.hidden_states`` except the first
    (the embedding output), the activation vector at the last position is
    saved. A next token is then chosen from the last position's logits and
    appended to the sequence.

    Args:
        model: Callable as ``model(ids, output_hidden_states=True)`` returning
            an object with ``hidden_states`` (a sequence of rank-3 tensors
            indexed ``[batch, position, feature]``) and ``logits`` (rank 3,
            ``[batch, position, vocab]``). Only batch element 0 is read.
        tokenizer: Callable as ``tokenizer(prompt, return_tensors="pt")``
            returning a mapping with an ``"input_ids"`` tensor.
        prompt: Text to condition generation on.
        num_generate: Number of generation steps; must be >= 1.
        temperature: Softmax temperature used for sampling. ``0`` selects the
            highest-logit token (greedy decoding) instead of sampling.
        device: Device the input ids are moved to.

    Returns:
        ``np.ndarray`` of shape ``(num_generate * n_layers, hidden_size)`` with
        ``n_layers = len(outputs.hidden_states) - 1``. Rows are ordered step
        by step, and within one step from the first layer to the last. The
        token chosen at the final step is appended but never fed back through
        the model, so no activations are recorded for it.

    Raises:
        ValueError: if ``num_generate < 1``, ``temperature < 0``, or the
            prompt tokenizes to zero tokens.
    """
    if num_generate < 1:
        # Otherwise nothing is collected and np.stack fails with an unclear error.
        raise ValueError(f"num_generate must be >= 1, got {num_generate}")
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
    if input_ids.shape[-1] == 0:
        # The model needs at least one position to read activations/logits from.
        raise ValueError("prompt must tokenize to at least one token")

    generated_ids = input_ids.clone()
    hidden_vectors = []
    with torch.no_grad():
        for _ in range(num_generate):
            outputs = model(generated_ids, output_hidden_states=True)
            # Final-token activations from every transformer layer;
            # hidden_states[0] is the embedding output and is skipped.
            for layer_hidden in outputs.hidden_states[1:]:
                hidden_vectors.append(layer_hidden[0, -1, :].cpu().numpy())

            logits = outputs.logits[:, -1, :]
            if temperature == 0:
                # logits / 0 would yield inf/nan probabilities that
                # torch.multinomial cannot sample from; the T -> 0 limit of
                # temperature sampling is argmax.
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)
    return np.stack(hidden_vectors)
# === Project vectors to 3D ===
def project_vectors(vectors, method="pca"):
    """Reduce ``vectors`` to three dimensions for 3-D plotting.

    Args:
        vectors: 2-D array of samples x features to embed.
        method: One of ``"pca"``, ``"tsne"`` (fixed ``random_state=42``,
            ``perplexity=5``, random init) or ``"umap"``.

    Returns:
        The result of the chosen reducer's ``fit_transform`` with
        ``n_components=3``.

    Raises:
        ImportError: if ``method == "umap"`` but the ``umap`` package could
            not be imported.
        ValueError: for any other ``method`` value.

    NOTE(review): scikit-learn's TSNE rejects a perplexity that is not smaller
    than the number of samples, so very small inputs may fail with "tsne" —
    confirm against the installed scikit-learn version.
    """
    if method == "pca":
        reducer = PCA(n_components=3)
    elif method == "tsne":
        reducer = TSNE(n_components=3, init="random", random_state=42, perplexity=5)
    elif method == "umap":
        if not has_umap:
            raise ImportError("UMAP is not installed. Install it with `pip install umap-learn`.")
        reducer = umap.UMAP(n_components=3)
    else:
        raise ValueError(f"Unknown projection method: {method}")
    return reducer.fit_transform(vectors)
# === Plotting ===
def plot_trajectory(points_3d, alpha=0.6, title="GPT-2 Belief Trajectory", colormap="plasma"):
    """Draw ``points_3d`` as a colour-graded 3-D polyline on a black background.

    Consecutive points are joined by one segment each. Segment ``i`` (point
    ``i`` to point ``i + 1``) is coloured ``cmap(i / (num_points - 1))``. The
    colorbar spans point indices ``0 .. num_points - 1``, so a colour on the
    bar corresponds to the index at which that colour appears on the line.

    Args:
        points_3d: Array indexed as ``points_3d[i, k]``, with ``k`` in
            ``0..2`` for x, y, z, in trajectory order (e.g. the output of
            ``project_vectors``).
        alpha: Line transparency passed to ``Axes.plot``.
        title: Figure title.
        colormap: Name looked up in ``matplotlib.colormaps``.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    num_points = len(points_3d)
    cmap = colormaps[colormap]
    # Index of the last point. It is used both to scale segment colours and as
    # the colorbar's upper bound, so bar and line share one scale. It is
    # clamped to 1 so inputs with fewer than two points give a valid norm.
    last_index = max(num_points - 1, 1)

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')
    for i in range(num_points - 1):
        segment = points_3d[i:i + 2]  # rows i and i + 1
        ax.plot(segment[:, 0], segment[:, 1], segment[:, 2],
                color=cmap(i / last_index), linewidth=2.5, alpha=alpha)

    ax.set_facecolor("#000000")
    fig.patch.set_facecolor("#000000")
    ax.axis('off')

    sm = cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(0, last_index))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, shrink=0.5, aspect=10)
    cbar.set_label("Time progression", color='white')
    # Setting labelcolor through tick params (rather than recolouring the
    # current label objects) also applies to tick labels created on redraw.
    cbar.ax.yaxis.set_tick_params(color='white', labelcolor='white')

    ax.set_title(title, fontsize=14, color='white', pad=20)
    plt.tight_layout()
    plt.show()
# === Main CLI ===
def main():
    """Command-line entry point: generate from a prompt, project activations, plot.

    Steps:
    1. Parse the CLI arguments.
    2. Optionally seed torch's RNG.
    3. Load the tokenizer and LM-head model named by ``--model``, on CUDA
       when available and otherwise on CPU.
    4. Collect per-layer activations with ``collect_hidden_vectors``.
    5. Reduce them to 3-D with ``project_vectors``.
    6. Display them with ``plot_trajectory``.
    """
    parser = argparse.ArgumentParser(description="Visualize GPT-2's belief trajectory in latent space.")
    parser.add_argument('--prompt', type=str, required=True, help="Input prompt.")
    parser.add_argument('--num-generate', type=int, default=10, help="Number of tokens to generate.")
    parser.add_argument('--projection', type=str, default="pca", choices=["pca", "tsne", "umap"],
                        help="Dimensionality reduction method.")
    parser.add_argument('--temperature', type=float, default=1.0,
                        help="Sampling temperature (higher = more random, lower = more greedy).")
    parser.add_argument('--alpha', type=float, default=0.6, help="Line transparency.")
    parser.add_argument('--title', type=str, default="GPT-2 Belief Trajectory", help="Plot title.")
    parser.add_argument('--colormap', type=str, default="plasma",
                        help="Matplotlib colormap name used to colour the trajectory.")
    parser.add_argument('--model', type=str, default="gpt2",
                        help="Hugging Face name or local path of a GPT-2 checkpoint (default: gpt2).")
    parser.add_argument('--seed', type=int, default=None,
                        help="Random seed for token sampling (default: unseeded).")
    args = parser.parse_args()

    # Reject values that would otherwise fail deep inside generation with an
    # unhelpful traceback (num_generate < 1 leaves nothing to stack) or
    # produce a meaningless inverted distribution (negative temperature).
    if args.num_generate < 1:
        parser.error("--num-generate must be at least 1")
    if args.temperature < 0:
        parser.error("--temperature must be non-negative")

    if args.seed is not None:
        # Seeds torch's RNG so multinomial sampling is reproducible.
        torch.manual_seed(args.seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = GPT2Tokenizer.from_pretrained(args.model)
    model = GPT2LMHeadModel.from_pretrained(args.model, output_hidden_states=True).to(device).eval()

    print(f"[INFO] Prompt: \"{args.prompt}\" | Generating {args.num_generate} tokens | Temp={args.temperature}")
    vectors = collect_hidden_vectors(
        model, tokenizer, args.prompt,
        num_generate=args.num_generate,
        temperature=args.temperature,
        device=device,
    )
    print(f"[INFO] Collected {vectors.shape[0]} hidden vectors. Projecting with {args.projection.upper()}...")
    points_3d = project_vectors(vectors, args.projection)
    plot_trajectory(points_3d, alpha=args.alpha, title=args.title, colormap=args.colormap)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment