Skip to content

Instantly share code, notes, and snippets.

@Ethkuil
Last active September 1, 2025 10:24
Show Gist options
  • Save Ethkuil/bcaad8759a5e1317085903fb2a7107e8 to your computer and use it in GitHub Desktop.
Save Ethkuil/bcaad8759a5e1317085903fb2a7107e8 to your computer and use it in GitHub Desktop.
Convenient torch profiler that profiles only the part of a run you need
import torch
def on_trace_ready(dir_name=".", use_gzip=True):
    """
    Build an ``on_trace_ready`` callback that exports one Chrome trace per worker.

    The returned handler writes to
    ``<dir_name>/<HH-MM-SS>.rank<rank>.<hostname>.pt.trace.json[.gz]``.

    :param dir_name: Directory the trace file is written to; created on demand.
    :param use_gzip: When true, ``.gz`` is appended to the file name (the
        exporter is expected to gzip paths with that suffix -- confirm against
        the installed torch version).
    :return: A callable taking the profiler object and returning ``None``.
    :raises RuntimeError: From the handler, when ``dir_name`` cannot be created.
    """
    import os
    import socket
    import time

    def handler_fn(prof) -> None:
        # exist_ok=True already tolerates an existing directory, so no
        # separate isdir() pre-check is needed.
        try:
            os.makedirs(dir_name, exist_ok=True)
        except OSError as e:
            raise RuntimeError("Can't create directory: " + dir_name) from e

        # get_rank() raises when no process group has been initialised, so
        # treat single-process runs as rank 0 instead of crashing.
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0

        worker_name = f"rank{rank}.{socket.gethostname()}"
        # NOTE: second resolution only -- two exports by the same worker within
        # one second produce the same file name.
        curr_time = time.strftime("%H-%M-%S", time.localtime())
        file_name = f"{curr_time}.{worker_name}.pt.trace.json"
        if use_gzip:
            file_name += ".gz"
        prof.export_chrome_trace(os.path.join(dir_name, file_name))

    return handler_fn
from torch.profiler import profile, ProfilerActivity
# Module-wide profiler instance that ``my_profile`` starts and stops.
# It records CPU and CUDA activity with stack information, and hands the
# finished trace to the handler built by ``on_trace_ready`` for the "profile"
# directory.
my_profiler = profile(
    on_trace_ready=on_trace_ready("profile", use_gzip=True),
    with_stack=True,
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
)
from enum import Enum
class ProfilerState(Enum):
    """Lifecycle of the profiler: not started yet, recording, or finished."""

    # Members take the values 0, 1 and 2 in declaration order.
    NOT_STARTED, RUNNING, STOPPED = range(3)
# Call counter advanced by the wrapper that ``my_profile`` produces; it is
# compared against ``begin_cnt`` / ``begin_cnt + len_cnt`` to decide when the
# profiler starts and stops. Being a module global, it is shared by every
# function decorated with ``my_profile``.
COUNTER: int = 0
# Current phase of ``my_profiler``; the wrapper moves it from NOT_STARTED to
# RUNNING to STOPPED and never back.
PROFILER_STATE: ProfilerState = ProfilerState.NOT_STARTED
def my_profile(begin_cnt: int, len_cnt: int, rank_list=(0,)):
    """
    Return a decorator that profiles a window of calls to the decorated function.

    Calls are counted in the module-level ``COUNTER``. ``my_profiler`` is
    started just before the call at which the counter reaches ``begin_cnt`` and
    stopped just before the call at which it reaches ``begin_cnt + len_cnt``,
    so ``len_cnt`` consecutive calls (and whatever runs between them) are
    captured. The stop is only triggered by a further call of the decorated
    function. Once stopped, the profiler is never restarted.

    The counter and the profiler state are module globals, so all functions
    decorated with ``my_profile`` share one profiling window.

    :param begin_cnt: The count to start profiling.
    :param len_cnt: How many counts to profile.
    :param rank_list: The ranks to profile (any container supporting ``in``).
        When ``torch.distributed`` is unavailable or not initialised the
        process is treated as rank 0.
    :return: A decorator; the wrapped function keeps its metadata and its
        return value is passed through unchanged.
    """
    import logging
    from functools import wraps

    # The file defines no ``logger``; create one here so the start/stop
    # messages do not fail with a NameError.
    logger = logging.getLogger(__name__)

    def decorator(func):
        @wraps(func)
        def wrapped_func(*args, **kwargs):
            global COUNTER, PROFILER_STATE

            # get_rank() raises without an initialised process group, so fall
            # back to rank 0 for single-process runs.
            dist = torch.distributed
            if dist.is_available() and dist.is_initialized():
                rank = dist.get_rank()
            else:
                rank = 0

            if rank in rank_list and PROFILER_STATE != ProfilerState.STOPPED:
                if PROFILER_STATE == ProfilerState.NOT_STARTED:
                    if COUNTER >= begin_cnt:
                        # Keep the start/stop messages from being filtered out
                        # by this logger's own level.
                        logger.setLevel("INFO")
                        logger.info(
                            "Profiler started at count %s for rank %s",
                            COUNTER,
                            rank,
                        )
                        my_profiler.start()
                        PROFILER_STATE = ProfilerState.RUNNING
                elif COUNTER >= begin_cnt + len_cnt:
                    # Only RUNNING can reach this branch: STOPPED is excluded
                    # by the outer condition.
                    my_profiler.stop()
                    PROFILER_STATE = ProfilerState.STOPPED
                    logger.info(
                        "Profiler stopped at count %s for rank %s",
                        COUNTER,
                        rank,
                    )
                COUNTER += 1

            return func(*args, **kwargs)

        return wrapped_func

    return decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment