Created
February 27, 2024 02:57
-
-
Save madebyollin/41a948a7c69a36b1e1fded71f253e7ef to your computer and use it in GitHub Desktop.
Add human-readable profiling markers to a PyTorch module
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def add_profiling_markers(model):
    """Monkey-patch profiling markers into an nn.Module.

    Every module yielded by ``model.named_modules()`` (including ``model``
    itself) has its ``forward`` replaced by a wrapper. The wrapper runs the
    original ``forward`` inside a ``torch.profiler.record_function`` scope
    labelled ``"<qualified name>.forward"``. Work done by each module then
    appears under its own named scope in profiler traces, which makes them
    easier to understand.

    Calling this more than once on the same model is harmless, because
    modules that were already patched are skipped.

    Args:
        model: an ``nn.Module``. It is modified in place.

    Returns:
        None.
    """
    import functools

    from torch.profiler import record_function

    # The unpatched forward is stashed under a dedicated attribute name, which
    # doubles as the "already patched" marker. A generic name like
    # ``_forward`` is avoided because many modules define their own
    # ``_forward`` method. Reusing it would make such modules look
    # "already patched", and they would silently get no marker.
    original_attr = "_unprofiled_forward"

    def add_profiling_to_forward(label, original_forward):
        # Factory function: binds label/original_forward per module, avoiding
        # the late-binding-closure pitfall of defining the wrapper in the loop.
        @functools.wraps(original_forward)
        def profiled_forward(*args, **kwargs):
            with record_function(label):
                return original_forward(*args, **kwargs)

        return profiled_forward

    for name, module in model.named_modules():
        if hasattr(module, original_attr):
            continue  # already patched by an earlier call
        # named_modules() yields "" as the root's name. Use the class name
        # instead so the root scope is not labelled just ".forward".
        label = f"{name or type(module).__name__}.forward"
        original_forward = module.forward
        setattr(module, original_attr, original_forward)
        # Instance attribute shadows the class's forward, so nn.Module.__call__
        # dispatches to the wrapper.
        module.forward = add_profiling_to_forward(label, original_forward)
if __name__ == "__main__":
    # Usage example. The model and input below are placeholders;
    # substitute your own nn.Module and input tensor(s).
    import torch

    model = torch.nn.Sequential(
        torch.nn.Linear(64, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 8),
    )
    x = torch.randn(16, 64)

    add_profiling_markers(model)
    with torch.profiler.profile() as prof:
        # .cpu() copies the result to the host, so any pending device work
        # has to finish while the profiler is still recording.
        y = model(x).cpu()
    prof.export_chrome_trace("trace.json")
    # then open chrome and load trace.json into the chrome://tracing tab
Author
madebyollin
commented
Feb 27, 2024

Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment