Created
May 29, 2019 17:12
-
-
Save georgepar/7a7370bc3ccb399444c9d21fe07a6d80 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mlflow | |
import mlflow.pytorch | |
class MlFlowLogger(object):
    """Thin convenience wrapper around the module-level ``mlflow`` tracking API.

    Constructing an instance immediately sets the tracking URI, selects (and
    if needed creates) the experiment, starts a run and logs ``**params`` as
    run parameters. Call :meth:`end` when finished, or use the instance as a
    context manager so the run is ended even if the ``with`` body raises::

        with MlFlowLogger(uri, 'my-exp', lr=1e-3) as logger:
            logger.log_metrics({'loss': 0.5}, step=1)

    Args:
        uri: Tracking URI passed verbatim to ``mlflow.set_tracking_uri``.
        experiment_name: Experiment to log into; ``None`` leaves MLflow's
            current experiment selection untouched.
        model_path: Default local directory used by :meth:`log_model`.
        **params: Parameters logged to the run as soon as it starts.
    """

    def __init__(self,
                 uri=None,
                 experiment_name=None,
                 model_path='models',
                 **params):
        self.params = params
        self.experiment_name = experiment_name
        self.run = None  # handle returned by mlflow.start_run(); reset by end()
        self.uri = uri
        self.model_path = model_path
        self.start()

    def __enter__(self):
        """Return the logger itself; the run was already started in __init__."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """End the run when leaving a ``with`` block, even on error."""
        self.end()
        # Returning False lets any exception from the with-body propagate.
        return False

    def get_or_set_experiment(self):
        """Create ``self.experiment_name`` if needed and make it the active experiment.

        Does nothing when no experiment name was given.
        """
        if self.experiment_name is None:
            return
        # Local import keeps the module header unchanged; only this method needs it.
        from mlflow.exceptions import MlflowException
        try:
            mlflow.create_experiment(self.experiment_name)
        except MlflowException:
            # Only MLflow's own errors are treated as "already exists"; other
            # exceptions (e.g. programming errors) now propagate instead of
            # being misreported. NOTE(review): MlflowException may also cover
            # other store failures; set_experiment below would then surface them.
            print('Experiment {} already exists'
                  .format(self.experiment_name))
        mlflow.set_experiment(self.experiment_name)

    @staticmethod
    def log_param(k, v):
        """Log a single run parameter ``k = v``."""
        mlflow.log_param(k, v)

    def log_params(self, params=None):
        """Log every item of ``params`` as a run parameter.

        Args:
            params: Mapping of parameter names to values; defaults to the
                ``**params`` given to the constructor.
        """
        if params is None:
            params = self.params
        for k, v in params.items():
            self.log_param(k, v)

    @staticmethod
    def log_metric(k, v, step=None):
        """Log a single metric value.

        Args:
            k: Metric name.
            v: Metric value.
            step: Optional training step. It is forwarded only when given, so
                the two-argument call used previously is unchanged.
        """
        if step is None:
            mlflow.log_metric(k, v)
        else:
            mlflow.log_metric(k, v, step=step)

    def log_metrics(self, metrics, step=None):
        """Log every item of ``metrics`` through :meth:`log_metric`.

        Args:
            metrics: Mapping of metric names to values.
            step: Optional training step applied to every metric.
        """
        for k, v in metrics.items():
            # Only pass step when set so subclasses overriding log_metric with
            # the original (k, v) signature keep working.
            if step is None:
                self.log_metric(k, v)
            else:
                self.log_metric(k, v, step=step)

    def log_model(self, model, path=None):
        """Save ``model`` locally with ``mlflow.pytorch.save_model``.

        This writes to the local filesystem only; it does not log the model
        to the tracking run.

        Args:
            model: The PyTorch model to save.
            path: Target directory; defaults to ``self.model_path``. Supply a
                distinct path for repeated saves. NOTE(review): save_model is
                expected to refuse an existing directory — confirm for the
                installed MLflow version.
        """
        target = self.model_path if path is None else path
        mlflow.pytorch.save_model(model, target)

    def start(self):
        """Set the tracking URI, select the experiment, start a run and log ``self.params``."""
        mlflow.set_tracking_uri(self.uri)
        self.get_or_set_experiment()
        self.run = mlflow.start_run()
        self.log_params()

    def end(self):
        """End the active MLflow run and drop the stored run handle."""
        mlflow.end_run()
        self.run = None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment