Created
November 14, 2021 17:59
-
-
Save jonnor/c107f3ca24a36c91d8ff94029a0cd357 to your computer and use it in GitHub Desktop.
MLflow integration for KerasTuner
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""mlflow integration for KerasTuner | |
Copyright: Soundsensing AS, 2021 | |
License: MIT | |
""" | |
import uuid | |
import structlog | |
log = structlog.get_logger() | |
import mlflow | |
import keras_tuner | |
def get_run_id(run):
    """Return the run id of an mlflow run object, or None if *run* is None.

    Used for debug logging of ``mlflow.active_run()``, which may be None.
    """
    return None if run is None else run.info.run_id
class MlflowLogger(object):
    """KerasTuner Logger for integrating with mlflow.

    Each search is a top-level mlflow run. Each KerasTuner trial is a run
    nested under it, and each execution (model fit) of a trial is a run
    nested under the trial run.

    XXX: assumes that executions are done sequentially and non-concurrently.
    mlflow.end_run() ends whichever run is currently active, so the
    start/end calls below must be strictly nested.
    """

    def __init__(self):
        # Top-level run for the whole search, plus a random id used in run names
        self.search_run = None
        self.search_id = None
        # Nested run for the trial in progress, with the trial info from KerasTuner
        self.trial_run = None
        self.trial_id = None
        self.trial_state = None
        # Nested run for the execution in progress
        self.execution_run = None
        # Number of executions started in the current trial (used in run names)
        self.execution_id = 0

    def register_tuner(self, tuner_state):
        """Called at start of search: open the top-level search run.

        ``tuner_state`` is part of the KerasTuner Logger interface and is
        not used here.
        """
        log.debug('mlflow-logger-search-start')
        self.search_id = str(uuid.uuid4())
        # Register a top-level run
        self.search_run = mlflow.start_run(
            nested=False,
            run_name=f'search-{self.search_id[0:8]}',
        )

    def exit(self):
        """Called at end of a search: close the top-level search run."""
        log.debug('mlflow-logger-search-end')
        if self.search_run is not None:
            # Without this the search run stays active, and the next
            # register_tuner() call (start_run with nested=False) fails
            # because a run is already active.
            mlflow.end_run()
        self.search_run = None
        self.search_id = None

    def register_trial(self, trial_id, trial_state):
        """Called at beginning of trial: open a run nested under the search run.

        Args:
            trial_id: KerasTuner trial id. Its first 8 characters are used
                in the run name (assumes a string - TODO confirm).
            trial_state: KerasTuner trial state. Stored so that
                register_execution can read
                ``trial_state['hyperparameters']['values']``.
        """
        log.debug('mlflow-logger-trial-start',
            trial_id=trial_id,
            active_run_id=get_run_id(mlflow.active_run()),
        )
        assert self.search_run is not None
        assert self.trial_run is None
        assert self.execution_run is None
        assert self.execution_id == 0

        self.trial_id = trial_id
        self.trial_state = trial_state

        # Start a new run, under the search run
        self.trial_run = mlflow.start_run(
            nested=True,
            run_name=f'trial-{self.trial_id[0:8]}-{self.search_id[0:8]}',
        )
        # Hyperparameters are deliberately logged on each execution run
        # (in register_execution) rather than on the trial run.

    def report_trial_state(self, trial_id, trial_state):
        """Called at end of trial: close the trial run and reset trial state."""
        log.debug('mlflow-logger-trial-end',
            trial_id=trial_id,
            active_run_id=get_run_id(mlflow.active_run()),
        )
        assert self.search_run is not None
        assert self.trial_run is not None
        assert self.execution_run is None

        # End the trial run. It is the active run at this point because the
        # execution run has already been ended (asserted above).
        mlflow.end_run()  # XXX: ends the *active* run; no run id is passed
        self.trial_run = None
        self.trial_id = None
        self.trial_state = None
        self.execution_id = 0

    def register_execution(self):
        """Open a run for one execution, nested under the trial run.

        Logs the trial's hyperparameter values as params of the new run.
        """
        log.debug('mlflow-logger-execution-start',
            active_run_id=get_run_id(mlflow.active_run()),
        )
        assert self.search_run is not None
        assert self.trial_run is not None
        assert self.execution_run is None

        self.execution_run = mlflow.start_run(
            nested=True,
            run_name=f'exec-{self.execution_id}-{self.trial_id[0:8]}-{self.search_id[0:8]}',
        )
        self.execution_id += 1

        # register hyperparameters from the trial
        hyperparams = self.trial_state['hyperparameters']['values']
        mlflow.log_params(hyperparams)

    def report_execution_state(self, histories):
        """Close the current execution run.

        ``histories`` is accepted for interface symmetry but is not logged.
        """
        log.debug('mlflow-logger-execution-end',
            active_run_id=get_run_id(mlflow.active_run()),
        )
        assert self.search_run is not None
        assert self.trial_run is not None
        assert self.execution_run is not None

        mlflow.end_run()  # XXX: ends the *active* run; no run id is passed
        self.execution_run = None
class FakeHistories():
    """Minimal stand-in for a Keras ``History`` object.

    Only provides the ``history`` attribute, a dict of metric values. It is
    used to report a placeholder result when model training fails.
    """

    def __init__(self, metrics=None):
        # Avoid a shared mutable default: each instance gets its own dict
        # unless the caller supplies one.
        self.history = {} if metrics is None else metrics
class LoggerTunerMixin():
    """Mixin for KerasTuner tuners that logs every training execution.

    List it before the KerasTuner tuner class in the bases. It sets the
    tuner's ``logger`` to an MlflowLogger when none is given. It also accepts
    an extra keyword argument ``on_exception``:

    - ``'pass'`` (default): a failed execution is logged and recorded with
      the worst possible objective value, and the search continues.
    - any other value: the exception is re-raised.
    """

    def __init__(self, *args, **kwargs):
        if kwargs.get('logger') is None:
            kwargs['logger'] = MlflowLogger()
        # Pop rather than get, so the KerasTuner base constructor does not
        # receive an unexpected keyword argument.
        self.on_exception = kwargs.pop('on_exception', 'pass')
        super().__init__(*args, **kwargs)

    # Hack in registration for each model training "execution"
    def _build_and_fit_model(self, trial, *args, **kwargs):
        """Wrap the base tuner's build-and-fit with execution logging.

        Returns:
            The base implementation's result. If training raised and
            ``on_exception == 'pass'``, returns a FakeHistories holding the
            worst value for the oracle objective instead.

        Raises:
            Whatever the base implementation raised, when ``on_exception``
            is not ``'pass'``.
        """
        # log start of execution
        if self.logger:
            self.logger.register_execution()

        histories = None
        try:
            # call the original function
            histories = super()._build_and_fit_model(trial, *args, **kwargs)
        except Exception:
            if self.on_exception != 'pass':
                raise
            # Best-effort mode: keep searching, but don't hide the failure.
            log.exception('mlflow-logger-execution-failed')
            objective = self.oracle.objective
            worst = float('inf') if objective.direction == 'min' else float('-inf')
            histories = FakeHistories({objective.name: worst})
        finally:
            # Always close the execution run, even when re-raising, so the
            # logger's nested-run bookkeeping stays consistent.
            if self.logger:
                self.logger.report_execution_state(histories)
        return histories
# Integrate with keras tuners
class RandomSearch(LoggerTunerMixin, keras_tuner.RandomSearch):
    """keras_tuner.RandomSearch with per-execution logging (MlflowLogger by default)."""
class BayesianOptimization(LoggerTunerMixin, keras_tuner.BayesianOptimization):
    """keras_tuner.BayesianOptimization with per-execution logging (MlflowLogger by default)."""
class SklearnTuner(LoggerTunerMixin, keras_tuner.SklearnTuner):
    """keras_tuner.SklearnTuner with per-execution logging (MlflowLogger by default)."""
class Hyperband(LoggerTunerMixin, keras_tuner.Hyperband):
    """keras_tuner.Hyperband with per-execution logging (MlflowLogger by default)."""
Yes, you need to install the "structlog" module (e.g. with `pip install structlog`).
…On Sat, 4 Dec 2021 at 10:40, abdulbasitds ***@***.***> wrote:
***@***.**** commented on this gist.
------------------------------
@jonnor <https://github.com/jonnor> I am getting this error ModuleNotFoundError:
No module named 'structlog', am I missing something?
—
You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub
<https://gist.github.com/c107f3ca24a36c91d8ff94029a0cd357#gistcomment-3983720>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAALBAJRN5ALNKSYTFG2ISTUPHOYFANCNFSM5ITIKBYA>
.
Triage notifications on the go with GitHub Mobile for iOS
<https://apps.apple.com/app/apple-store/id1477376905?ct=notification-email&mt=8&pt=524675>
or Android
<https://play.google.com/store/apps/details?id=com.github.android&referrer=utm_campaign%3Dnotification-email%26utm_medium%3Demail%26utm_source%3Dgithub>.
--
Jon Nordby - www.jonnor.com
@metalglove ah yes, good point. Anything you want logged that is not covered by autolog should be logged with the MLflow tracking API (mlflow.log_*): https://mlflow.org/docs/latest/tracking.html#logging-functions
I'm trying to log a plot of a confusion matrix as an mlflow artifact. Where should I add that bit of code? And any suggestions for how I can retrieve the model from the active trial to run model.predict?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@jonnor I am getting this error: `ModuleNotFoundError: No module named 'structlog'`. Am I missing something?