|
import functools |
|
import json |
|
|
|
from asgiref.sync import async_to_sync |
|
from channels.consumer import SyncConsumer |
|
from channels.exceptions import StopConsumer |
|
from rx import Observable |
|
|
|
from .schema import schema |
|
|
|
|
|
# GraphQL types might use info.context.user to access the currently authenticated user.
# When a Query is executed, info.context is the request object;
# however, when a Subscription is executed, info.context is the scope dict.
# This is a minimal wrapper around a dict that mimics object (attribute) behavior.
|
class AttrDict:
    """Minimal wrapper exposing the keys of a dict as attributes.

    Looking up an attribute that is not set on the instance falls back to
    the wrapped mapping; a key that is absent from the mapping yields
    ``None`` instead of raising. Attributes assigned on the instance
    (``obj.name = value``) are stored normally and take precedence.
    """

    def __init__(self, data):
        """Wrap *data*; any falsy value (e.g. ``None``) becomes an empty dict."""
        self.data = data or {}

    def __getattr__(self, item):
        """Return ``self.data.get(item)`` for attributes not found normally.

        Raises:
            AttributeError: for ``data`` itself and for dunder names.
        """
        # __getattr__ only runs after normal lookup has failed. If ``data``
        # itself is missing (instance built without __init__, e.g. by copy or
        # pickle), falling through to self.get() would recurse forever.
        # Dunder probes such as hasattr(obj, "__setstate__") must report
        # "absent" rather than hand back None.
        if item == "data" or (item.startswith("__") and item.endswith("__")):
            raise AttributeError(item)
        return self.get(item)

    def get(self, item):
        """Return the value stored under *item*, or ``None`` if absent."""
        return self.data.get(item)
|
|
|
|
|
class StreamObservable:
    """Callable that captures an observer so values can be pushed to it later.

    Calling the instance with an observer stores that observer; ``send``
    then forwards values to the observer's ``on_next``.
    """

    def __init__(self):
        # No observer until the instance is called with one. Initialising
        # here lets send() raise its descriptive error instead of an
        # AttributeError when nothing is attached yet.
        self.observer = None

    def __call__(self, observer):
        """Remember *observer* as the target of subsequent ``send`` calls."""
        self.observer = observer

    def send(self, value):
        """Forward *value* to the attached observer via ``on_next``.

        Raises:
            Exception: if no observer has been attached yet.
        """
        if self.observer is None:
            raise Exception("Can't send values to disconnected observer.")
        self.observer.on_next(value)
|
|
|
|
|
class GraphqlSubcriptionConsumer(SyncConsumer):
    """Websocket consumer that runs GraphQL operations and streams results.

    The connection is accepted with the ``graphql-ws`` subprotocol. Each
    ``start`` request is executed against ``schema``; if execution returns
    something with a ``subscribe`` method, every value it emits is sent back
    to the client tagged with the operation id. Subscriptions register
    interest in model names, which map to channel-layer groups named
    ``django.<model_name>``.
    """

    def __init__(self, scope):
        super().__init__(scope)
        # operation id -> StreamObservable feeding that live subscription
        self.subscriptions = {}
        # model name -> set of operation ids subscribed to that model
        self.groups = {}

    def websocket_connect(self, message):
        """Accept the websocket, selecting the ``graphql-ws`` subprotocol."""
        self.send({
            "type": "websocket.accept",
            "subprotocol": "graphql-ws",
        })

    def websocket_disconnect(self, message):
        """Leave every joined group, send a close frame and stop the consumer.

        Raises:
            StopConsumer: always.
        """
        # Only touch the channel layer when there is something to leave, so
        # a connection that never subscribed does not need a layer at all.
        if self.groups:
            group_discard = async_to_sync(self.channel_layer.group_discard)
            for group in self.groups:
                group_discard(f'django.{group}', self.channel_name)

        self.send({
            "type": "websocket.close", "code": 1000
        })
        raise StopConsumer()

    def websocket_receive(self, message):
        """Handle one JSON text frame from the client.

        Handled request types:
          * ``connection_init`` - ignored.
          * ``start`` - execute ``payload['query']`` and send the result, or
            subscribe to the returned observable.
          * ``stop`` - drop the subscription with the given id, if any.

        Any other (or missing) type is ignored.
        """
        request = json.loads(message['text'])
        id = request.get('id')
        # .get(): a frame without "type" is ignored like any unhandled type
        # instead of raising KeyError.
        request_type = request.get('type')

        if request_type == 'connection_init':
            return

        if request_type == 'start':
            payload = request['payload']
            context = AttrDict(self.scope)
            # Expose a subscribe(model_name) callable on the context, bound
            # to this operation id.
            context.subscribe = functools.partial(self._subscribe, id)

            stream = StreamObservable()

            result = schema.execute(
                payload['query'],
                # Clients may omit these two fields; treat absent as None.
                operation_name=payload.get('operationName'),
                variable_values=payload.get('variables'),
                context_value=context,
                root_value=Observable.create(stream).share(),
                allow_subscriptions=True,
            )
            if hasattr(result, 'subscribe'):
                # Observable result: push every emitted value to the client.
                result.subscribe(functools.partial(self._send_result, id))
                self.subscriptions[id] = stream
            else:
                # Plain execution result: send it once.
                self._send_result(id, result)

        elif request_type == 'stop':
            self._unsubscribe(id)
            # pop(): stopping an unknown or already stopped id must not
            # raise KeyError.
            self.subscriptions.pop(id, None)

    def model_changed(self, message):
        """Push ``(pk, model)`` to every subscription watching ``message['model']``.

        NOTE(review): presumably invoked by Channels for group messages of
        type ``model.changed`` - confirm against the sender.
        """
        model = message['model']
        pk = message['pk']

        for id in self.groups.get(model, ()):
            stream = self.subscriptions.get(id)
            if stream is None:
                continue
            stream.send((pk, model))

    def _subscribe(self, id, model_name):
        """Register operation *id* as interested in *model_name*.

        Joins the ``django.<model_name>`` group when this is the first
        interested operation for that model.
        """
        group = self.groups.setdefault(model_name, set())
        if not group:
            group_add = async_to_sync(self.channel_layer.group_add)
            group_add(f'django.{model_name}', self.channel_name)
        group.add(id)

    def _unsubscribe(self, id):
        """Remove operation *id* from all groups, leaving those that become empty."""
        for group, ids in self.groups.items():
            if id not in ids:
                continue

            ids.remove(id)
            if not ids:
                # no more subscriptions for this group
                group_discard = async_to_sync(self.channel_layer.group_discard)
                group_discard(f'django.{group}', self.channel_name)

    def _send_result(self, id, result):
        """Send *result* to the client as a ``data`` message for operation *id*.

        Errors are stringified; ``errors`` is ``None`` when there are none.
        """
        errors = result.errors

        self.send({
            'type': 'websocket.send',
            'text': json.dumps({
                'id': id,
                'type': 'data',
                'payload': {
                    'data': result.data,
                    'errors': list(map(str, errors)) if errors else None,
                }
            })
        })