@Birch-san
Last active May 21, 2024 14:36
Watch your activation norms fly into the sunset
# Contains MIT-licensed code from wandb
# https://github.com/wandb/wandb/blob/main/LICENSE
# This gist is MIT-licensed (Copyright Alex Birch)
from torch import Tensor, FloatTensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
import torch
from typing import List, Callable, Dict, Sequence, Optional, Tuple, Any
from wandb.wandb_torch import log_track_init, log_track_update
import wandb
def hook_activ_stats(
    mod: Module,
    mod_name: str,
    metric_name: str,
    log_track: List[int],
) -> RemovableHandle:
    def fwd_hook(mod_: Module, args: Tuple[Any, ...], out: Any) -> None:
        if not log_track_update(log_track):
            return
        assert torch.is_tensor(out)
        log_dict: Dict[str, float] = {}
        o: Tensor = out.detach()
        if mod_name.endswith("qkv_proj"):
            # yes I hardcoded it to a head-dim of 64, deal with it
            per_head: FloatTensor = o.unflatten(-1, (3, -1, 64))
            # reduce over every dim except (proj, head): one scalar per head per q/k/v projection
            norms: FloatTensor = torch.linalg.vector_norm(per_head, dim=[*range(per_head.ndim-3), -1])
            for proj, proj_name in zip(norms.unbind(), 'qkv'):
                for head_ix, head in enumerate(proj.unbind()):
                    log_dict[f"{metric_name}/{proj_name}/{head_ix}"] = head.item()
        elif mod_name.endswith("up_proj"):
            # gated MLP: the up projection's output concatenates the "x" and "gate" halves
            per_proj: FloatTensor = o.unflatten(-1, (2, -1))
            norms: FloatTensor = torch.linalg.vector_norm(per_proj, dim=[*range(per_proj.ndim-2), -1])
            for proj, proj_name in zip(norms.unbind(), ['x', 'gate']):
                log_dict[f"{metric_name}/{proj_name}"] = proj.item()
        else:
            log_dict[metric_name] = torch.linalg.vector_norm(o).item()
        # commit=False accumulates into the current step; a later commit=True log flushes it
        wandb.log(log_dict, commit=False)
    return mod.register_forward_hook(fwd_hook)
def add_log_module_outputs_hook(
    mod: Module,
    model_name: str = "",
    prefix: str = "",
    log_freq: int = 0,
) -> Callable[[], None]:
    handles: List[RemovableHandle] = []
    prefix = f'{prefix}{model_name}'
    if not hasattr(mod, "_kdiff_hook_names"):
        mod._kdiff_hook_names = []
    kdiff_hook_names: List[str] = mod._kdiff_hook_names
    for name, module in mod.named_modules():
        log_track_activ = log_track_init(log_freq)
        metric_name = f"activ/{prefix}{name}"
        kdiff_hook_names.append(metric_name)
        handle: RemovableHandle = hook_activ_stats(
            module,
            mod_name=name,
            metric_name=metric_name,
            log_track=log_track_activ,
        )
        handles.append(handle)
    def remove_hooks():
        for h in handles:
            h.remove()
    return remove_hooks
_global_watch_idx = 0
def watch(
    models: Sequence[Module],
    log_freq: int = 1000,
    idx: Optional[int] = None,
) -> Callable[[], None]:
    global _global_watch_idx
    prefix = ""
    unhooks: List[Callable[[], None]] = []
    if idx is None:
        idx = _global_watch_idx
    for local_ix, model in enumerate(models):
        global_ix: int = idx + local_ix
        _global_watch_idx += 1
        if global_ix > 0:
            # TODO: this makes ugly chart names like gradients/graph_1conv1d.bias
            prefix = "graph_%i" % global_ix
        unhook: Callable[[], None] = add_log_module_outputs_hook(
            model,
            prefix=prefix,
            log_freq=log_freq,
        )
        unhooks.append(unhook)
    def unhook_all() -> None:
        for unhook in unhooks:
            unhook()
    return unhook_all
@Birch-san (Author)
Usage:

# log activation norms to wandb every 10 steps (yes 10 is a bit often)
watch([model], log_freq=10)

and watch those activation norms fly
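For context, a fuller end-to-end sketch (the two-layer toy model, the project name, and the dummy forward loop are made up for illustration; watch returns a callable that removes the hooks):

import torch
from torch import nn
import wandb

wandb.init(project="activ-norm-demo")  # hypothetical project name

model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
unhook = watch([model], log_freq=10)

for step in range(100):
    model(torch.randn(4, 16))
    # the hooks log with commit=False, so flush the accumulated metrics once per step
    wandb.log({}, commit=True)

unhook()  # remove the forward hooks when you're done watching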

[image: wandb activation-norm charts]

@yuanzhi-zhu
Usage should be watch([score_model], log_freq=10)
