-
-
Save jonnor/c107f3ca24a36c91d8ff94029a0cd357 to your computer and use it in GitHub Desktop.
"""mlflow integration for KerasTuner | |
Copyright: Soundsensing AS, 2021 | |
License: MIT | |
""" | |
import uuid | |
import structlog | |
log = structlog.get_logger() | |
import mlflow | |
import keras_tuner | |
def get_run_id(run):
    """Return the mlflow run_id for *run*, or None if there is no run."""
    return run.info.run_id if run is not None else None
class MlflowLogger(object):
    """KerasTuner Logger for integrating with mlflow

    The whole hyperparameter search is one top-level mlflow run,
    each KerasTuner trial is a nested (child) mlflow run,
    and each execution (model fit) is nested under its trial.

    XXX: assumes that executions are done sequentially and non-concurrently,
    because mlflow.start_run()/end_run() operate on one implicit active run.
    """

    def __init__(self):
        # Top-level mlflow run covering the whole search
        self.search_run = None
        self.search_id = None
        # Current trial (one hyperparameter configuration)
        self.trial_run = None
        self.trial_id = None
        self.trial_state = None
        # Current execution (one model training) within the trial
        self.execution_run = None
        self.execution_id = 0

    def register_tuner(self, tuner_state):
        """Called at start of search"""
        log.debug('mlflow-logger-search-start')
        self.search_id = str(uuid.uuid4())
        # Register a top-level run
        self.search_run = mlflow.start_run(nested=False, run_name=f'search-{self.search_id[0:8]}')

    def exit(self):
        """Called at end of a search"""
        log.debug('mlflow-logger-search-end')
        # FIX: was 'self.seach_run' (typo), which left search_run set forever
        self.search_run = None
        self.search_id = None
        # NOTE(review): the top-level run started in register_tuner() is never
        # mlflow.end_run()'d here — presumably left to the caller; verify.

    def register_trial(self, trial_id, trial_state):
        """Called at beginning of trial"""
        log.debug('mlflow-logger-trial-start',
            trial_id=trial_id,
            active_run_id=get_run_id(mlflow.active_run()),
        )
        # Invariant: a search is active and no trial/execution is in flight
        assert self.search_run is not None
        assert self.trial_run is None
        assert self.execution_run is None
        assert self.execution_id == 0
        self.trial_id = trial_id
        self.trial_state = trial_state
        # Start a new run, under the search run
        self.trial_run = mlflow.start_run(nested=True,
            run_name=f'trial-{self.trial_id[0:8]}-{self.search_id[0:8]}'
        )
        # For now, only register these on each execution
        #hyperparams = self.trial_state['hyperparameters']['values']
        #mlflow.log_params(hyperparams)

    def report_trial_state(self, trial_id, trial_state):
        """Called at end of trial"""
        log.debug('mlflow-logger-trial-end',
            trial_id=trial_id,
            active_run_id=get_run_id(mlflow.active_run()),
        )
        # Invariant: the trial is open but its last execution has been closed
        assert self.search_run is not None
        assert self.trial_run is not None
        assert self.execution_run is None
        # Close the trial run (ends the currently active nested run)
        mlflow.end_run() ## XXX: no way to specify the id?
        self.trial_run = None
        self.trial_id = None
        self.trial_state = None
        self.execution_id = 0

    def register_execution(self):
        """Called (by LoggerTunerMixin) at the start of each model training"""
        log.debug('mlflow-logger-execution-start',
            active_run_id=get_run_id(mlflow.active_run()),
        )
        # Invariant: search and trial are open, no execution in flight
        assert self.search_run is not None
        assert self.trial_run is not None
        assert self.execution_run is None
        self.execution_run = mlflow.start_run(nested=True,
            run_name=f'exec-{self.execution_id}-{self.trial_id[0:8]}-{self.search_id[0:8]}',
        )
        self.execution_id += 1
        # register hyperparameters from the trial
        hyperparams = self.trial_state['hyperparameters']['values']
        mlflow.log_params(hyperparams)

    def report_execution_state(self, histories):
        """Called (by LoggerTunerMixin) at the end of each model training"""
        log.debug('mlflow-logger-execution-end',
            active_run_id=get_run_id(mlflow.active_run()),
        )
        assert self.search_run is not None
        assert self.trial_run is not None
        assert self.execution_run is not None
        # Close the execution run (ends the currently active nested run)
        mlflow.end_run() ## XXX: no way to specify the id?
        self.execution_run = None
