Skip to content

Instantly share code, notes, and snippets.

@Aquaakuma
Last active July 15, 2025 03:37
Show Gist options
  • Save Aquaakuma/1210db7714cdce3cabdaf53f0eef4972 to your computer and use it in GitHub Desktop.
Save Aquaakuma/1210db7714cdce3cabdaf53f0eef4972 to your computer and use it in GitHub Desktop.
自定义accelerate追踪器的tensorboard类
from torch.utils.tensorboard import SummaryWriter
from accelerate.tracking import GeneralTracker, on_main_process
import os
from typing import Union, Optional
# 0. Custom tracker
class MyCustomTracker(GeneralTracker):
    """
    A custom `Tracker` class that supports `tensorboard`. Should be initialized at the start of your script.

    Every method that touches the writer is decorated with `on_main_process`, so it is meant to be a no-op on
    non-main processes (including `__init__`, which means `self.writer` is only created where the decorator
    lets the constructor body run).

    Args:
        run_name (`str`):
            The name of the experiment run. Also used as the sub-directory created under `logging_dir`.
        logging_dir (`str`, `os.PathLike`):
            Location for TensorBoard logs to be stored.
        kwargs:
            Additional key word arguments passed along to the `tensorboard.SummaryWriter.__init__` method.
    """

    name = "custom"
    requires_logging_directory = True

    @on_main_process
    def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike], **kwargs):
        super().__init__()
        self.run_name = run_name
        # Each run gets its own sub-directory so several runs can share one `logging_dir`.
        self.logging_dir = os.path.join(logging_dir, run_name)
        os.makedirs(self.logging_dir, exist_ok=True)
        self.writer = SummaryWriter(self.logging_dir, **kwargs)

    @property
    def tracker(self):
        """Return the underlying `SummaryWriter` instance."""
        return self.writer

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Log every entry of `values` to TensorBoard.

        Args:
            values (`dict`):
                Mapping of tag to value. `str` values are written with `add_text`, `dict` values with
                `add_scalars`, and everything else (numbers, 0-d tensors, ...) with `add_scalar`.
            step (`int`, *optional*):
                The run step. `None` is forwarded to the writer unchanged instead of silently discarding the
                values, so a call without a step still records its metrics.
            kwargs:
                Additional key word arguments passed along to the writer method chosen for each value.
        """
        for key, value in values.items():
            if isinstance(value, str):
                self.add_text(key, value, global_step=step, **kwargs)
            elif isinstance(value, dict):
                self.add_scalars(key, value, global_step=step, **kwargs)
            else:
                self.add_scalar(key, value, global_step=step, **kwargs)

    @on_main_process
    def store_init_configuration(self, values: dict):
        """
        Log the experiment configuration as a single text entry tagged `config`.

        Args:
            values (`dict`):
                Configuration parameters; rendered as one `key: value` line per entry.
        """
        text = "\n".join(f"{k}: {v}" for k, v in values.items())
        self.add_text(tag="config", text_string=text)

    @on_main_process
    def add_scalar(self, tag, scalar_value, **kwargs):
        """Forward to `SummaryWriter.add_scalar`; `kwargs` are passed through unchanged."""
        self.writer.add_scalar(tag=tag, scalar_value=scalar_value, **kwargs)

    @on_main_process
    def add_scalars(self, main_tag, tag_scalar_dict, **kwargs):
        """Forward to `SummaryWriter.add_scalars`; `kwargs` are passed through unchanged."""
        self.writer.add_scalars(main_tag=main_tag, tag_scalar_dict=tag_scalar_dict, **kwargs)

    @on_main_process
    def add_text(self, tag, text_string, **kwargs):
        """Forward to `SummaryWriter.add_text`; `kwargs` are passed through unchanged."""
        self.writer.add_text(tag=tag, text_string=text_string, **kwargs)

    @on_main_process
    def add_figure(self, tag, figure, **kwargs):
        """Forward to `SummaryWriter.add_figure`; `kwargs` are passed through unchanged."""
        self.writer.add_figure(tag=tag, figure=figure, **kwargs)

    @on_main_process
    def finish(self):
        """Close the writer when finished."""
        self.writer.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment