Last active
July 15, 2025 03:37
-
-
Save Aquaakuma/1210db7714cdce3cabdaf53f0eef4972 to your computer and use it in GitHub Desktop.
自定义accelerate追踪器的tensorboard类
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.tensorboard import SummaryWriter | |
from accelerate.tracking import GeneralTracker, on_main_process | |
import os | |
from typing import Union, Optional | |
# 0. 自定义追踪器 (custom tracker)
class MyCustomTracker(GeneralTracker):
    """
    Custom Accelerate `Tracker` that writes to TensorBoard through
    `torch.utils.tensorboard.SummaryWriter`. Should be initialized at the start of your script.

    All logging methods (and `__init__`) are wrapped with accelerate's `on_main_process`
    decorator, which is meant to restrict them to the main process.

    Args:
        run_name (`str`):
            The name of the experiment run. Logs are written to `os.path.join(logging_dir, run_name)`.
        logging_dir (`str`, `os.PathLike`):
            Location for TensorBoard logs to be stored.
        kwargs:
            Additional keyword arguments passed along to the `tensorboard.SummaryWriter.__init__` method.
    """

    name = "custom"
    requires_logging_directory = True

    @on_main_process
    def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike], **kwargs):
        super().__init__()
        self.run_name = run_name
        self.logging_dir = os.path.join(logging_dir, run_name)
        os.makedirs(self.logging_dir, exist_ok=True)
        self.writer = SummaryWriter(self.logging_dir, **kwargs)

    @property
    def tracker(self):
        """Return the underlying `SummaryWriter`, or `None` if it was never created."""
        # `__init__` is wrapped in `on_main_process`; where it was skipped there is no
        # `writer` attribute, so avoid raising AttributeError from the property.
        return getattr(self, "writer", None)

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Log a mapping of tag -> value to TensorBoard.

        `str` values are written with `add_text`, `dict` values with `add_scalars`,
        and everything else (Python numbers, tensors, ...) with `add_scalar`.

        Args:
            values (`dict`): Tags mapped to the values to log.
            step (`int`, *optional*): Global step; passed straight through to
                `SummaryWriter` (including `None`).
            kwargs: Forwarded to the underlying `SummaryWriter.add_*` call.
        """
        for key, value in values.items():
            if isinstance(value, str):
                self.add_text(key, value, global_step=step, **kwargs)
            elif isinstance(value, dict):
                self.add_scalars(key, value, global_step=step, **kwargs)
            else:
                self.add_scalar(key, value, global_step=step, **kwargs)
        # Push events to disk now instead of waiting for the writer's periodic flush.
        self.writer.flush()

    @on_main_process
    def store_init_configuration(self, values: dict):
        """Log experiment configuration parameters as a single `config` text entry, one `key: value` per line."""
        # TensorBoard renders text summaries as Markdown, where a lone "\n" is a soft
        # break (rendered as a space); two trailing spaces force a real line break.
        text = "  \n".join(f"{k}: {v}" for k, v in values.items())
        self.add_text(tag="config", text_string=text)

    @on_main_process
    def add_scalar(self, tag, scalar_value, **kwargs):
        """Thin wrapper around `SummaryWriter.add_scalar`."""
        self.writer.add_scalar(tag=tag, scalar_value=scalar_value, **kwargs)

    @on_main_process
    def add_scalars(self, main_tag, tag_scalar_dict, **kwargs):
        """Thin wrapper around `SummaryWriter.add_scalars`."""
        self.writer.add_scalars(main_tag=main_tag, tag_scalar_dict=tag_scalar_dict, **kwargs)

    @on_main_process
    def add_text(self, tag, text_string, **kwargs):
        """Thin wrapper around `SummaryWriter.add_text`."""
        self.writer.add_text(tag=tag, text_string=text_string, **kwargs)

    @on_main_process
    def add_figure(self, tag, figure, **kwargs):
        """Thin wrapper around `SummaryWriter.add_figure`."""
        self.writer.add_figure(tag=tag, figure=figure, **kwargs)

    @on_main_process
    def finish(self):
        """Close the writer when finished."""
        self.writer.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment