Skip to content

Instantly share code, notes, and snippets.

@adrianschneider94
Created October 13, 2020 12:39
Show Gist options
  • Save adrianschneider94/90f662ffab9dce06e2f291579ad480b7 to your computer and use it in GitHub Desktop.
Save adrianschneider94/90f662ffab9dce06e2f291579ad480b7 to your computer and use it in GitHub Desktop.
Load data smarter with graphene-sqlalchemy
from graphene.utils.str_converters import to_snake_case
from graphql.utils.ast_to_dict import ast_to_dict
from sqlalchemy.orm import ColumnProperty, Query, RelationshipProperty, joinedload, lazyload, load_only, selectinload, \
subqueryload
def collect_fields(node, fragments):
    """Recursively collect the requested fields below an AST node.

    Walks ``node['selection_set']['selections']`` and builds a nested dict
    keyed by field name (the ``name`` of each Field node, not its alias).
    Named fragment spreads are resolved through *fragments*, and inline
    fragments (``... on Type { ... }``) are descended into, so their fields
    are attributed to the enclosing node. When the same field is selected
    more than once (directly or via fragments), the sub-selections are
    merged rather than the later one overwriting the earlier one.

    Args:
        node (dict): A node in the AST, as produced by ``ast_to_dict``.
        fragments (dict): Fragment definitions (also ``ast_to_dict``
            output), keyed by fragment name.

    Returns:
        dict: A dict mapping each field found to a dict of its sub fields,
        e.g.::

            {'name': {},
             'sentimentsPerLanguage': {'id': {},
                                       'name': {},
                                       'totalSentiments': {}},
             'slug': {}}

    Raises:
        KeyError: If a fragment spread names a fragment that is not in
            *fragments*.
    """
    fields = {}
    selection_set = node.get('selection_set')
    # Leaf fields have no selection set (missing key or None).
    if not selection_set:
        return fields
    for leaf in selection_set['selections']:
        kind = leaf['kind']
        if kind == 'Field':
            _merge_fields(fields, {leaf['name']['value']: collect_fields(leaf, fragments)})
        elif kind == 'FragmentSpread':
            _merge_fields(fields, collect_fields(fragments[leaf['name']['value']], fragments))
        elif kind == 'InlineFragment':
            # An inline fragment carries its own selection set; its fields
            # belong to the enclosing node.
            _merge_fields(fields, collect_fields(leaf, fragments))
    return fields


def _merge_fields(target, source):
    """Deep-merge the nested field dict *source* into *target* in place."""
    for name, subfields in source.items():
        if name in target:
            _merge_fields(target[name], subfields)
        else:
            target[name] = subfields
def get_fields(info):
    """Collect the fields requested below the field currently being resolved.

    A call to ``collect_fields`` with the selections taken from *info*.
    ``info.field_asts`` is a list; the selections of every entry are
    combined, not just the first one. When GraphQL merges repeated
    selections of the same response key (e.g. ``{ users { id } users
    { name } }``) into a single resolver call, each occurrence has its own
    AST. Reading only the first one would miss ``name``.

    Args:
        info (ResolveInfo): The resolve info passed to the resolver.

    Returns:
        dict: Returned from collect_fields
    """
    fragments = {name: ast_to_dict(value) for name, value in info.fragments.items()}
    selections = []
    for field_ast in info.field_asts:
        selection_set = ast_to_dict(field_ast).get('selection_set') or {}
        selections.extend(selection_set.get('selections') or [])
    # Wrap the combined selections in a synthetic node so collect_fields can
    # treat them exactly like a single field's selection set.
    return collect_fields({'selection_set': {'selections': selections}}, fragments)
def smart_load(query: Query, info, strategy="select-in"):
    """Attach loader options matching the GraphQL selection to *query*.

    Only queries over exactly one entity are handled; any other query is
    returned unchanged. The requested fields are read from *info* via
    ``get_fields`` and turned into loader options by ``get_options``.

    Args:
        query (Query): The SQLAlchemy query to augment.
        info (ResolveInfo): The resolve info of the field being resolved.
        strategy (str): Relationship loading strategy forwarded to
            ``get_options``. It defaults to ``"select-in"`` here, whereas
            ``get_options`` itself defaults to ``"lazy"``.

    Returns:
        Query: A new query with the options applied, or *query* itself when
        it is not single-entity or there are no options to apply.
    """
    # NOTE(review): ``_entities`` is private Query state and may differ
    # between SQLAlchemy versions.
    entities = query._entities
    if len(entities) == 1:
        target = entities[0].type
        loader_options = get_options(target, get_fields(info), strategy=strategy)
        if loader_options:
            return query.options(*loader_options)
    return query
# Maps the ``strategy`` names accepted by get_options to SQLAlchemy
# relationship loader functions. Unknown names fall back to lazyload.
_RELATIONSHIP_LOADERS = {
    "select-in": selectinload,
    "joined": joinedload,
    "subquery": subqueryload,
    "lazy": lazyload,
}


def get_options(
        entity,
        fields,
        strategy: str = "lazy"):
    """Build SQLAlchemy loader options for *entity* from a collected field tree.

    Each key of *fields* is looked up on *entity*, first as-is and then via
    ``to_snake_case``. Keys that do not resolve to a mapped attribute are
    skipped.

    * Column attributes are gathered into a single ``load_only`` option.
    * Each relationship attribute gets a loader chosen by *strategy*. The
      options for the related entity, built recursively from the sub
      fields, are chained onto that loader.
    * When a ``load_only`` option is emitted, it also keeps the columns on
      the local side of every requested relationship (for a many-to-one,
      the foreign key). SQLAlchemy needs those values to locate the related
      row; deferring them would force a separate load per parent row.

    Args:
        entity: The mapped class to build options for.
        fields (dict): Nested field dict as returned by ``collect_fields``.
        strategy (str): ``"select-in"``, ``"joined"``, ``"subquery"`` or
            ``"lazy"``. Any other value falls back to ``lazyload``.

    Returns:
        list: Loader options, suitable for ``Query.options(*options)``.
    """
    loader = _RELATIONSHIP_LOADERS.get(strategy, lazyload)
    options = []
    column_properties = []
    relationship_columns = []
    for field, subfields in fields.items():
        attribute = _resolve_attribute(entity, field)
        if attribute is None:
            continue
        try:
            mapped_property = attribute.property
            if isinstance(mapped_property, ColumnProperty):
                column_properties.append(attribute)
            elif isinstance(mapped_property, RelationshipProperty):
                sub_options = get_options(mapped_property.entity.entity, subfields, strategy=strategy)
                options.append(loader(attribute).options(*sub_options))
                relationship_columns.extend(_local_column_attributes(entity, mapped_property))
        except AttributeError:
            # Not a mapped attribute (plain Python attribute, method,
            # hybrid, ...): nothing to load for it.
            continue
    # Only restrict columns when some were explicitly requested. Otherwise
    # all columns load, as before, so resolvers for unmapped fields still
    # find their data.
    if column_properties:
        options.append(load_only(*(column_properties + relationship_columns)))
    return options


def _resolve_attribute(entity, field):
    """Return ``entity.<field>``, else ``entity.<snake_case field>``, else None."""
    try:
        return getattr(entity, field)
    except AttributeError:
        pass
    try:
        return getattr(entity, to_snake_case(field))
    except AttributeError:
        return None


def _local_column_attributes(entity, relationship):
    """Return *entity*'s column attributes covering *relationship*'s local columns."""
    local_columns = relationship.local_columns
    return [
        getattr(entity, column_attr.key)
        for column_attr in relationship.parent.column_attrs
        if local_columns.intersection(column_attr.columns)
    ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment