Skip to content

Instantly share code, notes, and snippets.

@Stfort52
Last active August 1, 2024 05:53
Show Gist options
  • Save Stfort52/88123d22625e580abb7b9ca80a2fceec to your computer and use it in GitHub Desktop.
Save Stfort52/88123d22625e580abb7b9ca80a2fceec to your computer and use it in GitHub Desktop.

Plot PyTorch module parameters

A simple implementation that draws a PyTorch model's per-submodule parameter counts as a Plotly treemap.

plot

Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import plotly.express as px\n",
"import transformers\n",
"\n",
"from module_to_plotly import parse_module_into_plotly_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = transformers.BertModel.from_pretrained(\"ctheodoris/Geneformer\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig = px.treemap(\n",
" parse_module_into_plotly_data(model),\n",
" names=\"name\",\n",
" parents=\"parent\",\n",
" values=\"params\",\n",
")\n",
"fig.update_traces(root_color=\"lightgrey\")\n",
"fig.write_html(\"test.html\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "adiformer",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
from collections import defaultdict
from typing import Literal
from torch.nn import Module
def parse_module_into_plotly_data(
module: Module, max_depth: int = -1
) -> dict[Literal["name", "parent", "params"], list[str | int]]:
def object_basename(obj: object) -> str:
return obj.__class__.__name__.split(".")[-1]
def module_parameters(module: Module) -> int:
return sum(p.numel() for p in module.parameters())
data = dict(
name=[object_basename(module)], parent=[""], params=[module_parameters(module)]
)
names = defaultdict(int)
def _parse_module(module: Module, parent: str, max_depth: int):
if max_depth == 0:
return
for name, child in module.named_children():
params = module_parameters(child)
name = object_basename(child)
names[name] += 1
data["name"].append(f"{name}-{names[name]}")
data["parent"].append(parent)
data["params"].append(params)
_parse_module(child, f"{name}-{names[name]}", max_depth - 1)
_parse_module(module, object_basename(module), max_depth)
return data
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment