Created
April 11, 2019 11:46
-
-
Save ferrine/0df6ae873ad0308e2e63d24f7ae37e6b to your computer and use it in GitHub Desktop.
Tensorboard sacred observer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sacred.observers | |
import tensorboardX | |
import os | |
class TensorboardObserver(sacred.observers.FileStorageObserver, tensorboardX.SummaryWriter):
    """Sacred file-storage observer that is also a tensorboardX summary writer.

    The ``FileStorageObserver`` half is initialised in ``__init__``.  The
    ``SummaryWriter`` half is initialised lazily in ``started_event``, because
    its log directory is derived from ``self.dir`` and the run config, both of
    which are only available once a run has started.  The writer is closed
    when the run completes, is interrupted, or fails.
    """

    VERSION = "TensorboardObserver-0.0.1"

    def __init__(self, basedir, resource_dir=None, source_dir=None,
                 template=None,
                 priority=sacred.observers.file_storage.DEFAULT_FILE_STORAGE_PRIORITY,
                 config2name=lambda c: "", **kwargs):
        """Initialise the file-storage part and remember the writer options.

        Args:
            basedir, resource_dir, source_dir, template, priority: Forwarded
                unchanged to ``FileStorageObserver.__init__``.
            config2name: Callable mapping the run config to the last path
                component of the tensorboard log directory.  The default
                returns an empty string.
            **kwargs: Extra keyword arguments passed to
                ``tensorboardX.SummaryWriter.__init__`` when a run starts.
        """
        sacred.observers.FileStorageObserver.__init__(
            self,
            basedir=basedir,
            resource_dir=resource_dir,
            source_dir=source_dir,
            template=template,
            priority=priority,
        )
        self._tb_kwargs = kwargs
        self.config2name = config2name
        # True only between a successful SummaryWriter.__init__ and the
        # matching close(); guards against closing an uninitialised writer.
        self._tb_open = False

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Start the file-storage run, then open the summary writer.

        Returns:
            The run id returned by the parent ``started_event``.
        """
        _id = super().started_event(
            ex_info=ex_info, command=command, host_info=host_info,
            start_time=start_time, config=config,
            meta_info=meta_info, _id=_id,
        )
        # self.dir is read only after the parent hook has run.
        log_dir = os.path.join(self.dir, "_", self.config2name(config))
        # The stored kwargs are kept (not deleted) so that the same observer
        # instance can be used for more than one run.
        tensorboardX.SummaryWriter.__init__(self, log_dir=log_dir, **self._tb_kwargs)
        self._tb_open = True
        return _id

    def __eq__(self, other):
        """Compare equal to observers of the same class with the same basedir."""
        if isinstance(other, self.__class__):
            return self.basedir == other.basedir
        return False

    def _close_writer(self):
        """Close the summary writer if ``started_event`` opened it."""
        if self._tb_open:
            self._tb_open = False
            self.close()

    def completed_event(self, stop_time, result):
        """Record completion in file storage and close the summary writer."""
        try:
            super().completed_event(stop_time=stop_time, result=result)
        finally:
            # Close the writer even if the parent hook raises.
            self._close_writer()

    def interrupted_event(self, interrupt_time, status):
        """Record the interruption in file storage and close the summary writer."""
        try:
            super().interrupted_event(interrupt_time=interrupt_time, status=status)
        finally:
            self._close_writer()

    def failed_event(self, fail_time, fail_trace):
        """Record the failure in file storage and close the summary writer."""
        try:
            super().failed_event(fail_time=fail_time, fail_trace=fail_trace)
        finally:
            self._close_writer()

    @classmethod
    def create(cls, basedir, resource_dir=None, source_dir=None,
               template=None,
               priority=sacred.observers.file_storage.DEFAULT_FILE_STORAGE_PRIORITY,
               **kwargs):
        """Build an observer, creating ``basedir`` and filling in default paths.

        Args:
            basedir: Root directory; created if it does not exist.
            resource_dir: Defaults to ``<basedir>/_resources`` when falsy.
            source_dir: Defaults to ``<basedir>/_sources`` when falsy.
            template: Path to a template file.  When ``None``,
                ``<basedir>/template.html`` is used if it exists, otherwise
                no template is used.
            priority: Forwarded to the constructor.
            **kwargs: Forwarded to the constructor (this includes
                ``config2name`` and any summary-writer options).

        Raises:
            FileNotFoundError: If an explicit ``template`` does not exist.
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(basedir, exist_ok=True)
        resource_dir = resource_dir or os.path.join(basedir, '_resources')
        source_dir = source_dir or os.path.join(basedir, '_sources')
        if template is not None:
            if not os.path.exists(template):
                raise FileNotFoundError("Couldn't find template file '{}'"
                                        .format(template))
        else:
            template = os.path.join(basedir, 'template.html')
            if not os.path.exists(template):
                template = None
        return cls(basedir, resource_dir, source_dir, template, priority, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment