Created
November 1, 2020 18:33
-
-
Save zbyte64/4730f49639335078857c3559ce105e01 to your computer and use it in GitHub Desktop.
Example: Django-Graphene subscriptions with graphene_subscriptions and django_lifecycle
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from enum import Enum, auto | |
class NotificationEvents(Enum):
    """Operations attached to message ``SubscriptionEvent``s.

    Values are spelled out explicitly; they match what ``auto()`` assigns
    (1 for the first member, 2 for the second).
    """

    NEW_MESSAGE = 1
    UPDATE_MESSAGE = 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.db import models | |
from org.models import User | |
from django_lifecycle import LifecycleModelMixin, hook, AFTER_CREATE, AFTER_UPDATE | |
from graphene_subscriptions.events import SubscriptionEvent | |
from .events import NotificationEvents | |
class Message(LifecycleModelMixin, models.Model):
    """A message sent by ``owner``, optionally addressed ``to`` a single user.

    After the row is created or updated, a graphene_subscriptions
    ``SubscriptionEvent`` is sent with the matching ``NotificationEvents``
    operation so websocket subscribers can react.
    """

    owner = models.ForeignKey(
        User, related_name="sent_messages", on_delete=models.CASCADE
    )
    title = models.CharField(max_length=100)
    text = models.TextField()
    to = models.ForeignKey(
        User,
        null=True,
        blank=True,
        on_delete=models.CASCADE,
        related_name="to_received_messages",
    )
    participants = models.ManyToManyField(User, related_name="received_messages")

    def _publish(self, operation):
        """Send a ``SubscriptionEvent`` carrying ``operation`` and this instance."""
        SubscriptionEvent(operation=operation, instance=self).send()

    # django_lifecycle only dispatches @hook methods on models that include
    # LifecycleModelMixin; without the mixin these hooks would never run.
    @hook(AFTER_CREATE)
    def notify_new_message(self):
        """Announce a newly created message."""
        # NOTE(review): M2M rows are written after the model row is saved, so
        # ``participants`` is likely still empty here; subscribers filtering on
        # participants may miss this event — confirm how senders populate it.
        self._publish(NotificationEvents.NEW_MESSAGE)

    @hook(AFTER_UPDATE)
    def notify_update_message(self):
        """Announce a change to an existing message."""
        self._publish(NotificationEvents.UPDATE_MESSAGE)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from django.test import TransactionTestCase | |
from snapshottest.unittest import TestCase | |
from graphene_django.utils.testing import graphql_query | |
from graphql_relay import to_global_id | |
import json | |
from asgiref.sync import sync_to_async, async_to_sync | |
from channels.testing import WebsocketCommunicator | |
from graphene_subscriptions.consumers import GraphqlSubscriptionConsumer | |
from org.models import User | |
from .models import UserMessage | |
from .forms import SendMessageForm | |
class UnreadMessageCountTestCase(TestCase, TransactionTestCase):
    """Websocket tests for the ``unreadMessageCount`` GraphQL subscription.

    Mixes snapshottest's ``TestCase`` (for ``assertMatchSnapshot``) with
    Django's ``TransactionTestCase`` and loads the ``test_data`` fixture.
    """

    fixtures = ["test_data"]
    maxDiff = 2000

    UNREAD_COUNT_SUBSCRIPTION = """
        subscription {
            unreadMessageCount
        }
    """

    # NOTE(review): every email literal below reads "[email protected]", which
    # looks like an address redacted when this code was copied; restore the
    # real fixture addresses (receiver and sender are presumably different).

    def setUp(self):
        assert self.client.login(email="[email protected]", password="password")

    @classmethod
    def tearDownClass(cls):
        """Close every database connection, then run the parent teardowns."""
        # Fix for async tests not cleaning up connections; probably fixed in
        # Django 3.1.
        # https://stackoverflow.com/questions/8242837/django-multiprocessing-and-database-connections
        import django.db

        for conn in django.db.connections.all():
            conn.close()
        # The parent teardowns must still run (snapshottest persists snapshots
        # from its tearDownClass; Django restores class-level test state).
        super().tearDownClass()

    @staticmethod
    async def _start_subscription(communicator, query, variables=None):
        """Send a ``start`` frame with operation id 1 carrying ``query``."""
        await communicator.send_json_to(
            {
                "id": 1,
                "type": "start",
                "payload": {"query": query, "variables": variables},
            }
        )

    @staticmethod
    async def _connect_as(email):
        """Open a websocket to the subscription consumer as the user ``email``.

        Returns ``(communicator, user)``. Fails the test if the handshake is
        refused.
        """
        communicator = WebsocketCommunicator(GraphqlSubscriptionConsumer, "/graphql")
        user = await sync_to_async(User.objects.get)(email=email)
        communicator.scope["user"] = user
        connected, _subprotocol = await communicator.connect()
        assert connected
        return communicator, user

    @staticmethod
    def _assert_no_errors(response):
        """Fail with the reported errors if a data frame carries any."""
        errors = response["payload"].get("errors")
        assert not errors, str(errors)

    @async_to_sync
    async def test_unread_message_count(self):
        communicator, _receiver = await self._connect_as("[email protected]")
        try:
            await self._start_subscription(
                communicator, self.UNREAD_COUNT_SUBSCRIPTION
            )
            response = await communicator.receive_json_from()
            self._assert_no_errors(response)
            self.assertMatchSnapshot(response)
        finally:
            # Release the consumer so it does not outlive the test.
            await communicator.disconnect()

    @async_to_sync
    async def test_update_unread_message_count_on_new_message(self):
        communicator, receiver = await self._connect_as("[email protected]")
        try:
            await self._start_subscription(
                communicator, self.UNREAD_COUNT_SUBSCRIPTION
            )
            response = await communicator.receive_json_from()
            self._assert_no_errors(response)

            # Send a message to the subscribed user; a new count should be pushed.
            send_form = SendMessageForm(
                data={
                    "to": receiver.id,
                    "title": "hello world",
                    "text": "text",
                    "description": "description",
                }
            )
            owner = await sync_to_async(User.objects.get)(email="[email protected]")
            await sync_to_async(send_form.save)(owner=owner)

            response_two = await communicator.receive_json_from()
            # Validate the pushed update frame itself, not only the initial one.
            self._assert_no_errors(response_two)
            initial_count = response["payload"]["data"]["unreadMessageCount"]
            updated_count = response_two["payload"]["data"]["unreadMessageCount"]
            assert updated_count == initial_count + 1, str((response, response_two))
            self.assertMatchSnapshot(response, "initial")
            self.assertMatchSnapshot(response_two, "result")
        finally:
            await communicator.disconnect()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import graphene | |
import rx | |
from .events import NotificationEvents | |
class ActiveMessageCounter:
    """Callable that re-counts ``user.received_messages`` each time it is invoked.

    The most recent count is cached on ``self.count``. It is computed once at
    construction and refreshed by every call.
    """

    def __init__(self, user):
        self.user = user
        self.count = self.get_count()

    def get_count(self):
        """Query the current number of received messages without caching it."""
        received = self.user.received_messages
        return received.count()

    def __call__(self, *args):
        # Positional arguments (e.g. the triggering event or tick) are ignored;
        # the call only refreshes and reports the count.
        latest = self.get_count()
        self.count = latest
        return latest
class UnreadMessageCountSubscription(graphene.ObjectType):
    """Subscription type exposing a live ``unreadMessageCount`` field."""

    unread_message_count = graphene.Int(test=graphene.Boolean())

    def resolve_unread_message_count(root, info, test=False):
        """Return an observable of the current user's received-message count.

        ``root`` is the subscription event stream. It is filtered and mapped
        below like an rx Observable whose items have ``operation`` and
        ``instance``. The stream emits the count at subscribe time. It
        re-counts whenever a NEW_MESSAGE / UPDATE_MESSAGE event concerns a
        message the user participates in. With ``test=True`` it also re-counts
        on ``rx.Observable.interval(3000)``. Duplicate consecutive counts are
        suppressed.
        """
        user = info.context.user
        active_counter = ActiveMessageCounter(user)
        watched_operations = (
            NotificationEvents.NEW_MESSAGE,
            NotificationEvents.UPDATE_MESSAGE,
        )

        def concerns_user(event):
            # Check the operation first so unrelated events short-circuit
            # before touching ``instance.participants`` or hitting the database.
            return (
                event.operation in watched_operations
                and event.instance.participants.filter(id=user.id).exists()
            )

        sources = [rx.Observable.of(active_counter.count)]
        if test:
            sources.append(rx.Observable.interval(3000).map(active_counter))
        sources.append(root.filter(concerns_user).map(active_counter))

        # NOTE(review): under RxPY 1.x numeric durations are milliseconds
        # (consistent with interval(3000) above), so debounce(0.1) is a 0.1 ms
        # window, not 100 ms — confirm the intended value.
        return (
            rx.Observable.merge(*sources)
            .debounce(0.1)
            .distinct_until_changed()
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment