Created
July 29, 2024 23:05
-
-
Save joelburget/3457f8dba59c46f792720c28689b8e2a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "b0e5a8cc-b744-48d1-b0e6-5c868cea0f5f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"from torch.utils.data import Dataset, DataLoader\n", | |
"from torch.distributions.normal import Normal\n", | |
"from transformers import AutoTokenizer\n", | |
"from transformer_lens import HookedTransformer\n", | |
"from datasets import load_dataset" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "7097b114-1143-4467-a235-a93107a0e2a0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Hugging Face identifiers for the model and the dataset used below\n", | |
"ds_name = \"roneneldan/TinyStories\"\n", | |
"model_name = \"roneneldan/TinyStories-1M\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "dcd29074-622a-4a1d-8da7-3a00ebacaaff", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/joel/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", | |
" return self.fget.__get__(instance, owner)()\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loaded pretrained model roneneldan/TinyStories-1M into HookedTransformer\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Repo card metadata block was not found. Setting CardData to empty.\n", | |
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | |
"To disable this warning, you can either:\n", | |
"\t- Avoid using `tokenizers` before the fork if possible\n", | |
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" | |
] | |
} | |
], | |
"source": [ | |
"# Load the dataset, the TransformerLens model (on CPU) and the tokenizer for the same checkpoint.\n", | |
"ds = load_dataset(ds_name)\n", | |
"model = HookedTransformer.from_pretrained(model_name, device=\"cpu\")\n", | |
"tokenizer = AutoTokenizer.from_pretrained(model_name)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "56e642b2-fe90-4f59-9cea-68d9219c1575", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# SAE and optimisation hyperparameters\n", | |
"input_dim = model.cfg.d_model  # model width (d_model)\n", | |
"expansion_factor = 8\n", | |
"hidden_dim = input_dim * expansion_factor  # number of SAE features\n", | |
"sigma = 1.0  # scale of the per-feature std in SparseAutoencoder.expected_l0_loss\n", | |
"learning_rate = 1e-3\n", | |
"beta = 1e-3  # weight of the expected-L0 penalty in the training loss\n", | |
"hook_point = 'blocks.4.hook_resid_post'  # activation the SAE reconstructs\n", | |
"\n", | |
"# Seed torch so the SAE's random initialisation is reproducible across runs.\n", | |
"seed = 0\n", | |
"torch.manual_seed(seed);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "d65908d7-e2d2-49fd-a9cc-bd0fa1ec5b79", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class SparseAutoencoder(nn.Module):\n", | |
"    \"\"\"One-hidden-layer ReLU autoencoder trained with a differentiable expected-L0 penalty.\n", | |
"\n", | |
"    Args:\n", | |
"        input_dim: size of the activations being reconstructed.\n", | |
"        hidden_dim: number of latent features.\n", | |
"        sigma: scale factor for the per-feature std used in `expected_l0_loss`.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, input_dim, hidden_dim, sigma):\n", | |
"        super().__init__()\n", | |
"        self.encoder = nn.Linear(input_dim, hidden_dim)\n", | |
"        self.decoder = nn.Linear(hidden_dim, input_dim)\n", | |
"        self.sigma = sigma\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Return `(x_hat, h)`: the reconstruction and the pre-ReLU encoder output.\"\"\"\n", | |
"        h = self.encoder(x)\n", | |
"        x_hat = self.decoder(F.relu(h))\n", | |
"        return x_hat, h\n", | |
"\n", | |
"    def expected_l0_loss(self, h):\n", | |
"        \"\"\"Sum over all entries of `h` of the probability that the feature is active.\n", | |
"\n", | |
"        Pre-activation j is treated as Gaussian with mean `h[..., j]` and std\n", | |
"        `sigma * ||W1[j]||` (the spread induced by N(0, sigma^2 I) noise on the\n", | |
"        input), so it is positive with probability Phi(h / std).\n", | |
"        \"\"\"\n", | |
"        # Per-feature std, shape (hidden_dim,); broadcasts over h's last dim.\n", | |
"        std = self.sigma * self.encoder.weight.pow(2).sum(dim=1).sqrt()\n", | |
"        # ndtr(z) equals 1 - Phi(-z) but does not cancel to exactly 0 for very\n", | |
"        # negative z, so tiny probabilities (and the loss) stay accurate.\n", | |
"        prob_non_zero = torch.special.ndtr(h / std)\n", | |
"        return prob_non_zero.sum()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "6e945fa7-1a13-4431-8a99-2385a0e0fe58", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Fresh, untrained SAE sized by the hyperparameters above\n", | |
"sae = SparseAutoencoder(input_dim=input_dim, hidden_dim=hidden_dim, sigma=sigma)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "36d7bc62-b22d-420a-9d27-5c05b6044590", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def train(model, sae, ds, learning_rate=learning_rate, beta=beta,\n", | |
"          num_steps=1_000, log_every=100, hook_point=hook_point, tokenizer=tokenizer):\n", | |
"    \"\"\"Fit `sae` to the `hook_point` activations of a frozen `model` on `ds['train']`.\n", | |
"\n", | |
"    Every non-empty example is tokenized and run through `model` as a single\n", | |
"    sequence; its activations form one batch, i.e. one optimizer step.\n", | |
"    Training stops after `num_steps` steps or when the split is exhausted.\n", | |
"\n", | |
"    Args:\n", | |
"        model: HookedTransformer providing activations (its weights are not updated).\n", | |
"        sae: SparseAutoencoder to train in place.\n", | |
"        ds: dataset dict whose 'train' split yields {'text': str} examples.\n", | |
"        learning_rate: Adam learning rate.\n", | |
"        beta: weight of the expected-L0 penalty added to the MSE reconstruction loss.\n", | |
"        num_steps: maximum number of optimizer steps.\n", | |
"        log_every: print the losses every this many steps.\n", | |
"        hook_point: name of the cached activation to reconstruct.\n", | |
"        tokenizer: callable mapping a string to a dict with 'input_ids'.\n", | |
"    \"\"\"\n", | |
"    optimizer = torch.optim.Adam(sae.parameters(), lr=learning_rate)\n", | |
"    criterion = nn.MSELoss()\n", | |
"\n", | |
"    step = 0\n", | |
"    for example in ds['train']:\n", | |
"        if step >= num_steps:\n", | |
"            break\n", | |
"        # Clip to the model's context window and skip examples with no tokens.\n", | |
"        tokens = tokenizer(example['text'])['input_ids'][: model.cfg.n_ctx]\n", | |
"        if not tokens:\n", | |
"            continue\n", | |
"\n", | |
"        # The transformer is frozen, so skip building its autograd graph and cache\n", | |
"        # only the activation we reconstruct.\n", | |
"        with torch.no_grad():\n", | |
"            _, cache = model.run_with_cache(\n", | |
"                torch.tensor(tokens), names_filter=hook_point, remove_batch_dim=True\n", | |
"            )\n", | |
"        x = cache[hook_point]\n", | |
"\n", | |
"        x_hat, h = sae(x)\n", | |
"        reconstruction_loss = criterion(x_hat, x)\n", | |
"        l0_loss = sae.expected_l0_loss(h)\n", | |
"        loss = reconstruction_loss + beta * l0_loss\n", | |
"\n", | |
"        optimizer.zero_grad()\n", | |
"        loss.backward()\n", | |
"        optimizer.step()\n", | |
"\n", | |
"        step += 1\n", | |
"        if step % log_every == 0:\n", | |
"            print(f'Step [{step}/{num_steps}], Loss: {loss.item():.4f}, Reconstruction Loss: {reconstruction_loss.item():.4f}, L0 Loss: {l0_loss.item():.6f}')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "76f9997e-a548-4966-8070-39b097207719", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"2024-07-29 16:00:54.595 Python[65390:4951453] getMetalPluginClassForService: Failed to find bundle for accelerator bundle named: AGXMetalA12 errno: 0\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Step [100/10000], Loss: 25.8655, Reconstruction Loss: 0.0004, L0 Loss: 25865.042969\n", | |
"Step [200/10000], Loss: 8.5769, Reconstruction Loss: 0.0004, L0 Loss: 8576.440430\n", | |
"Step [300/10000], Loss: 0.0009, Reconstruction Loss: 0.0004, L0 Loss: 0.454972\n", | |
"Step [400/10000], Loss: 0.0005, Reconstruction Loss: 0.0004, L0 Loss: 0.059135\n", | |
"Step [500/10000], Loss: 0.0005, Reconstruction Loss: 0.0004, L0 Loss: 0.042081\n", | |
"Step [600/10000], Loss: 0.0004, Reconstruction Loss: 0.0004, L0 Loss: 0.017423\n", | |
"Step [700/10000], Loss: 0.0004, Reconstruction Loss: 0.0004, L0 Loss: 0.015343\n", | |
"Step [800/10000], Loss: 0.0005, Reconstruction Loss: 0.0005, L0 Loss: 0.027245\n", | |
"Step [900/10000], Loss: 0.0004, Reconstruction Loss: 0.0004, L0 Loss: 0.009188\n", | |
"Step [1000/10000], Loss: 0.0004, Reconstruction Loss: 0.0004, L0 Loss: 0.008290\n" | |
] | |
} | |
], | |
"source": [ | |
"# Fit the SAE to the model's activations on the dataset's train split.\n", | |
"train(model=model, sae=sae, ds=ds)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment