Skip to content

Instantly share code, notes, and snippets.

@joelburget
Created July 30, 2024 00:59
Show Gist options
  • Save joelburget/967a38065e1c519f782c95e108a26f92 to your computer and use it in GitHub Desktop.
Save joelburget/967a38065e1c519f782c95e108a26f92 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "e74f0b8e-6cc8-42ef-a27d-f0ebb00479a7",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"test_auto_hook imported\n"
]
}
],
"source": [
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\n",
"from Auto_HookPoint import auto_hook \n",
"from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner\n",
"from transformer_lens.hook_points import HookedRootModule\n",
"from transformer_lens import ActivationCache\n",
"import transformer_lens.utils as tl_utils\n",
"from typing import Union, List, Optional, Literal\n",
"from dataclasses import dataclass"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a2d8bd84-8d38-4d98-a994-72a8276142c3",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"joelb/Mixtral-8x7B-1l\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "40187b91-00f8-41ee-aa9b-2dec2edffb83",
"metadata": {},
"outputs": [],
"source": [
"# hf_model = AutoModelForCausalLM.from_pretrained(\"joelb/Mixtral-8x7B-1l\")\n",
"# hooked_model = auto_hook(hf_model)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "be805cc0-4932-4637-a168-06533d10e1ba",
"metadata": {},
"outputs": [],
"source": [
"config = AutoConfig.from_pretrained(model_name)"
]
},
{
"cell_type": "markdown",
"id": "4410911e-d61f-4650-8266-9d032e88d8e5",
"metadata": {},
"source": [
"## SAE Training"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f05137bd-2cc2-4b18-93bb-553b9f689f2e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Run name: 16384-L1-5-LR-5e-05-Tokens-6.144e+07\n",
"n_tokens_per_buffer (millions): 0.524288\n",
"Lower bound: n_contexts_per_buffer (millions): 0.001024\n",
"Total training steps: 15000\n",
"Total wandb updates: 500\n",
"n_tokens_per_feature_sampling_window (millions): 2097.152\n",
"n_tokens_per_dead_feature_window (millions): 2097.152\n",
"We will reset the sparsity calculation 15 times.\n",
"Number tokens in sparsity calculation window: 4.10e+06\n"
]
}
],
"source": [
"total_training_steps = 15_000 # probably we should do more\n",
"batch_size = 4096\n",
"total_training_tokens = total_training_steps * batch_size\n",
"\n",
"lr_warm_up_steps = 0\n",
"lr_decay_steps = total_training_steps // 5 # 20% of training\n",
"l1_warm_up_steps = total_training_steps // 20 # 5% of training\n",
"\n",
"cfg = LanguageModelSAERunnerConfig(\n",
" model_name=model_name,\n",
" hook_name=\"model.layers.0.hook_point\",\n",
" hook_layer=0,\n",
" d_in=config.hidden_size,\n",
" dataset_path=\"monology/pile-uncopyrighted\",\n",
" is_dataset_tokenized=False,\n",
" streaming=True, # we could pre-download the token dataset if it was small.\n",
" # SAE Parameters\n",
" mse_loss_normalization=None, # We won't normalize the mse loss,\n",
" expansion_factor=4,\n",
" b_dec_init_method=\"zeros\", # The geometric median can be used to initialize the decoder weights.\n",
" apply_b_dec_to_input=False, # We won't apply the decoder weights to the input.\n",
" normalize_sae_decoder=False,\n",
" scale_sparsity_penalty_by_decoder_norm=True,\n",
" decoder_heuristic_init=True,\n",
" init_encoder_as_decoder_transpose=True,\n",
" normalize_activations=\"expected_average_only_in\",\n",
" # Training Parameters\n",
" lr=5e-5, # lower the better, we'll go fairly high to speed up the tutorial.\n",
" adam_beta1=0.9, # adam params (default, but once upon a time we experimented with these.)\n",
" adam_beta2=0.999,\n",
" lr_scheduler_name=\"constant\", # constant learning rate with warmup. Could be better schedules out there.\n",
" lr_warm_up_steps=lr_warm_up_steps, # this can help avoid too many dead features initially.\n",
" lr_decay_steps=lr_decay_steps, # this will help us avoid overfitting.\n",
" l1_coefficient=5, # will control how sparse the feature activations are\n",
" l1_warm_up_steps=l1_warm_up_steps, # this can help avoid too many dead features initially.\n",
" lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n",
" train_batch_size_tokens=batch_size,\n",
" context_size=512, # will control the lenght of the prompts we feed to the model. Larger is better but slower. so for the tutorial we'll use a short one.\n",
" # Activation Store Parameters\n",
" n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n",
" training_tokens=total_training_tokens, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n",
" store_batch_size_prompts=16,\n",
" # Resampling protocol\n",
" use_ghost_grads=False, # we don't use ghost grads anymore.\n",
" feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n",
" dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n",
" dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n",
" # WANDB\n",
" log_to_wandb=True, # always use wandb unless you are just testing code.\n",
" wandb_project=\"mixtral-1l-test\",\n",
" wandb_log_frequency=30,\n",
" eval_every_n_wandb_logs=20,\n",
" # Misc\n",
" device=\"cpu\",\n",
" seed=42,\n",
" n_checkpoints=0,\n",
" checkpoint_path=\"checkpoints\",\n",
" dtype=\"float32\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a377e41a-76e9-4c0b-ac7e-312e25b3962e",
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class Cfg:\n",
" device: str\n",
"\n",
"class HookedTransformerAdapter(HookedRootModule):\n",
" def __init__(self, model_name, n_ctx=8192):\n",
" super().__init__()\n",
" self.cfg = Cfg(device='cpu')\n",
" self.model = auto_hook(AutoModelForCausalLM.from_pretrained(model_name))\n",
" self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
" self.tokenizer.pad_token = self.tokenizer.eos_token\n",
" self.n_ctx = n_ctx\n",
" # self.model._module: MixtralForCausalLM\n",
" # self.model._module.model._module: MixtralModel\n",
" self.W_E = self.model._module.model._module.embed_tokens._module.weight\n",
" self.setup()\n",
"\n",
" def to_tokens(\n",
" self,\n",
" input: Union[str, List[str]],\n",
" prepend_bos: Optional[Union[bool, None]] = True,\n",
" padding_side: Optional[Union[Literal[\"left\", \"right\"], None]] = None,\n",
" move_to_device: bool = True,\n",
" truncate: bool = True,\n",
" ):\n",
" return self.tokenizer(\n",
" input,\n",
" return_tensors=\"pt\",\n",
" padding=True,\n",
" truncation=truncate,\n",
" max_length=self.n_ctx if truncate else None,\n",
" )[\"input_ids\"]\n",
" \n",
" def run_with_cache(self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs):\n",
" out, cache_dict = super().run_with_cache(\n",
" *model_args, remove_batch_dim=remove_batch_dim # , **kwargs\n",
" )\n",
" if return_cache_object:\n",
" cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)\n",
" return out, cache\n",
" else:\n",
" return out, cache_dict\n",
"\n",
" def forward(self, *args, **kwargs):\n",
" print(f\"forward {args=}\\n{kwargs=}\")\n",
" result = self.model.forward(*args, **kwargs)\n",
" print(f\"{result=}\")\n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "ab0a39e3-0e67-4383-825e-d4be7434a8ef",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b5da513019594a0c89fc9479ee1f6a6b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of MixtralForCausalLM were not initialized from the model checkpoint at joelb/Mixtral-8x7B-1l and are newly initialized: ['lm_head.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"hooked_model = HookedTransformerAdapter(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1f61f546-1bf5-4924-b901-ddf546eb0882",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:root:You just passed in a model which will override the one specified in your configuration: joelb/Mixtral-8x7B-1l. As a consequence this run will not be reproducible via configuration alone.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "284a19561f714eb48d677acebd401478",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Resolving data files: 0%| | 0/30 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info.\n"
]
},
{
"data": {
"text/html": [
"Finishing last run (ID:ij58bbt9) before initializing another..."
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run <strong style=\"color:#cdcd00\">16384-L1-5-LR-5e-05-Tokens-6.144e+07</strong> at: <a href='https://wandb.ai/joelb/mixtral-1l-test/runs/ij58bbt9' target=\"_blank\">https://wandb.ai/joelb/mixtral-1l-test/runs/ij58bbt9</a><br/> View project at: <a href='https://wandb.ai/joelb/mixtral-1l-test' target=\"_blank\">https://wandb.ai/joelb/mixtral-1l-test</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Find logs at: <code>./wandb/run-20240729_175208-ij58bbt9/logs</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Successfully finished last run (ID:ij58bbt9). Initializing new run:<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"data": {
"text/html": [
"wandb version 0.17.5 is available! To upgrade, please run:\n",
" $ pip install wandb --upgrade"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.16.6"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in <code>/Users/joel/code/github/TransformerLens/debugging/wandb/run-20240729_175356-l7hmitr9</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/joelb/mixtral-1l-test/runs/l7hmitr9' target=\"_blank\">16384-L1-5-LR-5e-05-Tokens-6.144e+07</a></strong> to <a href='https://wandb.ai/joelb/mixtral-1l-test' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at <a href='https://wandb.ai/joelb/mixtral-1l-test' target=\"_blank\">https://wandb.ai/joelb/mixtral-1l-test</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at <a href='https://wandb.ai/joelb/mixtral-1l-test/runs/l7hmitr9' target=\"_blank\">https://wandb.ai/joelb/mixtral-1l-test/runs/l7hmitr9</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"Training SAE: 0%| | 0/61440000 [00:00<?, ?it/s]\u001b[A\n",
"\n",
"Estimating norm scaling factor: 0%| | 0/1000 [00:00<?, ?it/s]\u001b[A\u001b[A"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"forward args=(tensor([[ 1, 661, 349, ..., 575, 28723, 13],\n",
" [ 1, 13, 5183, ..., 625, 21762, 354],\n",
" [ 1, 799, 2609, ..., 2590, 264, 2286],\n",
" ...,\n",
" [ 1, 658, 3266, ..., 873, 910, 298],\n",
" [ 1, 938, 396, ..., 23454, 297, 272],\n",
" [ 1, 1682, 28733, ..., 28781, 28734, 28823]]),)\n",
"kwargs={}\n",
"MixralSdpaAttention.forward\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Estimating norm scaling factor: 0%| | 0/1000 [00:12<?, ?it/s]\n"
]
},
{
"ename": "AttributeError",
"evalue": "'tuple' object has no attribute 'detach'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[17], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m sparse_autoencoder \u001b[38;5;241m=\u001b[39m \u001b[43mSAETrainingRunner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcfg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverride_model\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhooked_model\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/sae_training_runner.py:106\u001b[0m, in \u001b[0;36mSAETrainingRunner.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 97\u001b[0m trainer \u001b[38;5;241m=\u001b[39m SAETrainer(\n\u001b[1;32m 98\u001b[0m model\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel,\n\u001b[1;32m 99\u001b[0m sae\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msae,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 102\u001b[0m cfg\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg,\n\u001b[1;32m 103\u001b[0m )\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compile_if_needed()\n\u001b[0;32m--> 106\u001b[0m sae \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_trainer_with_interruption_handling\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mlog_to_wandb:\n\u001b[1;32m 109\u001b[0m wandb\u001b[38;5;241m.\u001b[39mfinish()\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/sae_training_runner.py:149\u001b[0m, in \u001b[0;36mSAETrainingRunner.run_trainer_with_interruption_handling\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 146\u001b[0m signal\u001b[38;5;241m.\u001b[39msignal(signal\u001b[38;5;241m.\u001b[39mSIGTERM, interrupt_callback)\n\u001b[1;32m 148\u001b[0m \u001b[38;5;66;03m# train SAE\u001b[39;00m\n\u001b[0;32m--> 149\u001b[0m sae \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, InterruptedException):\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minterrupted, saving progress\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py:164\u001b[0m, in \u001b[0;36mSAETrainer.fit\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m TrainingSAE:\n\u001b[1;32m 162\u001b[0m pbar \u001b[38;5;241m=\u001b[39m tqdm(total\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mtotal_training_tokens, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTraining SAE\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 164\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_estimate_norm_scaling_factor_if_needed\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;66;03m# Train loop\u001b[39;00m\n\u001b[1;32m 167\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_training_tokens \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mtotal_training_tokens:\n\u001b[1;32m 168\u001b[0m \u001b[38;5;66;03m# Do a training step.\u001b[39;00m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py:205\u001b[0m, in \u001b[0;36mSAETrainer._estimate_norm_scaling_factor_if_needed\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39mno_grad()\n\u001b[1;32m 202\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_estimate_norm_scaling_factor_if_needed\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mnormalize_activations \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexpected_average_only_in\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactivation_store\u001b[38;5;241m.\u001b[39mestimated_norm_scaling_factor \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mactivation_store\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mestimate_norm_scaling_factor\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m )\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 208\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactivation_store\u001b[38;5;241m.\u001b[39mestimated_norm_scaling_factor \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1.0\u001b[39m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:346\u001b[0m, in \u001b[0;36mActivationsStore.estimate_norm_scaling_factor\u001b[0;34m(self, n_batches_for_norm_estimate)\u001b[0m\n\u001b[1;32m 342\u001b[0m norms_per_batch \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m tqdm(\n\u001b[1;32m 344\u001b[0m \u001b[38;5;28mrange\u001b[39m(n_batches_for_norm_estimate), desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEstimating norm scaling factor\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 345\u001b[0m ):\n\u001b[0;32m--> 346\u001b[0m acts \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnext_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m norms_per_batch\u001b[38;5;241m.\u001b[39mappend(acts\u001b[38;5;241m.\u001b[39mnorm(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mmean()\u001b[38;5;241m.\u001b[39mitem())\n\u001b[1;32m 348\u001b[0m mean_norm \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmean(norms_per_batch)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:640\u001b[0m, in \u001b[0;36mActivationsStore.next_batch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 634\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 635\u001b[0m \u001b[38;5;124;03mGet the next batch from the current DataLoader.\u001b[39;00m\n\u001b[1;32m 636\u001b[0m \u001b[38;5;124;03mIf the DataLoader is exhausted, refill the buffer and create a new DataLoader.\u001b[39;00m\n\u001b[1;32m 637\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 638\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 639\u001b[0m \u001b[38;5;66;03m# Try to get the next batch\u001b[39;00m\n\u001b[0;32m--> 640\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataloader\u001b[49m)\n\u001b[1;32m 641\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 642\u001b[0m \u001b[38;5;66;03m# If the DataLoader is exhausted, create a new one\u001b[39;00m\n\u001b[1;32m 643\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataloader \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_data_loader()\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:383\u001b[0m, in \u001b[0;36mActivationsStore.dataloader\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 381\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdataloader\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Iterator[Any]:\n\u001b[1;32m 382\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataloader \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 383\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataloader \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_data_loader\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 384\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataloader\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:593\u001b[0m, in \u001b[0;36mActivationsStore.get_data_loader\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 590\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_batch_size_tokens\n\u001b[1;32m 592\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 593\u001b[0m new_samples \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_buffer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 594\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhalf_buffer_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mraise_on_epoch_end\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\n\u001b[1;32m 595\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 596\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 597\u001b[0m \u001b[38;5;28mprint\u001b[39m(\n\u001b[1;32m 598\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWarning: All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 599\u001b[0m )\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:548\u001b[0m, in \u001b[0;36mActivationsStore.get_buffer\u001b[0;34m(self, n_batches_in_buffer, raise_on_epoch_end)\u001b[0m\n\u001b[1;32m 543\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m refill_batch_idx_start \u001b[38;5;129;01min\u001b[39;00m refill_iterator:\n\u001b[1;32m 544\u001b[0m \u001b[38;5;66;03m# move batch toks to gpu for model\u001b[39;00m\n\u001b[1;32m 545\u001b[0m refill_batch_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_batch_tokens(\n\u001b[1;32m 546\u001b[0m raise_at_epoch_end\u001b[38;5;241m=\u001b[39mraise_on_epoch_end\n\u001b[1;32m 547\u001b[0m )\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mcfg\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[0;32m--> 548\u001b[0m refill_activations \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_activations\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrefill_batch_tokens\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 549\u001b[0m \u001b[38;5;66;03m# move acts back to cpu\u001b[39;00m\n\u001b[1;32m 550\u001b[0m refill_activations\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/sae_lens/training/activations_store.py:431\u001b[0m, in \u001b[0;36mActivationsStore.get_activations\u001b[0;34m(self, batch_tokens)\u001b[0m\n\u001b[1;32m 428\u001b[0m autocast_if_enabled \u001b[38;5;241m=\u001b[39m contextlib\u001b[38;5;241m.\u001b[39mnullcontext()\n\u001b[1;32m 430\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m autocast_if_enabled:\n\u001b[0;32m--> 431\u001b[0m layerwise_activations \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_with_cache\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 432\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 433\u001b[0m \u001b[43m \u001b[49m\u001b[43mnames_filter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhook_name\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 434\u001b[0m \u001b[43m \u001b[49m\u001b[43mstop_at_layer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhook_layer\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 435\u001b[0m \u001b[43m \u001b[49m\u001b[43mprepend_bos\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 436\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 437\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 439\u001b[0m n_batches, n_context 
\u001b[38;5;241m=\u001b[39m batch_tokens\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 441\u001b[0m stacked_activations \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mzeros((n_batches, n_context, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39md_in))\n",
"Cell \u001b[0;32mIn[15], line 35\u001b[0m, in \u001b[0;36mHookedTransformerAdapter.run_with_cache\u001b[0;34m(self, return_cache_object, remove_batch_dim, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mrun_with_cache\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39mmodel_args, return_cache_object\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, remove_batch_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 35\u001b[0m out, cache_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_with_cache\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 36\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mremove_batch_dim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mremove_batch_dim\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# , **kwargs\u001b[39;49;00m\n\u001b[1;32m 37\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_cache_object:\n\u001b[1;32m 39\u001b[0m cache \u001b[38;5;241m=\u001b[39m ActivationCache(cache_dict, \u001b[38;5;28mself\u001b[39m, has_batch_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;129;01mnot\u001b[39;00m remove_batch_dim)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformer_lens/hook_points.py:566\u001b[0m, in \u001b[0;36mHookedRootModule.run_with_cache\u001b[0;34m(self, names_filter, device, remove_batch_dim, incl_bwd, reset_hooks_end, clear_contexts, pos_slice, *model_args, **model_kwargs)\u001b[0m\n\u001b[1;32m 552\u001b[0m cache_dict, fwd, bwd \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_caching_hooks(\n\u001b[1;32m 553\u001b[0m names_filter,\n\u001b[1;32m 554\u001b[0m incl_bwd,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 557\u001b[0m pos_slice\u001b[38;5;241m=\u001b[39mpos_slice,\n\u001b[1;32m 558\u001b[0m )\n\u001b[1;32m 560\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhooks(\n\u001b[1;32m 561\u001b[0m fwd_hooks\u001b[38;5;241m=\u001b[39mfwd,\n\u001b[1;32m 562\u001b[0m bwd_hooks\u001b[38;5;241m=\u001b[39mbwd,\n\u001b[1;32m 563\u001b[0m reset_hooks_end\u001b[38;5;241m=\u001b[39mreset_hooks_end,\n\u001b[1;32m 564\u001b[0m clear_contexts\u001b[38;5;241m=\u001b[39mclear_contexts,\n\u001b[1;32m 565\u001b[0m ):\n\u001b[0;32m--> 566\u001b[0m model_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m incl_bwd:\n\u001b[1;32m 568\u001b[0m model_out\u001b[38;5;241m.\u001b[39mbackward()\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"Cell \u001b[0;32mIn[15], line 46\u001b[0m, in \u001b[0;36mHookedTransformerAdapter.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mforward \u001b[39m\u001b[38;5;132;01m{\u001b[39;00margs\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mkwargs\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 46\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:311\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 310\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 311\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhook_point(\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:1407\u001b[0m, in \u001b[0;36mMixtralForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict)\u001b[0m\n\u001b[1;32m 1404\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1406\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1407\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1408\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1409\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1410\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1411\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1412\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1413\u001b[0m \u001b[43m 
\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1414\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1415\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1416\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_router_logits\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_router_logits\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1417\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1418\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1420\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1421\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:311\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 310\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 311\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhook_point(\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:1275\u001b[0m, in \u001b[0;36mMixtralModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict)\u001b[0m\n\u001b[1;32m 1264\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 1265\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 1266\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1272\u001b[0m use_cache,\n\u001b[1;32m 1273\u001b[0m )\n\u001b[1;32m 1274\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1275\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1276\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1277\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1278\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1279\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1280\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1281\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_router_logits\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_router_logits\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1282\u001b[0m \u001b[43m 
\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1283\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1285\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1287\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:311\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 310\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 311\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhook_point(\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:981\u001b[0m, in \u001b[0;36mMixtralDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, output_router_logits, use_cache, **kwargs)\u001b[0m\n\u001b[1;32m 978\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_layernorm(hidden_states)\n\u001b[1;32m 980\u001b[0m \u001b[38;5;66;03m# Self Attention\u001b[39;00m\n\u001b[0;32m--> 981\u001b[0m hidden_states, self_attn_weights, present_key_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself_attn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 982\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 984\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 985\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 986\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 987\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 988\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 989\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m 
hidden_states\n\u001b[1;32m 991\u001b[0m \u001b[38;5;66;03m# Fully Connected\u001b[39;00m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:311\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 310\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 311\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhook_point(\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:750\u001b[0m, in \u001b[0;36mMixtralSdpaAttention.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)\u001b[0m\n\u001b[1;32m 748\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m past_key_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 749\u001b[0m kv_seq_len \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m past_key_value\u001b[38;5;241m.\u001b[39mget_usable_length(kv_seq_len, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer_idx)\n\u001b[0;32m--> 750\u001b[0m cos, sin \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrotary_emb\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mseq_len\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkv_seq_len\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 751\u001b[0m global_cache[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcos\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m cos\n\u001b[1;32m 752\u001b[0m global_cache[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msin\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m sin\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/Auto_HookPoint/hook.py:311\u001b[0m, in \u001b[0;36mHookedModule._create_forward.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(original_forward)\n\u001b[1;32m 310\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnew_forward\u001b[39m(\u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[0;32m--> 311\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhook_point\u001b[49m\u001b[43m(\u001b[49m\u001b[43moriginal_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1581\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1579\u001b[0m hook_result \u001b[38;5;241m=\u001b[39m hook(\u001b[38;5;28mself\u001b[39m, args, kwargs, result)\n\u001b[1;32m 1580\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1581\u001b[0m hook_result \u001b[38;5;241m=\u001b[39m \u001b[43mhook\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresult\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1583\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m hook_result \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1584\u001b[0m result \u001b[38;5;241m=\u001b[39m hook_result\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformer_lens/hook_points.py:109\u001b[0m, in \u001b[0;36mHookPoint.add_hook.<locals>.full_hook\u001b[0;34m(module, module_input, module_output)\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 106\u001b[0m \u001b[38;5;28mdir\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbwd\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 107\u001b[0m ): \u001b[38;5;66;03m# For a backwards hook, module_output is a tuple of (grad,) - I don't know why.\u001b[39;00m\n\u001b[1;32m 108\u001b[0m module_output \u001b[38;5;241m=\u001b[39m module_output[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m--> 109\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mhook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodule_output\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhook\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/code/github/TransformerLens/.venv/lib/python3.11/site-packages/transformer_lens/hook_points.py:624\u001b[0m, in \u001b[0;36mHookedRootModule.get_caching_hooks.<locals>.save_hook\u001b[0;34m(tensor, hook, is_backward)\u001b[0m\n\u001b[1;32m 622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_backward:\n\u001b[1;32m 623\u001b[0m hook_name \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_grad\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 624\u001b[0m resid_stream \u001b[38;5;241m=\u001b[39m \u001b[43mtensor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdetach\u001b[49m()\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 625\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m remove_batch_dim:\n\u001b[1;32m 626\u001b[0m resid_stream \u001b[38;5;241m=\u001b[39m resid_stream[\u001b[38;5;241m0\u001b[39m]\n",
"\u001b[0;31mAttributeError\u001b[0m: 'tuple' object has no attribute 'detach'"
]
}
],
"source": [
"# Build the sae_lens training runner from `cfg`, handing it the auto-hooked\n",
"# model through `override_model`, then start the run and keep its result.\n",
"# NOTE(review): the saved output of this cell ends in\n",
"# `AttributeError: 'tuple' object has no attribute 'detach'`, raised inside the\n",
"# transformer_lens caching hook. Presumably the hooked module selected in `cfg`\n",
"# returns a tuple rather than a tensor; confirm the hook name before re-running.\n",
"sparse_autoencoder = SAETrainingRunner(\n",
"    cfg,\n",
"    override_model=hooked_model,\n",
").run()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "TransformerLens",
"language": "python",
"name": "transformerlens"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment