{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "68a82751-4c11-4889-9415-edb48bee247f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-08-01 17:18:15.177 Python[64493:7647624] getMetalPluginClassForService: Failed to find bundle for accelerator bundle named: AGXMetalA12 errno: 0\n"
]
}
],
"source": [
"from transformers import AutoConfig\n",
"from Auto_HookPoint import HookedTransformerAdapter \n",
"from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner\n",
"from dataclasses import dataclass"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7148e9ae-e375-4d4e-8282-2f42731a25cb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Run name: 16384-L1-5-LR-5e-05-Tokens-6.144e+07\n",
"n_tokens_per_buffer (millions): 0.524288\n",
"Lower bound: n_contexts_per_buffer (millions): 0.001024\n",
"Total training steps: 15000\n",
"Total wandb updates: 500\n",
"n_tokens_per_feature_sampling_window (millions): 2097.152\n",
"n_tokens_per_dead_feature_window (millions): 2097.152\n",
"We will reset the sparsity calculation 15 times.\n",
"Number tokens in sparsity calculation window: 4.10e+06\n"
]
}
],
"source": [
"model_name = \"joelb/Mixtral-8x7B-1l\"\n",
"config = AutoConfig.from_pretrained(model_name)\n",
"\n",
"total_training_steps = 15_000 # probably we should do more\n",
"batch_size = 4096\n",
"total_training_tokens = total_training_steps * batch_size\n",
"\n",
"lr_warm_up_steps = 0\n",
"lr_decay_steps = total_training_steps // 5 # 20% of training\n",
"l1_warm_up_steps = total_training_steps // 20 # 5% of training\n",
"\n",
"\n",
"#sae_lens\n",
"cfg = LanguageModelSAERunnerConfig(\n",
" model_name=model_name,\n",
" hook_name=\"model.norm.hook_point\",\n",
" hook_layer=0,\n",
" d_in=config.hidden_size,\n",
" dataset_path=\"monology/pile-uncopyrighted\",\n",
" is_dataset_tokenized=False,\n",
" streaming=True, \n",
" mse_loss_normalization=None, # We won't normalize the mse loss,\n",
" expansion_factor=4,\n",
" b_dec_init_method=\"zeros\", # The geometric median can be used to initialize the decoder weights.\n",
" apply_b_dec_to_input=False, # We won't apply the decoder weights to the input.\n",
" normalize_sae_decoder=False,\n",
" scale_sparsity_penalty_by_decoder_norm=True,\n",
" decoder_heuristic_init=True,\n",
" init_encoder_as_decoder_transpose=True,\n",
" normalize_activations=\"expected_average_only_in\",\n",
" # Training Parameters\n",
" lr=5e-5, # lower the better, we'll go fairly high to speed up the tutorial.\n",
" adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n",
" adam_beta2=0.999,\n",
" lr_scheduler_name=\"constant\", # constant learning rate with warmup. Could be better schedules out there.\n",
" lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially.\n",
" lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting.\n",
" l1_coefficient=5, # will control how sparse the feature activations are\n",
" l1_warm_up_steps=l1_warm_up_steps, # this can help avoid too many dead features initially.\n",
" lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n",
" train_batch_size_tokens=batch_size,\n",
" context_size=512, # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.\n",
" # Activation Store Parameters\n",
" n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n",
" training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n",
" store_batch_size_prompts=16,\n",
" # Resampling protocol\n",
" use_ghost_grads=False, # we don't use ghost grads anymore.\n",
" feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n",
" dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n",
" dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n",
" # WANDB\n",
" log_to_wandb=True, # always use wandb unless you are just testing code.\n",
" wandb_log_frequency=30,\n",
" eval_every_n_wandb_logs=20,\n",
" # Misc\n",
" device=\"cpu\",\n",
" seed=42,\n",
" n_checkpoints=0,\n",
" checkpoint_path=\"checkpoints\",\n",
" dtype=\"float32\"\n",
")\n",
"\n",
"@dataclass\n",
"class Cfg:\n",
" device: str\n",
" n_ctx: int"
]
},
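{
"cell_type": "markdown",
"id": "hook-point-sanity-note",
"metadata": {},
"source": [
"Before launching the run below, it can be worth confirming that `model.norm.hook_point` actually exists on the adapted model. A minimal sanity-check sketch, assuming `HookedTransformerAdapter` behaves like a TransformerLens `HookedRootModule` and exposes `hook_dict` (an assumption, not verified against the `Auto_HookPoint` docs):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "hook-point-sanity-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Sanity-check sketch: list available hook points containing \"norm\".\n",
"# Assumes the adapter exposes hook_dict like a TransformerLens HookedRootModule.\n",
"adapter = HookedTransformerAdapter(model_name, Cfg(device=\"cpu\", n_ctx=512))\n",
"print([name for name in adapter.hook_dict if \"norm\" in name])"
]
},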
{
"cell_type": "code",
"execution_count": 4,
"id": "c513635f-bf47-4301-8e20-14c6de2ba7d6",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "09e250b7583045ffacb870b3f72c60a0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of MixtralForCausalLM were not initialized from the model checkpoint at joelb/Mixtral-8x7B-1l and are newly initialized: ['lm_head.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"WARNING:root:You just passed in a model which will override the one specified in your configuration: joelb/Mixtral-8x7B-1l. As a consequence this run will not be reproducible via configuration alone.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found embedding matrix: _module.model._module.embed_tokens._module\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6b3187feec5c4ad28301c9047755c5a7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/30 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info.\n"
]
},
{
"data": {
"text/html": [
"Finishing last run (ID:b7mvk41i) before initializing another..."
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run <strong style=\"color:#cdcd00\">16384-L1-5-LR-5e-05-Tokens-6.144e+07</strong> at: <a href='https://wandb.ai/joelb/mats_sae_training_language_model/runs/b7mvk41i' target=\"_blank\">https://wandb.ai/joelb/mats_sae_training_language_model/runs/b7mvk41i</a><br/> View project at: <a href='https://wandb.ai/joelb/mats_sae_training_language_model' target=\"_blank\">https://wandb.ai/joelb/mats_sae_training_language_model</a><br/>Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Find logs at: <code>./wandb/run-20240801_171909-b7mvk41i/logs</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Successfully finished last run (ID:b7mvk41i). Initializing new run:<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"data": {
"text/html": [
"wandb version 0.17.5 is available! To upgrade, please run:\n",
" $ pip install wandb --upgrade"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.16.6"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in <code>/Users/joel/code/github/TransformerLens/debugging/wandb/run-20240801_171952-cv07gebr</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/joelb/mats_sae_training_language_model/runs/cv07gebr' target=\"_blank\">16384-L1-5-LR-5e-05-Tokens-6.144e+07</a></strong> to <a href='https://wandb.ai/joelb/mats_sae_training_language_model' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at <a href='https://wandb.ai/joelb/mats_sae_training_language_model' target=\"_blank\">https://wandb.ai/joelb/mats_sae_training_language_model</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at <a href='https://wandb.ai/joelb/mats_sae_training_language_model/runs/cv07gebr' target=\"_blank\">https://wandb.ai/joelb/mats_sae_training_language_model/runs/cv07gebr</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"Training SAE: 0%| | 0/61440000 [00:00<?, ?it/s]\u001b[A\n",
"\n",
"Estimating norm scaling factor: 0%| | 0/1000 [00:00<?, ?it/s]\u001b[A\u001b[A"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MixralSdpaAttention.forward\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Estimating norm scaling factor: 0%| | 0/1000 [00:19<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"hf\n",
"router_logits 0.376691073179245\n",
"routing_weights 0.6589474678039551\n",
"expert_idx=0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"ename": "AttributeError",
"evalue": "'HookedModule' object has no attribute 'w1'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m hooked_model \u001b[38;5;241m=\u001b[39m HookedTransformerAdapter(model_name, Cfg(device\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, n_ctx\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m512\u001b[39m))\n\u001b[0;32m----> 2\u001b[0m sparse_autoencoder \u001b[38;5;241m=\u001b[39m \u001b[43mSAETrainingRunner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverride_model\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhooked_model\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/sae_training_runner.py:106\u001b[0m, in \u001b[0;36mSAETrainingRunner.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 97\u001b[0m trainer \u001b[38;5;241m=\u001b[39m SAETrainer(\n\u001b[1;32m 98\u001b[0m model\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel,\n\u001b[1;32m 99\u001b[0m sae\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msae,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 102\u001b[0m cfg\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg,\n\u001b[1;32m 103\u001b[0m )\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compile_if_needed()\n\u001b[0;32m--> 106\u001b[0m sae \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_trainer_with_interruption_handling\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mlog_to_wandb:\n\u001b[1;32m 109\u001b[0m wandb\u001b[38;5;241m.\u001b[39mfinish()\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/sae_training_runner.py:149\u001b[0m, in \u001b[0;36mSAETrainingRunner.run_trainer_with_interruption_handling\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 146\u001b[0m signal\u001b[38;5;241m.\u001b[39msignal(signal\u001b[38;5;241m.\u001b[39mSIGTERM, interrupt_callback)\n\u001b[1;32m 148\u001b[0m \u001b[38;5;66;03m# train SAE\u001b[39;00m\n\u001b[0;32m--> 149\u001b[0m sae \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, InterruptedException):\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minterrupted, saving progress\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py:164\u001b[0m, in \u001b[0;36mSAETrainer.fit\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m TrainingSAE:\n\u001b[1;32m 162\u001b[0m pbar \u001b[38;5;241m=\u001b[39m tqdm(total\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mtotal_training_tokens, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining SAE\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 164\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_estimate_norm_scaling_factor_if_needed\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;66;03m# Train loop\u001b[39;00m\n\u001b[1;32m 167\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_training_tokens \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mtotal_training_tokens:\n\u001b[1;32m 168\u001b[0m \u001b[38;5;66;03m# Do a training step.\u001b[39;00m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py:205\u001b[0m, in \u001b[0;36mSAETrainer._estimate_norm_scaling_factor_if_needed\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39mno_grad()\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_estimate_norm_scaling_factor_if_needed\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mnormalize_activations \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexpected_average_only_in\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactivation_store\u001b[38;5;241m.\u001b[39mestimated_norm_scaling_factor \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mactivation_store\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mestimate_norm_scaling_factor\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m )\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 208\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactivation_store\u001b[38;5;241m.\u001b[39mestimated_norm_scaling_factor \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1.0\u001b[39m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:346\u001b[0m, in \u001b[0;36mActivationsStore.estimate_norm_scaling_factor\u001b[0;34m(self, n_batches_for_norm_estimate)\u001b[0m\n\u001b[1;32m 342\u001b[0m norms_per_batch \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m tqdm(\n\u001b[1;32m 344\u001b[0m \u001b[38;5;28mrange\u001b[39m(n_batches_for_norm_estimate), desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEstimating norm scaling factor\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 345\u001b[0m ):\n\u001b[0;32m--> 346\u001b[0m acts \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnext_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m norms_per_batch\u001b[38;5;241m.\u001b[39mappend(acts\u001b[38;5;241m.\u001b[39mnorm(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mmean()\u001b[38;5;241m.\u001b[39mitem())\n\u001b[1;32m 348\u001b[0m mean_norm \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmean(norms_per_batch)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:640\u001b[0m, in \u001b[0;36mActivationsStore.next_batch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 634\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 635\u001b[0m \u001b[38;5;124;03mGet the next batch from the current DataLoader.\u001b[39;00m\n\u001b[1;32m 636\u001b[0m \u001b[38;5;124;03mIf the DataLoader is exhausted, refill the buffer and create a new DataLoader.\u001b[39;00m\n\u001b[1;32m 637\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 638\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 639\u001b[0m \u001b[38;5;66;03m# Try to get the next batch\u001b[39;00m\n\u001b[0;32m--> 640\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataloader\u001b[49m)\n\u001b[1;32m 641\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 642\u001b[0m \u001b[38;5;66;03m# If the DataLoader is exhausted, create a new one\u001b[39;00m\n\u001b[1;32m 643\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataloader \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_data_loader()\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:383\u001b[0m, in \u001b[0;36mActivationsStore.dataloader\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 381\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdataloader\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Iterator[Any]:\n\u001b[1;32m 382\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataloader \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 383\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataloader \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_data_loader\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 384\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataloader\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:593\u001b[0m, in \u001b[0;36mActivationsStore.get_data_loader\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 590\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_batch_size_tokens\n\u001b[1;32m 592\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 593\u001b[0m new_samples \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_buffer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 594\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhalf_buffer_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mraise_on_epoch_end\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m 595\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 596\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 597\u001b[0m \u001b[38;5;28mprint\u001b[39m(\n\u001b[1;32m 598\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWarning: All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 599\u001b[0m )\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:548\u001b[0m, in \u001b[0;36mActivationsStore.get_buffer\u001b[0;34m(self, n_batches_in_buffer, raise_on_epoch_end)\u001b[0m\n\u001b[1;32m 543\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m refill_batch_idx_start \u001b[38;5;129;01min\u001b[39;00m refill_iterator:\n\u001b[1;32m 544\u001b[0m \u001b[38;5;66;03m# move batch toks to gpu for model\u001b[39;00m\n\u001b[1;32m 545\u001b[0m refill_batch_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_batch_tokens(\n\u001b[1;32m 546\u001b[0m raise_at_epoch_end\u001b[38;5;241m=\u001b[39mraise_on_epoch_end\n\u001b[1;32m 547\u001b[0m )\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[0;32m--> 548\u001b[0m refill_activations \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_activations\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrefill_batch_tokens\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 549\u001b[0m \u001b[38;5;66;03m# move acts back to cpu\u001b[39;00m\n\u001b[1;32m 550\u001b[0m refill_activations\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:431\u001b[0m, in \u001b[0;36mActivationsStore.get_activations\u001b[0;34m(self, batch_tokens)\u001b[0m\n\u001b[1;32m 428\u001b[0m autocast_if_enabled \u001b[38;5;241m=\u001b[39m contextlib\u001b[38;5;241m.\u001b[39mnullcontext()\n\u001b[1;32m 430\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m autocast_if_enabled:\n\u001b[0;32m--> 431\u001b[0m layerwise_activations \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_with_cache\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 432\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 433\u001b[0m \u001b[43m \u001b[49m\u001b[43mnames_filter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhook_name\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 434\u001b[0m \u001b[43m \u001b[49m\u001b[43mstop_at_layer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhook_layer\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 435\u001b[0m \u001b[43m \u001b[49m\u001b[43mprepend_bos\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 436\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 437\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 439\u001b[0m n_batches, n_context \u001b[38;5;241m=\u001b[39m batch_tokens\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 441\u001b[0m stacked_activations \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mzeros((n_batches, n_context, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39md_in))\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/HookedTransformerAdapter.py:116\u001b[0m, in \u001b[0;36mHookedTransformerAdapter.run_with_cache\u001b[0;34m(self, return_cache_object, remove_batch_dim, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun_with_cache\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39mmodel_args, return_cache_object\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, remove_batch_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m names_filter \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnames_filter\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 116\u001b[0m out, cache_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_with_cache\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 117\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mremove_batch_dim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mremove_batch_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnames_filter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnames_filter\u001b[49m\n\u001b[1;32m 118\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_cache_object:\n\u001b[1;32m 120\u001b[0m cache \u001b[38;5;241m=\u001b[39m ActivationCache(cache_dict, \u001b[38;5;28mself\u001b[39m, has_batch_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;129;01mnot\u001b[39;00m remove_batch_dim)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformer_lens/hook_points.py:566\u001b[0m, in \u001b[0;36mHookedRootModule.run_with_cache\u001b[0;34m(self, names_filter, device, remove_batch_dim, incl_bwd, reset_hooks_end, clear_contexts, pos_slice, *model_args, **model_kwargs)\u001b[0m\n\u001b[1;32m 552\u001b[0m cache_dict, fwd, bwd \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_caching_hooks(\n\u001b[1;32m 553\u001b[0m names_filter,\n\u001b[1;32m 554\u001b[0m incl_bwd,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 557\u001b[0m pos_slice\u001b[38;5;241m=\u001b[39mpos_slice,\n\u001b[1;32m 558\u001b[0m )\n\u001b[1;32m 560\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhooks(\n\u001b[1;32m 561\u001b[0m fwd_hooks\u001b[38;5;241m=\u001b[39mfwd,\n\u001b[1;32m 562\u001b[0m bwd_hooks\u001b[38;5;241m=\u001b[39mbwd,\n\u001b[1;32m 563\u001b[0m reset_hooks_end\u001b[38;5;241m=\u001b[39mreset_hooks_end,\n\u001b[1;32m 564\u001b[0m clear_contexts\u001b[38;5;241m=\u001b[39mclear_contexts,\n\u001b[1;32m 565\u001b[0m ):\n\u001b[0;32m--> 566\u001b[0m model_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m incl_bwd:\n\u001b[1;32m 568\u001b[0m model_out\u001b[38;5;241m.\u001b[39mbackward()\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/HookedTransformerAdapter.py:126\u001b[0m, in \u001b[0;36mHookedTransformerAdapter.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 126\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:317\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 317\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhook_point(\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:1407\u001b[0m, in \u001b[0;36mMixtralForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict)\u001b[0m\n\u001b[1;32m 1404\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1406\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1407\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1408\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1409\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1410\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1411\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1412\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1413\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1414\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1415\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1416\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_router_logits\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_router_logits\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1417\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1418\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1420\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1421\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:317\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 317\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhook_point(\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:1275\u001b[0m, in \u001b[0;36mMixtralModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict)\u001b[0m\n\u001b[1;32m 1264\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 1265\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 1266\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1272\u001b[0m use_cache,\n\u001b[1;32m 1273\u001b[0m )\n\u001b[1;32m 1274\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1275\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1276\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1277\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1278\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1279\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1280\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1281\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_router_logits\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_router_logits\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1282\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1283\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1285\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1287\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:317\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 317\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhook_point(\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:994\u001b[0m, in \u001b[0;36mMixtralDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, output_router_logits, use_cache, **kwargs)\u001b[0m\n\u001b[1;32m 992\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 993\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_attention_layernorm(hidden_states)\n\u001b[0;32m--> 994\u001b[0m hidden_states, router_logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mblock_sparse_moe\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 995\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 997\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (hidden_states,)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:317\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 317\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhook_point(\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:901\u001b[0m, in \u001b[0;36mMixtralSparseMoeBlock.forward\u001b[0;34m(self, hidden_states)\u001b[0m\n\u001b[1;32m 898\u001b[0m idx, top_x \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mwhere(expert_mask[expert_idx])\n\u001b[1;32m 900\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m expert_idx \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 901\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mW_gate\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28mstr\u001b[39m(\u001b[43mexpert_layer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mw1\u001b[49m\u001b[38;5;241m.\u001b[39mweight[\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mitem()), expert_layer\u001b[38;5;241m.\u001b[39mw1\u001b[38;5;241m.\u001b[39mweight\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m 902\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mW_in\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28mstr\u001b[39m(expert_layer\u001b[38;5;241m.\u001b[39mw3\u001b[38;5;241m.\u001b[39mweight[\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mitem()), expert_layer\u001b[38;5;241m.\u001b[39mw3\u001b[38;5;241m.\u001b[39mweight\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m 903\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mW_out\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28mstr\u001b[39m(expert_layer\u001b[38;5;241m.\u001b[39mw2\u001b[38;5;241m.\u001b[39mweight[\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mitem()), expert_layer\u001b[38;5;241m.\u001b[39mw2\u001b[38;5;241m.\u001b[39mweight\u001b[38;5;241m.\u001b[39mdtype)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1695\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1693\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1695\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[0;31mAttributeError\u001b[0m: 'HookedModule' object has no attribute 'w1'"
]
}
],
"source": [
"hooked_model = HookedTransformerAdapter(model_name, Cfg(device=\"cpu\", n_ctx=512))\n",
"sparse_autoencoder = SAETrainingRunner(cfg, override_model=hooked_model).run()"
]
}
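,
{
"cell_type": "markdown",
"id": "hookedmodule-unwrap-note",
"metadata": {},
"source": [
"The `AttributeError` comes from a debug `print` patched into `modeling_mixtral.py` (line 901 of the traceback): `Auto_HookPoint` wraps submodules in `HookedModule`, which nests the original module inside the wrapper (note the `_module.model._module.embed_tokens._module` path printed when the embedding matrix was found), so `expert_layer.w1` resolves against the wrapper rather than the underlying expert MLP. The cell below is a minimal, self-contained sketch of that failure mode and an unwrapping workaround; the `_module` attribute name is inferred from the printed paths and is an assumption, not a documented `Auto_HookPoint` API."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "hookedmodule-unwrap-sketch",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"# Mimic a wrapper that stores the wrapped module as `_module`, the layout\n",
"# suggested by the paths printed above (an assumption about Auto_HookPoint).\n",
"class Wrapper(nn.Module):\n",
"    def __init__(self, module: nn.Module):\n",
"        super().__init__()\n",
"        self._module = module\n",
"\n",
"    def forward(self, *args, **kwargs):\n",
"        return self._module(*args, **kwargs)\n",
"\n",
"expert = nn.Module()\n",
"expert.w1 = nn.Linear(4, 4)\n",
"wrapped = Wrapper(expert)\n",
"\n",
"# Accessing wrapped.w1 raises AttributeError, as in the traceback above:\n",
"try:\n",
"    wrapped.w1\n",
"except AttributeError as e:\n",
"    print(e)\n",
"\n",
"# Unwrapping first reaches the real expert weights:\n",
"inner = getattr(wrapped, \"_module\", wrapped)\n",
"print(inner.w1.weight.dtype)"
]
}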
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}