Skip to content

Instantly share code, notes, and snippets.

@jszym
Created February 13, 2023 00:46
Show Gist options
  • Save jszym/479db2af32411b64249bfb1bff43a95e to your computer and use it in GitHub Desktop.
Save jszym/479db2af32411b64249bfb1bff43a95e to your computer and use it in GitHub Desktop.
Dictionary Logger for PyTorch Lightning
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.loggers import Logger
from pytorch_lightning.loggers.logger import rank_zero_experiment
from collections import defaultdict
class DictLogger(Logger):
    """In-memory Lightning logger that keeps every logged metric in a dict.

    ``logger.metrics[name]`` is the list of values logged under ``name``, in
    the order ``log_metrics`` received them. ``logger.steps[name]`` is the
    index-aligned list of ``step`` values passed with those calls (``None``
    when no step was given). Hyperparameters are not recorded and nothing is
    written to disk.
    """

    def __init__(self):
        super().__init__()
        # Metric name -> list of logged values, in logging order.
        # The builtin ``list`` is used as the factory (rather than a function
        # defined inside ``__init__``) so the dict stays picklable; a
        # defaultdict holding a local function cannot be pickled.
        self.metrics = defaultdict(list)
        # Metric name -> list of steps, index-aligned with ``self.metrics``.
        self.steps = defaultdict(list)

    @property
    def name(self):
        """Name of this logger."""
        return 'DictLogger'

    @property
    @rank_zero_experiment
    def experiment(self):
        """Experiment object for this logger.

        This logger has no backing experiment object, so the body returns
        ``None``.
        """
        return None

    @property
    def version(self):
        """Experiment version string."""
        return '0.1'

    @rank_zero_only
    def log_hyperparams(self, params):
        """Accept hyperparameters and ignore them; this logger only stores metrics.

        Args:
            params: Hyperparameters to log (e.g. an ``argparse.Namespace``);
                unused.
        """

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """Append each logged value to its per-metric history.

        Args:
            metrics: Mapping of metric name to value.
            step: Step associated with this batch of values. It is recorded in
                ``self.steps`` alongside each value. It defaults to ``None`` to
                match the ``Logger.log_metrics`` base signature.
        """
        for key, value in metrics.items():
            self.metrics[key].append(value)
            self.steps[key].append(step)

    @rank_zero_only
    def save(self):
        """Delegate to the base ``Logger.save``; there is nothing extra to persist.

        Any override that adds saving logic should keep calling
        ``super().save()`` first.
        """
        super().save()

    @rank_zero_only
    def finalize(self, status):
        """Hook run after training finishes; a no-op for this logger.

        Args:
            status: Final status string supplied by the caller; unused.
        """
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment