Watch your activation norms fly into the sunset
# Contains MIT-licensed code from wandb
# https://github.com/wandb/wandb/blob/main/LICENSE
# This gist is MIT-licensed (Copyright Alex Birch)
from torch import Tensor, FloatTensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
import torch
from typing import List, Callable, Dict, Sequence, Optional, Tuple, Any
from wandb.wandb_torch import log_track_init, log_track_update
import wandb
def hook_activ_stats(
    mod: Module,
    mod_name: str,
    metric_name: str,
    log_track: List[int],
) -> RemovableHandle:
    def fwd_hook(mod_: Module, args: Tuple[Any, ...], out: Any) -> None:
        # rate-limit: only log every log_freq-th forward pass
        if not log_track_update(log_track):
            return
        assert torch.is_tensor(out)
        log_dict: Dict[str, float] = {}
        o: Tensor = out.detach()
        if mod_name.endswith("qkv_proj"):
            # fused QKV projection: log a norm per projection, per attention head
            # yes I hardcoded it to a head-dim of 64, deal with it
            per_head: FloatTensor = o.unflatten(-1, (3, -1, 64))
            # reduce over all batch/sequence dims and the head dim, leaving shape (3, heads)
            norms: FloatTensor = torch.linalg.vector_norm(per_head, dim=[*range(per_head.ndim - 3), -1])
            for proj, proj_name in zip(norms.unbind(), "qkv"):
                for head_ix, head in enumerate(proj.unbind()):
                    log_dict[f"{metric_name}/{proj_name}/{head_ix}"] = head.item()
        elif mod_name.endswith("up_proj"):
            # gated up-projection: log a norm for each half (x and gate)
            per_proj: FloatTensor = o.unflatten(-1, (2, -1))
            norms: FloatTensor = torch.linalg.vector_norm(per_proj, dim=[*range(per_proj.ndim - 2), -1])
            for proj, proj_name in zip(norms.unbind(), ["x", "gate"]):
                log_dict[f"{metric_name}/{proj_name}"] = proj.item()
        else:
            log_dict[metric_name] = torch.linalg.vector_norm(o).item()
        # commit=False accumulates into the current step; your training loop's
        # own wandb.log() call flushes it
        wandb.log(log_dict, commit=False)
    # return the handle so the caller can remove the hook later
    return mod.register_forward_hook(fwd_hook)
def add_log_module_outputs_hook(
    mod: Module,
    model_name: str = "",
    prefix: str = "",
    log_freq: int = 0,
) -> Callable[[], None]:
    handles: List[RemovableHandle] = []
    prefix = f"{prefix}{model_name}"
    if hasattr(mod, "_kdiff_hook_names"):
        kdiff_hook_names: List[str] = mod._kdiff_hook_names
    else:
        kdiff_hook_names = []
        mod._kdiff_hook_names = kdiff_hook_names
    # hook every named submodule (and the root module itself, whose name is "")
    for name, module in mod.named_modules():
        log_track_activ = log_track_init(log_freq)
        metric_name = f"activ/{prefix}{name}"
        kdiff_hook_names.append(metric_name)
        handle: RemovableHandle = hook_activ_stats(
            module,
            mod_name=name,
            metric_name=metric_name,
            log_track=log_track_activ,
        )
        handles.append(handle)
    def remove_hooks() -> None:
        for h in handles:
            h.remove()
    return remove_hooks
_global_watch_idx = 0

def watch(
    models: Sequence[Module],
    log_freq: int = 1000,
    idx: Optional[int] = None,
) -> Callable[[], None]:
    global _global_watch_idx
    prefix = ""
    if idx is None:
        idx = _global_watch_idx
    removers: List[Callable[[], None]] = []
    for local_ix, model in enumerate(models):
        global_ix: int = idx + local_ix
        _global_watch_idx += 1
        if global_ix > 0:
            # TODO: this makes ugly chart names like gradients/graph_1conv1d.bias
            prefix = "graph_%i" % global_ix
        remover = add_log_module_outputs_hook(
            model,
            prefix=prefix,
            log_freq=log_freq,
        )
        removers.append(remover)
    # return a single callable that removes every hook we registered
    def remove_hooks() -> None:
        for remove in removers:
            remove()
    return remove_hooks
Usage: watch([score_model], log_freq=10), and watch those activation norms fly.
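For instance, here is a minimal training-loop sketch. The project name, the nn.Sequential stand-in for score_model, and the placeholder loss are all hypothetical; it assumes watch as defined above, which returns a callable that removes the hooks. The hooks log with commit=False, so your own wandb.log call is what commits each step.

import torch
from torch import nn
import wandb

wandb.init(project="activ-norms")  # hypothetical project name

# hypothetical stand-in for an actual score model
score_model = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))

remove_hooks = watch([score_model], log_freq=10)

for step in range(100):
    x = torch.randn(4, 16)
    out = score_model(x)
    loss = out.square().mean()  # placeholder loss
    # the forward hooks logged their norms with commit=False;
    # this call commits the step alongside them
    wandb.log({"loss": loss.item()})

remove_hooks()
wandb.finish()

With log_freq=10, each hooked module logs the norm of its output on every 10th forward pass, under metric names like activ/<module_name> (with per-projection and per-head breakdowns for modules whose names end in qkv_proj or up_proj).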