@joelburget
Created July 29, 2024 23:05
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "b0e5a8cc-b744-48d1-b0e6-5c868cea0f5f",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torch.distributions.normal import Normal\n",
"from transformers import AutoTokenizer\n",
"from transformer_lens import HookedTransformer\n",
"from datasets import load_dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7097b114-1143-4467-a235-a93107a0e2a0",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"roneneldan/TinyStories-1M\"\n",
"ds_name = \"roneneldan/TinyStories\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "dcd29074-622a-4a1d-8da7-3a00ebacaaff",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/joel/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
" return self.fget.__get__(instance, owner)()\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded pretrained model roneneldan/TinyStories-1M into HookedTransformer\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Repo card metadata block was not found. Setting CardData to empty.\n",
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = HookedTransformer.from_pretrained(model_name, device=\"cpu\")\n",
"ds = load_dataset(ds_name)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "56e642b2-fe90-4f59-9cea-68d9219c1575",
"metadata": {},
"outputs": [],
"source": [
"input_dim = model.cfg.d_model\n",
"expansion_factor = 8\n",
"hidden_dim = input_dim * expansion_factor\n",
"sigma = 1.0\n",
"learning_rate = 1e-3\n",
"beta = 1e-3\n",
"hook_point = 'blocks.4.hook_resid_post'"
]
},
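{
"cell_type": "markdown",
"id": "c4b1d2e3-9a01-4f5b-8c2d-1e2f3a4b5c01",
"metadata": {},
"source": [
"A quick sanity check on the setup (a sketch added here, not part of the original run): the hook name should exist on the model, and the dictionary will be `expansion_factor` times `d_model` features wide. `model.hook_dict` is TransformerLens's registry of hook points."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4b1d2e3-9a01-4f5b-8c2d-1e2f3a4b5c02",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (sketch): the hook point must be a real hook on this model,\n",
"# and the SAE input width must match d_model at that point.\n",
"assert hook_point in model.hook_dict\n",
"print(f\"d_model={model.cfg.d_model}, n_layers={model.cfg.n_layers}, dictionary={hidden_dim}\")"
]
},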
{
"cell_type": "code",
"execution_count": 5,
"id": "d65908d7-e2d2-49fd-a9cc-bd0fa1ec5b79",
"metadata": {},
"outputs": [],
"source": [
"class SparseAutoencoder(nn.Module):\n",
" def __init__(self, input_dim, hidden_dim, sigma):\n",
" super(SparseAutoencoder, self).__init__()\n",
" self.encoder = nn.Linear(input_dim, hidden_dim)\n",
" self.decoder = nn.Linear(hidden_dim, input_dim)\n",
" self.sigma = sigma\n",
" \n",
" def forward(self, x):\n",
" h = self.encoder(x)\n",
" a = F.relu(h)\n",
" x_hat = self.decoder(a)\n",
" return x_hat, h\n",
" \n",
" def expected_l0_loss(self, h):\n",
" W1, b1 = self.encoder.weight, self.encoder.bias\n",
" mu = h # (mu is just the pre-activations)\n",
" sigma = self.sigma * torch.sqrt((W1**2).sum(dim=1))\n",
" normal = Normal(0, 1)\n",
" prob_non_zero = 1 - normal.cdf(-mu / sigma)\n",
" return prob_non_zero.sum()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6e945fa7-1a13-4431-8a99-2385a0e0fe58",
"metadata": {},
"outputs": [],
"source": [
"sae = SparseAutoencoder(input_dim, hidden_dim, sigma)"
]
},
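{
"cell_type": "markdown",
"id": "d5e6f7a8-1b23-4c4d-9e5f-6a7b8c9d0e01",
"metadata": {},
"source": [
"`expected_l0_loss` is a differentiable surrogate for the L0 \"norm\" (the number of active features). For feature $i$ with pre-activation $\\mu_i = w_i \\cdot x + b_i$, adding input noise $\\varepsilon \\sim \\mathcal{N}(0, \\sigma^2 I)$ makes the pre-activation $\\mathcal{N}(\\mu_i, \\sigma^2 \\lVert w_i \\rVert^2)$, so the ReLU output is nonzero with probability $1 - \\Phi(-\\mu_i / (\\sigma \\lVert w_i \\rVert))$; summing over features gives the expected active count. The cell below is a Monte Carlo check of that formula, a sketch added here rather than part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d5e6f7a8-1b23-4c4d-9e5f-6a7b8c9d0e02",
"metadata": {},
"outputs": [],
"source": [
"# Monte Carlo check (sketch): perturb the input with N(0, sigma^2) noise and\n",
"# count nonzero ReLU outputs; the average should match expected_l0_loss.\n",
"with torch.no_grad():\n",
"    x = torch.randn(4, input_dim)\n",
"    _, h = sae(x)\n",
"    analytic = sae.expected_l0_loss(h) / x.shape[0]  # expected active count per row\n",
"    counts = torch.stack([\n",
"        (sae.encoder(x + sigma * torch.randn_like(x)) > 0).float().sum(dim=-1).mean()\n",
"        for _ in range(1_000)\n",
"    ])\n",
"    print(f\"analytic: {analytic.item():.2f}, monte carlo: {counts.mean().item():.2f}\")"
]
},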
{
"cell_type": "code",
"execution_count": 7,
"id": "36d7bc62-b22d-420a-9d27-5c05b6044590",
"metadata": {},
"outputs": [],
"source": [
"def train(model, sae, ds, learning_rate=learning_rate, beta=beta):\n",
" optimizer = torch.optim.Adam(sae.parameters(), lr=learning_rate)\n",
" criterion = nn.MSELoss()\n",
"\n",
" i = 0\n",
" for input in ds['train']:\n",
" input = input['text']\n",
" toks = tokenizer(input)\n",
" tokens = tokenizer(input)['input_ids']\n",
" _, cache = model.run_with_cache(torch.tensor(tokens), remove_batch_dim=True)\n",
" x = cache[hook_point]\n",
" \n",
" x_hat, h = sae(x)\n",
" \n",
" reconstruction_loss = criterion(x_hat, x)\n",
" l0_loss = sae.expected_l0_loss(h)\n",
" loss = reconstruction_loss + beta * l0_loss\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" i += 1\n",
" if i % 100 == 0:\n",
" print(f'Step [{i}/10000], Loss: {loss.item():.4f}, Reconstruction Loss: {reconstruction_loss.item():.4f}, L0 Loss: {l0_loss.item():.6f}')\n",
" if i % 1_000 == 0:\n",
" break"
]
},
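{
"cell_type": "markdown",
"id": "e6f7a8b9-2c34-4d5e-8f6a-7b8c9d0e1f01",
"metadata": {},
"source": [
"Training note: each optimizer step uses every token position of a single story as its batch, and the loss is reconstruction MSE plus $\\beta$ times the expected-L0 penalty, so $\\beta$ directly trades reconstruction quality against sparsity. Activations are detached so gradients reach only the SAE, never the frozen TinyStories model."
]
},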
{
"cell_type": "code",
"execution_count": 8,
"id": "76f9997e-a548-4966-8070-39b097207719",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-07-29 16:00:54.595 Python[65390:4951453] getMetalPluginClassForService: Failed to find bundle for accelerator bundle named: AGXMetalA12 errno: 0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Step [100/10000], Loss: 25.8655, Reconstruction Loss: 0.0004, L0 Loss: 25865.042969\n",
"Step [200/10000], Loss: 8.5769, Reconstruction Loss: 0.0004, L0 Loss: 8576.440430\n",
"Step [300/10000], Loss: 0.0009, Reconstruction Loss: 0.0004, L0 Loss: 0.454972\n",
"Step [400/10000], Loss: 0.0005, Reconstruction Loss: 0.0004, L0 Loss: 0.059135\n",
"Step [500/10000], Loss: 0.0005, Reconstruction Loss: 0.0004, L0 Loss: 0.042081\n",
"Step [600/10000], Loss: 0.0004, Reconstruction Loss: 0.0004, L0 Loss: 0.017423\n",
"Step [700/10000], Loss: 0.0004, Reconstruction Loss: 0.0004, L0 Loss: 0.015343\n",
"Step [800/10000], Loss: 0.0005, Reconstruction Loss: 0.0005, L0 Loss: 0.027245\n",
"Step [900/10000], Loss: 0.0004, Reconstruction Loss: 0.0004, L0 Loss: 0.009188\n",
"Step [1000/10000], Loss: 0.0004, Reconstruction Loss: 0.0004, L0 Loss: 0.008290\n"
]
}
],
"source": [
"train(model, sae, ds)"
]
}
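,
{
"cell_type": "markdown",
"id": "f7a8b9c0-3d45-4e6f-9a7b-8c9d0e1f2a01",
"metadata": {},
"source": [
"A quick post-training check (a sketch, not part of the original run, assuming the dataset's `validation` split is available): how many dictionary features fire per token on a held-out story, and how good is the reconstruction?"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7a8b9c0-3d45-4e6f-9a7b-8c9d0e1f2a02",
"metadata": {},
"outputs": [],
"source": [
"# Post-training sparsity check (sketch): count active features per token on a\n",
"# held-out story and report the reconstruction error.\n",
"with torch.no_grad():\n",
"    text = ds['validation'][0]['text']\n",
"    tokens = tokenizer(text)['input_ids']\n",
"    _, cache = model.run_with_cache(torch.tensor(tokens), remove_batch_dim=True)\n",
"    x = cache[hook_point]\n",
"    x_hat, h = sae(x)\n",
"    active = (h > 0).float().sum(dim=-1)\n",
"    print(f\"mean active features per token: {active.mean().item():.1f} / {hidden_dim}\")\n",
"    print(f\"reconstruction MSE: {F.mse_loss(x_hat, x).item():.6f}\")"
]
}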
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}