pubsub.py

from __future__ import annotations

import abc
import asyncio
import contextlib
import functools
import uuid
from collections.abc import AsyncIterator, Awaitable, Hashable, Iterable
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Concatenate, NamedTuple, Protocol, cast, overload, override

import msgspec
import redis.asyncio as redis
from aiostream import stream
from loguru import logger
from result import Err, Ok

from t5hob_sdk.bases.err import ResourceNotFoundErr
from t5hob_sdk.bases.patternclass import BasePatternClass
from t5hob_sdk.bases.settings import RedisSettings
from t5hob_sdk.serde import Decoder, default_encode, get_decoder
from t5hob_sdk.types import Broadcast, ContextedAsyncIterator, DecodedT, Encoder, SerializeAble

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Callable

    from result import Result

type ChannelDecoder[T: msgspec.Struct] = Callable[[Any, bytes], T]


def get_decoder_fn[T: msgspec.Struct](t: type[T]) -> ChannelDecoder[T]:
    decoder = get_decoder(t)
    return lambda _, msg: decoder.decode(msg)


@dataclass(slots=True)
class SubscribeOptions[T, T_pattern: BasePatternClass]:
    path: bytes
    pattern_klass: type[T_pattern]
    dec_fn: Callable[[T_pattern, bytes], T]


class SubscriptionResult[P, V](NamedTuple):
    pattern: P
    value: V
    subscriber_path: bytes
    """The original path used to subscribe."""


class SubscriptionManager:
    """
    Public interface for subscribing to broadcast messages.

    Registers one message handler per path. It is safe to subscribe many times
    with the same path; that won't create multiple subscriptions. When
    `(p)subscribe` is called, it checks whether a handler already exists for
    the path and returns it; otherwise it creates a new handler and caches it
    for subsequent subscriptions.
    """

    def __init__(self, broadcast: Broadcast) -> None:
        self.broadcast = broadcast
        self.pubsub = broadcast.pubsub()
        self._phandlers: dict[bytes, _MessageHandler] = {}
        self._handlers: dict[bytes, _MessageHandler] = {}

    async def execute(self) -> None:
        while True:
            if self._handlers or self._phandlers:
                msg = await self.pubsub.get_message(ignore_subscribe_messages=True)
                if msg is not None:
                    channel = msg["channel"]
                    if (pattern := msg["pattern"]) is not None and (
                        handler := self._phandlers.get(pattern)
                    ):
                        try:
                            await handler.handle_message(channel, msg["data"], pattern)
                        except Exception as e:  # pragma: no cover # noqa: BLE001
                            logger.error(f"Error handling message {msg} on channel {channel}: {e}")
                    elif handler := self._handlers.get(channel):
                        try:
                            await handler.handle_message(channel, msg["data"], channel)
                        except Exception as e:  # pragma: no cover # noqa: BLE001
                            logger.error(f"Error handling message {msg} on channel {channel}: {e}")
            await asyncio.sleep(0.01)

    # NOTE: context managers are used so the subscription is initialized
    # NOTE: without having to iterate the generator first.
    @contextlib.asynccontextmanager
    async def subscribe[T, T_pattern: BasePatternClass](
        self, options: SubscribeOptions[T, T_pattern]
    ) -> ContextedAsyncIterator[SubscriptionResult[T_pattern, T]]:
        if not (handler := self._handlers.get(options.path)):
            handler = _MessageHandler(dec_fn=options.dec_fn, pattern_cls=options.pattern_klass)
            await self.pubsub.subscribe(options.path)
            self._handlers[options.path] = handler
        try:
            sub_id, agen = handler.subscribe()
            yield agen
        finally:
            with contextlib.suppress(NameError):
                handler.unsubscribe(sub_id)  # pyright: ignore [reportPossiblyUnboundVariable]
            if not handler.has_listeners():
                self._handlers.pop(options.path)
                await self.pubsub.unsubscribe(options.path)

    @contextlib.asynccontextmanager
    async def psubscribe[T, T_pattern: BasePatternClass](
        self, options: SubscribeOptions[T, T_pattern]
    ) -> ContextedAsyncIterator[SubscriptionResult[T_pattern, T]]:
        """Subscribe to a pattern."""
        if not (handler := self._phandlers.get(options.path)):
            handler = _MessageHandler(dec_fn=options.dec_fn, pattern_cls=options.pattern_klass)
            await self.pubsub.psubscribe(options.path)
            self._phandlers[options.path] = handler
        try:
            sub_id, gen = handler.subscribe()
            yield gen
        finally:
            with contextlib.suppress(NameError):
                handler.unsubscribe(sub_id)  # pyright: ignore [reportPossiblyUnboundVariable]
            if not handler.has_listeners():
                self._phandlers.pop(options.path)
                await self.pubsub.punsubscribe(options.path)

    @overload
    @contextlib.asynccontextmanager
    async def batch_psubscribe[T1, TP1: BasePatternClass, T2, TP2: BasePatternClass](
        self, __op1: SubscribeOptions[T1, TP1], __op2: SubscribeOptions[T2, TP2], /
    ) -> AsyncGenerator[
        AsyncGenerator[
            tuple[SubscriptionResult[TP1, T1] | None, SubscriptionResult[TP2, T2] | None]
        ]
    ]: ...

    @overload
    @contextlib.asynccontextmanager
    async def batch_psubscribe[
        T1,
        TP1: BasePatternClass,
        T2,
        TP2: BasePatternClass,
        T3,
        TP3: BasePatternClass,
    ](
        self,
        __op1: SubscribeOptions[T1, TP1],
        __op2: SubscribeOptions[T2, TP2],
        __op3: SubscribeOptions[T3, TP3],
        /,
    ) -> AsyncGenerator[
        AsyncGenerator[
            tuple[
                SubscriptionResult[TP1, T1] | None,
                SubscriptionResult[TP2, T2] | None,
                SubscriptionResult[TP3, T3] | None,
            ]
        ]
    ]: ...

    @overload
    @contextlib.asynccontextmanager
    async def batch_psubscribe[
        T1,
        TP1: BasePatternClass,
        T2,
        TP2: BasePatternClass,
        T3,
        TP3: BasePatternClass,
        T4,
        TP4: BasePatternClass,
    ](
        self,
        __op1: SubscribeOptions[T1, TP1],
        __op2: SubscribeOptions[T2, TP2],
        __op3: SubscribeOptions[T3, TP3],
        __op4: SubscribeOptions[T4, TP4],
        /,
    ) -> AsyncGenerator[
        AsyncGenerator[
            tuple[
                SubscriptionResult[TP1, T1] | None,
                SubscriptionResult[TP2, T2] | None,
                SubscriptionResult[TP3, T3] | None,
                SubscriptionResult[TP4, T4] | None,
            ]
        ]
    ]: ...

    @contextlib.asynccontextmanager
    async def batch_psubscribe(
        self, *options: SubscribeOptions[Any, Any]
    ) -> AsyncIterator[AsyncIterator[tuple[SubscriptionResult[Any, Any] | None, ...]]]:
        """Subscribe to multiple paths at once."""
        paths = tuple(op.path for op in options)
        assert len(paths) == len(set(paths)), "Duplicate paths in options will cause ambiguity"
        exit_stack = contextlib.AsyncExitStack()
        # TODO(#236): optimize for cases where the subscribe options could suffice with a normal subscribe
        gens = [await exit_stack.enter_async_context(self.psubscribe(opt)) for opt in options]
        try:

            async def impl() -> AsyncIterator[tuple[SubscriptionResult[Any, Any] | None, ...]]:
                async with stream.merge(*gens).stream() as streamer:
                    async for res in streamer:
                        yield tuple(
                            res if res.subscriber_path == option_pattern else None
                            for option_pattern in paths
                        )

            yield impl()
        finally:
            await exit_stack.aclose()

    async def wait_until_done(self) -> None:
        """Wait until all handlers are unsubscribed."""
        while self._handlers or self._phandlers:  # noqa: ASYNC110, this is only used in tests
            await asyncio.sleep(0.01)
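

# Example (sketch, added for illustration): consuming one message through a
# SubscriptionManager. The `execute` poll loop must run concurrently with the
# subscribers; here it is spawned as a task and cancelled when done.
async def _example_consume(broadcast: Broadcast, options: SubscribeOptions[Any, Any]) -> None:
    manager = SubscriptionManager(broadcast)
    poller = asyncio.create_task(manager.execute())
    try:
        async with manager.psubscribe(options) as results:
            async for pattern, value, path in results:
                logger.info(f"got {value} for {pattern} (subscribed via {path!r})")
                break  # handle a single message; exiting the `with` unsubscribes
    finally:
        poller.cancel()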


class _MessageHandler[T, T_pattern: BasePatternClass]:
    """Handles messages for a specific channel."""

    def __init__(
        self, dec_fn: Callable[[T_pattern, bytes], T], pattern_cls: type[T_pattern]
    ) -> None:
        self._current_message: T
        self._subscribers: dict[str, asyncio.Queue[SubscriptionResult[T_pattern, T]]] = {}
        self._is_subscribed = False
        self._pattern_cls = pattern_cls
        self._decode_fn = dec_fn

    def has_listeners(self) -> bool:
        return bool(self._subscribers)

    async def handle_message(self, channel: bytes, data: bytes, subscriber_path: bytes) -> None:
        chan = self._pattern_cls.decode(channel)
        resolved = self._decode_fn(chan, data)
        # Fan the decoded message out to every subscriber queue.
        await asyncio.gather(
            *[
                que.put(SubscriptionResult(chan, resolved, subscriber_path))
                for que in self._subscribers.values()
            ]
        )

    def subscribe(self) -> tuple[str, AsyncGenerator[SubscriptionResult[T_pattern, T]]]:
        id_ = uuid.uuid4().hex
        queue: asyncio.Queue[SubscriptionResult[T_pattern, T]] = asyncio.Queue()
        self._subscribers[id_] = queue

        async def _impl() -> AsyncGenerator[SubscriptionResult[T_pattern, T]]:
            while True:
                yield await queue.get()

        return id_, _impl()

    def unsubscribe(self, id_: str) -> None:
        self._subscribers.pop(id_)


@contextlib.asynccontextmanager
async def get_broadcast(settings: RedisSettings) -> AsyncIterator[Broadcast]:
    ret = redis.Redis(
        connection_pool=redis.ConnectionPool(
            username=settings.username or None,
            password=settings.password,
            host=settings.host,
            port=settings.port,
            db=settings.db,
        ),
    )
    assert await ret.ping() is True, "Redis connection failed"
    yield cast("Broadcast", ret)
    await ret.aclose()
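

# Usage sketch (illustrative): the context manager verifies the connection
# with PING on entry and closes the client on exit.
#
#     async with get_broadcast(settings) as broadcast:
#         manager = SubscriptionManager(broadcast)
#         await manager.execute()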


@dataclass(slots=True)
class BaseRedisSetRepository[K: BasePatternClass, V: Hashable]:
    """
    Repository with a Redis backend.

    Uses Redis sets to store the data; usually used to store a set of db ids.
    Ref: https://redis.io/docs/latest/develop/data-types/sets/
    """

    r: Broadcast

    @classmethod
    def deserialize(cls, data: bytes) -> V:
        raise NotImplementedError

    @classmethod
    def serialize(cls, data: V) -> SerializeAble:
        raise NotImplementedError

    async def members(self, key: BasePatternClass) -> Iterable[V]:
        return [self.__class__.deserialize(m) for m in await self.r.smembers(key.sserialize())]  # type: ignore

    async def exist(self, k: BasePatternClass, member: V) -> bool:
        return await self.r.sismember(k.sserialize(), self.__class__.serialize(member)) == 1  # type: ignore

    async def remove(self, k: BasePatternClass, *members: V) -> None:
        await self.r.srem(k.sserialize(), *(self.__class__.serialize(m) for m in members))  # type: ignore

    async def add(self, k: BasePatternClass, *members: V) -> None:
        await self.r.sadd(k.sserialize(), *(self.__class__.serialize(m) for m in members))  # type: ignore


class BaseSimpleRedisSetRepo[K: BasePatternClass, V: DecodedT](BaseRedisSetRepository[K, V]):
    @classmethod
    def type_(cls) -> type[DecodedT]: ...

    @override
    @classmethod
    def deserialize(cls, data: bytes) -> V:
        return cls.type_()(data)  # pyright: ignore[reportCallIssue]

    @override
    @classmethod
    def serialize(cls, data: SerializeAble) -> SerializeAble:
        return data
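

# Example (sketch, added for illustration): a set repository of integer ids,
# assuming `int` is part of the `DecodedT` union. `type_` tells the base class
# which constructor turns raw bytes back into a value.
class _ExampleIdSetRepo(BaseSimpleRedisSetRepo[BasePatternClass, int]):
    @classmethod
    def type_(cls) -> type[DecodedT]:
        return int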


@dataclass(slots=True)
class RedisHashRepository[K: BasePatternClass, V]:
    """
    Repository with a Redis backend.

    Uses Redis hashes to store the data; usually used to store a mapping of
    db ids to the data.
    Ref: https://redis.io/docs/latest/develop/data-types/hashes
    """

    r: Broadcast

    def deserialize(self, data: bytes) -> V:
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def serialize(cls, data: V) -> str | bytes:
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def pattern_class(cls) -> type[K]:
        raise NotImplementedError

    async def items(self, key: K) -> dict[str, V]:
        res: dict[bytes, bytes] = await self.r.hgetall(key.sserialize())  # type: ignore
        return {k.decode("utf-8"): self.deserialize(v) for k, v in res.items()}

    async def keys(self, pattern: K) -> list[K]:
        return [self.pattern_class().decode(k) for k in await self.r.keys(pattern.as_listen())]

    async def values(self, key: K) -> Iterable[V]:
        res: list[bytes] = await self.r.hvals(key.sserialize())  # pyright: ignore[reportGeneralTypeIssues]
        return [self.deserialize(v) for v in res]

    async def get(self, k: K, field: str) -> V | None:
        res: bytes | None = await self.r.hget(k.sserialize(), field)  # type: ignore
        return self.deserialize(res) if res else None

    async def update(self, k: K, field: str, v: V) -> Result[V, ResourceNotFoundErr]:
        """Update a field in the hashmap, failing if it does not exist yet."""
        if await self.r.hexists(k.sserialize(), field):  # type: ignore
            await self.r.hset(k.sserialize(), field, self.serialize(v))  # type: ignore
            return Ok(v)
        return Err(ResourceNotFoundErr(f"Field {field} not found in hashmap {k}"))

    async def set(self, k: K, field: str, v: V) -> None:
        """
        Set a field in the hashmap.

        Args:
        ----
            k: The key.
            field: The field to set (usually the ID field of the entity).
            v: The value to set.

        """
        await self.r.hset(k.sserialize(), field, self.serialize(v))  # type: ignore

    async def remove(self, k: K, *fields: str) -> None:
        """Remove fields from the hashmap."""
        await self.r.hdel(k.sserialize(), *fields)  # type: ignore

    async def delete(self, k: K) -> None:
        """Delete the entire hashmap."""
        await self.r.delete(k.sserialize())
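

# Example (sketch, added for illustration): a hash repository mapping entity
# ids to a msgspec struct serialized as JSON. `_ExampleEntity` is illustrative;
# a real repository would return its concrete pattern class from `pattern_class`.
class _ExampleEntity(msgspec.Struct):
    id: str
    name: str


class _ExampleEntityRepo(RedisHashRepository[BasePatternClass, _ExampleEntity]):
    def deserialize(self, data: bytes) -> _ExampleEntity:
        return msgspec.json.decode(data, type=_ExampleEntity)

    @classmethod
    def serialize(cls, data: _ExampleEntity) -> bytes:
        return msgspec.json.encode(data)

    @classmethod
    def pattern_class(cls) -> type[BasePatternClass]:
        return BasePatternClass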


class Channel[T: msgspec.Struct](BasePatternClass):
    __abstract__ = True

    async def publish(self, r: Broadcast, data: T) -> None:
        """Publish data to the channel."""
        await r.publish(self.sserialize(), default_encode(data))
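

# Usage sketch (assumption: concrete BasePatternClass subclasses declare their
# path fields in application code; `OrderEvent` is a hypothetical msgspec struct):
#
#     class OrderChannel(Channel[OrderEvent]):
#         order_id: str
#
#     await OrderChannel(order_id="42").publish(broadcast, OrderEvent(...))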


class RedisCacheEntry[T](msgspec.Struct):
    data: T
    insert_time: datetime


class HasRedis(Protocol):
    redis: Broadcast


def cached_command[T_self: HasRedis, **Ps, T_ret: msgspec.Struct, T_err](
    encoder: Encoder[RedisCacheEntry[T_ret]],
    decoder: Decoder[RedisCacheEntry[T_ret]],
    ttl: timedelta | None = None,
    name: str | None = None,
) -> Callable[
    [Callable[Concatenate[T_self, Ps], Awaitable[Result[T_ret, T_err]]]],
    Callable[Concatenate[T_self, Ps], Awaitable[Result[T_ret, T_err]]],
]:
    """Cache a command's `Ok` results in Redis, keyed by the call arguments."""

    def _decorator(
        fn: Callable[Concatenate[T_self, Ps], Awaitable[Result[T_ret, T_err]]],
    ) -> Callable[Concatenate[T_self, Ps], Awaitable[Result[T_ret, T_err]]]:
        cache_key = name or fn.__qualname__

        @functools.wraps(fn)
        async def wrapper(
            self_: T_self, *args: Ps.args, **kwargs: Ps.kwargs
        ) -> Result[T_ret, T_err]:
            serialized_args = msgspec.json.encode(args)
            serialized_kwargs = msgspec.json.encode(kwargs)
            key = f"{cache_key}:{serialized_args}:{serialized_kwargs}"
            entry = await self_.redis.get(key)
            if entry is not None:
                # Cache hit: `entry` holds the encoded RedisCacheEntry bytes.
                return Ok(decoder.decode(entry).data)
            match await fn(self_, *args, **kwargs):
                case Ok(value):
                    # Only successful results are cached; errors are recomputed.
                    await self_.redis.set(
                        key, encoder(RedisCacheEntry(value, datetime.now())), ex=ttl
                    )
                    return Ok(value)
                case Err(error):
                    return Err(error)

        return wrapper

    return _decorator
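

# Usage sketch (assumptions: `default_encode` satisfies `Encoder` and
# `get_decoder(RedisCacheEntry[Profile])` builds a matching `Decoder`, both
# from t5hob_sdk.serde; `Profile` and `Service` are hypothetical):
#
#     class Profile(msgspec.Struct):
#         id: str
#
#     class Service:
#         redis: Broadcast
#
#         @cached_command(
#             encoder=default_encode,
#             decoder=get_decoder(RedisCacheEntry[Profile]),
#             ttl=timedelta(minutes=5),
#         )
#         async def fetch_profile(self, user_id: str) -> Result[Profile, str]:
#             ...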