Skip to content

Instantly share code, notes, and snippets.

@Stfort52
Last active August 1, 2024 05:53
Show Gist options
  • Save Stfort52/88123d22625e580abb7b9ca80a2fceec to your computer and use it in GitHub Desktop.
Save Stfort52/88123d22625e580abb7b9ca80a2fceec to your computer and use it in GitHub Desktop.

Plot PyTorch module parameters

A simple implementation.

plot

Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
from collections import defaultdict
from typing import Literal
from torch.nn import Module
def parse_module_into_plotly_data(
module: Module, max_depth: int = -1
) -> dict[Literal["name", "parent", "params"], list[str | int]]:
def object_basename(obj: object) -> str:
return obj.__class__.__name__.split(".")[-1]
def module_parameters(module: Module) -> int:
return sum(p.numel() for p in module.parameters())
data = dict(
name=[object_basename(module)], parent=[""], params=[module_parameters(module)]
)
names = defaultdict(int)
def _parse_module(module: Module, parent: str, max_depth: int):
if max_depth == 0:
return
for name, child in module.named_children():
params = module_parameters(child)
name = object_basename(child)
names[name] += 1
data["name"].append(f"{name}-{names[name]}")
data["parent"].append(parent)
data["params"].append(params)
_parse_module(child, f"{name}-{names[name]}", max_depth - 1)
_parse_module(module, object_basename(module), max_depth)
return data
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment