Last active
October 30, 2024 01:55
-
-
Save kiyoon/d7dffd0c20f5e4c0123e5b18defe1b77 to your computer and use it in GitHub Desktop.
Print a PyTorch model with the number of parameters and requires_grad information for each module.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE: __future__ annotations are needed because we want to lazily evaluate
# torch type hints during runtime.
# Or, you need to wrap the types in quotes, e.g. "torch.nn.Module" instead of
# torch.nn.Module.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # torch is only needed for type checking; importing it lazily keeps the
    # module importable without torch installed.
    import torch
def _addindent(s_: str, num_spaces: int): | |
s = s_.split("\n") | |
# don't do anything for single-line stuff | |
if len(s) == 1: | |
return s_ | |
first = s.pop(0) | |
s = [(num_spaces * " ") + line for line in s] | |
s = "\n".join(s) | |
s = first + "\n" + s | |
return s | |
def _count_module_requires_grad(module: torch.nn.Module) -> str: | |
"""Check if all of the parameters in the module requires grad.""" | |
if len(list(module.parameters())) == 0: | |
return "" | |
num_params = 0 | |
num_requires_grad = 0 | |
msg = "" | |
for param in module.parameters(): | |
nn = 1 | |
for s in param.size(): | |
nn = nn * s | |
num_params += nn | |
if param.requires_grad: | |
num_requires_grad += nn | |
if num_requires_grad == num_params: | |
msg = f" - num_params={num_params:,} - requires_grad=True" | |
elif num_requires_grad == 0: | |
msg = f" - num_params={num_params:,} - requires_grad=False" | |
else: | |
msg = f" - requires_grad={num_requires_grad:,}/{num_params:,} ({num_requires_grad/num_params*100:.2f}%)" | |
return msg | |
def repr_model_with_requires_grad(model: torch.nn.Module, *, recursed: bool = False):
    """
    Like ``repr(model)``, but every module line is annotated with its parameter
    count and requires_grad status (True/False/partial).

    Args:
        model: nn.Module to render.
        recursed: True only for the internal recursive calls on submodules, so
            that the requires_grad annotation after the closing parenthesis is
            emitted exactly once, for the root. Must be False when calling
            this function from outside.
    """
    # extra_repr() is rendered one item per line, like the submodules.
    # An empty string would split into [""], so guard against it.
    extra_repr = model.extra_repr()
    extra_lines = extra_repr.split("\n") if extra_repr else []

    child_lines = []
    # Iterate _modules directly (not named_children(), which deduplicates
    # shared submodule instances) to match torch's own __repr__ traversal.
    for name, child in model._modules.items():
        assert child is not None
        child_repr = _addindent(repr_model_with_requires_grad(child, recursed=True), 2)
        child_lines.append(f"({name}): {child_repr}{_count_module_requires_grad(child)}")

    body = extra_lines + child_lines
    result = model._get_name() + "("
    if body:
        if len(extra_lines) == 1 and not child_lines:
            # Simple one-liner info, which most builtin Modules will use.
            result += extra_lines[0]
        else:
            result += "\n  " + "\n  ".join(body) + "\n"
    result += ")"
    if not recursed:
        result += _count_module_requires_grad(model)
    return result
def repr_model_with_requires_grad_simple(
    model: torch.nn.Module, skip_after_num=1, *, recursed: bool = False
):
    """
    Like ``repr_model_with_requires_grad``, but collapses long runs of
    identically-named, numerically-keyed submodules (e.g. repeated layers in
    a Sequential): only the first ``skip_after_num`` of each run is printed in
    full; the rest get a one-line "(skipped)" summary.

    Args:
        model: nn.Module to render.
        skip_after_num: numeric children with index >= this value are skipped
            when their class name matches the previous child's.
        recursed: True only for the internal recursive calls on submodules.
            Must be False when calling this function from outside.
    """
    # extra_repr() is rendered one item per line; "" would split into [""].
    extra_repr = model.extra_repr()
    extra_lines = extra_repr.split("\n") if extra_repr else []

    child_lines = []
    last_child_cls = ""
    for name, child in model._modules.items():
        assert child is not None
        child_cls = child._get_name()
        if name.isdigit() and child_cls == last_child_cls and int(name) >= skip_after_num:
            # Repeated sibling of the same class: summarize instead of recursing.
            # Note: last_child_cls is deliberately NOT updated here, matching
            # the original behavior of tracking only fully-printed children.
            child_lines.append(f"({name}): {child_cls} (skipped printing to simplify)")
            continue
        child_repr = _addindent(
            repr_model_with_requires_grad_simple(child, skip_after_num, recursed=True), 2
        )
        child_lines.append(f"({name}): {child_repr}{_count_module_requires_grad(child)}")
        last_child_cls = child_cls

    body = extra_lines + child_lines
    result = model._get_name() + "("
    if body:
        if len(extra_lines) == 1 and not child_lines:
            # Simple one-liner info, which most builtin Modules will use.
            result += extra_lines[0]
        else:
            result += "\n  " + "\n  ".join(body) + "\n"
    result += ")"
    if not recursed:
        result += _count_module_requires_grad(model)
    return result
if __name__ == "__main__":
    import torchvision

    # Demo: render a ResNet-18, then freeze its classifier head and render
    # again to show the per-module requires_grad annotations changing.
    resnet = torchvision.models.resnet18()
    print(repr_model_with_requires_grad(resnet))

    resnet.fc.requires_grad_(False)
    print(repr_model_with_requires_grad(resnet))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment