Created
April 14, 2024 03:33
-
-
Save CyrusNuevoDia/76e3daaa23a39e4b2bed02a9e4eca732 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# License: MIT | |
import random | |
import typing as t | |
import logging | |
from abc import ABC | |
from dataclasses import dataclass | |
from anyio import create_memory_object_stream, create_task_group, fail_after | |
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream | |
logger = logging.getLogger(__name__) | |
@dataclass | |
class Call[T]: | |
request: T | |
_callback_stream: MemoryObjectSendStream[T] | |
@dataclass | |
class Cast[T]: | |
request: T | |
class Actor[Req](ABC): | |
""" | |
An actor that can handle call and cast messages. | |
Call messages wait for a response, while cast messages are fire-and-forget. | |
""" | |
inbox: MemoryObjectSendStream[Req] | |
mailbox: MemoryObjectReceiveStream[Req] | |
def __init__(self, mailbox_size: int = 128): | |
self.inbox, self.mailbox = create_memory_object_stream(mailbox_size) | |
async def init(self): | |
"""Any async startup code should go here.""" | |
async def handle_call(self, sender: "Actor", message: Req): | |
"""Handle a call message. Return a response.""" | |
async def handle_cast(self, sender: "Actor", message: Req): | |
"""Handle a cast message. No response.""" | |
async def call[ | |
Res | |
](self, receiver: "Actor", req: Req, timeout: float | None = 5) -> Res: | |
send_stream, receive_stream = create_memory_object_stream(1) | |
with fail_after(timeout): | |
receiver.inbox.send_nowait(Call(req, _callback_stream=send_stream)) | |
response = await receive_stream.receive() | |
return response | |
def cast(self, receiver: "Actor", message: Req) -> None: | |
receiver.inbox.send_nowait(Cast(message)) | |
async def run(self): | |
await self.init() | |
async with self.mailbox: | |
while True: | |
async for message in self.mailbox: | |
await process_message(self, message) | |
async def process_message(actor: Actor, message: Call | Cast): | |
match message: | |
case Call(): | |
response = await actor.handle_call(message.request) | |
message._callback_stream.send_nowait(response) | |
case Cast(): | |
await actor.handle_cast(message.request) | |
async def run(actor: Actor):
    """Execute *actor*'s main loop in the current task.

    The supervisors hand this coroutine function to ``start_soon``. Any
    exception raised by the actor propagates unchanged.
    """
    main_loop = actor.run
    await main_loop()
async def run_supervised(actor: Actor, max_restarts: int = 3):
    """Run *actor*, restarting it after each crash up to *max_restarts* times.

    Every crash is logged with its traceback. Once the restart budget is used
    up (or if it was zero or negative to begin with), the last exception is
    re-raised. Returns normally if the actor's run loop returns.
    """
    restarts_left = max_restarts
    while True:
        try:
            async with create_task_group() as tg:
                tg.start_soon(run, actor)
        except Exception:
            # logger.exception already attaches the active exception.
            logger.exception("Actor crashed")
            if restarts_left <= 0:
                logger.error("Max restarts reached. Actor will not be restarted.")
                raise
            restarts_left -= 1
        else:
            return
async def run_one_for_one(actors: list[Actor], max_restarts: int = 3):
    """Supervise each actor independently.

    Every actor gets its own ``run_supervised`` task with its own budget of
    *max_restarts*, so a crash restarts only that actor. If one actor exhausts
    its budget, the error propagates out of the shared task group, which
    cancels the others.
    """
    async with create_task_group() as group:
        spawn = group.start_soon
        for child in actors:
            spawn(run_supervised, child, max_restarts)
async def run_one_for_all(actors: list[Actor], max_restarts: int = 3):
    """Run *actors* together; when any of them crashes, restart all of them.

    All actors share one task group, so a crash cancels the siblings. The
    whole group shares a single budget of *max_restarts* restarts. Once the
    budget is used up (or if it was zero or negative to begin with), the
    exception is re-raised. Returns normally if every actor's loop returns.
    """
    restarts_left = max_restarts
    while True:
        try:
            async with create_task_group() as tg:
                for actor in actors:
                    tg.start_soon(run, actor)
        except Exception:
            # logger.exception already attaches the active exception.
            logger.exception("Actor crashed")
            if restarts_left <= 0:
                logger.error("Max restarts reached. Actor will not be restarted.")
                raise
            restarts_left -= 1
        else:
            return
type RestartStrategy = t.Literal["one_for_one", "one_for_all"] | |
@dataclass | |
class Supervisor(Actor): | |
children: list[Actor] | |
strategy: RestartStrategy = "one_for_one" | |
max_restarts: int = 3 | |
async def run(self, actors: list[Actor]): | |
async with create_task_group() as tg: | |
tg.start_soon(super().run) | |
runner = { | |
"one_for_one": run_one_for_one, | |
"one_for_all": run_one_for_all, | |
}[self.strategy] | |
tg.start_soon(runner, actors, self.max_restarts) | |
@dataclass
class WorkerPool(Supervisor, ABC):
    """
    A worker pool that randomly routes messages to workers.

    The pool supervises ``children`` with the configured restart strategy.
    It forwards every envelope from its own mailbox into the chosen worker's
    inbox. Each worker therefore handles messages one at a time under its
    own supervision. Replies to ``Call`` messages go straight from the
    worker to the original caller through the envelope's callback stream.
    """

    def __post_init__(self):
        # The dataclass-generated __init__ skips Actor.__init__; create this
        # pool's mailbox explicitly.
        Actor.__init__(self)

    async def router(self, message: Call | Cast) -> Actor:
        """Select a worker to handle the message.

        The default picks a child uniformly at random. It raises IndexError
        if there are no children.
        """
        return random.choice(self.children)

    async def run(self):
        """Supervise the workers and route incoming messages to them.

        This does not call ``Supervisor.run``/``Actor.run``: those would
        consume this same mailbox and call ``init`` a second time.

        Raises:
            KeyError: if ``self.strategy`` is not a known strategy name.
        """
        await self.init()
        supervise = {
            "one_for_one": run_one_for_one,
            "one_for_all": run_one_for_all,
        }[self.strategy]
        async with create_task_group() as tg:
            tg.start_soon(supervise, self.children, self.max_restarts)
            async for message in self.mailbox:
                worker = await self.router(message)
                # Waits (backpressure) while the worker's mailbox is full.
                await worker.inbox.send(message)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment