Skip to content

Instantly share code, notes, and snippets.

@cofob
Last active July 24, 2023 09:06
Show Gist options
  • Select an option

  • Save cofob/719924897bddaa7504e27739e8122572 to your computer and use it in GitHub Desktop.

Select an option

Save cofob/719924897bddaa7504e27739e8122572 to your computer and use it in GitHub Desktop.
Taskiq routing support

This code does roughly the same thing as this example, but in a more optimized and flexible way:

  1. It does not open a channel for every node on both the client and the server. In the example there were 3 queues, and each worker opened channels for all 3 of them, even the ones it never used. This is achieved by splitting the brokers into a Listener broker and a Sender (Routed) broker.
  2. It supports dynamically sending a task to a specific node; the code in the example only supported predefined queues.
  3. It adds convenient aliases that remove the need to repeat with_broker(broker).with_labels(node=node_name) on every task call.

It is also not fully compatible with taskiq_aio_pika, since it uses a different exchange type and uses the node name as the routing key.

Despite these advantages, it does not support some taskiq features and is more complex to use. For example, a worker cannot call context.requeue, because the listener broker has no write channel (its kick method always raises an error).

The startup is almost the same as in the example:

taskiq worker worker:async_shared_broker --receiver router:RoutedReceiver --receiver_arg node_name=node1

and

python3 client.py node1
"""Routing support for taskiq."""
from abc import ABC
from concurrent.futures import Executor
from typing import Any, AsyncGenerator, Callable, ParamSpec, TypeVar, overload
from aio_pika import DeliveryMode, ExchangeType, Message, connect_robust
from aio_pika.abc import AbstractChannel, AbstractRobustConnection
from taskiq import (
AckableMessage,
AsyncBroker,
AsyncResultBackend,
BrokerMessage,
)
from taskiq.acks import AckableMessage
from taskiq.brokers.shared_broker import AsyncSharedBroker, SharedDecoratedTask
from taskiq.kicker import AsyncKicker
from taskiq.message import BrokerMessage
from taskiq.receiver import Receiver
# Generic type parameters used in the annotations below:
# _FuncParams / _ReturnType carry a task function's parameters and return type
# through the decorated-task and kicker types; _T is the result type of the
# converter passed to parse_val.
_FuncParams = ParamSpec("_FuncParams")
_ReturnType = TypeVar("_ReturnType")
_T = TypeVar("_T")
def parse_val(
    parse_func: Callable[[str], _T],
    target: str | None = None,
) -> _T | None:
    """Convert a string with ``parse_func``, yielding ``None`` on failure.

    :param parse_func: converter applied to ``target`` (for example ``int``).
    :param target: raw value to convert; ``None`` short-circuits to ``None``.
    :return: the converted value, or ``None`` when ``target`` is ``None`` or
        ``parse_func`` raises ``ValueError``. Other exceptions propagate.
    """
    if target is not None:
        try:
            converted = parse_func(target)
        except ValueError:
            converted = None
        return converted
    return None
class AsyncTaskiqDecoratedTask(SharedDecoratedTask[_FuncParams, _ReturnType]):
    """Shared taskiq task with shortcuts for picking a broker and a node."""

    def with_broker(self, broker: AsyncBroker) -> AsyncKicker[_FuncParams, _ReturnType]:
        """Return a kicker bound to ``broker``.

        Alias for ``task.kicker().with_broker(broker)``.
        """
        return self.kicker().with_broker(broker)

    def send_to(
        self,
        node: str,
        broker: AsyncBroker | None = None,
    ) -> AsyncKicker[_FuncParams, _ReturnType]:
        """Return a kicker that addresses the task to ``node``.

        Alias for ``task.kicker().with_labels(node=node).with_broker(broker)``.

        :param node: value stored in the ``node`` label of the message.
        :param broker: broker to send through. When ``None`` (the default),
            ``with_broker`` is not applied and the kicker keeps the broker
            it was created with.
        :return: the configured kicker; call ``.kiq(...)`` on it to send.
        """
        kicker = self.kicker().with_labels(node=node)
        if broker is not None:
            kicker = kicker.with_broker(broker)
        return kicker
class AsyncRoutedBroker(AsyncBroker, ABC):
    """Async broker that only sends tasks, routing each one to a node.

    It cannot be the broker a worker listens on: ``listen`` always raises.
    """

    async def startup(self) -> None:
        """Startup broker.

        :raises ValueError: if this broker is marked as a worker process.
        """
        # Reject before super().startup() so the base-class startup work is
        # not performed for a broker that is about to be refused.
        if self.is_worker_process:
            raise ValueError("Routed broker cannot be worker.")
        await super().startup()

    def listen(self) -> Any:
        """Refuse to listen; a routed broker can only send.

        :raises ValueError: always.
        """
        raise ValueError("This is routed broker. It cannot listen to messages.")
class AsyncRoutedSharedBroker(AsyncRoutedBroker, AsyncSharedBroker):
    """Shared broker whose tasks support per-node routing.

    Tasks declared through ``task`` are intended to be
    ``AsyncTaskiqDecoratedTask`` instances, which add the ``with_broker`` /
    ``send_to`` shortcuts. Being an ``AsyncRoutedBroker``, it refuses to
    listen itself; ``RoutedReceiver`` listens through ``_default_broker``.
    """

    def __init__(self) -> None:
        """Init."""
        super().__init__()
        # Presumably consulted by taskiq's task() when wrapping functions, so
        # tasks get the send_to / with_broker aliases — confirm against taskiq.
        self.decorator_class = AsyncTaskiqDecoratedTask
        # Listener broker handed to the worker by RoutedReceiver; stays None
        # until a default broker is registered on this shared broker.
        self._default_broker: AsyncListenerBroker | None = None

    # The two overloads below exist only for type checkers: they narrow the
    # return type of task() to AsyncTaskiqDecoratedTask for the bare-decorator
    # form (@broker.task) and the called form (@broker.task("name", ...)).
    @overload
    def task(
        self,
        task_name: Callable[_FuncParams, _ReturnType],
        **labels: Any,
    ) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]:
        ...

    @overload
    def task(
        self,
        task_name: str | None = None,
        **labels: Any,
    ) -> Callable[[Callable[_FuncParams, _ReturnType]], AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType],]:
        ...

    def task(self, *args: Any, **kwargs: Any) -> Any:
        """Create task.

        Runtime implementation: delegates to the next ``task`` in the MRO.
        """
        return super().task(*args, **kwargs)
class AioPikaRoutedBroker(AsyncRoutedBroker):
    """Async AMQP broker that routes tasks to specified nodes.

    Every message is published to ``exchange_name`` with the value of its
    ``node`` label as the routing key.
    """

    def __init__(
        self,
        url: str | None = None,
        result_backend: AsyncResultBackend[Any] | None = None,
        task_id_generator: Callable[[], str] | None = None,
        exchange_name: str = "taskiq",
        exchange_type: ExchangeType = ExchangeType.DIRECT,
        declare_exchange: bool = True,
    ) -> None:
        """Init.

        :param url: AMQP URL passed to ``connect_robust``.
        :param result_backend: taskiq result backend.
        :param task_id_generator: custom task id factory.
        :param exchange_name: exchange messages are published to.
        :param exchange_type: type used when declaring the exchange.
        :param declare_exchange: declare the exchange on startup.
        """
        super().__init__(result_backend, task_id_generator)
        self._url = url
        self._exchange_name = exchange_name
        self._declare_exchange = declare_exchange
        self._exchange_type = exchange_type
        # Annotations only: these attributes do not exist until startup().
        self._connection: AbstractRobustConnection
        self._channel: AbstractChannel

    async def startup(self) -> None:
        """Startup broker: connect, open a channel, optionally declare the exchange."""
        await super().startup()
        self._connection = await connect_robust(self._url)
        self._channel = await self._connection.channel()
        if self._declare_exchange:
            await self._channel.declare_exchange(
                self._exchange_name,
                self._exchange_type,
            )

    async def shutdown(self) -> None:
        """Shutdown broker and release the AMQP channel and connection."""
        await super().shutdown()
        # Close the channel before the connection that owns it. Either may be
        # missing if startup() never ran or failed part-way.
        channel = getattr(self, "_channel", None)
        if channel is not None:
            await channel.close()
        connection = getattr(self, "_connection", None)
        if connection is not None:
            await connection.close()

    async def kick(self, message: BrokerMessage) -> None:
        """Publish ``message`` to the node named by its ``node`` label.

        :raises ValueError: if the ``node`` label is missing or empty.
        """
        node = message.labels.get("node")
        # An empty routing key would not reach any node queue, so reject it
        # together with a missing label.
        if not node:
            raise ValueError("Node cannot be empty")
        message_base_params: dict[str, Any] = {
            "body": message.message,
            "headers": {
                "task_id": message.task_id,
                "task_name": message.task_name,
                **message.labels,
            },
            "delivery_mode": DeliveryMode.PERSISTENT,
            # Unparseable or absent priority labels become None.
            "priority": parse_val(
                int,
                message.labels.get("priority"),
            ),
        }
        rmq_message: Message = Message(**message_base_params)
        exchange = await self._channel.get_exchange(self._exchange_name)
        await exchange.publish(rmq_message, routing_key=node)
class AsyncListenerBroker(AsyncBroker, ABC):
    """Async broker that only listens to the tasks of one node.

    It cannot send messages: ``kick`` always raises.
    """

    # Name of the node whose tasks this broker receives.
    node_name: str

    async def startup(self) -> None:
        """Startup broker.

        :raises ValueError: if this broker is marked as a worker process.
        """
        # NOTE(review): RoutedReceiver hands this broker to the worker's
        # Receiver, so this guard only passes while is_worker_process is left
        # unset on the listener itself — confirm that is the intended check.
        # Checked before super().startup() so the base-class startup work is
        # not performed for a broker that is about to be refused.
        if self.is_worker_process:
            raise ValueError("Listener broker cannot be worker.")
        await super().startup()

    async def kick(self, message: BrokerMessage) -> None:
        """Refuse to send; a listener broker cannot kick messages.

        :raises ValueError: always.
        """
        raise ValueError("This is listener broker. It cannot kick messages.")
class AioPikaListenerBroker(AsyncListenerBroker):
    """Async AMQP broker that listens to the queue of one node.

    On startup it declares a queue named ``node_name`` and binds it to
    ``exchange_name`` with ``node_name`` as the routing key.
    """

    def __init__(
        self,
        node_name: str,
        url: str | None = None,
        result_backend: AsyncResultBackend[Any] | None = None,
        task_id_generator: Callable[[], str] | None = None,
        exchange_name: str = "taskiq",
        exchange_type: ExchangeType = ExchangeType.DIRECT,
        declare_exchange: bool = True,
        qos: int = 1,
    ) -> None:
        """Init.

        :param node_name: node (queue name and routing key) to listen to.
        :param url: AMQP URL passed to ``connect_robust``.
        :param result_backend: taskiq result backend.
        :param task_id_generator: custom task id factory.
        :param exchange_name: exchange the node queue is bound to.
        :param exchange_type: type used when declaring the exchange.
        :param declare_exchange: declare the exchange on startup.
        :param qos: prefetch count set on the channel before consuming.
        """
        super().__init__(result_backend, task_id_generator)
        self._url = url
        self.node_name = node_name
        self._exchange_name = exchange_name
        self._declare_exchange = declare_exchange
        self._exchange_type = exchange_type
        # Annotations only: these attributes do not exist until startup().
        self._connection: AbstractRobustConnection
        self._channel: AbstractChannel
        self._qos = qos

    async def startup(self) -> None:
        """Startup broker: connect, declare the exchange and the node queue.

        :raises ValueError: if ``node_name`` is empty.
        """
        # An empty name would declare a server-named queue bound with an
        # empty routing key, which no routed message targets.
        if not self.node_name:
            raise ValueError("Please specify node_name")
        await super().startup()
        self._connection = await connect_robust(self._url)
        self._channel = await self._connection.channel()
        if self._declare_exchange:
            await self._channel.declare_exchange(
                self._exchange_name,
                self._exchange_type,
            )
        # NOTE(review): declare_queue defaults to a non-durable queue, so the
        # PERSISTENT messages sent by the routed broker will not survive a
        # RabbitMQ restart; durable=True would need existing queues recreated.
        self._queue = await self._channel.declare_queue(self.node_name)
        await self._queue.bind(self._exchange_name, self.node_name)

    async def shutdown(self) -> None:
        """Shutdown broker and release the AMQP channel and connection."""
        await super().shutdown()
        # Close the channel before the connection that owns it. Either may be
        # missing if startup() never ran or failed part-way.
        channel = getattr(self, "_channel", None)
        if channel is not None:
            await channel.close()
        connection = getattr(self, "_connection", None)
        if connection is not None:
            await connection.close()

    async def kick(self, message: BrokerMessage) -> None:
        """Refuse to send; use the routed broker instead.

        :raises ValueError: always.
        """
        raise ValueError("This is receiver. To kick message, use router.")

    async def listen(self) -> AsyncGenerator[bytes | AckableMessage, None]:
        """Yield messages from the node queue as ackable messages.

        :raises ValueError: if called before ``startup``.
        """
        # getattr: before startup() the attributes are only annotated, so a
        # plain attribute access would raise AttributeError instead.
        channel = getattr(self, "_channel", None)
        queue = getattr(self, "_queue", None)
        if channel is None or queue is None:
            raise ValueError("Call startup before starting listening.")
        await channel.set_qos(prefetch_count=self._qos)
        async with queue.iterator() as iterator:
            async for message in iterator:
                yield AckableMessage(
                    data=message.body,
                    ack=message.ack,
                )
class RoutedReceiver(Receiver):
    """Receiver that consumes only the tasks addressed to one node.

    It is given the routed shared broker, replaces it with that broker's
    default (listener) broker, sets the listener's ``node_name``, and passes
    the listener to the base ``Receiver``.
    """

    def __init__(
        self,
        broker: AsyncRoutedSharedBroker,
        executor: Executor | None = None,
        validate_params: bool = True,
        max_async_tasks: int | None = None,
        max_prefetch: int = 0,
        propagate_exceptions: bool = True,
        node_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Init.

        :param broker: shared broker whose default broker will listen.
        :param node_name: node to listen to (e.g. from
            ``--receiver_arg node_name=...``); must be non-empty.
        :param kwargs: any further keyword arguments, forwarded unchanged to
            ``Receiver`` so options this class does not name are not rejected.
        :raises ValueError: if ``node_name`` is empty or no default broker is set.
        """
        if not node_name:
            raise ValueError("Please specify node_name")
        default_broker = broker._default_broker
        if default_broker is None:
            raise ValueError("Please specify default broker")
        default_broker.node_name = node_name
        super().__init__(
            default_broker,
            executor,
            validate_params,
            max_async_tasks,
            max_prefetch,
            propagate_exceptions,
            **kwargs,
        )
from asyncio import run
from sys import argv
from taskiq import Context, TaskiqDepends
from router import AioPikaRoutedBroker, AsyncRoutedSharedBroker
# AMQP URL used by this client and imported by the worker module.
AMQP_URL = "amqp://guest:guest@localhost:5672/"
# Shared broker the tasks are registered on. The client picks a concrete
# broker per call (send_to); the worker registers a default broker on it.
async_shared_broker = AsyncRoutedSharedBroker()


@async_shared_broker.task
async def test_task(context: Context = TaskiqDepends()) -> None:
    """Print the node name of the broker that delivered this task."""
    # context.broker is presumably the listener broker set as default in the
    # worker (it has node_name); the generic broker type does not, hence the ignore.
    print("Node", context.broker.node_name) # type: ignore[attr-defined]
async def client_main() -> None:
    """Send ``test_task`` to the node named by the first command-line argument.

    :raises SystemExit: if no node name was given.
    """
    if len(argv) < 2:
        raise SystemExit(f"usage: {argv[0]} NODE_NAME")
    node_name = argv[1]
    amqp_broker = AioPikaRoutedBroker(AMQP_URL)
    await amqp_broker.startup()
    try:
        await test_task.send_to(node_name, amqp_broker).kiq()
        print(f"Sent task to {node_name}")
    finally:
        # Release the AMQP connection opened by startup(), even if sending fails.
        await amqp_broker.shutdown()


if __name__ == "__main__":
    run(client_main())
from client import AMQP_URL, async_shared_broker
from router import AioPikaListenerBroker
# The first positional parameter of AioPikaListenerBroker is node_name, so the
# URL must be passed by keyword. The node is not known here: RoutedReceiver
# overwrites node_name from its node_name argument
# (--receiver_arg node_name=...) before the broker is started.
broker = AioPikaListenerBroker(node_name="", url=AMQP_URL)
async_shared_broker.default_broker(broker)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment