Created February 13, 2023 00:46
-
-
Save jszym/479db2af32411b64249bfb1bff43a95e to your computer and use it in GitHub Desktop.
Dictionary Logger for PyTorch Lightning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pytorch_lightning.utilities import rank_zero_only | |
from pytorch_lightning.loggers import Logger | |
from pytorch_lightning.loggers.logger import rank_zero_experiment | |
from collections import defaultdict | |
class DictLogger(Logger):
    """PyTorch Lightning logger that keeps every logged metric in memory.

    Each call to :meth:`log_metrics` appends the value of every metric it
    receives to ``self.metrics[name]``. It also appends the accompanying
    ``step`` to ``self.steps[name]``, so the two lists for a given name stay
    index-aligned even when different metrics are logged at different
    frequencies.

    Nothing is written to disk. Read ``self.metrics`` / ``self.steps``
    directly after training. The logging methods are wrapped in
    ``rank_zero_only``, so in a multi-process run only the rank-0 instance
    accumulates values.
    """

    def __init__(self):
        super().__init__()
        # metric name -> list of logged values, in call order.
        self.metrics = defaultdict(list)
        # metric name -> list of steps, parallel to ``self.metrics[name]``
        # (entries are ``None`` when the caller passed no step).
        self.steps = defaultdict(list)

    @property
    def name(self):
        """Name of this logger."""
        return 'DictLogger'

    @property
    @rank_zero_experiment
    def experiment(self):
        """There is no backing experiment object; always ``None``."""
        return None

    @property
    def version(self):
        """Fixed experiment version string."""
        return '0.1'

    @rank_zero_only
    def log_hyperparams(self, params):
        """Accept hyperparameters but do not record them (intentional no-op)."""
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """Append each metric value, and the step it belongs to, to the history.

        Args:
            metrics: mapping of metric name to value for this logging call.
            step: step the values belong to. Defaults to ``None`` to match
                the optional ``step`` of Lightning's ``Logger.log_metrics``.
                The value is stored so that series logged at different
                rates can be aligned later.
        """
        for key, value in metrics.items():
            self.metrics[key].append(value)
            self.steps[key].append(step)

    @rank_zero_only
    def save(self):
        """Nothing to persist; delegate to the base class.

        ``super().save()`` is kept so any base-class bookkeeping still runs.
        """
        super().save()

    @rank_zero_only
    def finalize(self, status):
        """No cleanup is needed when training finishes (intentional no-op)."""
        pass
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment