-
-
Save jonnor/c107f3ca24a36c91d8ff94029a0cd357 to your computer and use it in GitHub Desktop.
| """mlflow integration for KerasTuner | |
| Copyright: Soundsensing AS, 2021 | |
| License: MIT | |
| """ | |
| import uuid | |
| import structlog | |
| log = structlog.get_logger() | |
| import mlflow | |
| import keras_tuner | |
def get_run_id(run):
    """Return the id of an mlflow run, or None when there is no run.

    `run` is expected to expose `run.info.run_id`, as the objects returned
    by `mlflow.active_run()` / `mlflow.start_run()` are used here.
    """
    return run.info.run_id if run is not None else None
class MlflowLogger(object):
    """KerasTuner Logger for integrating with mlflow

    A search opens a top-level mlflow run. Each KerasTuner trial is started
    as a nested run while the search run is active, and each execution is
    started as a nested run while its trial run is active.

    The logger tracks which level is currently open and asserts that the
    callbacks arrive in the expected order.

    XXX: assumes that executions are done sequentially and non-concurrently
    """

    def __init__(self):
        # Top-level run covering one whole search
        self.search_run = None
        self.search_id = None

        # Run and state of the trial currently in progress
        self.trial_run = None
        self.trial_id = None
        self.trial_state = None

        # Run of the execution currently in progress, and the index that
        # the next execution within the current trial will get
        self.execution_run = None
        self.execution_id = 0

    def register_tuner(self, tuner_state):
        """Called at start of search

        Generates a fresh search id and opens the top-level run.
        `tuner_state` is accepted for interface compatibility and not used.
        """
        log.debug('mlflow-logger-search-start')

        self.search_id = str(uuid.uuid4())

        # Register a top-level run
        self.search_run = mlflow.start_run(nested=False, run_name=f'search-{self.search_id[0:8]}')

    def exit(self):
        """Called at end of a search

        Ends the top-level run opened by `register_tuner` (if any) and
        clears the search state, so that a later search can open a new
        top-level run.
        """
        log.debug('mlflow-logger-search-end')

        # Trial and execution runs are ended by their own callbacks, so the
        # active run at this point is expected to be the search run.
        if self.search_run is not None:
            mlflow.end_run()

        self.search_run = None
        self.search_id = None

    def register_trial(self, trial_id, trial_state):
        """Called at beginning of trial

        Opens a nested run for the trial and remembers `trial_state`, from
        which hyperparameters are logged on each execution.
        """
        log.debug('mlflow-logger-trial-start',
            trial_id=trial_id,
            active_run_id=get_run_id(mlflow.active_run()),
        )

        # A trial may only start inside a search, with no other trial or
        # execution in progress
        assert self.search_run is not None
        assert self.trial_run is None
        assert self.execution_run is None
        assert self.execution_id == 0

        self.trial_id = trial_id
        self.trial_state = trial_state

        # Start a new run, under the search run
        self.trial_run = mlflow.start_run(nested=True,
            run_name=f'trial-{self.trial_id[0:8]}-{self.search_id[0:8]}'
        )

        # Hyperparameters are only registered on each execution for now,
        # see register_execution()

    def report_trial_state(self, trial_id, trial_state):
        """Called at end of trial

        Ends the trial run and resets the per-trial state, including the
        execution counter.
        """
        log.debug('mlflow-logger-trial-end',
            trial_id=trial_id,
            active_run_id=get_run_id(mlflow.active_run()),
        )

        assert self.search_run is not None
        assert self.trial_run is not None
        assert self.execution_run is None

        # End the trial run (mlflow ends whichever run is currently active)
        mlflow.end_run() ## XXX: no way to specify the id?

        self.trial_run = None
        self.trial_id = None
        self.trial_state = None
        self.execution_id = 0

    def register_execution(self):
        """Called at beginning of an execution within the current trial

        Opens a nested run for the execution and logs the trial's
        hyperparameter values to it.
        """
        log.debug('mlflow-logger-execution-start',
            active_run_id=get_run_id(mlflow.active_run()),
        )

        assert self.search_run is not None
        assert self.trial_run is not None
        assert self.execution_run is None

        self.execution_run = mlflow.start_run(nested=True,
            run_name=f'exec-{self.execution_id}-{self.trial_id[0:8]}-{self.search_id[0:8]}',
        )
        self.execution_id += 1

        # register hyperparameters from the trial
        hyperparams = self.trial_state['hyperparameters']['values']
        mlflow.log_params(hyperparams)

    def report_execution_state(self, histories):
        """Called at end of an execution

        Ends the execution run. `histories` is accepted for interface
        compatibility and not used.
        """
        log.debug('mlflow-logger-execution-end',
            active_run_id=get_run_id(mlflow.active_run()),
        )

        assert self.search_run is not None
        assert self.trial_run is not None
        assert self.execution_run is not None

        mlflow.end_run() ## XXX: no way to specify the id?
        self.execution_run = None
class FakeHistories():
    """Placeholder result object exposing a `history` dict of metrics.

    Used in place of the real training result when an execution failed and
    a substitute objective value has to be reported.
    """

    def __init__(self, metrics=None):
        # A fresh dict per instance: a `{}` default would be shared between
        # all instances created without an argument.
        self.history = {} if metrics is None else metrics
class LoggerTunerMixin():
    """Mixin adding per-execution logging to a KerasTuner tuner class.

    Must be listed before the tuner class in the bases, so that its
    `__init__` and `_build_and_fit_model` wrap the tuner's own.

    Keyword arguments handled here:
        logger: logger to use. When missing or None, a `MlflowLogger` is
            created.
        on_exception: 'pass' (default) turns an exception raised during an
            execution into a worst-case objective value, so the search can
            continue. Any other value lets the exception propagate.
    """

    def __init__(self, *args, **kwargs):
        if kwargs.get('logger') is None:
            kwargs['logger'] = MlflowLogger()

        # Pop rather than get: this option belongs to the mixin and must not
        # be forwarded to the tuner constructor.
        self.on_exception = kwargs.pop('on_exception', 'pass')

        super().__init__(*args, **kwargs)

    # Hack in registration for each model training "execution"
    def _build_and_fit_model(self, trial, *args, **kwargs):
        """Run the tuner's `_build_and_fit_model`, bracketed by logger calls.

        Returns the wrapped method's result, or a `FakeHistories` holding the
        worst possible objective value when the execution raised and
        `on_exception` is 'pass'.

        The end of the execution is always reported to the logger, also when
        the exception propagates; in that case it is reported with None.
        """
        # log start of execution
        if self.logger:
            self.logger.register_execution()

        histories = None
        try:
            # call the original function
            histories = super()._build_and_fit_model(trial, *args, **kwargs)
        except Exception:
            if self.on_exception != 'pass':
                raise

            # Keep the failure visible even though the search continues
            log.exception('mlflow-logger-execution-failed')

            # Report the worst value for the objective's direction, so this
            # execution can never be picked as the best one
            objective = self.oracle.objective
            worst = float('inf') if objective.direction == 'min' else float('-inf')
            histories = FakeHistories({objective.name: worst})
        finally:
            # log end of execution, so the logger can close what it opened
            if self.logger:
                self.logger.report_execution_state(histories)

        return histories
| # Integrate with keras tuners | |
class RandomSearch(LoggerTunerMixin, keras_tuner.RandomSearch):
    """`keras_tuner.RandomSearch` with the logging added by `LoggerTunerMixin`."""
class BayesianOptimization(LoggerTunerMixin, keras_tuner.BayesianOptimization):
    """`keras_tuner.BayesianOptimization` with the logging added by `LoggerTunerMixin`."""
class SklearnTuner(LoggerTunerMixin, keras_tuner.SklearnTuner):
    """`keras_tuner.SklearnTuner` with the logging added by `LoggerTunerMixin`.

    NOTE(review): the mixin hooks `_build_and_fit_model`; confirm that
    `keras_tuner.SklearnTuner` actually goes through that method, otherwise
    no execution runs are registered for this tuner.
    """
class Hyperband(LoggerTunerMixin, keras_tuner.Hyperband):
    """`keras_tuner.Hyperband` with the logging added by `LoggerTunerMixin`."""
To use it, import these overridden classes instead of the ones from keras_tuner, and instantiate the tuner for the search as usual.
from mlflow_keras_tuner import RandomSearch
tuner = RandomSearch(
.... ordinary parameters ...
)
I have to add mlflow.tensorflow.autolog() for it to log my loss as well.
@metalglove ah yes, good point.
And for anything one wants logged that is not covered by the autolog, one should use the mlflow tracking API (mlflow.log_*).
https://mlflow.org/docs/latest/tracking.html#logging-functions
@jonnor I am getting this error ModuleNotFoundError: No module named 'structlog', am I missing something?
@metalglove ah yes, good point. And for anything one wants logged that is not covered by the autolog, one should use the mlflow tracking API (mlflow.log_*). https://mlflow.org/docs/latest/tracking.html#logging-functions
I'm trying to log a plot of a confusion matrix as an mlflow artifact. Where should I add that bit of code? And any suggestions for how I can retrieve the model from the active trial to run model.predict?
Could you provide an example on how you would use it with mlflow?