Graphene pagination
import functools

from django.db.models import Prefetch, QuerySet

import attr
import graphene
from cursor_pagination import CursorPaginator
from graphene.utils.str_converters import to_snake_case
from graphql_relay import connection_from_list

from .optimization import get_fields, get_node_type
@attr.s
class PageQuery(object):
    # Defaults make PageQuery() constructible, as connection_from_cursor_paginated
    # below relies on when no page_query is passed.
    before = attr.ib(default=None)
    after = attr.ib(default=None)
    first = attr.ib(default=None)
    last = attr.ib(default=None)
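
# A minimal usage sketch (hedged: assumes the None defaults above). PageQuery
# is a plain container whose attr.asdict() output matches the keyword
# arguments of CursorPaginator.page():
#
#     >>> attr.asdict(PageQuery(first=10))
#     {'before': None, 'after': None, 'first': 10, 'last': None}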
def connection_from_cursor_paginated(queryset, connection_type, edge_type, pageinfo_type, page_query=None):
    """Create a Connection object from a queryset, using CursorPaginator."""
    # CursorPaginator needs an explicit ordering; reuse the queryset's own.
    paginator = CursorPaginator(queryset, queryset.query.order_by)
    if page_query is None:
        page_query = PageQuery()
    page = paginator.page(**attr.asdict(page_query))
    edges = []
    for item in page:
        edge = edge_type(node=item, cursor=paginator.cursor(item))
        edges.append(edge)
    if page:
        page_info = pageinfo_type(
            start_cursor=paginator.cursor(edges[0].node),
            end_cursor=paginator.cursor(edges[-1].node),
            has_previous_page=page.has_previous,
            has_next_page=page.has_next,
        )
    else:
        page_info = pageinfo_type(
            start_cursor=None,
            end_cursor=None,
            has_previous_page=False,
            has_next_page=False,
        )
    return connection_type(
        edges=edges,
        page_info=page_info,
    )
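
# Hedged usage sketch: `Book`, `BookConnection` and its `Edge` are hypothetical
# stand-ins for your own Django model and graphene connection types. The
# queryset must carry an explicit order_by(), since the paginator derives its
# ordering from queryset.query.order_by:
#
#     >>> connection = connection_from_cursor_paginated(
#     ...     Book.objects.order_by('-created'),
#     ...     connection_type=BookConnection,
#     ...     edge_type=BookConnection.Edge,
#     ...     pageinfo_type=graphene.relay.PageInfo,
#     ...     page_query=PageQuery(first=20),
#     ... )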
def apply_optimizers(queryset, node_type, subfields, info, post_processors):
    """Apply optimisations to a given queryset.

    This function calls itself recursively to follow the GraphQL tree down,
    looking at node_type to find functions it can call to improve the queryset
    before execution.
    """
    # Firstly, look for any `dependant_fields_FOO` functions.
    # These functions allow additional fields to be optimised for if FOO
    # exists. Iterate over a snapshot, since new keys may be added to
    # subfields while we loop.
    for field_name, details in list(subfields.items()):
        field_name = to_snake_case(field_name)
        dependant_fields_getter = getattr(node_type, 'dependant_fields_%s' % field_name, None)
        if dependant_fields_getter is not None:
            dependant_fields = dependant_fields_getter()
            for dependant_field, dependant_subfields in dependant_fields.items():
                if dependant_field not in subfields:
                    subfields[dependant_field] = {'fields': {}}
                for subfield in dependant_subfields:
                    if subfield not in subfields[dependant_field]['fields']:
                        subfields[dependant_field]['fields'][subfield] = {}

    # The meat of the problem - apply optimizers.
    for field_name, details in subfields.items():
        field_name = to_snake_case(field_name)

        # If the field has subfields, and is a node, then recurse and apply as
        # a prefetch_related.
        if details.get('fields'):
            related_node_type = get_node_type(node_type._meta.fields[field_name])
            if related_node_type:
                model = related_node_type._meta.model
                related_queryset = apply_optimizers(model.objects.all(), related_node_type, details['fields'], info, post_processors)
                # The node can also have a custom `prefetch_FOO` function to
                # customise the way a prefetch_related is applied.
                prefetcher = getattr(node_type, 'prefetch_%s' % field_name, None)
                if prefetcher is not None:
                    queryset = prefetcher(queryset, related_queryset)
                elif isinstance(queryset, QuerySet):
                    queryset = queryset.prefetch_related(Prefetch(field_name, queryset=related_queryset))

        # Now look for any field-specific optimizers, and apply them.
        # These are called `optimize_FOO`, and take the queryset, info,
        # subfields if they are present, and any arguments.
        optimizer = getattr(node_type, 'optimize_%s' % field_name, None)
        if optimizer is not None:
            kwargs = {'queryset': queryset, 'info': info}
            if 'fields' in details:
                kwargs['subfields'] = details['fields']
            if 'arguments' in details:
                kwargs.update(details['arguments'])
            queryset = optimizer(**kwargs)

        # Now look for any subfield-specific optimizers, and apply them.
        # These are called `optimize_FOO__BAR`, and take the queryset, info,
        # and any arguments.
        for subfield_name in details.get('fields', []):
            optimizer = getattr(node_type, 'optimize_%s__%s' % (field_name, subfield_name), None)
            if optimizer is not None:
                kwargs = {'queryset': queryset, 'info': info}
                if details.get('arguments'):
                    kwargs.update(details['arguments'])
                queryset = optimizer(**kwargs)

        # Finally, look for any `post_process_FOO` functions and collect them.
        post_processor = getattr(node_type, 'post_process_%s' % field_name, None)
        if post_processor:
            post_processors.append(post_processor)

    return queryset
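
# Hedged sketch of the hook convention apply_optimizers discovers via getattr()
# on the node type. `BookNode` and its fields are hypothetical, and
# DjangoObjectType (graphene_django) is assumed; only the method names and
# signatures matter. The hooks are called unbound, hence staticmethods:
#
#     class BookNode(DjangoObjectType):
#         @staticmethod
#         def dependant_fields_display_title():
#             # Selecting displayTitle also requires author name fields.
#             return {'author': ['firstName', 'lastName']}
#
#         @staticmethod
#         def optimize_author(queryset, info, subfields=None, **arguments):
#             return queryset.select_related('author')
#
#         @staticmethod
#         def prefetch_reviews(queryset, related_queryset):
#             return queryset.prefetch_related(
#                 Prefetch('reviews', queryset=related_queryset.filter(published=True)))
#
#         @staticmethod
#         def post_process_reviews(nodes, info):
#             # Receives the top-level page of nodes after pagination.
#             pass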
def node_fields_are_queried(fields, key):
    """Utility function to check if any subfields are needed."""
    node_fields = fields[key].get('fields', {}).get('edges', {}).get('fields', {}).get('node', {}).get('fields', {})
    return bool(node_fields)
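
# The fields dict mirrors the GraphQL selection. For example, for
# `{ books { edges { node { title } } } }` (shape as produced by get_fields
# in optimization.py):
#
#     >>> fields = {'books': {'fields': {'edges': {'fields': {'node': {'fields': {'title': {}}}}}}}}
#     >>> node_fields_are_queried(fields, 'books')
#     True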
def optimize_qs(connection_type, queryset, info=None, fields=None, post_processors=None):
    """Apply optimisations to the queryset based on the fields queried."""
    if post_processors is None:
        post_processors = []
    if fields is None:
        fields = get_fields(info.field_asts, info)
    key = next(iter(fields))
    if not node_fields_are_queried(fields, key):
        # Return the same (queryset, post_processors) pair as the main path,
        # so callers can always unpack the result.
        return queryset, post_processors
    subfields = fields[key]['fields']['edges']['fields']['node']['fields']
    node_type = connection_type._meta.node
    queryset = apply_optimizers(queryset, node_type, subfields, info, post_processors)
    return queryset, post_processors
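
# Hedged usage sketch (`BookConnection` and `Book` are hypothetical; `info` is
# the graphene ResolveInfo for the current query):
#
#     >>> queryset, post_processors = optimize_qs(
#     ...     BookConnection, Book.objects.order_by('pk'), info)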
class CursorPaginatedConnectionField(graphene.Field):
    def __init__(self, *args, **kwargs):
        kwargs['resolver'] = self.resolver
        kwargs.setdefault('before', graphene.String())
        kwargs.setdefault('after', graphene.String())
        kwargs.setdefault('first', graphene.Int())
        kwargs.setdefault('last', graphene.Int())
        super().__init__(*args, **kwargs)

    def resolver(self, instance, info, parent_resolver=None, before=None, after=None, first=None, last=None):
        if parent_resolver:
            qs = parent_resolver(instance, info)
        else:
            qs = instance.items
        if isinstance(qs, self.type):
            return qs
        page_query = PageQuery(before=before, after=after, first=first, last=last)
        queryset, post_processors = optimize_qs(self.type, qs, info)
        if isinstance(qs, list):
            connection = connection_from_list(
                queryset,
                connection_type=self.type,
                edge_type=self.type.Edge,
                pageinfo_type=graphene.relay.PageInfo,
                args=attr.asdict(page_query),
            )
        else:
            connection = connection_from_cursor_paginated(
                queryset,
                connection_type=self.type,
                edge_type=self.type.Edge,
                pageinfo_type=graphene.relay.PageInfo,
                page_query=page_query,
            )
        # Post processors currently only work when the processor model is the
        # topmost model; behaviour below that is undefined and may break.
        for processor in post_processors:
            processor([edge.node for edge in connection.edges], info)
        return connection

    def get_resolver(self, parent_resolver):
        return functools.partial(self.resolver, parent_resolver=parent_resolver)
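
A hedged end-to-end sketch of mounting the field in a schema. `Book`, `BookNode` and `BookConnection` are hypothetical stand-ins; the real points are that the connection class is the field's type, that graphene hands `resolve_books` to `get_resolver` as the parent resolver, and that the returned queryset must carry an explicit ordering for the cursor paginator:

class BookConnection(graphene.relay.Connection):
    class Meta:
        node = BookNode

class Query(graphene.ObjectType):
    books = CursorPaginatedConnectionField(BookConnection)

    def resolve_books(self, info):
        # An explicit order_by is required, since connection_from_cursor_paginated
        # reads the ordering from queryset.query.order_by.
        return Book.objects.order_by('-created')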
optimization.py (imported above as `.optimization`):
from graphql.language.ast import (
    FragmentSpread, InlineFragment, ListValue, Variable,
)
def simplify_argument(arg):
    """Reduce a GraphQL argument AST node to a plain Python value."""
    if isinstance(arg, ListValue):
        return [simplify_argument(a) for a in arg.values]
    if isinstance(arg, Variable):
        return arg.name.value
    return arg.value
def merge_dicts(d1, d2):
    """Recursively merge d2 into d1, in place."""
    for key in d2:
        if key in d1 and isinstance(d1[key], dict) and isinstance(d2[key], dict):
            merge_dicts(d1[key], d2[key])
        else:
            d1[key] = d2[key]
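
# A quick sanity check of the in-place deep merge:
#
#     >>> d = {'a': {'x': 1}}
#     >>> merge_dicts(d, {'a': {'y': 2}, 'b': 3})
#     >>> d
#     {'a': {'x': 1, 'y': 2}, 'b': 3}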
def get_fields(asts, resolve_info):
    """Build a nested {name: {'fields': ..., 'arguments': ...}} dict from field ASTs."""
    fields = {}
    for field in asts:
        if isinstance(field, (FragmentSpread, InlineFragment)):
            # Fragments flatten into the current level; merge their fields in.
            if isinstance(field, FragmentSpread):
                selections = resolve_info.fragments[field.name.value].selection_set.selections
            else:
                selections = field.selection_set.selections
            local = get_fields(selections, resolve_info)
            for key in local:
                if key in fields and local[key].get('fields'):
                    try:
                        fields[key]['fields'].update(local[key]['fields'])
                    except KeyError:
                        fields[key].update(local[key])
                else:
                    fields[key] = local[key]
        elif hasattr(field, 'selection_set'):
            local = {}
            if field.selection_set:
                local['fields'] = get_fields(field.selection_set.selections, resolve_info)
            if field.arguments:
                local['arguments'] = {arg.name.value: simplify_argument(arg.value) for arg in field.arguments}
            if fields.get(field.name.value):
                if local.get('fields'):
                    try:
                        merge_dicts(fields[field.name.value]['fields'], local['fields'])
                    except KeyError:
                        fields[field.name.value]['fields'] = local['fields']
                if local.get('arguments'):
                    try:
                        fields[field.name.value]['arguments'].update(local['arguments'])
                    except KeyError:
                        fields[field.name.value]['arguments'] = local['arguments']
            else:
                fields[field.name.value] = local
    return fields
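
# Hedged illustration of the shape get_fields produces, for a query like
# `{ books(genre: "scifi") { edges { node { title } } } }`:
#
#     {'books': {'arguments': {'genre': 'scifi'},
#                'fields': {'edges': {'fields': {'node': {'fields': {'title': {}}}}}}}}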
def get_node_type(field):
    try:
        return field.get_type().type.of_type
    except AttributeError:
        try:
            if hasattr(field.type._meta, 'django_fields'):
                return field.type
        except AttributeError:
            pass
    return None
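
# Hedged reading of the two branches above: the first unwraps a wrapped field
# whose get_type().type.of_type is the node type; the second recognises a
# directly attached type by the `_meta.django_fields` attribute the code
# checks for. Anything else yields None, and apply_optimizers skips recursion.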