|
"""Routing support for taskiq.""" |
|
|
|
from abc import ABC |
|
from concurrent.futures import Executor |
|
from typing import Any, AsyncGenerator, Callable, ParamSpec, TypeVar, overload |
|
|
|
from aio_pika import DeliveryMode, ExchangeType, Message, connect_robust |
|
from aio_pika.abc import AbstractChannel, AbstractRobustConnection |
|
from taskiq import ( |
|
AckableMessage, |
|
AsyncBroker, |
|
AsyncResultBackend, |
|
BrokerMessage, |
|
) |
|
from taskiq.acks import AckableMessage |
|
from taskiq.brokers.shared_broker import AsyncSharedBroker, SharedDecoratedTask |
|
from taskiq.kicker import AsyncKicker |
|
from taskiq.message import BrokerMessage |
|
from taskiq.receiver import Receiver |
|
|
|
_FuncParams = ParamSpec("_FuncParams") |
|
_ReturnType = TypeVar("_ReturnType") |
|
_T = TypeVar("_T") |
|
|
|
|
|
def parse_val(
    parse_func: Callable[[str], _T],
    target: str | None = None,
) -> _T | None:
    """Convert ``target`` with ``parse_func``, falling back to ``None``.

    :param parse_func: converter applied to the raw string (e.g. ``int``).
    :param target: raw value to convert; ``None`` short-circuits to ``None``.
    :return: the converted value, or ``None`` when ``target`` is ``None`` or
        ``parse_func`` raises ``ValueError``. Any other exception propagates.
    """
    parsed: _T | None = None
    if target is not None:
        try:
            parsed = parse_func(target)
        except ValueError:
            # Unparseable input is treated like a missing value.
            pass
    return parsed
|
|
|
|
|
class AsyncTaskiqDecoratedTask(SharedDecoratedTask[_FuncParams, _ReturnType]):
    """Shared task with helpers for choosing a broker and a target node."""

    def with_broker(self, broker: AsyncBroker) -> AsyncKicker[_FuncParams, _ReturnType]:
        """Return a kicker for this task bound to ``broker``.

        Alias for ``task.kicker().with_broker(broker)``.

        :param broker: broker that will receive the kicked message.
        :return: kicker bound to ``broker``.
        """
        return self.kicker().with_broker(broker)

    def send_to(
        self,
        node: str,
        broker: AsyncBroker | None = None,
    ) -> AsyncKicker[_FuncParams, _ReturnType]:
        """Return a kicker that addresses this task to ``node``.

        Alias for ``task.kicker().with_labels(node=node).with_broker(broker)``;
        the ``with_broker`` step is skipped when ``broker`` is ``None``.
        ``AioPikaRoutedBroker`` publishes using the ``node`` label as the
        routing key and rejects messages without it.

        :param node: value stored in the message's ``node`` label.
        :param broker: optional broker to bind the kicker to. When omitted
            (``None``), the broker is left as returned by ``self.kicker()``.
        :return: kicker carrying the ``node`` label.
        """
        kicker = self.kicker().with_labels(node=node)
        if broker is not None:
            kicker = kicker.with_broker(broker)
        return kicker
|
|
|
|
|
class AsyncRoutedBroker(AsyncBroker, ABC):
    """Async broker that routes tasks to specified nodes.

    A routed broker only sends messages: it refuses to start inside a
    worker process and refuses to listen.
    """

    async def startup(self) -> None:
        """Start the broker.

        The worker-process check runs before delegating to the parent
        ``startup``. A misconfigured broker therefore fails before anything
        the parent would initialize is started and left half-running.

        :raises ValueError: if the broker is flagged as a worker process.
        """
        if self.is_worker_process:
            raise ValueError("Routed broker cannot be worker.")

        await super().startup()

    def listen(self) -> Any:
        """Refuse to listen.

        Routed brokers only publish; consuming messages is the job of a
        listener broker on the target node.

        :raises ValueError: always.
        """
        raise ValueError("This is routed broker. It cannot listen to messages.")
|
|
|
|
|
class AsyncRoutedSharedBroker(AsyncRoutedBroker, AsyncSharedBroker):
    """Shared broker whose tasks can be routed to specific nodes.

    Tasks declared through it use ``AsyncTaskiqDecoratedTask`` as their
    decorator class, which adds the ``with_broker`` and ``send_to`` helpers.
    """

    def __init__(self) -> None:
        """Init the shared broker with routing-aware task decorators."""
        super().__init__()
        # Broker that ``RoutedReceiver`` runs on; ``None`` until set elsewhere.
        self._default_broker: AsyncListenerBroker | None = None
        self.decorator_class = AsyncTaskiqDecoratedTask

    @overload
    def task(
        self,
        task_name: Callable[_FuncParams, _ReturnType],
        **labels: Any,
    ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]: ...

    @overload
    def task(
        self,
        task_name: str | None = None,
        **labels: Any,
    ) -> Callable[
        [Callable[_FuncParams, _ReturnType]],
        AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType],
    ]: ...

    def task(self, *args: Any, **kwargs: Any) -> Any:
        """Declare a task by delegating to the parent ``task``.

        The overloads above only refine the static return type; the runtime
        work is done entirely by the parent implementation.
        """
        return super().task(*args, **kwargs)
|
|
|
|
|
class AioPikaRoutedBroker(AsyncRoutedBroker):
    """Async AMQP broker that routes tasks to specified nodes.

    Every kicked message is published to a single exchange, using the
    message's ``node`` label as the routing key.
    """

    def __init__(
        self,
        url: str | None = None,
        result_backend: AsyncResultBackend[Any] | None = None,
        task_id_generator: Callable[[], str] | None = None,
        exchange_name: str = "taskiq",
        exchange_type: ExchangeType = ExchangeType.DIRECT,
        declare_exchange: bool = True,
    ) -> None:
        """Init.

        :param url: AMQP connection URL passed to ``connect_robust``.
        :param result_backend: forwarded to the parent broker constructor.
        :param task_id_generator: forwarded to the parent broker constructor.
        :param exchange_name: exchange that messages are published to.
        :param exchange_type: type used when declaring the exchange.
        :param declare_exchange: whether ``startup`` declares the exchange.
        """
        super().__init__(result_backend, task_id_generator)

        self._url = url
        self._exchange_name = exchange_name
        self._declare_exchange = declare_exchange
        self._exchange_type = exchange_type
        # Opened by ``startup`` and reset by ``shutdown``; ``None`` means the
        # broker is not connected.
        self._connection: AbstractRobustConnection | None = None
        self._channel: AbstractChannel | None = None

    async def startup(self) -> None:
        """Start the broker and open the AMQP connection and channel.

        Also declares the exchange when ``declare_exchange`` is set.
        """
        await super().startup()

        self._connection = await connect_robust(self._url)
        self._channel = await self._connection.channel()

        if self._declare_exchange:
            await self._channel.declare_exchange(
                self._exchange_name,
                self._exchange_type,
            )

    async def shutdown(self) -> None:
        """Shut the broker down and close the AMQP channel and connection.

        The channel is closed before the connection that owns it. Both are
        closed even if the parent shutdown raises, and both are reset to
        ``None`` so a second call does not try to close them again.
        """
        try:
            await super().shutdown()
        finally:
            channel = self._channel
            connection = self._connection
            self._channel = None
            self._connection = None
            if channel is not None:
                await channel.close()
            if connection is not None:
                await connection.close()

    async def kick(self, message: BrokerMessage) -> None:
        """Publish ``message`` to the exchange, routed by its ``node`` label.

        All labels are copied into the AMQP headers after ``task_id`` and
        ``task_name``, so labels win on key clashes. A ``priority`` label
        that parses as ``int`` becomes the message priority. Messages are
        published with persistent delivery mode.

        :param message: message to publish; must carry a ``node`` label.
        :raises ValueError: if the ``node`` label is missing or the broker
            has not been started.
        """
        node = message.labels.get("node")
        if node is None:
            raise ValueError("Node cannot be empty")
        if self._channel is None:
            raise ValueError("Call startup before kicking messages.")

        rmq_message = Message(
            body=message.message,
            headers={
                "task_id": message.task_id,
                "task_name": message.task_name,
                **message.labels,
            },
            delivery_mode=DeliveryMode.PERSISTENT,
            priority=parse_val(int, message.labels.get("priority")),
        )

        exchange = await self._channel.get_exchange(self._exchange_name)
        await exchange.publish(rmq_message, routing_key=node)
|
|
|
|
|
class AsyncListenerBroker(AsyncBroker, ABC):
    """Async broker that listens to a specified node.

    A listener broker only consumes messages; kicking is done through a
    routed broker.
    """

    # Node whose messages this broker consumes. Subclasses set it in their
    # constructor, and ``RoutedReceiver`` may overwrite it.
    node_name: str

    async def startup(self) -> None:
        """Start the broker.

        The worker-process check runs before delegating to the parent
        ``startup``, so a rejected broker does not start anything first.

        :raises ValueError: if the broker is flagged as a worker process.
        """
        # NOTE(review): a listener is the side that consumes messages, so
        # rejecting worker processes here looks inverted (possibly copied
        # from AsyncRoutedBroker). The behavior is kept as-is; confirm the
        # intent before relying on this broker under a taskiq worker.
        if self.is_worker_process:
            raise ValueError("Listener broker cannot be worker.")

        await super().startup()

    async def kick(self, message: BrokerMessage) -> None:
        """Refuse to kick messages.

        :raises ValueError: always; use a routed broker to send tasks.
        """
        raise ValueError("This is listener broker. It cannot kick messages.")
|
|
|
|
|
class AioPikaListenerBroker(AsyncListenerBroker):
    """Async AMQP broker that listens to a specified node.

    On startup it declares a queue named after ``node_name`` and binds it
    to the exchange with ``node_name`` as the routing key.
    """

    def __init__(
        self,
        node_name: str,
        url: str | None = None,
        result_backend: AsyncResultBackend[Any] | None = None,
        task_id_generator: Callable[[], str] | None = None,
        exchange_name: str = "taskiq",
        exchange_type: ExchangeType = ExchangeType.DIRECT,
        declare_exchange: bool = True,
        qos: int = 1,
    ) -> None:
        """Init.

        :param node_name: node to consume; used as queue name and binding key.
        :param url: AMQP connection URL passed to ``connect_robust``.
        :param result_backend: forwarded to the parent broker constructor.
        :param task_id_generator: forwarded to the parent broker constructor.
        :param exchange_name: exchange the node queue is bound to.
        :param exchange_type: type used when declaring the exchange.
        :param declare_exchange: whether ``startup`` declares the exchange.
        :param qos: prefetch count applied when listening starts.
        """
        super().__init__(result_backend, task_id_generator)

        self._url = url
        self.node_name = node_name
        self._exchange_name = exchange_name
        self._declare_exchange = declare_exchange
        self._exchange_type = exchange_type
        self._qos = qos
        # Opened by ``startup`` and reset by ``shutdown``; ``None`` means the
        # broker is not started. ``_queue`` holds the aio_pika queue object.
        self._connection: AbstractRobustConnection | None = None
        self._channel: AbstractChannel | None = None
        self._queue: Any = None

    async def startup(self) -> None:
        """Start the broker, connect, and bind the node queue.

        Also declares the exchange when ``declare_exchange`` is set.
        """
        await super().startup()

        self._connection = await connect_robust(self._url)
        self._channel = await self._connection.channel()

        if self._declare_exchange:
            await self._channel.declare_exchange(
                self._exchange_name,
                self._exchange_type,
            )

        # One queue per node, bound with the node name as the routing key.
        # It is only stored once the bind succeeded, so ``listen`` never
        # consumes from an unbound queue.
        queue = await self._channel.declare_queue(self.node_name)
        await queue.bind(self._exchange_name, self.node_name)
        self._queue = queue

    async def shutdown(self) -> None:
        """Shut the broker down and close the AMQP channel and connection.

        The channel is closed before the connection that owns it. Both are
        closed even if the parent shutdown raises, and the handles are reset
        to ``None`` so a second call does not try to close them again.
        """
        try:
            await super().shutdown()
        finally:
            channel = self._channel
            connection = self._connection
            self._queue = None
            self._channel = None
            self._connection = None
            if channel is not None:
                await channel.close()
            if connection is not None:
                await connection.close()

    async def kick(self, message: BrokerMessage) -> None:
        """Refuse to kick messages.

        :raises ValueError: always; listeners only consume.
        """
        raise ValueError("This is receiver. To kick message, use router.")

    async def listen(self) -> AsyncGenerator[bytes | AckableMessage, None]:
        """Consume messages from the node queue.

        Applies ``qos`` as the channel prefetch count, then yields each
        incoming message as an ``AckableMessage``. Its ``ack`` is the
        incoming message's own ``ack`` method.

        :raises ValueError: if ``startup`` has not completed.
        """
        if self._channel is None or self._queue is None:
            raise ValueError("Call startup before starting listening.")

        await self._channel.set_qos(prefetch_count=self._qos)
        async with self._queue.iterator() as iterator:
            async for message in iterator:
                yield AckableMessage(
                    data=message.body,
                    ack=message.ack,
                )
|
|
|
|
|
class RoutedReceiver(Receiver):
    """Receiver that consumes tasks for one node of a routed shared broker.

    The routed broker itself cannot listen, so this receiver runs on the
    routed broker's default broker after pointing it at ``node_name``.
    """

    def __init__(
        self,
        broker: AsyncRoutedSharedBroker,
        executor: Executor | None = None,
        validate_params: bool = True,
        max_async_tasks: int | None = None,
        max_prefetch: int = 0,
        propagate_exceptions: bool = True,
        node_name: str | None = None,
    ) -> None:
        """Init.

        :param broker: routed shared broker whose default broker is used.
        :param executor: forwarded positionally to ``Receiver``.
        :param validate_params: forwarded positionally to ``Receiver``.
        :param max_async_tasks: forwarded positionally to ``Receiver``.
        :param max_prefetch: forwarded positionally to ``Receiver``.
        :param propagate_exceptions: forwarded positionally to ``Receiver``.
        :param node_name: node to consume; required despite the ``None`` default.
        :raises ValueError: if ``node_name`` is missing or ``broker`` has no
            default broker.
        """
        if node_name is None:
            raise ValueError("Please specify node_name")

        listener = broker._default_broker
        if listener is None:
            raise ValueError("Please specify default broker")

        # This mutates the default broker object held by ``broker`` itself.
        listener.node_name = node_name

        receiver_options = (
            executor,
            validate_params,
            max_async_tasks,
            max_prefetch,
            propagate_exceptions,
        )
        super().__init__(listener, *receiver_options)