Created
July 13, 2024 22:35
-
-
Save colonelpanic8/14e79edf21c754df0ee0ecb94be9ca79 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class HatchetClient:
    """Wrapper around the Hatchet SDK for looking up workflows and their runs.

    Holds a synchronous and an async Hatchet handle plus project collaborators
    (an SQLAlchemy async sessionmaker and a path resolver). The methods defined
    here only use ``self._workflows`` and ``self._async_hatchet``; the other
    attributes are stored for use elsewhere.
    """

    # Run statuses treated as "already enqueued" by async_is_task_already_enqueued.
    # QUEUED is included because a run waiting to be picked up is still pending
    # work; if the server never reports it, including it is harmless.
    # NOTE(review): confirm against the Hatchet WorkflowRunStatus enum.
    _ACTIVE_RUN_STATUSES = ("PENDING", "QUEUED", "RUNNING")

    def __init__(
        self,
        workflows: Dict[str, Any],
        sync_hatchet: Hatchet,
        async_hatchet: AsyncClientProxy[Hatchet],
        async_sessionmaker: sqlalchemy_async_sessionmaker,
        path_resolver: SegmentedVideoPathResolver,
    ):
        """Store the workflow registry and client handles.

        Args:
            workflows: Mapping whose values are workflow instances; only the
                instances' class names are consulted (validate_workflow_name).
            sync_hatchet: Synchronous Hatchet client (stored as ``_hatchet``).
            async_hatchet: Async proxy around a Hatchet client; its
                ``client.rest`` API backs all ``async_*`` methods.
            async_sessionmaker: SQLAlchemy async session factory (stored only).
            path_resolver: Project path resolver (stored only).
        """
        self._workflows = workflows
        self._hatchet = sync_hatchet
        self._async_hatchet = async_hatchet
        self._async_sessionmaker = async_sessionmaker
        self._path_resolver = path_resolver

    def validate_workflow_name(self, workflow_name: str) -> bool:
        """Return True if a registered workflow instance's class is named ``workflow_name``."""
        return any(
            workflow_name == workflow.__class__.__name__
            for workflow in self._workflows.values()
        )

    async def async_workflow_id_by_name(self, workflow_name: str) -> str:
        """Return the Hatchet id of the workflow matching ``workflow_name``.

        A workflow whose name equals ``workflow_name`` exactly is preferred.
        Otherwise the first workflow whose name contains it as a substring is
        used; this keeps the original lenient lookup, e.g. for names carrying a
        prefix (presumably a namespace — confirm).

        Raises:
            ValueError: if no workflow name matches.
        """
        workflows = await self._async_hatchet.client.rest.workflow_list()
        rows = workflows.rows or []
        exact_matches = [workflow for workflow in rows if workflow.name == workflow_name]
        matching_workflows = exact_matches or [
            workflow for workflow in rows if workflow_name in (workflow.name or "")
        ]
        if not matching_workflows:
            raise ValueError(f"No workflow found with name '{workflow_name}'")
        return matching_workflows[0].metadata.id

    async def async_get_workflow_runs(self, workflow_id: str) -> List[Any]:
        """Return the runs listed by Hatchet for ``workflow_id`` (empty list if none).

        NOTE(review): no offset/limit is passed, so this returns whatever the
        API's default page contains; confirm that covers all active runs.
        """
        workflow_runs = await self._async_hatchet.client.rest.workflow_run_list(
            workflow_id
        )
        # Guard against an unset ``rows`` so callers can always iterate.
        return list(workflow_runs.rows or [])

    @cached(ttl=600, cache=Cache.MEMORY)
    async def async_cached_get_hatchet_workflow_run(self, workflow_run_id: str):
        """Return the decoded JSON input of the run's first step run that has input.

        Scans job runs and their step runs in order and returns
        ``json.loads(step_run.input)`` for the first non-empty input. Returns
        None when the run has no job runs or no step run carries input, so
        callers get one consistent "no input" value. Wrapped by ``@cached``
        with a 600-second TTL.
        """
        workflow_run = await self._async_hatchet.client.rest.workflow_run_get(
            workflow_run_id
        )
        for job_run in workflow_run.job_runs or []:
            for step_run in job_run.step_runs or []:
                if step_run.input:
                    return json.loads(step_run.input)
        # NOTE(review): whether ``@cached`` stores a None result depends on the
        # cache library; if it does not, runs that have no input yet (e.g. still
        # PENDING) are re-fetched on the next call instead of being cached empty.
        return None

    async def async_check_workflow_run_input(
        self, workflow_run_id: str, parameters_to_match: Dict[str, Any]
    ) -> bool:
        """Return True if the run's ``"input"`` contains every given key/value pair.

        Values are compared with ``==``; a key missing from the run input
        compares as None. A run with no retrievable input, or whose input is
        not a dict, never matches.
        """
        input_data = await self.async_cached_get_hatchet_workflow_run(workflow_run_id)
        if not isinstance(input_data, dict):
            return False
        # ``or {}`` also covers an explicit ``"input": None``.
        run_input = input_data.get("input") or {}
        if not isinstance(run_input, dict):
            return False
        return all(
            run_input.get(key) == value for key, value in parameters_to_match.items()
        )

    async def async_is_task_already_enqueued(
        self, workflow_name: str, parameters_to_match: Dict[str, Any]
    ) -> bool:
        """Return True if an active run of ``workflow_name`` has matching input.

        "Active" means a status in ``_ACTIVE_RUN_STATUSES``. Each active run's
        input is checked concurrently via async_check_workflow_run_input.

        Best-effort: any exception while querying Hatchet is logged with its
        traceback and the method returns False ("not enqueued").
        """
        try:
            workflow_id = await self.async_workflow_id_by_name(workflow_name)
            workflow_runs = await self.async_get_workflow_runs(workflow_id)
            active_run_ids = [
                run.metadata.id
                for run in workflow_runs
                if run.status in self._ACTIVE_RUN_STATUSES
            ]
            results = await gather(
                *(
                    self.async_check_workflow_run_input(run_id, parameters_to_match)
                    for run_id in active_run_ids
                )
            )
            return any(results)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(f"Error checking if task is enqueued: {e}")
            return False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment