Created August 2, 2024 00:24
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "68a82751-4c11-4889-9415-edb48bee247f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"2024-08-01 17:18:15.177 Python[64493:7647624] getMetalPluginClassForService: Failed to find bundle for accelerator bundle named: AGXMetalA12 errno: 0\n" | |
] | |
} | |
], | |
"source": [ | |
"from transformers import AutoConfig\n", | |
"from Auto_HookPoint import HookedTransformerAdapter \n", | |
"from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner\n", | |
"from dataclasses import dataclass" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "7148e9ae-e375-4d4e-8282-2f42731a25cb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Run name: 16384-L1-5-LR-5e-05-Tokens-6.144e+07\n", | |
"n_tokens_per_buffer (millions): 0.524288\n", | |
"Lower bound: n_contexts_per_buffer (millions): 0.001024\n", | |
"Total training steps: 15000\n", | |
"Total wandb updates: 500\n", | |
"n_tokens_per_feature_sampling_window (millions): 2097.152\n", | |
"n_tokens_per_dead_feature_window (millions): 2097.152\n", | |
"We will reset the sparsity calculation 15 times.\n", | |
"Number tokens in sparsity calculation window: 4.10e+06\n" | |
] | |
} | |
], | |
"source": [ | |
"model_name = \"joelb/Mixtral-8x7B-1l\"\n", | |
"config = AutoConfig.from_pretrained(model_name)\n", | |
"\n", | |
"total_training_steps = 15_000 # probably we should do more\n", | |
"batch_size = 4096\n", | |
"total_training_tokens = total_training_steps * batch_size\n", | |
"\n", | |
"lr_warm_up_steps = 0\n", | |
"lr_decay_steps = total_training_steps // 5 # 20% of training\n", | |
"l1_warm_up_steps = total_training_steps // 20 # 5% of training\n", | |
"\n", | |
"\n", | |
"#sae_lens\n", | |
"cfg = LanguageModelSAERunnerConfig(\n", | |
" model_name=model_name,\n", | |
" hook_name=\"model.norm.hook_point\",\n", | |
" hook_layer=0,\n", | |
" d_in=config.hidden_size,\n", | |
" dataset_path=\"monology/pile-uncopyrighted\",\n", | |
" is_dataset_tokenized=False,\n", | |
" streaming=True, \n", | |
" mse_loss_normalization=None, # We won't normalize the mse loss,\n", | |
" expansion_factor=4,\n", | |
" b_dec_init_method=\"zeros\", # The geometric median can be used to initialize the decoder weights.\n", | |
" apply_b_dec_to_input=False, # We won't apply the decoder weights to the input.\n", | |
" normalize_sae_decoder=False,\n", | |
" scale_sparsity_penalty_by_decoder_norm=True,\n", | |
" decoder_heuristic_init=True,\n", | |
" init_encoder_as_decoder_transpose=True,\n", | |
" normalize_activations=\"expected_average_only_in\",\n", | |
" # Training Parameters\n", | |
" lr=5e-5, # lower the better, we'll go fairly high to speed up the tutorial.\n", | |
" adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n", | |
" adam_beta2=0.999,\n", | |
" lr_scheduler_name=\"constant\", # constant learning rate with warmup. Could be better schedules out there.\n", | |
" lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially.\n", | |
" lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting.\n", | |
" l1_coefficient=5, # will control how sparse the feature activations are\n", | |
" l1_warm_up_steps=l1_warm_up_steps, # this can help avoid too many dead features initially.\n", | |
" lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", | |
" train_batch_size_tokens=batch_size,\n", | |
" context_size=512, # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.\n", | |
" # Activation Store Parameters\n", | |
" n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", | |
" training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n", | |
" store_batch_size_prompts=16,\n", | |
" # Resampling protocol\n", | |
" use_ghost_grads=False, # we don't use ghost grads anymore.\n", | |
" feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", | |
" dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n", | |
" dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n", | |
" # WANDB\n", | |
" log_to_wandb=True, # always use wandb unless you are just testing code.\n", | |
" wandb_log_frequency=30,\n", | |
" eval_every_n_wandb_logs=20,\n", | |
" # Misc\n", | |
" device=\"cpu\",\n", | |
" seed=42,\n", | |
" n_checkpoints=0,\n", | |
" checkpoint_path=\"checkpoints\",\n", | |
" dtype=\"float32\"\n", | |
")\n", | |
"\n", | |
"@dataclass\n", | |
"class Cfg:\n", | |
" device: str\n", | |
" n_ctx: int" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "c513635f-bf47-4301-8e20-14c6de2ba7d6", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "09e250b7583045ffacb870b3f72c60a0", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Some weights of MixtralForCausalLM were not initialized from the model checkpoint at joelb/Mixtral-8x7B-1l and are newly initialized: ['lm_head.weight']\n", | |
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", | |
"WARNING:root:You just passed in a model which will override the one specified in your configuration: joelb/Mixtral-8x7B-1l. As a consequence this run will not be reproducible via configuration alone.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Found embedding matrix: _module.model._module.embed_tokens._module\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "6b3187feec5c4ad28301c9047755c5a7", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"Resolving data files: 0%| | 0/30 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Warning: Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info.\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"Finishing last run (ID:b7mvk41i) before initializing another..." | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
" View run <strong style=\"color:#cdcd00\">16384-L1-5-LR-5e-05-Tokens-6.144e+07</strong> at: <a href='https://wandb.ai/joelb/mats_sae_training_language_model/runs/b7mvk41i' target=\"_blank\">https://wandb.ai/joelb/mats_sae_training_language_model/runs/b7mvk41i</a><br/> View project at: <a href='https://wandb.ai/joelb/mats_sae_training_language_model' target=\"_blank\">https://wandb.ai/joelb/mats_sae_training_language_model</a><br/>Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"Find logs at: <code>./wandb/run-20240801_171909-b7mvk41i/logs</code>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"Successfully finished last run (ID:b7mvk41i). Initializing new run:<br/>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", | |
"To disable this warning, you can either:\n", | |
"\t- Avoid using `tokenizers` before the fork if possible\n", | |
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"wandb version 0.17.5 is available! To upgrade, please run:\n", | |
" $ pip install wandb --upgrade" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"Tracking run with wandb version 0.16.6" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"Run data is saved locally in <code>/Users/joel/code/github/TransformerLens/debugging/wandb/run-20240801_171952-cv07gebr</code>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"Syncing run <strong><a href='https://wandb.ai/joelb/mats_sae_training_language_model/runs/cv07gebr' target=\"_blank\">16384-L1-5-LR-5e-05-Tokens-6.144e+07</a></strong> to <a href='https://wandb.ai/joelb/mats_sae_training_language_model' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
" View project at <a href='https://wandb.ai/joelb/mats_sae_training_language_model' target=\"_blank\">https://wandb.ai/joelb/mats_sae_training_language_model</a>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
" View run at <a href='https://wandb.ai/joelb/mats_sae_training_language_model/runs/cv07gebr' target=\"_blank\">https://wandb.ai/joelb/mats_sae_training_language_model/runs/cv07gebr</a>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"Training SAE: 0%| | 0/61440000 [00:00<?, ?it/s]\u001b[A\n", | |
"\n", | |
"Estimating norm scaling factor: 0%| | 0/1000 [00:00<?, ?it/s]\u001b[A\u001b[A" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"MixralSdpaAttention.forward\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Estimating norm scaling factor: 0%| | 0/1000 [00:19<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"hf\n", | |
"router_logits 0.376691073179245\n", | |
"routing_weights 0.6589474678039551\n", | |
"expert_idx=0\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
}, | |
{ | |
"ename": "AttributeError", | |
"evalue": "'HookedModule' object has no attribute 'w1'", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[4], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m hooked_model \u001b[38;5;241m=\u001b[39m HookedTransformerAdapter(model_name, Cfg(device\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, n_ctx\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m512\u001b[39m))\n\u001b[0;32m----> 2\u001b[0m sparse_autoencoder \u001b[38;5;241m=\u001b[39m \u001b[43mSAETrainingRunner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverride_model\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhooked_model\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/sae_training_runner.py:106\u001b[0m, in \u001b[0;36mSAETrainingRunner.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 97\u001b[0m trainer \u001b[38;5;241m=\u001b[39m SAETrainer(\n\u001b[1;32m 98\u001b[0m model\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel,\n\u001b[1;32m 99\u001b[0m sae\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msae,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 102\u001b[0m cfg\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg,\n\u001b[1;32m 103\u001b[0m )\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compile_if_needed()\n\u001b[0;32m--> 106\u001b[0m sae \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_trainer_with_interruption_handling\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mlog_to_wandb:\n\u001b[1;32m 109\u001b[0m wandb\u001b[38;5;241m.\u001b[39mfinish()\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/sae_training_runner.py:149\u001b[0m, in \u001b[0;36mSAETrainingRunner.run_trainer_with_interruption_handling\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 146\u001b[0m signal\u001b[38;5;241m.\u001b[39msignal(signal\u001b[38;5;241m.\u001b[39mSIGTERM, interrupt_callback)\n\u001b[1;32m 148\u001b[0m \u001b[38;5;66;03m# train SAE\u001b[39;00m\n\u001b[0;32m--> 149\u001b[0m sae \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, InterruptedException):\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minterrupted, saving progress\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py:164\u001b[0m, in \u001b[0;36mSAETrainer.fit\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m TrainingSAE:\n\u001b[1;32m 162\u001b[0m pbar \u001b[38;5;241m=\u001b[39m tqdm(total\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mtotal_training_tokens, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining SAE\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 164\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_estimate_norm_scaling_factor_if_needed\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;66;03m# Train loop\u001b[39;00m\n\u001b[1;32m 167\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_training_tokens \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mtotal_training_tokens:\n\u001b[1;32m 168\u001b[0m \u001b[38;5;66;03m# Do a training step.\u001b[39;00m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py:205\u001b[0m, in \u001b[0;36mSAETrainer._estimate_norm_scaling_factor_if_needed\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39mno_grad()\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_estimate_norm_scaling_factor_if_needed\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mnormalize_activations \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexpected_average_only_in\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactivation_store\u001b[38;5;241m.\u001b[39mestimated_norm_scaling_factor \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mactivation_store\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mestimate_norm_scaling_factor\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m )\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 208\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactivation_store\u001b[38;5;241m.\u001b[39mestimated_norm_scaling_factor \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1.0\u001b[39m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:346\u001b[0m, in \u001b[0;36mActivationsStore.estimate_norm_scaling_factor\u001b[0;34m(self, n_batches_for_norm_estimate)\u001b[0m\n\u001b[1;32m 342\u001b[0m norms_per_batch \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m tqdm(\n\u001b[1;32m 344\u001b[0m \u001b[38;5;28mrange\u001b[39m(n_batches_for_norm_estimate), desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEstimating norm scaling factor\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 345\u001b[0m ):\n\u001b[0;32m--> 346\u001b[0m acts \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnext_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m norms_per_batch\u001b[38;5;241m.\u001b[39mappend(acts\u001b[38;5;241m.\u001b[39mnorm(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mmean()\u001b[38;5;241m.\u001b[39mitem())\n\u001b[1;32m 348\u001b[0m mean_norm \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmean(norms_per_batch)\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:640\u001b[0m, in \u001b[0;36mActivationsStore.next_batch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 634\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 635\u001b[0m \u001b[38;5;124;03mGet the next batch from the current DataLoader.\u001b[39;00m\n\u001b[1;32m 636\u001b[0m \u001b[38;5;124;03mIf the DataLoader is exhausted, refill the buffer and create a new DataLoader.\u001b[39;00m\n\u001b[1;32m 637\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 638\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 639\u001b[0m \u001b[38;5;66;03m# Try to get the next batch\u001b[39;00m\n\u001b[0;32m--> 640\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataloader\u001b[49m)\n\u001b[1;32m 641\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 642\u001b[0m \u001b[38;5;66;03m# If the DataLoader is exhausted, create a new one\u001b[39;00m\n\u001b[1;32m 643\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataloader \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_data_loader()\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:383\u001b[0m, in \u001b[0;36mActivationsStore.dataloader\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 381\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdataloader\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Iterator[Any]:\n\u001b[1;32m 382\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataloader \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 383\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataloader \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_data_loader\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 384\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataloader\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:593\u001b[0m, in \u001b[0;36mActivationsStore.get_data_loader\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 590\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_batch_size_tokens\n\u001b[1;32m 592\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 593\u001b[0m new_samples \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_buffer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 594\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhalf_buffer_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mraise_on_epoch_end\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m 595\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 596\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 597\u001b[0m \u001b[38;5;28mprint\u001b[39m(\n\u001b[1;32m 598\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWarning: All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 599\u001b[0m )\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:548\u001b[0m, in \u001b[0;36mActivationsStore.get_buffer\u001b[0;34m(self, n_batches_in_buffer, raise_on_epoch_end)\u001b[0m\n\u001b[1;32m 543\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m refill_batch_idx_start \u001b[38;5;129;01min\u001b[39;00m refill_iterator:\n\u001b[1;32m 544\u001b[0m \u001b[38;5;66;03m# move batch toks to gpu for model\u001b[39;00m\n\u001b[1;32m 545\u001b[0m refill_batch_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_batch_tokens(\n\u001b[1;32m 546\u001b[0m raise_at_epoch_end\u001b[38;5;241m=\u001b[39mraise_on_epoch_end\n\u001b[1;32m 547\u001b[0m )\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[0;32m--> 548\u001b[0m refill_activations \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_activations\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrefill_batch_tokens\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 549\u001b[0m \u001b[38;5;66;03m# move acts back to cpu\u001b[39;00m\n\u001b[1;32m 550\u001b[0m refill_activations\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:431\u001b[0m, in \u001b[0;36mActivationsStore.get_activations\u001b[0;34m(self, batch_tokens)\u001b[0m\n\u001b[1;32m 428\u001b[0m autocast_if_enabled \u001b[38;5;241m=\u001b[39m contextlib\u001b[38;5;241m.\u001b[39mnullcontext()\n\u001b[1;32m 430\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m autocast_if_enabled:\n\u001b[0;32m--> 431\u001b[0m layerwise_activations \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_with_cache\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 432\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 433\u001b[0m \u001b[43m \u001b[49m\u001b[43mnames_filter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhook_name\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 434\u001b[0m \u001b[43m \u001b[49m\u001b[43mstop_at_layer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhook_layer\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 435\u001b[0m \u001b[43m \u001b[49m\u001b[43mprepend_bos\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 436\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 437\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 439\u001b[0m n_batches, n_context \u001b[38;5;241m=\u001b[39m batch_tokens\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 441\u001b[0m stacked_activations \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mzeros((n_batches, n_context, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39md_in))\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/HookedTransformerAdapter.py:116\u001b[0m, in \u001b[0;36mHookedTransformerAdapter.run_with_cache\u001b[0;34m(self, return_cache_object, remove_batch_dim, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun_with_cache\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39mmodel_args, return_cache_object\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, remove_batch_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m names_filter \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnames_filter\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m--> 116\u001b[0m out, cache_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_with_cache\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 117\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mremove_batch_dim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mremove_batch_dim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnames_filter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnames_filter\u001b[49m\n\u001b[1;32m 118\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_cache_object:\n\u001b[1;32m 120\u001b[0m cache \u001b[38;5;241m=\u001b[39m ActivationCache(cache_dict, \u001b[38;5;28mself\u001b[39m, has_batch_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;129;01mnot\u001b[39;00m remove_batch_dim)\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformer_lens/hook_points.py:566\u001b[0m, in \u001b[0;36mHookedRootModule.run_with_cache\u001b[0;34m(self, names_filter, device, remove_batch_dim, incl_bwd, reset_hooks_end, clear_contexts, pos_slice, *model_args, **model_kwargs)\u001b[0m\n\u001b[1;32m 552\u001b[0m cache_dict, fwd, bwd \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_caching_hooks(\n\u001b[1;32m 553\u001b[0m names_filter,\n\u001b[1;32m 554\u001b[0m incl_bwd,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 557\u001b[0m pos_slice\u001b[38;5;241m=\u001b[39mpos_slice,\n\u001b[1;32m 558\u001b[0m )\n\u001b[1;32m 560\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhooks(\n\u001b[1;32m 561\u001b[0m fwd_hooks\u001b[38;5;241m=\u001b[39mfwd,\n\u001b[1;32m 562\u001b[0m bwd_hooks\u001b[38;5;241m=\u001b[39mbwd,\n\u001b[1;32m 563\u001b[0m reset_hooks_end\u001b[38;5;241m=\u001b[39mreset_hooks_end,\n\u001b[1;32m 564\u001b[0m clear_contexts\u001b[38;5;241m=\u001b[39mclear_contexts,\n\u001b[1;32m 565\u001b[0m ):\n\u001b[0;32m--> 566\u001b[0m model_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m incl_bwd:\n\u001b[1;32m 568\u001b[0m model_out\u001b[38;5;241m.\u001b[39mbackward()\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/HookedTransformerAdapter.py:126\u001b[0m, in \u001b[0;36mHookedTransformerAdapter.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 126\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:317\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 317\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhook_point(\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:1407\u001b[0m, in \u001b[0;36mMixtralForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict)\u001b[0m\n\u001b[1;32m 1404\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1406\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1407\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1408\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1409\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1410\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1411\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1412\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1413\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1414\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1415\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1416\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_router_logits\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_router_logits\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1417\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1418\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1420\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1421\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:317\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 317\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhook_point(\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:1275\u001b[0m, in \u001b[0;36mMixtralModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict)\u001b[0m\n\u001b[1;32m 1264\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 1265\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 1266\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1272\u001b[0m use_cache,\n\u001b[1;32m 1273\u001b[0m )\n\u001b[1;32m 1274\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1275\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1276\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1277\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1278\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1279\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1280\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1281\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_router_logits\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_router_logits\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1282\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1283\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1285\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1287\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:317\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 317\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhook_point(\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:994\u001b[0m, in \u001b[0;36mMixtralDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, output_router_logits, use_cache, **kwargs)\u001b[0m\n\u001b[1;32m 992\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 993\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_attention_layernorm(hidden_states)\n\u001b[0;32m--> 994\u001b[0m hidden_states, router_logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mblock_sparse_moe\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 995\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 997\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (hidden_states,)\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:317\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 317\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhook_point(\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:901\u001b[0m, in \u001b[0;36mMixtralSparseMoeBlock.forward\u001b[0;34m(self, hidden_states)\u001b[0m\n\u001b[1;32m 898\u001b[0m idx, top_x \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mwhere(expert_mask[expert_idx])\n\u001b[1;32m 900\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m expert_idx \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 901\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mW_gate\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28mstr\u001b[39m(\u001b[43mexpert_layer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mw1\u001b[49m\u001b[38;5;241m.\u001b[39mweight[\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mitem()), expert_layer\u001b[38;5;241m.\u001b[39mw1\u001b[38;5;241m.\u001b[39mweight\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m 902\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mW_in\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28mstr\u001b[39m(expert_layer\u001b[38;5;241m.\u001b[39mw3\u001b[38;5;241m.\u001b[39mweight[\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mitem()), expert_layer\u001b[38;5;241m.\u001b[39mw3\u001b[38;5;241m.\u001b[39mweight\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m 903\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mW_out\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;28mstr\u001b[39m(expert_layer\u001b[38;5;241m.\u001b[39mw2\u001b[38;5;241m.\u001b[39mweight[\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mitem()), expert_layer\u001b[38;5;241m.\u001b[39mw2\u001b[38;5;241m.\u001b[39mweight\u001b[38;5;241m.\u001b[39mdtype)\n", | |
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1695\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1693\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1695\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", | |
"\u001b[0;31mAttributeError\u001b[0m: 'HookedModule' object has no attribute 'w1'" | |
] | |
} | |
], | |
"source": [ | |
"hooked_model = HookedTransformerAdapter(model_name, Cfg(device=\"cpu\", n_ctx=512))\n", | |
"sparse_autoencoder = SAETrainingRunner(cfg, override_model=hooked_model).run()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
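Note on the failure above: the AttributeError ('HookedModule' object has no attribute 'w1') comes from the debug print patched into MixtralSparseMoeBlock.forward, which reads expert_layer.w1 directly; Auto_HookPoint has wrapped each expert in a HookedModule, which does not expose that attribute. The sketch below is one possible workaround, assuming the wrapper keeps the original module on a _module attribute (the "Found embedding matrix: _module.model._module.embed_tokens._module" line suggests this naming); the unwrap helper is hypothetical and not part of either library.

def unwrap(module):
    # Return the wrapped nn.Module if this is a HookedModule-style wrapper, else the module itself.
    return getattr(module, "_module", module)

# In the patched modeling_mixtral.py debug print, unwrap before touching w1/w2/w3, e.g.:
#     raw = unwrap(expert_layer)
#     print('W_gate', raw.w1.weight[0, 0].item(), raw.w1.weight.dtype)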