Created
February 21, 2025 19:56
-
-
Save dmexs/3f7b697bab806494051b17a76bc4bab7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import os
import pickle
import tempfile
from typing import Type, Any, Optional, Dict, List

import mlflow
from hamilton import node, graph
from hamilton.lifecycle import base
# Module-level logger for this tracker. Its level is forced to DEBUG so the
# run-id message emitted from MLFlowTrackerAsync.__init__ is not dropped by
# this logger's own level (handlers attached elsewhere may still filter it).
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class MLFlowTrackerAsync(
    base.BasePreNodeExecuteAsync,
    base.BasePostNodeExecuteAsync,
    base.BasePreGraphExecuteAsync,
    base.BasePostGraphExecuteAsync,
):
    """Async Hamilton lifecycle adapter that records node executions in MLflow.

    For every node it opens an MLflow trace before execution and closes it
    afterwards with the node's result; successful results are additionally
    pickled and uploaded as a run artifact named ``<node name>.pickle``.

    The adapter only does work when an MLflow run is active at construction
    time; otherwise every hook returns immediately.

    NOTE(review): the MLflow client calls made from the ``async`` hooks are
    ordinary synchronous calls, so they run inline on the event loop.
    """

    # Maps (run_id, node name, task_id) -> request id of the trace opened for
    # that node. Created per instance in __init__ so that separate trackers do
    # not share (and overwrite) each other's entries.
    mlflow_cache: Dict[tuple, str]
    # Info of the MLflow run that was active when the tracker was built, or
    # None when tracking is disabled.
    mlflow_run_info = None
    mlflow_client: Optional[mlflow.MlflowClient] = None
    # Not used by any method in this class; kept so the attribute stays available.
    mlflow_request_id: Optional[str] = None

    def __init__(self, *args, **kwargs):
        """Bind the tracker to the currently active MLflow run, if there is one.

        Positional and keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.mlflow_cache = {}
        active_run = mlflow.active_run()  # look it up once and reuse it
        if active_run is not None:
            self.mlflow_client = mlflow.MlflowClient()
            self.mlflow_run_info = active_run.info
            logger.debug("MLFlowTrackerAsync: %s", self.mlflow_run_info.run_id)

    @staticmethod
    def _trace_key(run_id: str, node_: node.Node, task_id: Optional[str]) -> tuple:
        """Return the cache key identifying one node execution.

        ``task_id`` defaults to ``None`` in the hooks, so it cannot be the key
        on its own: every node without a task id would share a single slot.
        """
        return (run_id, node_.name, task_id)

    def _log_result_artifact(self, node_name: str, result: Any) -> None:
        """Pickle ``result`` and upload it as ``<node_name>.pickle`` to the run.

        The file is written into a temporary directory so nothing is left in
        the working directory once the upload has finished.
        """
        # NOTE: the artifact is a pickle; only unpickle it from trusted runs.
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_path = os.path.join(tmp_dir, f"{node_name}.pickle")
            with open(local_path, "wb") as f:
                pickle.dump(result, f)
            self.mlflow_client.log_artifact(self.mlflow_run_info.run_id, local_path)

    async def pre_graph_execute(
        self,
        *,
        run_id: str,
        graph: "graph.FunctionGraph",
        final_vars: List[str],
        inputs: Dict[str, Any],
        overrides: Dict[str, Any],
    ):
        """Graph-start hook; intentionally does nothing."""
        pass

    async def post_graph_execute(
        self,
        *,
        run_id: str,
        graph: "graph.FunctionGraph",
        success: bool,
        error: Optional[Exception],
        results: Optional[Dict[str, Any]],
    ):
        """Graph-end hook; intentionally does nothing."""
        pass

    async def pre_node_execute(
        self,
        run_id: str,
        node_: node.Node,
        kwargs: Dict[str, Any],
        task_id: Optional[str] = None,
    ) -> None:
        """Open an MLflow trace named ``hamilton.<node name>`` for the node.

        The node's keyword arguments are recorded as the trace inputs and the
        trace's request id is remembered so ``post_node_execute`` can close it.
        No-op when no MLflow run was active at construction time.
        """
        if self.mlflow_run_info is None:
            return
        trace = self.mlflow_client.start_trace(
            name=f"hamilton.{node_.name}",
            inputs=kwargs,
            experiment_id=self.mlflow_run_info.experiment_id,
            span_type=mlflow.entities.SpanType.CHAIN,
        )
        self.mlflow_cache[self._trace_key(run_id, node_, task_id)] = trace.request_id

    async def post_node_execute(
        self,
        run_id: str,
        node_: node.Node,
        success: bool,
        error: Optional[Exception],
        result: Any,
        task_id: Optional[str] = None,
        **future_kwargs: dict,
    ) -> None:
        """Close the node's trace and, on success, upload its pickled result.

        The trace is ended with status ``OK`` or ``ERROR`` according to
        ``success``. No artifact is logged for a failed node. No-op when no
        MLflow run was active at construction time.
        """
        if self.mlflow_run_info is None:
            return
        # pop() so finished executions do not accumulate in the cache.
        request_id = self.mlflow_cache.pop(
            self._trace_key(run_id, node_, task_id), None
        )
        if request_id is None:
            logger.warning(
                "No open MLflow trace found for node %s; nothing to end.",
                node_.name,
            )
        else:
            self.mlflow_client.end_trace(
                request_id=request_id,
                outputs=result,
                status="OK" if success else "ERROR",
            )
        if not success:
            return
        self._log_result_artifact(node_.name, result)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment