Skip to content

Instantly share code, notes, and snippets.

@kiyoon
Last active October 30, 2024 01:55
Show Gist options
  • Save kiyoon/d7dffd0c20f5e4c0123e5b18defe1b77 to your computer and use it in GitHub Desktop.
Save kiyoon/d7dffd0c20f5e4c0123e5b18defe1b77 to your computer and use it in GitHub Desktop.
Print a PyTorch model with per-module parameter counts (#params) and requires_grad information.
# NOTE: `from __future__ import annotations` is needed because torch is imported only under
# TYPE_CHECKING, so the torch type hints must not be evaluated at runtime.
# Alternatively, wrap the types in quotes, e.g. "torch.nn.Module" instead of torch.nn.Module.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only static type checkers need torch here; at runtime the annotations
    # stay unevaluated, so importing this module does not import torch.
    import torch
def _addindent(s_: str, num_spaces: int) -> str:
    """Indent every line of ``s_`` except the first by ``num_spaces`` spaces.

    Args:
        s_: Text that may span several lines (separated by ``"\\n"``).
        num_spaces: Number of spaces to prepend to each continuation line.

    Returns:
        ``s_`` unchanged when it is a single line; otherwise the first line
        as-is followed by the remaining lines, each prefixed with the padding
        (empty lines are padded too).
    """
    first, *rest = s_.split("\n")
    # Single-line text needs no continuation indent.
    if not rest:
        return s_
    pad = " " * num_spaces
    return "\n".join([first, *(pad + line for line in rest)])
def _count_module_requires_grad(module: torch.nn.Module) -> str:
    """Summarize how many of ``module``'s parameter elements require grad.

    Args:
        module: Object whose ``parameters()`` yields the parameters to count.

    Returns:
        ``""`` when the module yields no parameters. Otherwise a suffix that
        starts with ``" - "``:

        * every element requires grad: ``num_params=N - requires_grad=True``
        * no element requires grad: ``num_params=N - requires_grad=False``
        * mixed: ``requires_grad=K/N (P%)``
    """
    num_params = 0
    num_requires_grad = 0
    has_params = False

    # Single pass: detect emptiness and accumulate the counts together,
    # instead of materializing the parameter list just to test its length.
    for param in module.parameters():
        has_params = True
        count = param.numel()
        num_params += count
        if param.requires_grad:
            num_requires_grad += count

    if not has_params:
        return ""
    # This comparison comes first, so a module whose parameters hold zero
    # elements in total never reaches the division below.
    if num_requires_grad == num_params:
        return f" - num_params={num_params:,} - requires_grad=True"
    if num_requires_grad == 0:
        return f" - num_params={num_params:,} - requires_grad=False"
    return f" - requires_grad={num_requires_grad:,}/{num_params:,} ({num_requires_grad/num_params*100:.2f}%)"
def repr_model_with_requires_grad(model: torch.nn.Module, *, recursed: bool = False) -> str:
    """Build a ``repr(model)``-like string annotated with requires_grad info.

    Every registered submodule gets one entry, suffixed with the summary from
    ``_count_module_requires_grad`` (all / none / partial). The top-level
    module's own summary is appended after its closing parenthesis.

    Args:
        model: Module to describe.
        recursed: If True, this is a recursive call on a submodule, so no
            requires_grad information is appended after the closing
            parenthesis (the caller adds it). Must be False when calling
            this function from outside.

    Returns:
        The annotated, possibly multi-line, description.
    """
    # Width of one nesting level. The same value feeds both the child-line
    # prefix and _addindent so nested closing parentheses line up.
    indent = 2

    # The extra repr is treated like the submodules: one item per line.
    # An empty string is skipped because "".split("\n") would yield [""].
    extra_repr = model.extra_repr()
    extra_lines = extra_repr.split("\n") if extra_repr else []

    child_lines = []
    for key, module in model._modules.items():
        if module is None:
            # An empty slot (e.g. registered as None or later set to None) is
            # printed as "None" rather than crashing.
            child_lines.append("(" + key + "): None")
            continue
        mod_str = repr_model_with_requires_grad(module, recursed=True)
        mod_str = _addindent(mod_str, indent)
        child_lines.append(
            "(" + key + "): " + mod_str + _count_module_requires_grad(module)
        )

    lines = extra_lines + child_lines
    main_str = model._get_name() + "("
    if lines:
        if len(extra_lines) == 1 and not child_lines:
            # Simple one-liner info, which most builtin modules use.
            main_str += extra_lines[0]
        else:
            pad = "\n" + " " * indent
            main_str += pad + pad.join(lines) + "\n"
    main_str += ")"

    if not recursed:
        main_str += _count_module_requires_grad(model)
    return main_str
def repr_model_with_requires_grad_simple(
    model: torch.nn.Module, skip_after_num: int = 1, *, recursed: bool = False
) -> str:
    """Build an abbreviated ``repr(model)``-like string with requires_grad info.

    Works like ``repr_model_with_requires_grad`` but collapses runs of
    numbered children of the same class: a child is replaced by a one-line
    placeholder when its key is a decimal number, its class name equals that
    of the previously printed child, and ``int(key) >= skip_after_num``.
    With the default of 1, only child ``0`` of such a run is printed in full.

    Args:
        model: Module to describe.
        skip_after_num: Smallest numeric key that may be collapsed.
        recursed: If True, this is a recursive call on a submodule, so no
            requires_grad information is appended after the closing
            parenthesis (the caller adds it). Must be False when calling
            this function from outside.

    Returns:
        The annotated, possibly multi-line, description.
    """
    # Width of one nesting level. The same value feeds both the child-line
    # prefix and _addindent so nested closing parentheses line up.
    indent = 2

    # The extra repr is treated like the submodules: one item per line.
    # An empty string is skipped because "".split("\n") would yield [""].
    extra_repr = model.extra_repr()
    extra_lines = extra_repr.split("\n") if extra_repr else []

    child_lines = []
    prev_module_name = ""
    for key, module in model._modules.items():
        if module is None:
            # An empty slot is printed as "None" rather than crashing; it
            # also ends any run of same-class children.
            child_lines.append("(" + key + "): None")
            prev_module_name = ""
            continue

        module_name = module._get_name()
        # isdecimal() (not isdigit()) so that int(key) cannot raise on
        # characters such as superscript digits.
        if (
            key.isdecimal()
            and module_name == prev_module_name
            and int(key) >= skip_after_num
        ):
            child_lines.append(
                f"({key}): {module_name} (skipped printing to simplify)"
            )
            continue

        mod_str = repr_model_with_requires_grad_simple(
            module, skip_after_num, recursed=True
        )
        mod_str = _addindent(mod_str, indent)
        child_lines.append(
            "(" + key + "): " + mod_str + _count_module_requires_grad(module)
        )
        prev_module_name = module_name

    lines = extra_lines + child_lines
    main_str = model._get_name() + "("
    if lines:
        if len(extra_lines) == 1 and not child_lines:
            # Simple one-liner info, which most builtin modules use.
            main_str += extra_lines[0]
        else:
            pad = "\n" + " " * indent
            main_str += pad + pad.join(lines) + "\n"
    main_str += ")"

    if not recursed:
        main_str += _count_module_requires_grad(model)
    return main_str
if __name__ == "__main__":
    # Demo: print a torchvision ResNet-18 before and after freezing `fc`.
    import torchvision

    model = torchvision.models.resnet18()
    print(repr_model_with_requires_grad(model))

    # Turn off requires_grad for the `fc` submodule's parameters, then print
    # again so the changed annotations can be compared with the first output.
    model.fc.requires_grad_(False)
    print(repr_model_with_requires_grad(model))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment