Skip to content

Instantly share code, notes, and snippets.

@mjtamlyn
Last active December 2, 2021 08:42
Show Gist options
  • Save mjtamlyn/e8e03d78764552289ea4e2155c554deb to your computer and use it in GitHub Desktop.
Save mjtamlyn/e8e03d78764552289ea4e2155c554deb to your computer and use it in GitHub Desktop.
Graphene pagination
import functools
from django.db.models import Prefetch, QuerySet
import attr
import graphene
from cursor_pagination import CursorPaginator
from graphene.utils.str_converters import to_snake_case
from graphql_relay import connection_from_list
from .optimization import get_fields, get_node_type
@attr.s
class PageQuery(object):
    """Pagination arguments for one connection request.

    The four attributes are forwarded as keyword arguments (via
    ``attr.asdict``) to the paginator's ``page()`` call. Each one defaults
    to ``None`` so that ``PageQuery()`` is a valid "no pagination arguments
    supplied" value; without defaults, a bare ``PageQuery()`` raises
    ``TypeError``.
    """

    # Cursor bounds of the requested window.
    before = attr.ib(default=None)
    after = attr.ib(default=None)
    # Requested page size, counted from the start or the end respectively.
    first = attr.ib(default=None)
    last = attr.ib(default=None)
def connection_from_cursor_paginated(queryset, connection_type, edge_type, pageinfo_type, page_query=None):
    """Create a Connection object from a queryset, using CursorPaginator.

    Args:
        queryset: The queryset to paginate; its ``query.order_by`` is used
            as the paginator ordering.
        connection_type: Callable building the connection from ``edges``
            and ``page_info`` keyword arguments.
        edge_type: Callable building one edge from ``node`` and ``cursor``
            keyword arguments.
        pageinfo_type: Callable building the page info object.
        page_query: Optional ``PageQuery``. When omitted, every pagination
            argument is passed to the paginator as ``None``.

    Returns:
        The object produced by ``connection_type``.
    """
    paginator = CursorPaginator(queryset, queryset.query.order_by)
    if page_query is None:
        # Spell out every argument: PageQuery may not declare defaults, in
        # which case a bare ``PageQuery()`` raises TypeError.
        page_query = PageQuery(before=None, after=None, first=None, last=None)

    page = paginator.page(**attr.asdict(page_query))

    # Compute each cursor exactly once and reuse it for the page boundaries.
    nodes = list(page)
    cursors = [paginator.cursor(node) for node in nodes]
    edges = [
        edge_type(node=node, cursor=cursor)
        for node, cursor in zip(nodes, cursors)
    ]

    # Branch on the materialised edges rather than on the truthiness of the
    # page object, so an empty page can never reach the ``[0]`` lookup.
    if edges:
        page_info = pageinfo_type(
            start_cursor=cursors[0],
            end_cursor=cursors[-1],
            has_previous_page=page.has_previous,
            has_next_page=page.has_next,
        )
    else:
        page_info = pageinfo_type(
            start_cursor=None,
            end_cursor=None,
            has_previous_page=False,
            has_next_page=False,
        )

    return connection_type(
        edges=edges,
        page_info=page_info,
    )
def apply_optimizers(queryset, node_type, subfields, info, post_processors):
    """Apply optimisations to a given queryset.

    This function calls itself recursively to follow the GraphQL tree down,
    looking at node_type to find functions it can call to improve the queryset
    before execution.

    Args:
        queryset: The queryset (or other iterable) to optimise.
        node_type: The node type whose ``dependant_fields_FOO``,
            ``prefetch_FOO``, ``optimize_FOO``, ``optimize_FOO__BAR`` and
            ``post_process_FOO`` hooks are looked up.
        subfields: Mapping of queried field name to its details dict
            (optionally holding ``'fields'`` and ``'arguments'``). It is
            extended in place with any dependant fields.
        info: The resolve info, passed through to the hooks.
        post_processors: List that collected ``post_process_FOO`` hooks are
            appended to.

    Returns:
        The optimised queryset.
    """
    # Firstly, look for any `dependant_fields_FOO` functions
    # These functions allow additional fields to be optimised for if FOO
    # exists.
    # Iterate over a snapshot of the keys: this pass may insert new entries
    # into `subfields`, and changing a dict's size while iterating over it
    # raises RuntimeError.
    for queried_name in list(subfields):
        field_name = to_snake_case(queried_name)
        dependant_fields_getter = getattr(node_type, 'dependant_fields_%s' % field_name, None)
        if dependant_fields_getter is None:
            continue
        for dependant_field, dependant_subfields in dependant_fields_getter().items():
            dependant_details = subfields.setdefault(dependant_field, {'fields': {}})
            for subfield in dependant_subfields:
                # An already-queried field may have no 'fields' entry yet;
                # create it only when there is actually something to add.
                dependant_details.setdefault('fields', {}).setdefault(subfield, {})

    # The meat of the problem - apply optimizers
    for field_name, details in subfields.items():
        field_name = to_snake_case(field_name)

        # If the field has subfields, and is a node, then recurse and apply as a prefetch_related
        if details.get('fields'):
            related_node_type = get_node_type(node_type._meta.fields[field_name])
            if related_node_type:
                model = related_node_type._meta.model
                related_queryset = apply_optimizers(
                    model.objects.all(), related_node_type, details['fields'], info, post_processors,
                )
                # The node can also have a custom `prefetch_FOO` function to customise the way a prefetch_related is applied
                prefetcher = getattr(node_type, 'prefetch_%s' % field_name, None)
                if prefetcher is not None:
                    queryset = prefetcher(queryset, related_queryset)
                elif isinstance(queryset, QuerySet):
                    queryset = queryset.prefetch_related(Prefetch(field_name, queryset=related_queryset))

        # Now look for any field-specific optimizers, and apply them.
        # These are called `optimize_FOO`, and take the queryset, info,
        # subfields if they are present, and any arguments
        optimizer = getattr(node_type, 'optimize_%s' % field_name, None)
        if optimizer is not None:
            kwargs = {'queryset': queryset, 'info': info}
            if 'fields' in details:
                kwargs['subfields'] = details['fields']
            if 'arguments' in details:
                kwargs.update(details['arguments'])
            queryset = optimizer(**kwargs)

        # Now look for any subfield-specific optimizers, and apply them
        # These are called `optimize_FOO__BAR`, and take the queryset, info,
        # and any arguments
        for subfield_name in details.get('fields', []):
            optimizer = getattr(node_type, 'optimize_%s__%s' % (field_name, subfield_name), None)
            if optimizer is not None:
                kwargs = {'queryset': queryset, 'info': info}
                if details.get('arguments'):
                    kwargs.update(details['arguments'])
                queryset = optimizer(**kwargs)

        # Finally, look for any `post_process_FOO` functions and collect them
        post_processor = getattr(node_type, 'post_process_%s' % field_name, None)
        if post_processor:
            post_processors.append(post_processor)

    return queryset
def node_fields_are_queried(fields, key):
    """Report whether ``fields[key]`` selects anything under ``edges -> node``.

    Walks ``fields[key]['fields']['edges']['fields']['node']['fields']``,
    treating every missing level as empty, and returns ``True`` only when
    the innermost mapping is non-empty.
    """
    current = fields[key]
    for step in ('edges', 'node'):
        current = current.get('fields', {}).get(step, {})
    return bool(current.get('fields', {}))
def optimize_qs(connection_type, queryset, info=None, fields=None, post_processors=None):
    """Apply optimisations to the queryset based on the fields queried.

    Args:
        connection_type: Connection type; its ``_meta.node`` supplies the
            optimisation hooks.
        queryset: The queryset (or list) to optimise.
        info: Resolve info; used to derive ``fields`` when they are not
            supplied, and passed through to the hooks.
        fields: Optional pre-computed field tree. Only its first top-level
            entry is inspected.
        post_processors: Optional list that collected post-processing hooks
            are appended to; a new list is created when omitted.

    Returns:
        A ``(queryset, post_processors)`` tuple. The same shape is returned
        when no node fields are queried (queryset untouched), so callers can
        always unpack two values.
    """
    if fields is None:
        fields = get_fields(info.field_asts, info)
    if post_processors is None:
        post_processors = []

    key = next(iter(fields))
    # Walk down to the node's selection once, treating missing levels as empty.
    subfields = (
        fields[key]
        .get('fields', {}).get('edges', {})
        .get('fields', {}).get('node', {})
        .get('fields', {})
    )
    if not subfields:
        # Nothing is selected on the nodes, so there is nothing to optimise.
        return queryset, post_processors

    node_type = connection_type._meta.node
    queryset = apply_optimizers(queryset, node_type, subfields, info, post_processors)
    return queryset, post_processors
class CursorPaginatedConnectionField(graphene.Field):
    """Field that resolves to a cursor-paginated connection.

    Accepts ``before``/``after`` (strings) and ``first``/``last`` (ints)
    arguments, optimises the underlying queryset from the queried fields,
    paginates it and finally runs any collected post processors.
    """

    def __init__(self, *args, **kwargs):
        # The field always resolves through its own resolver.
        kwargs['resolver'] = self.resolver
        pagination_arguments = {
            'before': graphene.String(),
            'after': graphene.String(),
            'first': graphene.Int(),
            'last': graphene.Int(),
        }
        # Arguments explicitly supplied by the caller take precedence.
        for name, argument in pagination_arguments.items():
            kwargs.setdefault(name, argument)
        super().__init__(*args, **kwargs)

    def resolver(self, instance, info, parent_resolver=None, before=None, after=None, first=None, last=None):
        """Build the connection for ``instance``.

        The source data comes from ``parent_resolver`` when one is given,
        otherwise from ``instance.items``. A source that already is an
        instance of the connection type is returned unchanged.
        """
        source = parent_resolver(instance, info) if parent_resolver else instance.items
        if isinstance(source, self.type):
            return source

        page_query = PageQuery(before=before, after=after, first=first, last=last)
        queryset, post_processors = optimize_qs(self.type, source, info)

        shared = dict(
            connection_type=self.type,
            edge_type=self.type.Edge,
            pageinfo_type=graphene.relay.PageInfo,
        )
        if isinstance(source, list):
            connection = connection_from_list(queryset, args=attr.asdict(page_query), **shared)
        else:
            connection = connection_from_cursor_paginated(queryset, page_query=page_query, **shared)

        # Post processors currently only work when the processor model is the
        # topmost model, behaviour below that is undefined and may break
        for processor in post_processors:
            nodes = [edge.node for edge in connection.edges]
            processor(nodes, info)
        return connection

    def get_resolver(self, parent_resolver):
        """Return ``self.resolver`` with ``parent_resolver`` bound to it."""
        return functools.partial(self.resolver, parent_resolver=parent_resolver)
from graphql.language.ast import (
FragmentSpread, InlineFragment, ListValue, Variable,
)
def simplify_argument(arg):
    """Reduce a GraphQL AST argument value to a plain Python value.

    A list value is simplified element by element, a variable is
    represented by its name (it is not resolved to a value), and anything
    else yields its ``value`` attribute.
    """
    if isinstance(arg, ListValue):
        return list(map(simplify_argument, arg.values))
    return arg.name.value if isinstance(arg, Variable) else arg.value
def merge_dicts(d1, d2):
    """Recursively merge ``d2`` into ``d1`` in place.

    Where both sides hold a dict under the same key the two are merged
    recursively; in every other case the value from ``d2`` replaces (or
    adds) the entry in ``d1``. Returns ``None``.
    """
    for key, incoming in d2.items():
        both_are_dicts = (
            key in d1
            and isinstance(d1[key], dict)
            and isinstance(incoming, dict)
        )
        if not both_are_dicts:
            d1[key] = incoming
            continue
        merge_dicts(d1[key], incoming)
def _merge_field(fields, name, local):
    """Merge the details ``local`` of field ``name`` into ``fields`` in place."""
    existing = fields.get(name)
    if not existing:
        # Nothing (or only an empty placeholder) recorded so far.
        fields[name] = local
        return
    if local.get('fields'):
        # Deep merge so nested selections from both occurrences survive.
        merge_dicts(existing.setdefault('fields', {}), local['fields'])
    if local.get('arguments'):
        existing.setdefault('arguments', {}).update(local['arguments'])


def get_fields(asts, resolve_info):
    """Collect the queried fields of a selection into a nested dict.

    Args:
        asts: Iterable of selection AST nodes (fields, fragment spreads,
            inline fragments).
        resolve_info: Resolve info; its ``fragments`` mapping is used to
            look up named fragments.

    Returns:
        A dict mapping each field name to a details dict that may contain
        ``'fields'`` (the same structure, for the field's sub-selection)
        and ``'arguments'`` (argument name to simplified value). Fields
        reached through fragments are merged into the same level; repeated
        occurrences of one field are deep-merged rather than overwritten.
    """
    fields = {}
    for field in asts:
        if isinstance(field, FragmentSpread):
            fragment = resolve_info.fragments[field.name.value]
            spread = get_fields(fragment.selection_set.selections, resolve_info)
            for key, local in spread.items():
                _merge_field(fields, key, local)
        elif isinstance(field, InlineFragment):
            inlined = get_fields(field.selection_set.selections, resolve_info)
            for key, local in inlined.items():
                _merge_field(fields, key, local)
        elif hasattr(field, 'selection_set'):
            local = {}
            if field.selection_set:
                local['fields'] = get_fields(field.selection_set.selections, resolve_info)
            if field.arguments:
                local['arguments'] = {arg.name.value: simplify_argument(arg.value) for arg in field.arguments}
            _merge_field(fields, field.name.value, local)
    return fields
def get_node_type(field):
    """Return the node type behind ``field``, or ``None`` if there is none.

    Tries ``field.get_type().type.of_type`` first. If that attribute chain
    is unavailable, falls back to ``field.type`` provided its ``_meta``
    exposes ``django_fields``.
    """
    try:
        return field.get_type().type.of_type
    except AttributeError:
        pass

    try:
        candidate = field.type
        meta = candidate._meta
    except AttributeError:
        return None
    return candidate if hasattr(meta, 'django_fields') else None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment