@nrbnlulu
Created August 4, 2025 05:25
pubsub.py
from __future__ import annotations

import abc
import asyncio
import contextlib
import functools
import uuid
from collections.abc import AsyncIterator, Awaitable, Hashable, Iterable
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Concatenate, NamedTuple, Protocol, cast, overload, override

import msgspec
import redis.asyncio as redis
from aiostream import stream
from loguru import logger
from result import Err, Ok

from t5hob_sdk.bases.err import ResourceNotFoundErr
from t5hob_sdk.bases.patternclass import BasePatternClass
from t5hob_sdk.bases.settings import RedisSettings
from t5hob_sdk.serde import Decoder, default_encode, get_decoder
from t5hob_sdk.types import Broadcast, ContextedAsyncIterator, DecodedT, Encoder, SerializeAble

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Callable

    from result import Result

type ChannelDecoder[T: msgspec.Struct] = Callable[[Any, bytes], T]


def get_decoder_fn[T: msgspec.Struct](t: type[T]) -> ChannelDecoder[T]:
    decoder = get_decoder(t)
    return lambda _, msg: decoder.decode(msg)


@dataclass(slots=True)
class SubscribeOptions[T, T_pattern: BasePatternClass]:
    path: bytes
    pattern_klass: type[T_pattern]
    dec_fn: Callable[[T_pattern, bytes], T]


class SubscriptionResult[P, V](NamedTuple):
    pattern: P
    value: V
    subscriber_path: bytes
    """The original path used to subscribe."""


class SubscriptionManager:
    """
    Public interface for subscribing to broadcast messages.

    Register a message handler for each path.
    It is safe to subscribe many times with the same path;
    doing so won't create multiple subscriptions.
    When `(p)subscribe` is called, it checks whether a handler already exists for the path
    and returns it; otherwise it creates a new handler and caches it for further subscriptions.
    """

    def __init__(self, broadcast: Broadcast) -> None:
        self.broadcast = broadcast
        self.pubsub = broadcast.pubsub()
        self._phandlers: dict[bytes, _MessageHandler] = {}
        self._handlers: dict[bytes, _MessageHandler] = {}

    async def execute(self) -> None:
        while True:
            if self._handlers or self._phandlers:
                msg = await self.pubsub.get_message(ignore_subscribe_messages=True)
                if msg is not None:
                    channel = msg["channel"]
                    if (pattern := msg["pattern"]) is not None and (
                        handler := self._phandlers.get(pattern)
                    ):
                        try:
                            await handler.handle_message(channel, msg["data"], pattern)
                        except Exception as e:  # pragma: no cover # noqa: BLE001
                            logger.error(f"Error handling message {msg} on channel {channel}: {e}")
                    elif handler := self._handlers.get(channel):
                        try:
                            await handler.handle_message(channel, msg["data"], channel)
                        except Exception as e:  # pragma: no cover # noqa: BLE001
                            logger.error(f"Error handling message {msg} on channel {channel}: {e}")
            await asyncio.sleep(0.01)

    # NOTE: we are using context managers to be able to initialize the subscription # noqa: ERA001
    # NOTE: without having to iterate it at the first time. # noqa: ERA001
    @contextlib.asynccontextmanager
    async def subscribe[T, T_pattern: BasePatternClass](
        self, options: SubscribeOptions[T, T_pattern]
    ) -> ContextedAsyncIterator[SubscriptionResult[T_pattern, T]]:
        if not (handler := self._handlers.get(options.path)):
            handler = _MessageHandler(dec_fn=options.dec_fn, pattern_cls=options.pattern_klass)
            await self.pubsub.subscribe(options.path)
            self._handlers[options.path] = handler
        try:
            sub_id, agen = handler.subscribe()
            yield agen
        finally:
            with contextlib.suppress(NameError):
                handler.unsubscribe(sub_id)  # pyright: ignore [reportPossiblyUnboundVariable]
            if not handler.has_listeners():
                self._handlers.pop(options.path)
                await self.pubsub.unsubscribe(options.path)

    @contextlib.asynccontextmanager
    async def psubscribe[T, T_pattern: BasePatternClass](
        self, options: SubscribeOptions[T, T_pattern]
    ) -> ContextedAsyncIterator[SubscriptionResult[T_pattern, T]]:
        """Subscribe to a pattern."""
        if not (handler := self._phandlers.get(options.path)):
            handler = _MessageHandler(dec_fn=options.dec_fn, pattern_cls=options.pattern_klass)
            await self.pubsub.psubscribe(options.path)
            self._phandlers[options.path] = handler
        try:
            sub_id, gen = handler.subscribe()
            yield gen
        finally:
            with contextlib.suppress(NameError):
                handler.unsubscribe(sub_id)  # pyright: ignore [reportPossiblyUnboundVariable]
            if not handler.has_listeners():
                self._phandlers.pop(options.path)
                await self.pubsub.punsubscribe(options.path)

    @overload
    @contextlib.asynccontextmanager
    async def batch_psubscribe[T1, TP1: BasePatternClass, T2, TP2: BasePatternClass](
        self, __op1: SubscribeOptions[T1, TP1], __op2: SubscribeOptions[T2, TP2], /
    ) -> AsyncGenerator[
        AsyncGenerator[
            tuple[SubscriptionResult[TP1, T1] | None, SubscriptionResult[TP2, T2] | None]
        ]
    ]: ...

    @overload
    @contextlib.asynccontextmanager
    async def batch_psubscribe[
        T1,
        TP1: BasePatternClass,
        T2,
        TP2: BasePatternClass,
        T3,
        TP3: BasePatternClass,
    ](
        self,
        __op1: SubscribeOptions[T1, TP1],
        __op2: SubscribeOptions[T2, TP2],
        __op3: SubscribeOptions[T3, TP3],
        /,
    ) -> AsyncGenerator[
        AsyncGenerator[
            tuple[
                SubscriptionResult[TP1, T1] | None,
                SubscriptionResult[TP2, T2] | None,
                SubscriptionResult[TP3, T3] | None,
            ]
        ]
    ]: ...

    @overload
    @contextlib.asynccontextmanager
    async def batch_psubscribe[
        T1,
        TP1: BasePatternClass,
        T2,
        TP2: BasePatternClass,
        T3,
        TP3: BasePatternClass,
        T4,
        TP4: BasePatternClass,
    ](
        self,
        __op1: SubscribeOptions[T1, TP1],
        __op2: SubscribeOptions[T2, TP2],
        __op3: SubscribeOptions[T3, TP3],
        __op4: SubscribeOptions[T4, TP4],
        /,
    ) -> AsyncGenerator[
        AsyncGenerator[
            tuple[
                SubscriptionResult[TP1, T1] | None,
                SubscriptionResult[TP2, T2] | None,
                SubscriptionResult[TP3, T3] | None,
                SubscriptionResult[TP4, T4] | None,
            ]
        ]
    ]: ...

    @contextlib.asynccontextmanager
    async def batch_psubscribe(
        self, *options: SubscribeOptions[Any, Any]
    ) -> AsyncIterator[AsyncIterator[tuple[SubscriptionResult[Any, Any] | None, ...]]]:
        """Subscribe to multiple paths at once."""
        paths = tuple(op.path for op in options)
        assert len(paths) == len(set(paths)), "Duplicate paths in options will cause ambiguity"
        exit_stack = contextlib.AsyncExitStack()
        # TODO(#236): optimize for cases where the subscribe options could suffice with a normal subscribe
        gens = [await exit_stack.enter_async_context(self.psubscribe(opt)) for opt in options]
        try:

            async def impl() -> AsyncIterator[tuple[SubscriptionResult[Any, Any] | None, ...]]:
                async with stream.merge(*gens).stream() as streamer:
                    async for res in streamer:
                        yield tuple(
                            res if res.subscriber_path == option_pattern else None
                            for option_pattern in paths
                        )

            yield impl()
        finally:
            await exit_stack.aclose()

    async def wait_until_done(self) -> None:
        """Wait until all handlers are unsubscribed."""
        while self._handlers or self._phandlers:  # noqa: ASYNC110, this is only used in tests
            await asyncio.sleep(0.01)
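

# Usage sketch (not part of the original gist): how a consumer might drive the
# manager. `SensorChannel` and `SensorReading` are hypothetical placeholders for
# a project-specific `BasePatternClass` and msgspec struct, so the snippet is
# kept commented out rather than presented as working code.
#
# async def _example_manager(broadcast: Broadcast) -> None:
#     manager = SubscriptionManager(broadcast)
#     # `execute` polls the underlying pubsub connection, so it has to run
#     # concurrently with the subscribers it feeds.
#     poll_task = asyncio.create_task(manager.execute())
#     options = SubscribeOptions(
#         path=b"sensors.*",
#         pattern_klass=SensorChannel,
#         dec_fn=get_decoder_fn(SensorReading),
#     )
#     async with manager.psubscribe(options) as results:
#         async for pattern, reading, subscriber_path in results:
#             logger.info(f"{pattern}: {reading} (via {subscriber_path!r})")
#             break
#     poll_task.cancel()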


class _MessageHandler[T, T_pattern: BasePatternClass]:
    """Handles messages for a specific channel."""

    def __init__(
        self, dec_fn: Callable[[T_pattern, bytes], T], pattern_cls: type[T_pattern]
    ) -> None:
        self._current_message: T
        self._subscribers: dict[str, asyncio.Queue[SubscriptionResult[T_pattern, T]]] = {}
        self._is_subscribed = False
        self._pattern_cls = pattern_cls
        self._decode_fn = dec_fn

    def has_listeners(self) -> bool:
        return bool(self._subscribers)

    async def handle_message(self, channel: bytes, data: bytes, subscriber_path: bytes) -> None:
        chan = self._pattern_cls.decode(channel)
        resolved = self._decode_fn(chan, data)
        await asyncio.gather(
            *[
                que.put(SubscriptionResult(chan, resolved, subscriber_path))
                for que in self._subscribers.values()
            ]
        )

    def subscribe(self) -> tuple[str, AsyncGenerator[SubscriptionResult[T_pattern, T]]]:
        id_ = uuid.uuid4().hex
        queue = asyncio.Queue()
        self._subscribers[id_] = queue

        async def _impl():
            while True:
                yield await queue.get()

        return id_, _impl()

    def unsubscribe(self, id_: str) -> None:
        self._subscribers.pop(id_)


@contextlib.asynccontextmanager
async def get_broadcast(settings: RedisSettings) -> AsyncIterator[Broadcast]:
    ret = redis.Redis(
        connection_pool=redis.ConnectionPool(
            username=settings.username or None,
            password=settings.password,
            host=settings.host,
            port=settings.port,
            db=settings.db,
        ),
    )
    assert await ret.ping() is True, "Redis connection failed"
    yield cast("Broadcast", ret)
    await ret.aclose()
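

# Example sketch (not part of the original gist): opening a connection and
# publishing a raw payload. The settings object is taken as a parameter because
# the exact `RedisSettings` constructor is not shown in this file.
async def _example_get_broadcast(settings: RedisSettings) -> None:
    async with get_broadcast(settings) as broadcast:
        # `Broadcast` is a cast over `redis.asyncio.Redis`, so the usual client
        # methods such as `publish` should be available.
        await broadcast.publish(b"example-channel", b"payload")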


@dataclass(slots=True)
class BaseRedisSetRepository[K: BasePatternClass, V: Hashable]:
    """
    Repository with a redis backend.

    Uses redis sets to store the data.
    Usually used to store a set of db ids.
    Ref: https://redis.io/docs/latest/develop/data-types/sets/
    """

    r: Broadcast

    @classmethod
    def deserialize(cls, data: bytes) -> V:
        raise NotImplementedError

    @classmethod
    def serialize(cls, data: V) -> SerializeAble:
        raise NotImplementedError

    async def members(self, key: BasePatternClass) -> Iterable[V]:
        return [self.__class__.deserialize(m) for m in await self.r.smembers(key.sserialize())]  # type: ignore

    async def exist(self, k: BasePatternClass, member: V) -> bool:
        return await self.r.sismember(k.sserialize(), self.__class__.serialize(member)) == 1  # type: ignore

    async def remove(self, k: BasePatternClass, *members: V) -> None:
        await self.r.srem(k.sserialize(), *(self.__class__.serialize(m) for m in members))  # type: ignore

    async def add(self, k: BasePatternClass, *member: V) -> None:
        await self.r.sadd(k.sserialize(), *(self.__class__.serialize(m) for m in member))  # type: ignore


class BaseSimpleRedisSetRepo[K: BasePatternClass, V: DecodedT](BaseRedisSetRepository[K, V]):
    @classmethod
    def type_(cls) -> type[DecodedT]: ...

    @override
    @classmethod
    def deserialize(cls, data: bytes) -> V:
        return cls.type_()(data)  # pyright: ignore reportCallIssue

    @override
    @classmethod
    def serialize(cls, data: SerializeAble) -> SerializeAble:
        return data
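

# Example sketch (not part of the original gist): a concrete set repository that
# stores integer ids. Whether `int` satisfies the `DecodedT` bound is an
# assumption about the SDK's type aliases, not something this file states.
class _ExampleIntSetRepo[K: BasePatternClass](BaseSimpleRedisSetRepo[K, int]):
    @override
    @classmethod
    def type_(cls) -> type[DecodedT]:
        # Bytes returned by Redis are coerced back to ints by the base class,
        # i.e. `int(b"42") == 42`.
        return int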


@dataclass(slots=True)
class RedisHashRepository[K: BasePatternClass, V]:
    """
    Repository with a redis backend.

    Uses redis hashes to store the data.
    Usually used to store a mapping of db ids to the data.
    Ref: https://redis.io/docs/latest/develop/data-types/hashes
    """

    r: Broadcast

    def deserialize(self, data: bytes) -> V:
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def serialize(cls, data: V) -> str | bytes:
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def pattern_class(cls) -> type[K]:
        raise NotImplementedError

    async def items(self, key: K) -> dict[str, V]:
        res: dict[bytes, bytes] = await self.r.hgetall(key.sserialize())  # type: ignore
        return {k.decode("utf-8"): self.deserialize(v) for k, v in res.items()}

    async def keys(self, pattern: K) -> list[K]:
        return [self.pattern_class().decode(k) for k in await self.r.keys(pattern.as_listen())]

    async def values(self, key: K) -> Iterable[V]:
        res: list[bytes] = await self.r.hvals(key.sserialize())  # pyright: ignore[reportGeneralTypeIssues]
        return [self.deserialize(v) for v in res]

    async def get(self, k: K, field: str) -> V | None:
        res: bytes = await self.r.hget(k.sserialize(), field)  # type: ignore
        return self.deserialize(res) if res else None

    async def update(self, k: K, field: str, v: V) -> Result[V, ResourceNotFoundErr]:
        """Update a field in the hashmap."""
        if await self.r.hexists(k.sserialize(), field):  # type: ignore
            await self.r.hset(k.sserialize(), field, self.serialize(v))  # type: ignore
            return Ok(v)
        return Err(ResourceNotFoundErr(f"Field {field} not found in hashmap {k}"))

    async def set(self, k: K, field: str, v: V) -> None:
        """
        Set a field in the hashmap.

        Args:
        ----
            k: The key.
            field: The field to set (usually would be the ID field of the entity).
            v: The value to set.

        """
        await self.r.hset(k.sserialize(), field, self.serialize(v))  # type: ignore

    async def remove(self, k: K, *fields: str) -> None:
        """Remove fields from the hashmap."""
        await self.r.hdel(k.sserialize(), *fields)  # type: ignore

    async def delete(self, k: K) -> None:
        """Delete the entire hashmap."""
        await self.r.delete(k.sserialize())
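

# Example sketch (not part of the original gist): a hash repository whose values
# are msgspec structs serialized as JSON. `DeviceKey` and `DeviceState` are
# hypothetical, so the snippet is left commented out.
#
# class DeviceState(msgspec.Struct):
#     online: bool
#     last_seen: datetime
#
# @dataclass(slots=True)
# class DeviceStateRepo(RedisHashRepository["DeviceKey", DeviceState]):
#     @override
#     def deserialize(self, data: bytes) -> DeviceState:
#         return msgspec.json.decode(data, type=DeviceState)
#
#     @override
#     @classmethod
#     def serialize(cls, data: DeviceState) -> bytes:
#         return msgspec.json.encode(data)
#
#     @override
#     @classmethod
#     def pattern_class(cls) -> type["DeviceKey"]:
#         return DeviceKey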


class Channel[T: msgspec.Struct](BasePatternClass):
    __abstract__ = True

    async def publish(self, r: Broadcast, data: T) -> None:
        """Publish data to the channel."""
        await r.publish(self.sserialize(), default_encode(data))
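

# Example sketch (not part of the original gist): a typed channel. The field
# layout of a `BasePatternClass` subclass is project-specific, so `device_id`
# below is only an assumption about how such a channel might be keyed; the
# snippet is left commented out for that reason.
#
# class DeviceEvent(msgspec.Struct):
#     kind: str
#
# class DeviceEventChannel(Channel[DeviceEvent]):
#     device_id: str
#
# async def _example_publish(r: Broadcast) -> None:
#     await DeviceEventChannel(device_id="abc123").publish(r, DeviceEvent(kind="boot"))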


class RedisCacheEntry[T](msgspec.Struct):
    data: T
    insert_time: datetime


class HasRedis(Protocol):
    redis: Broadcast


def cached_command[T_self: HasRedis, **Ps, T_ret: msgspec.Struct, T_err](
    encoder: Encoder[RedisCacheEntry[T_ret]],
    decoder: Decoder[RedisCacheEntry[T_ret]],
    ttl: timedelta | None = None,
    name: str | None = None,
) -> Callable[
    [Callable[Concatenate[T_self, Ps], Awaitable[Result[T_ret, T_err]]]],
    Callable[Concatenate[T_self, Ps], Awaitable[Result[T_ret, T_err]]],
]:
    def _decorator(
        fn: Callable[Concatenate[T_self, Ps], Awaitable[Result[T_ret, T_err]]],
    ) -> Callable[Concatenate[T_self, Ps], Awaitable[Result[T_ret, T_err]]]:
        cache_key = name or fn.__qualname__

        @functools.wraps(fn)
        async def wrapper(
            self_: T_self, *args: Ps.args, **kwargs: Ps.kwargs
        ) -> Result[T_ret, T_err]:
            serialized_args = msgspec.json.encode(args)
            serialized_kwargs = msgspec.json.encode(kwargs)
            key = f"{cache_key}:{serialized_args}:{serialized_kwargs}"
            entry = await self_.redis.get(key)
            if entry is not None:
                # Cache hit: decode the stored bytes back into a `RedisCacheEntry`
                # and return its payload.
                return Ok(decoder.decode(entry).data)
            match await fn(self_, *args, **kwargs):
                case Ok(value):
                    await self_.redis.set(
                        key, encoder(RedisCacheEntry(value, datetime.now())), ex=ttl
                    )
                    return Ok(value)
                case Err(error):
                    return Err(error)

        return wrapper

    return _decorator
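

# Usage sketch (not part of the original gist): caching a slow query on a service
# that satisfies `HasRedis`. `WeatherReport` and `FetchErr` are hypothetical, and
# using `default_encode` / `get_decoder` as the entry encoder/decoder is an
# assumption about the serde helpers, so the snippet is left commented out.
#
# class WeatherReport(msgspec.Struct):
#     temperature_c: float
#
# class WeatherService:
#     def __init__(self, redis: Broadcast) -> None:
#         self.redis = redis
#
#     @cached_command(
#         encoder=default_encode,
#         decoder=get_decoder(RedisCacheEntry[WeatherReport]),
#         ttl=timedelta(minutes=5),
#     )
#     async def current(self, city: str) -> Result[WeatherReport, FetchErr]:
#         ...  # expensive upstream call, cached per (cache_key, args, kwargs)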