Last active
September 1, 2025 10:24
-
-
Save Ethkuil/bcaad8759a5e1317085903fb2a7107e8 to your computer and use it in GitHub Desktop.
Convenient torch profiler that profiles only the part of the run you need
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
def on_trace_ready(dir_name=".", use_gzip=True):
    """Build an ``on_trace_ready`` callback that exports one Chrome trace file.

    The returned handler writes
    ``<HH-MM-SS>.rank<R>.<hostname>.pt.trace.json`` into *dir_name*, with a
    trailing ``.gz`` when *use_gzip* is true. It creates the directory first
    if needed.

    The ``.gz`` suffix is what asks torch's ``export_chrome_trace`` to gzip the
    output. This is the same convention as
    ``torch.profiler.tensorboard_trace_handler``.

    :param dir_name: Directory that receives the trace files.
    :param use_gzip: Whether to append ``.gz`` to the exported file name.
    :return: A one-argument callable taking the profiler object; it calls
        ``prof.export_chrome_trace(path)``.
    :raises RuntimeError: If *dir_name* does not exist and cannot be created.
    """
    import os
    import socket
    import time

    def handler_fn(prof) -> None:
        if not os.path.isdir(dir_name):
            try:
                os.makedirs(dir_name, exist_ok=True)
            except OSError as e:
                raise RuntimeError("Can't create directory: " + dir_name) from e

        dist = torch.distributed
        # get_rank() raises when no default process group exists. Treat a
        # non-distributed (single-process) run as rank 0 so export still works.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0

        worker_name = f"rank{rank}.{socket.gethostname()}"
        curr_time = time.strftime("%H-%M-%S", time.localtime())
        file_name = f"{curr_time}.{worker_name}.pt.trace.json"
        if use_gzip:
            file_name = file_name + ".gz"
        prof.export_chrome_trace(os.path.join(dir_name, file_name))

    return handler_fn
from torch.profiler import ProfilerActivity, profile

# One module-wide profiler instance. It is started and stopped by the
# ``my_profile`` decorator and records CPU and CUDA activity with Python
# stacks. Finished traces are written under ./profile via ``on_trace_ready``.
my_profiler = profile(
    on_trace_ready=on_trace_ready("profile", use_gzip=True),
    with_stack=True,
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
)
from enum import Enum


class ProfilerState(Enum):
    """Lifecycle of ``my_profiler`` as tracked by the ``my_profile`` decorator.

    The decorator only ever advances it NOT_STARTED -> RUNNING -> STOPPED.
    """

    NOT_STARTED = 0  # waiting for the call counter to reach begin_cnt
    RUNNING = 1  # my_profiler.start() has been called
    STOPPED = 2  # my_profiler.stop() has been called


# Calls counted so far by functions wrapped with ``my_profile`` (shared by all
# of them, advanced through a ``global`` statement in the wrapper).
COUNTER: int = 0
# Current lifecycle state of the shared profiler.
PROFILER_STATE: ProfilerState = ProfilerState.NOT_STARTED
def my_profile(begin_cnt: int, len_cnt: int, rank_list=[0]): | |
""" | |
return decorator to profile the function. | |
:param begin_cnt: The count to start profiling. | |
:param len_cnt: How many counts to profile. | |
:param rank_list: The list of ranks to profile. | |
""" | |
def decorator(func): | |
from functools import wraps | |
@wraps(func) | |
def wrapped_func(*args, **kwargs): | |
rank = torch.distributed.get_rank() | |
global COUNTER, PROFILER_STATE | |
if rank in rank_list and PROFILER_STATE != ProfilerState.STOPPED: | |
match PROFILER_STATE: | |
case ProfilerState.NOT_STARTED: | |
if COUNTER >= begin_cnt: | |
logger.setLevel("INFO") | |
logger.info( | |
f"Profiler started at count {COUNTER} for rank {rank}" | |
) | |
my_profiler.start() | |
PROFILER_STATE = ProfilerState.RUNNING | |
case ProfilerState.RUNNING: | |
if COUNTER >= begin_cnt + len_cnt: | |
my_profiler.stop() | |
PROFILER_STATE = ProfilerState.STOPPED | |
logger.info( | |
f"Profiler stopped at count {COUNTER} for rank {rank}" | |
) | |
case _: | |
assert False | |
COUNTER += 1 | |
return func(*args, **kwargs) | |
return wrapped_func | |
return decorator | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment