Created May 31, 2021 16:42
-
-
Save jhillacre/ae87a5f038d619728bb0326b879e3f51 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from channels.db import database_sync_to_async | |
from djangochannelsrestframework.decorators import action as dcrf_action | |
from djangochannelsrestframework.generics import GenericAsyncAPIConsumer | |
from djangochannelsrestframework.mixins import CreateModelMixin | |
from djangochannelsrestframework.mixins import DeleteModelMixin | |
from djangochannelsrestframework.mixins import ListModelMixin | |
from djangochannelsrestframework.mixins import PaginatedModelListMixin | |
from djangochannelsrestframework.mixins import PatchModelMixin | |
from djangochannelsrestframework.mixins import UpdateModelMixin | |
from djangochannelsrestframework.observer import ModelObserver | |
from djangochannelsrestframework.observer.generics import ObserverModelInstanceMixin | |
from djangochannelsrestframework.observer.generics import _GenericModelObserver | |
from rest_framework import status | |
from rest_framework.exceptions import NotFound | |
# our client specific stuff. | |
from tagos.utils.consumers import DCRFDjangoFilterBackend | |
from tagos.utils.consumers import ModelPermissionMixin | |
from tagos.utils.consumers import NiceConsumerMixin | |
from tagos.utils.consumers import OurPaginator | |
class ModelConsumer(
    PaginatedModelListMixin,
    ListModelMixin,
    PatchModelMixin,
    UpdateModelMixin,
    CreateModelMixin,
    DeleteModelMixin,
    ObserverModelInstanceMixin,
    NiceConsumerMixin,
    ModelPermissionMixin,
    GenericAsyncAPIConsumer,
):
    """Generic websocket consumer exposing one model through DCRF actions.

    Combines list/paginated-list, patch, update, create and delete actions with
    per-instance subscriptions (from ``ObserverModelInstanceMixin``). It also adds
    model-wide "activity" subscriptions (``subscribe_activity`` /
    ``unsubscribe_activity``) driven by the ``model_activity`` observer.
    """

    # Action name -> Django permission verb required to run that action.
    # Presumably consumed by ModelPermissionMixin (defined outside this file).
    perm_names = {
        "list": "view",
        "retrieve": "view",
        "patch": "change",
        "update": "change",
        "create": "add",
        "delete": "delete",
        "subscribe_instance": "view",
        "unsubscribe_instance": "view",
        "subscribe_activity": "view",
        "unsubscribe_activity": "view",
    }
    filter_backends = (DCRFDjangoFilterBackend,)
    pagination_class = OurPaginator

    def get_model(self):
        """Return the model class of this consumer's ``queryset``."""
        return self.queryset.model

    @staticmethod
    def _require_request_id(request_id):
        """Raise ``ValueError`` when *request_id* is missing.

        Subscription bookkeeping is keyed by request id, so the subscribe and
        unsubscribe actions below cannot proceed without one.
        """
        if request_id is None:
            raise ValueError("request_id must have a value set")

    @dcrf_action()
    async def unsubscribe_instance(self, request_id=None, **kwargs):
        """Stop the instance subscription registered under *request_id*.

        ``kwargs`` select the instance via ``get_object``.

        Returns:
            ``(None, 204)`` on success.

        Raises:
            ValueError: if ``request_id`` is None.
            NotFound: if ``request_id`` is not tracked in ``subscribed_requests``.
        """
        self._require_request_id(request_id)
        instance = await database_sync_to_async(self.get_object)(**kwargs)
        # NOTE(review): the observer is unsubscribed before we know whether
        # request_id was tracked. It is also unsubscribed even if other request
        # ids still share the same groups. Confirm this is intended.
        await self.handle_instance_change.unsubscribe(instance=instance)
        try:
            self._unsubscribe(request_id)
        except KeyError:
            # The KeyError is only a "not tracked" signal; hide it from the
            # client-facing error's traceback.
            raise NotFound(detail="Subscription not found.") from None
        return None, status.HTTP_204_NO_CONTENT

    @_GenericModelObserver
    async def model_activity(self, message, observer=None, action=None, **kwargs):
        """Forward a model-wide change notification to ``handle_observed_action``.

        ``message`` is splatted as keyword arguments, so it must be a mapping.
        Its exact keys depend on how the observer serializes changes (not shown
        here).
        """
        await self.handle_observed_action(action=action, **message)

    @dcrf_action()
    async def subscribe_activity(self, request_id=None, **kwargs):
        """Subscribe *request_id* to changes on any instance of the model.

        Returns:
            ``(None, 201)`` once the observer groups are recorded.

        Raises:
            ValueError: if ``request_id`` is None.
        """
        self._require_request_id(request_id)
        observer: ModelObserver = self.model_activity
        groups = set(await observer.subscribe())
        self._subscribe(request_id, groups)
        return None, status.HTTP_201_CREATED

    @dcrf_action()
    async def unsubscribe_activity(self, request_id=None, **kwargs):
        """Cancel the model-wide activity subscription for *request_id*.

        Returns:
            ``(None, 204)`` on success.

        Raises:
            ValueError: if ``request_id`` is None.
            NotFound: if ``request_id`` is not tracked in ``subscribed_requests``.
        """
        self._require_request_id(request_id)
        observer: ModelObserver = self.model_activity
        # NOTE(review): as in unsubscribe_instance, this leaves the observer
        # groups even when other request ids still rely on them. Confirm this
        # is intended.
        await observer.unsubscribe()
        try:
            self._unsubscribe(request_id)
        except KeyError:
            raise NotFound(detail="Subscription not found.") from None
        return None, status.HTTP_204_NO_CONTENT

    def _unsubscribe(self, request_id: str):
        """Remove *request_id* from every group in ``subscribed_requests``.

        This replaces DCRF's ``_unsubscribe``. Per the original patch note, that
        version raised ``KeyError`` for groups that did not contain the id, which
        prevented proper cleanup of subscriptions.

        Groups are skipped when they do not contain the id. Afterwards, any group
        whose request-id set is empty is dropped.

        Raises:
            KeyError: if *request_id* was not present in any group.
        """
        found = False
        for request_ids in self.subscribed_requests.values():
            if request_id in request_ids:
                request_ids.remove(request_id)
                found = True
        # Prune every group left without subscribers. This includes empty sets
        # this call did not create; those could come from elsewhere, possibly
        # from defaultdict-style lookups (TODO confirm).
        empty_groups = [
            group
            for group, request_ids in self.subscribed_requests.items()
            if not request_ids
        ]
        for group in empty_groups:
            del self.subscribed_requests[group]
        if not found:
            raise KeyError(request_id)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.