[Timer]
import torch
import time
from typing import Literal


class TimerContext:
    @staticmethod
    def cuda_timer(unit: Literal['ms', 's'] = 's'):
        """Return a CUDA-event-based timing context manager."""
        class CudaTimerContext:
            def __init__(self):
                self.start_event = torch.cuda.Event(enable_timing=True)
                self.end_event = torch.cuda.Event(enable_timing=True)
                self.unit = unit
                self.elapsed_time = None

            def __enter__(self):
                self.start_event.record()
                return self

            def __exit__(self, exc_type, exc_value, traceback):
                self.end_event.record()
                # Events are recorded asynchronously; synchronize before reading.
                torch.cuda.synchronize()
                elapsed_ms = self.start_event.elapsed_time(self.end_event)
                self.elapsed_time = elapsed_ms if self.unit == 'ms' else elapsed_ms / 1000

        return CudaTimerContext()
    @staticmethod
    def timer(unit: Literal['ms', 's'] = 's', use_cuda_sync: bool = False):
        """Return a time.perf_counter()-based timing context manager."""
        class PerfTimerContext:
            def __init__(self):
                self.unit = unit
                self.use_cuda_sync = use_cuda_sync
                self.elapsed_time = None

            def __enter__(self):
                if self.use_cuda_sync and torch.cuda.is_available():
                    torch.cuda.synchronize()
                self.start_time = time.perf_counter()
                return self

            def __exit__(self, exc_type, exc_value, traceback):
                if self.use_cuda_sync and torch.cuda.is_available():
                    torch.cuda.synchronize()
                end_time = time.perf_counter()
                elapsed_s = end_time - self.start_time
                self.elapsed_time = elapsed_s * 1000 if self.unit == 'ms' else elapsed_s

        return PerfTimerContext()
"""示意用例"""
def test_function():
x = torch.randn(1000, 1000, device='cuda')
return torch.mm(x, x.T)
# 用with语句计时context
with TimerContext.cuda_timer(unit='ms') as timer:
test_function()
print(f"CUDA上下文计时: {timer.elapsed_time:.3f} ms")
import torch
import time
import functools
from typing import Literal, Optional, Callable
class TimerDecorator:
    @staticmethod
    def cuda_timer(unit: Literal['ms', 's'] = 's', print_result: bool = True, store_attr: Optional[str] = None):
        """
        CUDA-event-based timing decorator.

        Args:
            unit: Unit of the reported time, 'ms' or 's'; defaults to 's'.
            print_result: Whether to print the timing result; defaults to True.
            store_attr: If given, store the timing result on the decorated
                function under this attribute name.
        """
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
                result = func(*args, **kwargs)
                end_event.record()
                torch.cuda.synchronize()
                elapsed_ms = start_event.elapsed_time(end_event)
                elapsed_time = elapsed_ms if unit == 'ms' else elapsed_ms / 1000
                if print_result:
                    print(f"[CUDA Timer] {func.__name__}: {elapsed_time:.6f} {unit}")
                if store_attr:
                    # Store on the wrapper, which is what callers hold after decoration.
                    setattr(wrapper, store_attr, elapsed_time)
                return result
            return wrapper
        return decorator
    @staticmethod
    def timer(unit: Literal['ms', 's'] = 's', use_cuda_sync: bool = False,
              print_result: bool = True, store_attr: Optional[str] = None):
        """
        time.perf_counter()-based timing decorator.

        Args:
            unit: Unit of the reported time, 'ms' or 's'; defaults to 's'.
            use_cuda_sync: Whether to call torch.cuda.synchronize() around the
                timed call; defaults to False.
            print_result: Whether to print the timing result; defaults to True.
            store_attr: If given, store the timing result on the decorated
                function under this attribute name.
        """
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if use_cuda_sync and torch.cuda.is_available():
                    torch.cuda.synchronize()
                start_time = time.perf_counter()
                result = func(*args, **kwargs)
                if use_cuda_sync and torch.cuda.is_available():
                    torch.cuda.synchronize()
                end_time = time.perf_counter()
                elapsed_s = end_time - start_time
                elapsed_time = elapsed_s * 1000 if unit == 'ms' else elapsed_s
                if print_result:
                    sync_str = ' (with CUDA sync)' if use_cuda_sync else ''
                    print(f"[Perf Timer] {func.__name__}: {elapsed_time:.6f} {unit}{sync_str}")
                if store_attr:
                    # Store on the wrapper, which is what callers hold after decoration.
                    setattr(wrapper, store_attr, elapsed_time)
                return result
            return wrapper
        return decorator
"""示意用例"""
def test_function():
x = torch.randn(1000, 1000, device='cuda')
return torch.mm(x, x.T)
# 定义函数时使用装饰器
@TimerDecorator.cuda_timer(unit='ms')
def test_function_with_timer():
x = torch.randn(1000, 1000, device='cuda')
return torch.mm(x, x.T)
# 动态添加装饰器
perf_timed_func = TimerDecorator.timer(unit='ms')(test_function)
perf_timed_func_with_sync = TimerDecorator.timer(unit='ms',use_cuda_sync=True)(test_function)
test_function_with_timer()
perf_timed_func_with_sync()
perf_timed_func()
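
# A minimal sketch of the store_attr option (the attribute name "last_elapsed"
# is illustrative, not part of the API): each call records its timing on the
# decorated function itself, so it can be read back afterwards.
@TimerDecorator.timer(unit='ms', use_cuda_sync=True, print_result=False, store_attr='last_elapsed')
def test_function_stored():
    x = torch.randn(1000, 1000, device='cuda')
    return torch.mm(x, x.T)

test_function_stored()
print(f"Stored timing: {test_function_stored.last_elapsed:.3f} ms")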