@pszemraj · Last active September 19, 2025
Prints an accurate summary of a PyTorch model
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple

import torch
import torch.nn as nn


@dataclass
class _LayerSummary:
    name: str
    param_shape: Optional[torch.Size]
    inclusive_total_params: int
    inclusive_trainable_params: int


def model_summary(
    model: nn.Module, max_depth: int = 4, show_param_shapes: bool = False
) -> None:
    """
    Prints a hierarchical summary of a PyTorch model with *inclusive* parameter counts.
    Counts are robust to shared/tied parameters (each Parameter is counted once per subtree).
    """

    # ---------- formatting helpers ----------
    def _format_number(num: int) -> str:
        return f"{num:,}" if num > 0 else "--"

    def _format_shape(shape: Optional[torch.Size]) -> str:
        return "x".join(map(str, shape)) if shape else "N/A"

    # ---------- build param info once ----------
    # Map: id(param) -> (numel, requires_grad)
    param_info: Dict[int, Tuple[int, bool]] = {}
    for p in model.parameters(recurse=True):
        pid = id(p)
        if pid not in param_info:
            param_info[pid] = (p.numel(), bool(p.requires_grad))

    # Fast path: totals only
    if max_depth <= 0:
        total_params = sum(n for (n, _) in param_info.values())
        trainable_params = sum(n for (n, rg) in param_info.values() if rg)
        print("=" * 50)
        print("Total params:", _format_number(total_params))
        print("Trainable params:", _format_number(trainable_params))
        nontrain = total_params - trainable_params
        print("Non-trainable params:", _format_number(nontrain))
        print("=" * 50)
        return
    summary_list: List[_LayerSummary] = []

    def summarize_recursive(module: nn.Module, depth: int, prefix: str) -> Set[int]:
        """
        Return the set of unique Parameter IDs reachable from this module's subtree.
        Also records a _LayerSummary for this module.
        """
        # If we're beyond the print depth, just return the deduped set upward
        if depth > max_depth:
            ids = {id(p) for p in module.parameters(recurse=True)}
            return ids

        # Direct parameters of *this* module (non-recursive)
        direct_ids: Set[int] = {id(p) for p in module.parameters(recurse=False)}

        # Reserve this module's slot before recursing so the listing stays in
        # pre-order (parent row above its indented children)
        insert_at = len(summary_list)

        # Recurse into children and union their sets
        child_ids: Set[int] = set()
        for child in module.children():
            child_ids |= summarize_recursive(child, depth + 1, prefix + " ")
        all_ids = direct_ids | child_ids

        # Inclusive counts from the deduped set
        total = sum(param_info[i][0] for i in all_ids)
        trainable = sum(param_info[i][0] for i in all_ids if param_info[i][1])

        # First direct trainable parameter shape (display purpose only)
        param_shape = next(
            (p.shape for p in module.parameters(recurse=False) if p.requires_grad),
            None,
        )

        summary_list.insert(
            insert_at,
            _LayerSummary(
                name=f"{prefix}{type(module).__name__}",
                param_shape=param_shape,
                inclusive_total_params=total,
                inclusive_trainable_params=trainable,
            ),
        )
        return all_ids

    # Build the list (pre-order traversal)
    summarize_recursive(model, 1, "")
    # Totals from the whole model (already deduped)
    total_params = sum(n for (n, _) in param_info.values())
    trainable_params = sum(n for (n, rg) in param_info.values() if rg)

    # ---------- printing ----------
    name_col_width = max(len("Layer (type)"), max(len(s.name) for s in summary_list))
    shape_col_width = 0
    if show_param_shapes:
        shape_col_width = max(
            len("Param Shape"),
            max(len(_format_shape(s.param_shape)) for s in summary_list),
        )
    params_col_width = 12
    trainable_col_width = 10
    col_spacing = " "

    header_parts = [f"{'Layer (type)':<{name_col_width}}"]
    if show_param_shapes:
        header_parts.append(f"{'Param Shape':>{shape_col_width}}")
    header_parts.append(f"{'Param #':>{params_col_width}}")
    header_parts.append(f"{'Trainable':>{trainable_col_width}}")
    header = col_spacing.join(header_parts)

    sep = "=" * len(header)
    print(sep)
    print(header)
    print(sep)
    for e in summary_list:
        parts = [f"{e.name:<{name_col_width}}"]
        if show_param_shapes:
            parts.append(f"{_format_shape(e.param_shape):>{shape_col_width}}")
        parts.append(f"{_format_number(e.inclusive_total_params):>{params_col_width}}")
        parts.append(f"{str(e.inclusive_trainable_params > 0):>{trainable_col_width}}")
        print(col_spacing.join(parts))
    print(sep)
    print(f"Total params: {_format_number(total_params)}")
    print(f"Trainable params: {_format_number(trainable_params)}")
    print(f"Non-trainable params: {_format_number(total_params - trainable_params)}")
    print(sep)
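

# quick sanity check (illustrative sketch, not part of the original gist): the toy
# model below is a hypothetical example showing the docstring's claim that tied or
# shared parameters are counted once per subtree
_tied = nn.Linear(32, 32, bias=False)
_toy = nn.Sequential(_tied, nn.ReLU(), _tied)  # the same Linear instance appears twice
model_summary(_toy, max_depth=2, show_param_shapes=True)
# the Sequential row should report 1,024 params (32*32), not 2,048, because
# duplicate Parameter ids are deduplicated before summing
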
from transformers import AutoTokenizer, AutoModelForPreTraining
# example: using an obscure transformers auto class
tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
model = AutoModelForPreTraining.from_pretrained("google/electra-base-discriminator")
# assuming model_summary() already defined
model_summary(model)
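# optional knobs (hypothetical usage of the same `model` as above, not part of the
# original gist): max_depth=0 takes the fast path and prints only the overall totals;
# show_param_shapes adds a column with each module's first direct trainable
# parameter shape
model_summary(model, max_depth=0)
model_summary(model, max_depth=2, show_param_shapes=True)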