class FakeHistories():
    """Minimal stand-in for a Keras History object.

    Only provides the .history dict of metric values, which is all that
    the tuner needs when a training execution failed.
    """
    def __init__(self, metrics=None):
        # FIX: avoid a mutable default argument (metrics={}), which would be
        # shared across every instance constructed without arguments
        self.history = {} if metrics is None else metrics
class LoggerTunerMixin():
    """Mixin for keras_tuner Tuner classes.

    Attaches an MlflowLogger by default, and (with on_exception='pass',
    the default) converts exceptions from individual model trainings into
    a worst-case objective value, so one failing configuration does not
    abort the whole search.
    """

    def __init__(self, *args, **kwargs):
        if kwargs.get('logger') is None:
            kwargs['logger'] = MlflowLogger()
        # FIX: pop() instead of get(), so the 'on_exception' key is not
        # forwarded to the keras_tuner base __init__, which does not accept it
        self.on_exception = kwargs.pop('on_exception', 'pass')
        return super(LoggerTunerMixin, self).__init__(*args, **kwargs)

    # Hack in registration for each model training "execution"
    def _build_and_fit_model(self, trial, *args, **kwargs):
        """Wrap Tuner internals to log each execution start/end to mlflow"""
        # log start of execution
        if self.logger:
            self.logger.register_execution()
        histories = None
        try:
            # call the original function
            histories = super(LoggerTunerMixin, self)._build_and_fit_model(trial, *args, **kwargs)
        except Exception as e:
            if self.on_exception == 'pass':
                # Report the worst possible objective value so the oracle
                # discards this configuration instead of crashing the search
                o = self.oracle.objective
                value = float('inf') if o.direction == 'min' else float('-inf')
                histories = FakeHistories({o.name: value})
            else:
                raise e
        # log end of execution
        if self.logger:
            self.logger.report_execution_state(histories)
        return histories
# Integrate with keras tuners | |
class RandomSearch(LoggerTunerMixin, keras_tuner.RandomSearch):
    """keras_tuner.RandomSearch with mlflow logging via LoggerTunerMixin."""
    pass
class BayesianOptimization(LoggerTunerMixin, keras_tuner.BayesianOptimization):
    """keras_tuner.BayesianOptimization with mlflow logging via LoggerTunerMixin."""
    pass
class SklearnTuner(LoggerTunerMixin, keras_tuner.SklearnTuner):
    """keras_tuner.SklearnTuner with mlflow logging via LoggerTunerMixin."""
    pass
class Hyperband(LoggerTunerMixin, keras_tuner.Hyperband):
    """keras_tuner.Hyperband with mlflow logging via LoggerTunerMixin."""
    pass
I have to add mlflow.tensorflow.autolog()
for it to log my loss as well.
@metalglove ah yes, good point.
And anything one wants logged that is not part of the autolog should be logged explicitly with the mlflow tracking API (mlflow.log_*).
https://mlflow.org/docs/latest/tracking.html#logging-functions
@jonnor I am getting this error ModuleNotFoundError: No module named 'structlog'
, am I missing something?
@metalglove ah yes, good point. And anything one wants logged that is not part of the autolog should be logged explicitly with the mlflow tracking API (mlflow.log_*). https://mlflow.org/docs/latest/tracking.html#logging-functions
I'm trying to log a plot of a confusion matrix as an mlflow artifact. Where should I add that bit of code? And any suggestions for how I can retrieve the model from the active trial to run model.predict?
One imports these overridden classes and uses them in place of the keras_tuner originals to instantiate the search.