Created
March 17, 2021 19:56
-
-
Save joeydebreuk/2e1333fb8da82220bca5300ee81d225c to your computer and use it in GitHub Desktop.
strawberry-django-pagination
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Relay-style cursor pagination helpers for strawberry-django resolvers.

Usage example:

    class SomeResolver(PaginationMixin, ModelResolver):
        model = SomeModel

and

    class SomeOtherResolver(PaginationMixin, ModelResolver):
        model = SomeOtherModel

        @strawberry.field
        def related(
            self,
            root: Show,
            before: Optional[ID] = UNSET,
            after: Optional[ID] = UNSET,
            first: Optional[int] = UNSET,
            last: Optional[int] = UNSET,
        ) -> SomeResolver.get_connection_type():
            return paginate_django_queryset(
                queryset=root.related.all(),
                connection_type=SomeResolver.get_connection_type(),
                edge_type=SomeResolver.get_edge_type(),
                before=before,
                after=after,
                first=first,
                last=last,
            )
"""
from typing import List, Optional, Type | |
import strawberry | |
from django.conf import settings | |
from django.db import models | |
from graphql import GraphQLError | |
from strawberry import ID | |
from strawberry.arguments import UNSET, is_unset | |
from strawberry.types import Info | |
from strawberry_django.resolvers import get_permission_classes | |
@strawberry.type(
    description=(
        "Information about pagination in a connection, "
        "based on https://graphql.github.io/learn/pagination/"
    )
)
class Pagination:
    """Page-info object returned next to a page of edges.

    The cursors are the primary keys of the first and last rows of the page
    (see ``paginate_django_queryset``). Because pages are ordered by pk and
    ``after`` keeps rows with a larger pk, clients pass ``end_cursor`` as
    ``after`` to move forwards and ``start_cursor`` as ``before`` to move
    backwards.
    """

    has_next_page: bool = strawberry.field(
        description="When paginating forwards, are there more items?"
    )
    has_previous_page: bool = strawberry.field(
        description="When paginating backwards, are there more items?"
    )
    # Cursor of the first row of the page: feed it to `before` to go back.
    start_cursor: Optional[str] = strawberry.field(
        description="When paginating backwards, the cursor to continue"
    )
    # Cursor of the last row of the page: feed it to `after` to go forward.
    end_cursor: Optional[str] = strawberry.field(
        description="When paginating forwards, the cursor to continue"
    )
def _check_page_size(
    value: int, argument: str, limit: int, connection_type: type
) -> None:
    """Reject a negative page size or one larger than ``limit``.

    ``argument`` is the GraphQL argument name (``first`` or ``last``) quoted
    in the error message.
    """
    if value < 0:
        raise GraphQLError("Negative indexing not supported")
    if value > limit:
        raise GraphQLError(
            f"Requesting {value} records on the {connection_type.__name__} "
            f"exceeds the `{argument}` limit of {limit} records"
        )


def paginate_django_queryset(
    queryset: models.QuerySet,
    # connection_type is a strawberry type that implements page_info and edges
    connection_type: Type[strawberry.field],
    edge_type: Type[strawberry.field],
    first: Optional[int],
    last: Optional[int] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
):
    """Slice ``queryset`` into one page of a Relay-style connection.

    Rows are ordered by primary key, and the primary key is the cursor.
    ``after`` keeps rows with ``pk > after`` and ``before`` keeps rows with
    ``pk < before``. After that, either the first ``first`` or the last
    ``last`` of the remaining rows are returned. When neither is given, the
    first 50 rows (capped at the configured limit) are returned.

    Args:
        queryset: Rows to paginate. Any existing ordering is replaced by pk.
        connection_type: Class instantiated with ``total_count``,
            ``page_info`` and ``edges``.
        edge_type: Class whose ``from_queryset(row)`` wraps each returned row.
        first: Page size counted from the start of the cursor window.
        last: Page size counted from the end of the cursor window.
            For both ``first`` and ``last``, ``None`` and strawberry's
            ``UNSET`` mean "not given".
        before: Exclusive upper primary-key cursor.
        after: Exclusive lower primary-key cursor.

    Returns:
        A ``connection_type`` instance. ``total_count`` counts the whole
        ``queryset``, ignoring the cursors.

    Raises:
        GraphQLError: if both ``first`` and ``last`` are given, if either is
            negative, or if either exceeds
            ``settings.GRAPHQL_MAX_PAGINATION_LIMIT`` (default 500).
    """
    limit = getattr(settings, "GRAPHQL_MAX_PAGINATION_LIMIT", 500)
    first_is_set = first is not None and not is_unset(first)
    last_is_set = last is not None and not is_unset(last)
    if not first_is_set and not last_is_set:
        # Default to the first 50 rows, but never more than the configured
        # limit, so omitting both arguments cannot trip the limit check.
        first = min(50, limit)
        first_is_set = True
    if first_is_set and last_is_set:
        raise GraphQLError("Passing both `first` and `last` is not supported")
    if first_is_set:
        _check_page_size(first, "first", limit, connection_type)
    if last_is_set:
        _check_page_size(last, "last", limit, connection_type)

    # Cursors are primary keys, so the page order must be by pk as well.
    unfiltered = queryset.order_by("pk")
    total_count = unfiltered.count()

    window = unfiltered
    if after:
        window = window.filter(pk__gt=after)
    if before:
        window = window.filter(pk__lt=before)

    if first_is_set:
        window = window[:first]
    else:
        # Offset from the end of the *cursor-filtered* window. Using the
        # unfiltered total here would skip rows whenever `before`/`after`
        # narrowed the window.
        window = window[max(window.count() - last, 0):]

    # Evaluate the page once; everything below works on the fetched rows.
    nodes = list(window)

    if nodes:
        start_pk = nodes[0].pk
        end_pk = nodes[-1].pk
        # Neighbours are looked up in the full (cursor-free) ordering.
        has_previous_page = unfiltered.filter(pk__lt=start_pk).exists()
        has_next_page = unfiltered.filter(pk__gt=end_pk).exists()
    else:
        start_pk = end_pk = None
        has_previous_page = has_next_page = False

    return connection_type(
        total_count=total_count,
        page_info=Pagination(
            start_cursor=start_pk,
            end_cursor=end_pk,
            has_next_page=has_next_page,
            has_previous_page=has_previous_page,
        ),
        edges=[edge_type.from_queryset(node) for node in nodes],
    )
class PaginationMixin:
    """Mixin that builds connection/edge GraphQL types and a paginated list field.

    The concrete resolver class must provide ``model`` (a Django model class)
    and ``output_type`` (the strawberry type returned for one row). It must
    also be constructible as ``cls(info, root)`` and expose a
    ``list(filters=...)`` method whose result is handed to
    ``paginate_django_queryset`` as the queryset.
    """

    # Generated types, built lazily and cached on the concrete subclass.
    edge_type = None
    connection_type = None

    @classmethod
    def get_edge_type(cls):
        """Return the ``<Model>Edge`` type, building and caching it on first use."""
        if cls.edge_type:
            return cls.edge_type

        name = f"{cls.get_pacalcase_name()}Edge"

        @strawberry.type(name=name)
        class Edge:
            instance: strawberry.Private[cls.model]

            @strawberry.field
            def cursor(self) -> int:
                # Use pk (not `id`) so the cursor matches the pk-based
                # filtering in paginate_django_queryset and also works for
                # models whose primary key is not named `id`.
                return self.instance.pk

            @strawberry.field
            def node(self) -> cls.output_type:
                return self.instance

            @staticmethod
            def from_queryset(queryset: models.Model) -> "Edge":
                # Despite the parameter name, this receives one model
                # instance (a single row of the paginated queryset).
                return Edge(queryset)

        cls.edge_type = Edge
        return cls.edge_type

    @classmethod
    def get_pacalcase_name(cls) -> str:
        """Return the model's class name, used as the prefix of generated type names.

        (The method name keeps its historical "pacalcase" spelling for callers.)
        """
        return cls.model._meta.object_name

    @classmethod
    def get_connection_type(cls):
        """Return the ``<Model>Connection`` type, building and caching it on first use."""
        if cls.connection_type:
            return cls.connection_type

        name = f"{cls.get_pacalcase_name()}Connection"

        @strawberry.type(name=name)
        class Connection:
            edges: List[Optional[cls.get_edge_type()]]
            page_info: Pagination
            total_count: int = strawberry.field(
                description="Identifies the total count of items in the connection."
            )

        # Subclass under the model-specific name so `connection_type.__name__`
        # (printed in paginate_django_queryset's limit errors) reads e.g.
        # "SomeModelConnection" instead of "Connection".
        cls.connection_type = type(name, (Connection,), {})
        return cls.connection_type

    @classmethod
    def list_field(cls):
        """Build a strawberry field that returns this resolver's rows as a connection.

        The field uses the permission classes returned by
        ``get_permission_classes(cls, "view")``. It accepts ``filters`` plus
        the ``before``/``after``/``first``/``last`` pagination arguments,
        which are forwarded to ``paginate_django_queryset``.
        """
        permission_classes = get_permission_classes(cls, "view")
        edge_type = cls.get_edge_type()
        connection_type = cls.get_connection_type()

        @strawberry.field(permission_classes=permission_classes)
        def list_field(
            info: Info,
            root,
            filters: Optional[List[str]] = None,
            before: Optional[ID] = UNSET,
            after: Optional[ID] = UNSET,
            first: Optional[int] = UNSET,
            last: Optional[int] = UNSET,
        ) -> connection_type:
            instance = cls(info, root)
            return paginate_django_queryset(
                queryset=instance.list(filters=filters),
                connection_type=connection_type,
                edge_type=edge_type,
                before=before,
                after=after,
                first=first,
                last=last,
            )

        return list_field
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment