Created
September 8, 2022 22:58
-
-
Save acookin/aa5f0d718997e7a645cbc160e2c94a16 to your computer and use it in GitHub Desktop.
Agent + flow to cleanup local process flow runs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Runs a Prefect agent.

Keeps track of the flow runs the agent kicked off and, every couple of
seconds, saves a DateTime storage block set to the current time for every
flow run that is still running.

Also has a cleanup function that sets every still-running flow run to
AwaitingRetry.

This is meant to be used in a context where the agent process is PID 1
(e.g. in a container) and local process infrastructure is used to run
flows. That may be the only reasonable way to run agents/flows on
restrictive managed infrastructure, such as Heroku, but I wouldn't
recommend running your agents or flows this way unless you have to.
"""
import asyncio | |
import logging | |
from datetime import datetime | |
from typing import List, Set | |
from uuid import UUID | |
import anyio | |
import pendulum | |
from prefect.agent import OrionAgent | |
from prefect.blocks.system import DateTime | |
from prefect.client import OrionClient, get_client | |
from prefect.orion.schemas.states import AwaitingRetry, StateType | |
logger = logging.getLogger(__name__)

# State types under which a flow run is considered still in flight.
RUNNING_STATE_TYPES = [
    StateType.PENDING,
    StateType.RUNNING,
]
async def flow_is_running(client: OrionClient, run_id: UUID) -> bool:
    """Return True if flow run ``run_id`` is currently PENDING or RUNNING.

    The run is re-read from the API on every call.

    Args:
        client: An open Orion API client.
        run_id: ID of the flow run to check.

    Returns:
        True for a PENDING/RUNNING run; False if the run has no state or is
        in any other state type.
    """
    flow_run = await client.read_flow_run(run_id)
    # ``state`` may be None on a flow run; treat that as "not running"
    # rather than failing with AttributeError.
    state = flow_run.state
    return state is not None and state.type in RUNNING_STATE_TYPES
async def retry_flow(client: OrionClient, run_id: UUID):
    """Move a still-running flow run into the ``AwaitingRetry`` state.

    Runs that are no longer PENDING/RUNNING are left untouched, so calling
    this on a run that has already finished does nothing.

    Args:
        client: An open Orion API client.
        run_id: ID of the flow run to reschedule.
    """
    if not await flow_is_running(client, run_id):
        return
    # Use an explicit timezone-aware UTC timestamp. A naive ``datetime.now()``
    # is in the host's local zone and may be misinterpreted if that is not UTC.
    retry_state = AwaitingRetry(pendulum.now("UTC"))
    await client.set_flow_run_state(run_id, retry_state)
class PrefectAgent:
    """Wrap an ``OrionAgent`` and track the flow runs it submits.

    While running, the agent periodically saves a ``DateTime`` block per
    tracked flow run (a "ping"), so an external process can see the run is
    still alive. It also drops runs that are no longer PENDING/RUNNING from
    its tracking set. ``cleanup_flow_runs`` moves every still-running tracked
    run to ``AwaitingRetry``.
    """

    def __init__(self, queues: List[str]) -> None:
        self.work_queues = set(queues)
        # IDs of flow runs submitted by this agent that may still be in flight.
        self.running_flows: Set[UUID] = set()

    async def _ping_for_flow_run(self, flow_run_id: UUID):
        """Save a ``DateTime`` block named ``flowrun-<id>`` holding the current time.

        External processes can read this block to see that the flow run is
        still being tracked. Errors are logged and swallowed so that a failed
        ping never stops the agent loop.
        """
        storage_key = f"flowrun-{flow_run_id}"
        try:
            block = DateTime(name=storage_key, value=pendulum.now())
            await block.save(storage_key, overwrite=True)
        except Exception as e:
            logger.error("Error pinging for flow run: %s", str(e))

    async def _reset_running_flows(self):
        """Stop tracking flow runs that are no longer PENDING or RUNNING."""
        # Iterate over a snapshot and remove in place. Runs added while we are
        # awaiting API calls are then neither lost nor able to trigger
        # "set changed size during iteration".
        candidates = set(self.running_flows)
        finished = set()
        async with get_client() as client:
            for run_id in candidates:
                if not await flow_is_running(client, run_id):
                    finished.add(run_id)
        self.running_flows -= finished

    async def start(self):
        """Run the agent loop forever.

        Each iteration:

        1. submits any available flow runs and remembers their IDs;
        2. sleeps two seconds;
        3. pings every tracked run and prunes those that have left the
           PENDING/RUNNING states.
        """
        logger.info("Starting prefect agent...")
        async with OrionAgent(work_queues=self.work_queues) as agent:
            while True:
                flow_runs = await agent.get_and_submit_flow_runs()
                for r in flow_runs:
                    self.running_flows.add(r.id)
                await anyio.sleep(2.0)
                ping_tasks = [
                    self._ping_for_flow_run(flow_run_id)
                    for flow_run_id in self.running_flows
                ]
                # Await the batch. An un-awaited gather lets loop iterations
                # overlap and leaves any exception from the reset step
                # unretrieved. Errors are logged instead of killing the loop.
                results = await asyncio.gather(
                    *ping_tasks,
                    self._reset_running_flows(),
                    return_exceptions=True,
                )
                for result in results:
                    if isinstance(result, BaseException):
                        logger.error("Error during agent bookkeeping: %s", result)

    async def cleanup_flow_runs(self):
        """Move every tracked flow run that is still running to ``AwaitingRetry``."""
        async with get_client() as client:
            await asyncio.gather(*[retry_flow(client, f) for f in self.running_flows])
def main():
    """Run the agent until interrupted, then reschedule its unfinished runs.

    This is a plain (synchronous) function because it drives the event loop
    itself with ``asyncio.run``. That call cannot be made from inside a
    running loop, and the script entry point calls ``main()`` directly.

    On ``KeyboardInterrupt`` every tracked, still-running flow run is moved
    to ``AwaitingRetry`` before returning.
    """
    prefect_agent = PrefectAgent(queues=["default"])
    try:
        asyncio.run(prefect_agent.start())
    except KeyboardInterrupt:
        logger.info("Cleaning up flows")
        asyncio.run(prefect_agent.cleanup_flow_runs())
if __name__ == "__main__":
    # Script entry point.
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
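"""
Flow that cancels "zombie" flow runs left behind by the agent above, e.g.
when the agent was SIGKILLed before it could finish its cleanup.

The intent is to cancel flow runs that have not "pinged" their DateTime
storage block for some amount of time. Note that the current implementation
does not read those blocks. Instead, it cancels agent-tagged runs that are
still in a Scheduled state more than MAX_TIME after their scheduled start
time.
"""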
import asyncio | |
from datetime import timedelta | |
from typing import List | |
import pendulum | |
from prefect import flow, get_run_logger, task | |
from prefect.client import OrionClient, get_client | |
from prefect.orion.schemas.core import FlowRun | |
from prefect.orion.schemas.filters import (FlowRunFilter, FlowRunFilterState, | |
FlowRunFilterStateType) | |
from prefect.orion.schemas.states import Cancelled, StateType | |
# How long a run may remain Scheduled past its scheduled start time before
# it is treated as stale.
MAX_TIME = timedelta(minutes=5)

# Agent names; runs tagged "agent:<name>" for one of these are managed here.
AGENT_NAMES = ["default"]
def get_agent_tags():
    """Return the ``agent:<name>`` tag for each name in ``AGENT_NAMES``, in order."""
    return [f"agent:{name}" for name in AGENT_NAMES]
async def get_matching_flow_runs(
    client: OrionClient, state_types: List[StateType], tags: List[str]
) -> List[FlowRun]:
    """Return flow runs in any of ``state_types`` that carry at least one of ``tags``.

    The state filter is applied by the API; tag matching is done locally.

    Args:
        client: An open Orion API client.
        state_types: State types to match (any of).
        tags: Tags to match (any of).

    Returns:
        The matching flow runs, in the order the API returned them.
    """
    # NOTE(review): read_flow_runs may apply a server-side default page
    # limit, so very large result sets could be truncated — confirm.
    flow_runs = await client.read_flow_runs(
        flow_run_filter=FlowRunFilter(
            state=FlowRunFilterState(type=FlowRunFilterStateType(any_=state_types)),
        )
    )
    # Build the lookup set once instead of scanning the tag list per run tag.
    wanted = set(tags)
    return [run for run in flow_runs if wanted.intersection(run.tags)]
@task
async def get_stale_runs():
    """Return IDs of agent-tagged runs stuck in a Scheduled state for too long.

    A run counts as stale when all of the following hold:

    - it carries one of the ``agent:<name>`` tags;
    - it is not tagged ``auto-scheduled``;
    - its next scheduled start time is set and lies more than ``MAX_TIME``
      in the past.
    """
    async with get_client() as c:
        scheduled_runs = await get_matching_flow_runs(
            c, [StateType.SCHEDULED], get_agent_tags()
        )
    # MAX_TIME is already a timedelta, so subtract it directly
    # (timedelta(minutes=<timedelta>) raises TypeError).
    cutoff_time = pendulum.now() - MAX_TIME
    return [
        run.id
        for run in scheduled_runs
        if "auto-scheduled" not in run.tags
        # Runs without a scheduled start time cannot be compared; skip them.
        and run.next_scheduled_start_time is not None
        and run.next_scheduled_start_time < cutoff_time
    ]
@task(retries=3)
async def cancel_run(flow_run_id):
    """Set the given flow run to ``Cancelled`` because it stayed Scheduled too long.

    Prefect retries this task up to three times if it raises.
    """
    get_run_logger().info(f"Cancelling flow run {flow_run_id}")
    new_state = Cancelled(
        message="Cancelled because run was in Scheduled state beyond time limit"
    )
    async with get_client() as client:
        await client.set_flow_run_state(flow_run_id, new_state)
@flow()
async def cancel_stale_runs():
    """Cancel every stale Scheduled run reported by ``get_stale_runs``."""
    stale_runs = await get_stale_runs()
    # Calling an async task inside an async flow returns an awaitable. The
    # gather must be awaited, otherwise the flow may finish before any
    # cancellation has run.
    await asyncio.gather(*(cancel_run(str(run_id)) for run_id in stale_runs))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment