import logging
import pickle
from typing import Any, Dict, List, Optional

import mlflow

from hamilton import graph, node
from hamilton.lifecycle import base

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

class MLFlowTrackerAsync(
    base.BasePreNodeExecuteAsync,
    base.BasePostNodeExecuteAsync,
    base.BasePreGraphExecuteAsync,
    base.BasePostGraphExecuteAsync,
):
    """Async Hamilton lifecycle adapter that records each node execution as an MLflow trace
    and logs node results as pickled artifacts on the active MLflow run."""

    def __init__(self, *args, **kwargs):
        # Maps (task_id, node name) -> MLflow trace request_id, so post_node_execute can
        # close the trace opened in pre_node_execute.
        self.mlflow_cache: Dict[Any, str] = {}
        self.mlflow_run_info = None
        self.mlflow_client: Optional[mlflow.MlflowClient] = None
        # Only track when an MLflow run is already active at construction time.
        if mlflow.active_run():
            self.mlflow_client = mlflow.MlflowClient()
            self.mlflow_run_info = mlflow.active_run().info
            logger.debug(f"MLFlowTrackerAsync: {self.mlflow_run_info.run_id}")
    # The graph-level hooks are required by the base classes but need no work here:
    # all tracking happens at node level.
    async def pre_graph_execute(
        self,
        *,
        run_id: str,
        graph: "graph.FunctionGraph",
        final_vars: List[str],
        inputs: Dict[str, Any],
        overrides: Dict[str, Any],
    ):
        pass

    async def post_graph_execute(
        self,
        *,
        run_id: str,
        graph: "graph.FunctionGraph",
        success: bool,
        error: Optional[Exception],
        results: Optional[Dict[str, Any]],
    ):
        pass
    async def pre_node_execute(
        self,
        run_id: str,
        node_: node.Node,
        kwargs: Dict[str, Any],
        task_id: Optional[str] = None,
    ) -> None:
        if self.mlflow_run_info is None:
            return
        # Open an MLflow trace for this node; the node's inputs become the trace inputs.
        trace = self.mlflow_client.start_trace(
            name=f"hamilton.{node_.name}",
            inputs=kwargs,
            experiment_id=self.mlflow_run_info.experiment_id,
            span_type=mlflow.entities.SpanType.CHAIN,
        )
        # Key by node name as well as task_id: task_id is None outside task-based execution,
        # so task_id alone would collide when several nodes run concurrently.
        self.mlflow_cache[(task_id, node_.name)] = trace.request_id
    async def post_node_execute(
        self,
        run_id: str,
        node_: node.Node,
        success: bool,
        error: Optional[Exception],
        result: Any,
        task_id: Optional[str] = None,
        **future_kwargs: dict,
    ) -> None:
        if self.mlflow_run_info is None:
            return
        # Close the trace opened in pre_node_execute, recording the node's result as output.
        self.mlflow_client.end_trace(
            request_id=self.mlflow_cache.pop((task_id, node_.name)),
            outputs=result,
        )
        if not success:
            return
        # Pickle the node's result and attach the file to the MLflow run as an artifact.
        file_path = f"{node_.name}.pickle"
        with open(file_path, "wb") as f:
            pickle.dump(result, f)
        self.mlflow_client.log_artifact(self.mlflow_run_info.run_id, file_path)
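

# --- Usage sketch (not part of the original gist) ---
# A minimal example of wiring the adapter into Hamilton's async driver. This assumes
# Hamilton's `async_driver.Builder` API and `ad_hoc_utils.create_temporary_module`;
# the `doubled` node and module name below are placeholders for illustration. The
# MLflow run must be active before the adapter is constructed, since the adapter
# only tracks when it finds an active run in __init__.
if __name__ == "__main__":
    import asyncio

    from hamilton import ad_hoc_utils, async_driver

    async def doubled(x: int) -> int:
        return x * 2

    example_module = ad_hoc_utils.create_temporary_module(doubled, module_name="example_pipeline")

    async def main():
        with mlflow.start_run():
            dr = await (
                async_driver.Builder()
                .with_modules(example_module)
                .with_adapters(MLFlowTrackerAsync())
                .build()
            )
            print(await dr.execute(["doubled"], inputs={"x": 21}))

    asyncio.run(main())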