[Timer]
import torch
import time
from typing import Literal


class TimerContext:
    @staticmethod
    def cuda_timer(unit: Literal['ms', 's'] = 's'):
        """Return a context manager that times a block with CUDA events."""
        class CudaTimerContext:
            def __init__(self):
                self.start_event = torch.cuda.Event(enable_timing=True)
                self.end_event = torch.cuda.Event(enable_timing=True)
                self.unit = unit
                self.elapsed_time = None

            def __enter__(self):
                self.start_event.record()
                return self

            def __exit__(self, exc_type, exc_value, traceback):
                self.end_event.record()
                # Events are recorded asynchronously; synchronize before reading the elapsed time.
                torch.cuda.synchronize()
                elapsed_ms = self.start_event.elapsed_time(self.end_event)
                self.elapsed_time = elapsed_ms if self.unit == 'ms' else elapsed_ms / 1000

        return CudaTimerContext()

    @staticmethod
    def timer(unit: Literal['ms', 's'] = 's', use_cuda_sync: bool = False):
        """Return a context manager that times a block with time.perf_counter()."""
        class PerfTimerContext:
            def __init__(self):
                self.unit = unit
                self.use_cuda_sync = use_cuda_sync
                self.elapsed_time = None

            def __enter__(self):
                if self.use_cuda_sync and torch.cuda.is_available():
                    torch.cuda.synchronize()
                self.start_time = time.perf_counter()
                return self

            def __exit__(self, exc_type, exc_value, traceback):
                if self.use_cuda_sync and torch.cuda.is_available():
                    torch.cuda.synchronize()
                end_time = time.perf_counter()
                elapsed_s = end_time - self.start_time
                self.elapsed_time = elapsed_s * 1000 if self.unit == 'ms' else elapsed_s

        return PerfTimerContext()


# Example usage
def test_function():
    x = torch.randn(1000, 1000, device='cuda')
    return torch.mm(x, x.T)

# Time the call with a with-statement context manager
with TimerContext.cuda_timer(unit='ms') as timer:
    test_function()
print(f"CUDA context timing: {timer.elapsed_time:.3f} ms")
import torch
import time
import functools
from typing import Literal, Optional, Callable


class TimerDecorator:
    @staticmethod
    def cuda_timer(unit: Literal['ms', 's'] = 's', print_result: bool = True,
                   store_attr: Optional[str] = None):
        """
        CUDA-event timing decorator.

        Args:
            unit: time unit of the result, 'ms' or 's' (default 's')
            print_result: whether to print the timing result (default True)
            store_attr: if given, store the elapsed time as an attribute of
                the wrapped function under this name
        """
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
                result = func(*args, **kwargs)
                end_event.record()
                # Wait for the recorded events before reading the elapsed time.
                torch.cuda.synchronize()
                elapsed_ms = start_event.elapsed_time(end_event)
                elapsed_time = elapsed_ms if unit == 'ms' else elapsed_ms / 1000
                if print_result:
                    unit_str = 'ms' if unit == 'ms' else 's'
                    print(f"[CUDA Timer] {func.__name__}: {elapsed_time:.6f} {unit_str}")
                if store_attr:
                    setattr(func, store_attr, elapsed_time)
                return result
            return wrapper
        return decorator

    @staticmethod
    def timer(unit: Literal['ms', 's'] = 's', use_cuda_sync: bool = False,
              print_result: bool = True, store_attr: Optional[str] = None):
        """
        Timing decorator based on time.perf_counter().

        Args:
            unit: time unit of the result, 'ms' or 's' (default 's')
            use_cuda_sync: whether to call torch.cuda.synchronize() before and
                after the wrapped call (default False)
            print_result: whether to print the timing result (default True)
            store_attr: if given, store the elapsed time as an attribute of
                the wrapped function under this name
        """
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if use_cuda_sync and torch.cuda.is_available():
                    torch.cuda.synchronize()
                start_time = time.perf_counter()
                result = func(*args, **kwargs)
                if use_cuda_sync and torch.cuda.is_available():
                    torch.cuda.synchronize()
                end_time = time.perf_counter()
                elapsed_s = end_time - start_time
                elapsed_time = elapsed_s * 1000 if unit == 'ms' else elapsed_s
                if print_result:
                    unit_str = 'ms' if unit == 'ms' else 's'
                    sync_str = ' (with CUDA sync)' if use_cuda_sync else ''
                    print(f"[Perf Timer] {func.__name__}: {elapsed_time:.6f} {unit_str}{sync_str}")
                if store_attr:
                    setattr(func, store_attr, elapsed_time)
                return result
            return wrapper
        return decorator


# Example usage
def test_function():
    x = torch.randn(1000, 1000, device='cuda')
    return torch.mm(x, x.T)

# Apply the decorator at definition time
@TimerDecorator.cuda_timer(unit='ms')
def test_function_with_timer():
    x = torch.randn(1000, 1000, device='cuda')
    return torch.mm(x, x.T)

# Apply the decorator dynamically to an existing function
perf_timed_func = TimerDecorator.timer(unit='ms')(test_function)
perf_timed_func_with_sync = TimerDecorator.timer(unit='ms', use_cuda_sync=True)(test_function)

test_function_with_timer()
perf_timed_func_with_sync()
perf_timed_func()
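
The store_attr option is not shown in the usage example above. Here is a minimal sketch, assuming TimerDecorator and test_function from above are in scope; the attribute name last_elapsed is arbitrary. Note that, as written, the decorator calls setattr on the original function object it wraps rather than on the returned wrapper, so the stored value is read back from test_function itself.

# Sketch: silent timing with the result stored as a function attribute
# (assumes TimerDecorator and test_function from above are already defined).
silent_timed = TimerDecorator.timer(unit='s', print_result=False,
                                    store_attr='last_elapsed')(test_function)
silent_timed()
# setattr targets the original function object, so the value lands on
# test_function rather than on silent_timed.
print(f"Stored elapsed time: {test_function.last_elapsed:.6f} s")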