Skip to content

Instantly share code, notes, and snippets.

@colonelpanic8
Created July 13, 2024 22:35
Show Gist options
  • Save colonelpanic8/14e79edf21c754df0ee0ecb94be9ca79 to your computer and use it in GitHub Desktop.
Save colonelpanic8/14e79edf21c754df0ee0ecb94be9ca79 to your computer and use it in GitHub Desktop.
class HatchetClient:
    """Thin wrapper around sync/async Hatchet clients for inspecting workflows
    and their runs, plus duplicate-enqueue detection based on run input."""

    def __init__(
        self,
        workflows: Dict[str, Any],
        sync_hatchet: Hatchet,
        async_hatchet: AsyncClientProxy[Hatchet],
        async_sessionmaker: sqlalchemy_async_sessionmaker,
        path_resolver: SegmentedVideoPathResolver,
    ):
        self._workflows = workflows
        self._hatchet = sync_hatchet
        self._async_hatchet = async_hatchet
        self._async_sessionmaker = async_sessionmaker
        self._path_resolver = path_resolver

    def validate_workflow_name(self, workflow_name: str) -> bool:
        """Return True if workflow_name equals the class name of any registered workflow."""
        return any(
            workflow_name == workflow.__class__.__name__
            for workflow in self._workflows.values()
        )

    async def async_workflow_id_by_name(self, workflow_name: str) -> str:
        """Return the id of the first listed workflow whose name contains workflow_name.

        Raises:
            ValueError: if no workflow name contains workflow_name.
        """
        workflows = await self._async_hatchet.client.rest.workflow_list()
        # NOTE: substring match, not equality — "foo" also matches "foobar".
        matching_workflows = [
            workflow for workflow in workflows.rows if workflow_name in workflow.name
        ]
        if not matching_workflows:
            raise ValueError(f"No workflow found with name '{workflow_name}'")
        return matching_workflows[0].metadata.id

    async def async_get_workflow_runs(self, workflow_id: str) -> List[Any]:
        """Return all runs for the given workflow id."""
        workflow_runs = await self._async_hatchet.client.rest.workflow_run_list(
            workflow_id
        )
        return workflow_runs.rows

    @cached(ttl=600, cache=Cache.MEMORY)
    async def async_cached_get_hatchet_workflow_run(self, workflow_run_id: str):
        """Return the parsed JSON input of the first step run that has one, or None.

        Result is memory-cached for 10 minutes. Previously this returned False in
        one "no input" branch and None in the other; normalized to None so callers
        have a single sentinel to check.
        """
        workflow_run = await self._async_hatchet.client.rest.workflow_run_get(
            workflow_run_id
        )
        if not workflow_run.job_runs:
            return None
        for job_run in workflow_run.job_runs:
            if not job_run.step_runs:
                continue
            for step_run in job_run.step_runs:
                if not step_run.input:
                    continue
                return json.loads(step_run.input)
        return None

    async def async_check_workflow_run_input(
        self, workflow_run_id: str, parameters_to_match: Dict[str, Any]
    ) -> bool:
        """Return True if the run's input matches every key/value in parameters_to_match.

        An empty parameters_to_match matches any run that has retrievable input.
        """
        input_data = await self.async_cached_get_hatchet_workflow_run(workflow_run_id)
        # Guard: the cached getter yields None when the run has no step input;
        # without this check, .get() below raised AttributeError.
        if not input_data:
            return False
        return all(
            input_data.get("input", {}).get(key) == value
            for key, value in parameters_to_match.items()
        )

    async def async_is_task_already_enqueued(
        self, workflow_name: str, parameters_to_match: Dict[str, Any]
    ) -> bool:
        """Return True if a RUNNING or PENDING run of workflow_name already has
        input matching parameters_to_match.

        Best-effort: any failure is logged and reported as not-enqueued (False).
        """
        try:
            workflow_id = await self.async_workflow_id_by_name(workflow_name)
            workflow_runs = await self.async_get_workflow_runs(workflow_id)
            active_runs = [
                run for run in workflow_runs if run.status in ("RUNNING", "PENDING")
            ]
            # Fan out the input checks concurrently; each is independently cached.
            results = await gather(
                *[
                    self.async_check_workflow_run_input(
                        run.metadata.id, parameters_to_match
                    )
                    for run in active_runs
                ]
            )
            return any(results)
        except Exception as e:
            logger.error(f"Error checking if task is enqueued: {str(e)}")
            return False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment