Created October 13, 2020 12:39
-
-
Save adrianschneider94/90f662ffab9dce06e2f291579ad480b7 to your computer and use it in GitHub Desktop.
Load data smarter with graphene-sqlalchemy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from graphene.utils.str_converters import to_snake_case | |
from graphql.utils.ast_to_dict import ast_to_dict | |
from sqlalchemy.orm import ColumnProperty, Query, RelationshipProperty, joinedload, lazyload, load_only, selectinload, \ | |
subqueryload | |
def collect_fields(node, fragments):
    """Recursively collect the fields selected below a query AST node.

    Args:
        node (dict): A node of the query AST in ``ast_to_dict`` form.
        fragments (dict): Fragment definitions keyed by fragment name, also in
            ``ast_to_dict`` form.

    Returns:
        dict: Maps each selected field name to a dict of its own sub-fields
        (empty for leaf fields), e.g.::

            {'name': {},
             'sentimentsPerLanguage': {'id': {},
                                       'name': {},
                                       'totalSentiments': {}},
             'slug': {}}

        Keys are the fields' names (``name.value``), not their aliases.
        Named fragment spreads and inline fragments are expanded in place.
        A field selected more than once has its sub-selections merged rather
        than the last occurrence replacing the earlier ones.
    """
    fields = {}
    selection_set = node.get('selection_set')
    if not selection_set:
        return fields
    for selection in selection_set['selections']:
        kind = selection['kind']
        if kind == 'Field':
            _merge_fields(fields, {selection['name']['value']: collect_fields(selection, fragments)})
        elif kind == 'FragmentSpread':
            _merge_fields(fields, collect_fields(fragments[selection['name']['value']], fragments))
        elif kind == 'InlineFragment':
            # ``... on Type { ... }`` carries its own selection set; previously
            # these fields were silently dropped.
            _merge_fields(fields, collect_fields(selection, fragments))
    return fields


def _merge_fields(target, source):
    """Deep-merge the field mapping ``source`` into ``target`` in place."""
    for name, subfields in source.items():
        if name in target:
            _merge_fields(target[name], subfields)
        else:
            target[name] = subfields
def get_fields(info):
    """Collect the fields requested below the field currently being resolved.

    Convenience wrapper around ``collect_fields``: converts the resolver's AST
    nodes and fragment definitions to dicts with ``ast_to_dict`` and collects
    the nested field mapping.

    ``info.field_asts`` is a list. When the same field (same response name) is
    selected more than once, graphql-core passes one node per occurrence. The
    selections of all nodes are combined here so that none is dropped.
    Previously only ``field_asts[0]`` was used.

    Args:
        info (ResolveInfo): The resolve info passed to the resolver.

    Returns:
        dict: Nested field mapping as returned by ``collect_fields``.
    """
    fragments = {name: ast_to_dict(value) for name, value in info.fragments.items()}
    selections = []
    for field_ast in info.field_asts:
        selection_set = ast_to_dict(field_ast).get('selection_set') or {}
        selections.extend(selection_set.get('selections') or [])
    # Wrap the combined selections in a synthetic node so collect_fields can
    # treat them as a single selection set.
    return collect_fields({'selection_set': {'selections': selections}}, fragments)
def smart_load(query: Query, info, strategy="select-in"):
    """Apply loader options to ``query`` derived from the fields in ``info``.

    Only queries over exactly one entity are modified; any other query is
    returned unchanged. The requested fields are turned into ``load_only`` and
    relationship-loader options by ``get_options`` using ``strategy``.

    Args:
        query (Query): The query to optimise.
        info (ResolveInfo): Resolve info of the field being resolved.
        strategy (str): Relationship loading strategy forwarded to
            ``get_options``.

    Returns:
        Query: ``query.options(...)`` when options were produced, otherwise
        ``query`` itself.
    """
    # NOTE(review): ``_entities`` is a private SQLAlchemy Query attribute; it
    # may not exist in newer SQLAlchemy releases — confirm against the version
    # in use.
    entities = query._entities
    if len(entities) == 1:
        model = entities[0].type
        loader_options = get_options(model, get_fields(info), strategy=strategy)
        if loader_options:
            return query.options(*loader_options)
    return query
def get_options(entity, fields, strategy: str = "lazy"):
    """Build SQLAlchemy loader options matching a GraphQL field selection.

    Each key of ``fields`` is resolved to an attribute of ``entity``. The key
    is tried as-is first, then converted with ``to_snake_case``. Keys that
    match neither, and attributes without a ``.property`` (plain class
    attributes, methods, ...), are skipped.

    * Attributes backed by a ``ColumnProperty`` are gathered into a single
      ``load_only`` option.
    * Attributes backed by a ``RelationshipProperty`` get a loader chosen by
      ``strategy``. The loader carries the recursively built options for the
      related entity's sub-selection.

    NOTE(review): foreign-key columns used by many-to-one relationships are
    not added to ``load_only`` unless they are requested. Depending on the
    SQLAlchemy version and strategy this may trigger per-row loads — confirm.

    Args:
        entity: Mapped class whose attributes are looked up.
        fields (dict): Nested field mapping as produced by ``collect_fields``.
        strategy (str): ``"select-in"``, ``"joined"`` or ``"subquery"``; any
            other value falls back to ``lazyload``.

    Returns:
        list: Loader options for ``Query.options(*options)``; empty when
        nothing matched.
    """
    options = []
    column_properties = []
    for field, subfields in fields.items():
        try:
            attribute = getattr(entity, field)
        except AttributeError:
            try:
                attribute = getattr(entity, to_snake_case(field))
            except AttributeError:
                continue  # no such attribute on the model
        # Only mapped attributes expose ``.property``. Checking just this lookup
        # (instead of wrapping the whole branch in ``except AttributeError``)
        # keeps genuine errors raised while building nested options visible
        # instead of silently dropping the relationship.
        mapper_property = getattr(attribute, "property", None)
        if mapper_property is None:
            continue
        if isinstance(mapper_property, ColumnProperty):
            column_properties.append(attribute)
        elif isinstance(mapper_property, RelationshipProperty):
            sub_options = get_options(mapper_property.entity.entity, subfields, strategy=strategy)
            loader = _relationship_loader(strategy)
            options.append(loader(attribute).options(*sub_options))
    if column_properties:
        options.append(load_only(*column_properties))
    return options


def _relationship_loader(strategy):
    """Return the SQLAlchemy loader function for ``strategy`` (``lazyload`` if unknown)."""
    return {
        "select-in": selectinload,
        "joined": joinedload,
        "subquery": subqueryload,
    }.get(strategy, lazyload)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.