
@ahopkins
Last active September 23, 2024 17:20

Sanic Websockets Feeds v3

This is an example of how to build a distributed websocket feed. It allows for horizontal scaling using Redis as a pubsub broker to broadcast messages between application instances.

This is the third version of websocket feeds. It is built with Sanic v21.9+ in mind. Older versions:

```python
from sanic import Blueprint
from sanic.log import logger

from .channel import Channel

bp = Blueprint("Feed", url_prefix="/feed")


@bp.websocket("/<channel_name>")
async def feed(request, ws, channel_name):
    logger.info("Incoming WS request")
    channel, is_existing = await Channel.get(
        request.app.ctx.pubsub, request.app.ctx.redis, channel_name
    )

    if not is_existing:
        request.app.add_task(channel.receiver())

    client = await channel.register(ws)

    try:
        await client.receiver()
    finally:
        await channel.unregister(client)
```
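The blueprint above reads `request.app.ctx.pubsub` and `request.app.ctx.redis`, but the gist does not show how they are created. A minimal wiring sketch, assuming aioredis 2.x and a local Redis instance (the module path, listener names, and URL here are assumptions, not part of the original):

```python
# Hypothetical app wiring -- names and URL are examples only
from aioredis import Redis
from sanic import Sanic

from .view import bp  # the blueprint above; module name is a guess

app = Sanic("FeedApp")
app.blueprint(bp)


@app.listener("before_server_start")
async def setup_redis(app, _):
    # One Redis connection pool and one PubSub object per worker process
    app.ctx.redis = Redis.from_url("redis://localhost:6379")
    app.ctx.pubsub = app.ctx.redis.pubsub()


@app.listener("after_server_stop")
async def teardown_redis(app, _):
    await app.ctx.redis.close()
```

Note that a single `PubSub` object per worker is shared by all channels, which is why `Channel.get` subscribes the channel name on it rather than opening a new connection per feed.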
```python
from __future__ import annotations

from asyncio import Lock
from typing import Set, Tuple

from aioredis import Redis
from aioredis.client import PubSub
from aioredis.exceptions import PubSubError
from sanic.log import logger
from sanic.server.websockets.impl import WebsocketImplProtocol

from .client import Client


class ChannelCache(dict):
    ...


class Channel:
    cache = ChannelCache()

    def __init__(self, pubsub: PubSub, redis: Redis, name: str) -> None:
        self.pubsub = pubsub
        self.redis = redis
        self.name = name
        self.clients: Set[Client] = set()
        self.lock = Lock()

    @classmethod
    async def get(
        cls, pubsub: PubSub, redis: Redis, name: str
    ) -> Tuple[Channel, bool]:
        is_existing = False

        if name in cls.cache:
            channel = cls.cache[name]
            await channel.acquire_lock()
            is_existing = True
        else:
            channel = cls(pubsub=pubsub, redis=redis, name=name)
            await channel.acquire_lock()
            cls.cache[name] = channel
            await pubsub.subscribe(name)

        return channel, is_existing

    async def acquire_lock(self) -> None:
        if not self.lock.locked():
            logger.debug("Lock acquired")
            await self.lock.acquire()
        else:
            logger.debug("Lock already acquired")

    async def receiver(self) -> None:
        logger.debug(f"Starting PubSub receiver for {self.name}")
        while True:
            try:
                raw = await self.pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=1.0
                )
            except PubSubError:
                logger.error(f"PUBSUB closed <{self.name}>", exc_info=True)
                break
            else:
                if raw:
                    logger.debug(
                        f"PUBSUB rcvd <{self.name}>: length=={len(raw)}"
                    )
                    for client in self.clients:
                        logger.debug(f"Sending to: {client.uid}")
                        await client.protocol.send(raw["data"])

    async def register(self, protocol: WebsocketImplProtocol) -> Client:
        client = Client(
            protocol=protocol, redis=self.redis, channel_name=self.name
        )
        self.clients.add(client)
        await self.publish(f"Client {client.uid} has joined")
        return client

    async def unregister(self, client: Client) -> None:
        if client in self.clients:
            await client.shutdown()
            self.clients.remove(client)
            await self.publish(f"Client {client.uid} has left")

        if not self.clients:
            self.lock.release()
            await self.destroy()

    async def destroy(self) -> None:
        if not self.lock.locked():
            logger.debug(f"Destroying Channel {self.name}")
            del self.__class__.cache[self.name]
            await self.pubsub.reset()
        else:
            logger.debug(f"Abort destroying Channel {self.name}. It is locked")

    async def publish(self, message: str) -> None:
        logger.debug(f"Sending message: {message}")
        await self.redis.publish(self.name, message)
```
```python
from dataclasses import dataclass, field
from uuid import UUID, uuid4

from aioredis import Redis
from sanic.server.websockets.impl import WebsocketImplProtocol


@dataclass
class Client:
    protocol: WebsocketImplProtocol
    redis: Redis
    channel_name: str
    uid: UUID = field(default_factory=uuid4)

    def __hash__(self) -> int:
        return self.uid.int

    async def receiver(self):
        while True:
            message = await self.protocol.recv()
            if not message:
                break
            await self.redis.publish(self.channel_name, message)

    async def shutdown(self):
        await self.protocol.close()
```
@cnicodeme

Thank you for this code @ahopkins. I have a question: why the lock here? What are you trying to avoid running concurrently?

Thanks!

@ahopkins
Author

Thanks 😁

The lock is there to handle multiple clients on the same feed. Since I try to keep the memory footprint small and clean up feeds when no clients are connected, we need to guard against race conditions: a client may have committed to cleaning up the feed just as another client connects to it. We don't want the first one to remove the feed instance out from under the second.
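A stripped-down illustration of that guard (plain asyncio, no Redis or Sanic; the `Feed` class and names are invented stand-ins for the gist's `Channel`):

```python
import asyncio


class Feed:
    """Toy stand-in for the gist's Channel: a class-level cache plus a lock."""

    cache = {}

    def __init__(self, name):
        self.name = name
        self.lock = asyncio.Lock()

    @classmethod
    async def get(cls, name):
        feed = cls.cache.get(name)
        if feed is None:
            feed = cls.cache[name] = cls(name)
        # Each connecting client (re)takes the lock, keeping the feed alive
        if not feed.lock.locked():
            await feed.lock.acquire()
        return feed

    async def destroy(self):
        # Abort teardown if another client re-acquired the lock in the meantime
        if self.lock.locked():
            return False
        del self.cache[self.name]
        return True


async def race():
    feed = await Feed.get("news")   # client A connects; lock is taken
    feed.lock.release()             # A is the last client and leaves...
    await Feed.get("news")          # ...but client B connects before cleanup
    return await feed.destroy()     # A's destroy() must now be a no-op


destroyed = asyncio.run(race())
print(destroyed)  # False -- the feed survives because B holds the lock
```

Without the `lock.locked()` check in `destroy`, client A would delete the cached feed that client B was just handed.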

@cnicodeme

I just finished working through some issues I had with an implementation based on this code, and one key element I discovered that caused trouble is in the Channel.destroy function, specifically these two lines:

    await self.pubsub.reset()
    del self.__class__.cache[self.name]

This caused a race condition for me: while reset was being called, a new client arrived. Since del had not yet been executed, Channel.get found the cached instance and returned it; then the reset finished, leaving the new user in a Channel that had no real subscription. Any data sent to the pubsub was never retrieved (no one was listening).

I simply swapped those two lines, first deleting the cache entry and then resetting the pubsub, and this fixed the issue!

If you notice any mistakes in my logic, feel free to point them out, but on my end this seems to work.
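The window described above is easy to reproduce with a fake pubsub whose `reset()` yields control the way a real network call would (everything here is a toy stand-in, not the gist's classes):

```python
import asyncio


class FakePubSub:
    """Stands in for aioredis PubSub; reset() awaits like a network call."""

    def __init__(self):
        self.subscribed = True

    async def reset(self):
        await asyncio.sleep(0)  # yield to the loop, like a real round-trip
        self.subscribed = False


cache = {}


async def destroy_fixed(name, pubsub):
    # Fixed order: drop the cache entry FIRST, then reset the pubsub.
    # A client arriving mid-reset misses the cache and builds a fresh channel
    # instead of being handed one whose subscription is being torn down.
    del cache[name]
    await pubsub.reset()


async def main():
    pubsub = FakePubSub()
    cache["news"] = pubsub

    task = asyncio.create_task(destroy_fixed("news", pubsub))
    await asyncio.sleep(0)         # a new client connects while reset() runs
    found_stale = "news" in cache  # with the fixed order, never True
    await task
    return found_stale


stale = asyncio.run(main())
print(stale)  # False
```

With the original order (reset before del), the mid-reset lookup would still find `"news"` in the cache and hand the newcomer a half-dead channel, which is exactly the symptom described.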

@ahopkins
Author

> I simply swapped those two lines, first deleting the cache entry and then resetting the pubsub, and this fixed the issue!
>
> If you notice any mistakes in my logic, feel free to point them out, but on my end this seems to work.

That makes sense. I have made the change. Thanks for your feedback!

@cnicodeme

I recently found a problem and was wondering why you decided to create only one channel per name in the get method:

    @classmethod
    async def get(
        cls, pubsub: PubSub, redis: Redis, name: str
    ) -> Tuple[Channel, bool]:
        is_existing = False

        if name in cls.cache:
            channel = cls.cache[name]
            await channel.acquire_lock()
            is_existing = True
        else:
            channel = cls(pubsub=pubsub, redis=redis, name=name)
            await channel.acquire_lock()

            cls.cache[name] = channel

            await pubsub.subscribe(name)

Why use a cache? I'm asking because, for instance, if the user opens two tabs, won't they only receive the events in one?

@cnicodeme

I'm revisiting this (as per my previous comment) because I noticed that when I open two tabs on the same channel, in both Chrome and Firefox, the data is correctly sent to one client but not the other.

By removing the cache, everything works fine. If it shouldn't behave like this, I'm happy to share my code somewhere so we can check whether I've made a mistake.

Thank you in advance for your help!
