Last active
May 7, 2020 21:02
-
-
Save smurching/181ba02995e15a2b2a00bf1c3cf64f44 to your computer and use it in GitHub Desktop.
OSS MLflow post-run-creation hook
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from mlflow.tracking.context.abstract_context import RunContextProvider | |
from mlflow.utils import databricks_utils | |
from mlflow.entities import SourceType | |
from mlflow.utils.mlflow_tags import ( | |
MLFLOW_SOURCE_TYPE, | |
MLFLOW_SOURCE_NAME, | |
MLFLOW_DATABRICKS_WEBAPP_URL, | |
MLFLOW_DATABRICKS_NOTEBOOK_PATH, | |
MLFLOW_DATABRICKS_NOTEBOOK_ID | |
) | |
class MlflowCreateRunHook(object):
    """
    IPython event hook that maintains a counter of created runs & submits this count
    to the Databricks frontend after each command execution. Note that this hook does
    not directly detect run-creation & increment the counter itself - detecting run
    creation is the responsibility of ``DatabricksNotebookRunContext`` below.

    Lifecycle per executed cell, once :meth:`register` has been called:

    1. ``pre_execute`` resets the counter to zero.
    2. External code calls :meth:`increment_runs_created` once per created run.
    3. ``post_execute`` reports the counter if it is greater than zero.

    For more info, see the IPython event API:
    https://ipython.readthedocs.io/en/stable/config/callbacks.html#ipython-events
    """

    def __init__(self):
        # Number of runs created since the last ``pre_execute`` event.
        self._mlflow_runs_created = 0
        # Set by ``register()`` to the active shell's ``user_ns``; None until then.
        self._user_ns = None

    def pre_execute(self):
        """Reset the count of MLflow runs created before a new execution starts."""
        self._mlflow_runs_created = 0

    def post_execute(self):
        """Send the number of runs created during the execution, if any were created."""
        if self._mlflow_runs_created > 0:
            # NOTE(review): relies on ``user_ns`` exposing ``add_frontend_message``;
            # presumably a Databricks-specific extension - confirm.
            self._user_ns.add_frontend_message({"mlflowRunsCreated": self._mlflow_runs_created})

    def increment_runs_created(self):
        """Record that one more run was created during the current execution."""
        self._mlflow_runs_created += 1

    def register(self):
        """Attach this hook's callbacks to the active IPython shell's events."""
        # Imported lazily so that merely defining this class does not require IPython.
        import IPython

        ipython = IPython.get_ipython()
        self._user_ns = ipython.user_ns
        ipython.events.register('pre_execute', self.pre_execute)
        # Must be ``post_execute``: registering ``pre_execute`` here (as before) meant
        # the run count was reset again after execution and never reported.
        ipython.events.register('post_execute', self.post_execute)
class DatabricksNotebookRunContext(RunContextProvider):
    """
    Context provider defining a callback to be executed on run creation via the
    MlflowClient.create_run API. Increments the count of created runs for the
    currently-running cell.

    NOTE(review): this class does not define ``tags()``; confirm whether the
    ``RunContextProvider`` base class requires it.
    """

    def __init__(self):
        # Stays None unless ``in_context()`` is true at construction time.
        self.hook = None
        if self.in_context():
            # Register IPython hook for submitting a count of # of created runs to the
            # frontend if running in Databricks
            self.hook = MlflowCreateRunHook()
            self.hook.register()

    def in_context(self):
        """Return the result of ``databricks_utils.is_in_databricks_notebook()``."""
        return databricks_utils.is_in_databricks_notebook()

    def post_create_run_hook(self, run):
        """
        Hook that executes after a run is created via the MlflowClient.create_run API
        (note that fluent APIs like mlflow.start_run() ultimately call
        MlflowClient.create_run)

        :param run: The created run; only ``run.info.experiment_id`` is read.
        :return: None
        """
        # No IPython hook was registered at construction time, so there is no
        # counter to update. Previously this path could raise AttributeError on None.
        if self.hook is None:
            return
        experiment_id = run.info.experiment_id
        # Only count runs whose experiment ID equals the current notebook's ID.
        if experiment_id == databricks_utils.get_notebook_id():
            self.hook.increment_runs_created()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment