Skip to content

Instantly share code, notes, and snippets.

@tricoder42
Last active September 10, 2024 20:00
Show Gist options
  • Save tricoder42/af3d0337c1b33d82c1b32d12bd0265ec to your computer and use it in GitHub Desktop.
Save tricoder42/af3d0337c1b33d82c1b32d12bd0265ec to your computer and use it in GitHub Desktop.
GraphQL Subscriptions with django-channels

GraphQL Subscription with django-channels

Django Channels is the official way to implement asynchronous messaging in Django.

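The primary caveat when working with GraphQL subscriptions is that we can't serialize a message before broadcasting it to a group of subscribers. Each subscriber might use a different GraphQL query, so we don't know in advance how to serialize the instance. Instead, only the model label and the instance's primary key are broadcast, and each subscriber re-fetches and serializes the instance with its own query.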

See related issue

from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.db.models.signals import post_save
def notify_on_model_changes(model):
    """Broadcast a ``model.changed`` event whenever an instance of *model* is saved.

    The event is sent to the channel-layer group ``django.<app_label>.<model>``
    and carries only the model label and the saved instance's primary key;
    subscribers re-fetch and serialize the instance themselves.

    The broadcast is deferred with ``transaction.on_commit`` so that
    subscribers, which re-query the database, see the committed row, and so
    that saves rolled back with their transaction are never announced.
    Outside an atomic block Django runs the callback immediately.

    The receiver is registered with a fixed ``dispatch_uid``, so calling this
    more than once for the same model does not register it twice.

    :param model: Django model class to watch.
    """
    from django.contrib.contenttypes.models import ContentType
    from django.db import transaction

    ct = ContentType.objects.get_for_model(model)
    model_label = '.'.join([ct.app_label, ct.model])
    group_name = f'django.{model_label}'

    channel_layer = get_channel_layer()
    group_send = async_to_sync(channel_layer.group_send)

    def receiver(sender, instance, **kwargs):
        payload = {
            'type': 'model.changed',
            'pk': instance.pk,
            'model': model_label,
        }
        # Defer until the save's transaction (on the database it was written
        # to) commits; a subscriber querying this pk earlier may not see it.
        transaction.on_commit(
            lambda: group_send(group_name, payload),
            using=kwargs.get('using'),
        )

    # weak=False keeps the locally defined receiver alive after this returns.
    post_save.connect(receiver, sender=model, weak=False,
                      dispatch_uid=group_name)
import functools
import json
from asgiref.sync import async_to_sync
from channels.consumer import SyncConsumer
from channels.exceptions import StopConsumer
from rx import Observable
from .schema import schema
# GraphQL types might use info.context.user to access the currently authenticated user.
# When a Query is called, info.context is the request object;
# when a Subscription is called, info.context is the scope dict.
# This is a minimal wrapper around that dict to mimic attribute access.
class AttrDict:
    """Attribute-style view over a dict; missing keys read as ``None``.

    ``ctx.user`` is equivalent to ``ctx.get('user')``. Attributes assigned
    directly on the wrapper (e.g. ``ctx.subscribe = ...``) live on the
    instance and take precedence over dict keys.
    """

    def __init__(self, data):
        # Treat None (or any falsy mapping) as an empty scope.
        self.data = data or {}

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails. Refuse ``data``
        # itself and dunder names: during copy/unpickling the instance exists
        # before ``data`` is set, so looking it up here would recurse forever,
        # and protocol probes like hasattr(ctx, '__setstate__') must get
        # AttributeError rather than None.
        if item == 'data' or (item.startswith('__') and item.endswith('__')):
            raise AttributeError(item)
        return self.get(item)

    def get(self, item):
        """Return ``data[item]``, or ``None`` when the key is absent."""
        return self.data.get(item)
class StreamObservable:
    """Push-based source to hand to ``Observable.create``.

    Rx calls the instance with the subscribing observer; ``send`` then
    forwards values to that observer.
    """

    # No observer until Rx subscribes. Initialised here so that ``send``
    # before subscription raises the intended error, not AttributeError.
    observer = None

    def __call__(self, observer):
        """Remember the observer Rx subscribes with."""
        self.observer = observer

    def send(self, value):
        """Emit *value* to the subscribed observer.

        :raises Exception: if no observer has subscribed yet.
        """
        if self.observer is None:
            raise Exception("Can't send values to disconnected observer.")
        self.observer.on_next(value)
class GraphqlSubcriptionConsumer(SyncConsumer):
    """Serve GraphQL operations over a websocket using the ``graphql-ws`` protocol.

    Each ``start`` message executes a GraphQL operation. Subscription
    resolvers call ``info.context.subscribe(model_label)`` (bound here to
    ``_subscribe``), which joins the channel-layer group
    ``django.<model_label>``. Group events of type ``model.changed`` are
    dispatched to ``model_changed``, which pushes ``(pk, model_label)`` into
    each subscribed operation's stream; every resulting ExecutionResult is
    sent to the client as a ``data`` message.
    """

    def __init__(self, scope):
        super().__init__(scope)
        # operation id -> StreamObservable feeding that subscription
        self.subscriptions = {}
        # model label -> set of operation ids subscribed to it
        self.groups = {}

    def websocket_connect(self, message):
        """Accept the connection, selecting the ``graphql-ws`` subprotocol."""
        self.send({
            "type": "websocket.accept",
            "subprotocol": "graphql-ws"
        })

    def websocket_disconnect(self, message):
        """Leave every joined channel-layer group and stop the consumer."""
        for group in self.groups:
            group_discard = async_to_sync(self.channel_layer.group_discard)
            group_discard(f'django.{group}', self.channel_name)
        self.send({
            "type": "websocket.close", "code": 1000
        })
        raise StopConsumer()

    def websocket_receive(self, message):
        """Dispatch one graphql-ws client message (init / start / stop)."""
        request = json.loads(message['text'])
        op_id = request.get('id')
        msg_type = request['type']

        if msg_type == 'connection_init':
            # The graphql-ws protocol requires the server to acknowledge init.
            self._send_message({'type': 'connection_ack'})
        elif msg_type == 'start':
            self._start(op_id, request['payload'])
        elif msg_type == 'stop':
            self._unsubscribe(op_id)
            # pop(): 'stop' may arrive for an operation that never registered
            # a stream (a plain query) or twice for the same id.
            self.subscriptions.pop(op_id, None)

    def model_changed(self, message):
        """Feed a ``model.changed`` group event into every subscribed stream."""
        model = message['model']
        pk = message['pk']
        for op_id in self.groups.get(model, []):
            stream = self.subscriptions.get(op_id)
            if stream is None:
                continue
            stream.send((pk, model))

    def _start(self, op_id, payload):
        """Execute one GraphQL operation; stream its results if it subscribes."""
        context = AttrDict(self.scope)
        context.subscribe = functools.partial(self._subscribe, op_id)
        stream = StreamObservable()
        result = schema.execute(
            payload['query'],
            # operationName and variables are optional in a start payload.
            operation_name=payload.get('operationName'),
            variable_values=payload.get('variables'),
            context_value=context,
            root_value=Observable.create(stream).share(),
            allow_subscriptions=True,
        )
        if hasattr(result, 'subscribe'):
            # Subscription: execute returned an observable of results.
            result.subscribe(functools.partial(self._send_result, op_id))
            self.subscriptions[op_id] = stream
        else:
            # Query/mutation: a single result, sent immediately.
            self._send_result(op_id, result)

    def _subscribe(self, op_id, model_name):
        """Register *op_id* for changes to *model_name*, joining its group on first use."""
        ids = self.groups.setdefault(model_name, set())
        if not ids:
            group_add = async_to_sync(self.channel_layer.group_add)
            group_add(f'django.{model_name}', self.channel_name)
        ids.add(op_id)

    def _unsubscribe(self, op_id):
        """Remove *op_id* from all groups, leaving any group that becomes empty."""
        for group, ids in self.groups.items():
            if op_id not in ids:
                continue
            ids.remove(op_id)
            if not ids:
                # No more subscriptions for this group on this connection.
                group_discard = async_to_sync(self.channel_layer.group_discard)
                group_discard(f'django.{group}', self.channel_name)

    def _send_result(self, op_id, result):
        """Send an execution result to the client as a graphql-ws ``data`` message."""
        errors = result.errors
        self._send_message({
            'id': op_id,
            'type': 'data',
            'payload': {
                'data': result.data,
                'errors': [str(error) for error in errors] if errors else None,
            },
        })

    def _send_message(self, body):
        """JSON-encode *body* and send it as a websocket text frame."""
        self.send({
            'type': 'websocket.send',
            'text': json.dumps(body),
        })
from collections import OrderedDict
import graphene
from graphene import Field
from graphene.types.objecttype import ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs
from graphene.utils.props import props
from rx import Observable
from six import get_unbound_function
from django.contrib.contenttypes.models import ContentType
class SubscriptionOptions(ObjectTypeOptions):
    """Meta options stored on a ``Subscription`` subclass."""

    # Populated by Subscription.__init_subclass_with_meta__.
    arguments = output = resolver = None
class Subscription(graphene.ObjectType):
    """Base class for GraphQL subscriptions driven by model-change events.

    Subclasses set ``Meta.output`` (or an inner ``Output``), may declare an
    inner ``Arguments`` class, and must define a ``next`` classmethod. Each
    value pushed into ``info.root_value`` is mapped through
    ``next(value, info, **kwargs)`` and the result is sent to the client.
    """

    @classmethod
    def __init_subclass_with_meta__(cls, resolver=None, output=None, arguments=None,
                                    _meta=None, **options):
        if not _meta:
            _meta = SubscriptionOptions(cls)

        output = output or getattr(cls, 'Output', None)
        fields = {}
        if not output:
            # No explicit output type: the subscription class itself is the
            # output, so collect the fields declared on it and its bases.
            fields = OrderedDict()
            for base in reversed(cls.__mro__):
                fields.update(
                    yank_fields_from_attrs(base.__dict__, _as=Field)
                )
            output = cls

        if not arguments:
            input_class = getattr(cls, 'Arguments', None)
            if input_class:
                arguments = props(input_class)
            else:
                arguments = {}

        if not resolver:
            assert hasattr(cls, 'next'), 'All subscriptions must define a next method in it'
            resolver = get_unbound_function(getattr(cls, 'resolver'))

        if _meta.fields:
            _meta.fields.update(fields)
        else:
            _meta.fields = fields

        _meta.output = output
        _meta.resolver = resolver
        _meta.arguments = arguments

        super(Subscription, cls).__init_subclass_with_meta__(_meta=_meta, **options)

    @classmethod
    def subscribe(cls, info):
        """Return the model class (or list/tuple of classes) to listen to.

        Defaults to ``_meta.model`` of the output type (presumably a
        DjangoObjectType); override to watch several models.
        """
        return cls._meta.output._meta.model

    @classmethod
    def resolver(cls, obj, info, **kwargs):
        """Subscribe to change groups, then map root events through ``next``."""
        # getattr with a default: contexts that do not carry ``subscribe``
        # (anything but the websocket consumer's) must not raise here.
        subscribe = getattr(info.context, 'subscribe', None)
        if subscribe:
            models = cls.subscribe(info)
            if not isinstance(models, (list, tuple)):
                models = [models]
            for model in models:
                ct = ContentType.objects.get_for_model(model)
                model_label = '.'.join([ct.app_label, ct.model])
                subscribe(model_label)
        observable = info.root_value
        return observable.map(lambda event: cls.next(event, info, **kwargs))

    @classmethod
    def Field(cls, *args, **kwargs):
        """Return a ``graphene.Field`` mounting this subscription.

        Keyword arguments (``name``, ``description``, ``required``,
        ``deprecation_reason``, ...) are forwarded to ``graphene.Field``.
        ``args`` and ``resolver`` default to the subscription's declared
        arguments and resolver. Positional arguments are accepted for
        backward compatibility and ignored.
        """
        kwargs.setdefault('args', cls._meta.arguments)
        kwargs.setdefault('resolver', cls._meta.resolver)
        return Field(cls._meta.output, **kwargs)
import graphene
from .types import ConversationType
class ConversationsSubscription(Subscription):
    """Emit the affected conversation whenever a Conversation or Message changes."""

    class Meta:
        output = ConversationType

    @classmethod
    def subscribe(cls, info):
        # NOTE(review): Message and Conversation are not imported in this
        # snippet; confirm they are in scope (e.g. from the app's models).
        return [Message, Conversation]

    @classmethod
    def next(cls, pk_model, info):
        """Resolve a ``(pk, model_label)`` change event to a conversation.

        Returns the first matching conversation from
        ``Conversation.objects.for_user(user)`` after ``annotate_unread(user)``,
        or ``None`` when nothing visible to the user matches.
        """
        user = info.context.user
        pk, model_label = pk_model
        if model_label == 'conversations.conversation':
            lookup = {'id': pk}
        else:
            # The other subscribed model is Message: find its conversation.
            lookup = {'messages__id': pk}
        conversations = Conversation.objects.for_user(user).filter(**lookup)
        return conversations.annotate_unread(user).first()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment