Created
November 8, 2018 16:46
-
-
Save ahokinson/89ad498bfe78177406eb1618ca68d966 to your computer and use it in GitHub Desktop.
Adding filters with requirements to Graphene and SQLAlchemy.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from graphene import Connection | |
from graphene_sqlalchemy import SQLAlchemyConnectionField | |
from database.graphql.filters import filter_argument_for_model, Filter | |
class SQLAlchemyFilteredConnectionField(SQLAlchemyConnectionField):
    """``SQLAlchemyConnectionField`` with a generated ``filter`` argument.

    On top of the base connection field this class:

    * adds a ``filter`` query argument built by ``filter_argument_for_model``
      for the connection's node model (pass ``filter=None`` to opt out);
    * applies every entry of the supplied ``filter`` to the query through
      ``Filter.add_filter``;
    * can demand that certain filter keys always be supplied, through the
      ``required`` keyword, raising ``RequiredFilterException`` otherwise.

    Note: ``required`` is consumed here and is NOT forwarded to graphene, so
    it does not make the field non-null as graphene's own ``required`` would.
    """

    def __init__(self, type, *args, **kwargs):
        """Build the field.

        :param type: the connection type. When it is a ``Connection`` subclass
            and no ``filter`` kwarg is given, a filter argument is generated
            from ``type.Edge.node._type._meta.model``.
        :param required: optional iterable of filter keys every request must
            supply; a single key may be given as a plain string. Defaults to
            no required filters.
        :raises Exception: if the filter argument cannot be generated.
        """
        if "filter" not in kwargs and issubclass(type, Connection):
            try:
                model = type.Edge.node._type._meta.model
                kwargs.setdefault("filter", filter_argument_for_model(model))
            except Exception as err:
                raise Exception(
                    'Cannot create filter argument for {}. A model is required. Set the "filter" argument '
                    "to None to disable the creation of the filter query argument".format(type.__name__)
                ) from err
        elif "filter" in kwargs and kwargs["filter"] is None:
            # Explicit opt-out: don't pass filter=None on to the base field.
            del kwargs["filter"]

        # Previously a bare pop(), which raised KeyError for fields that
        # declare no required filters.
        required = kwargs.pop("required", None)
        if required is None:
            required = ()
        elif isinstance(required, str):
            # A lone string would otherwise be split into characters by set().
            required = (required,)
        self.required = tuple(required)

        super(SQLAlchemyFilteredConnectionField, self).__init__(type, *args, **kwargs)

    @classmethod
    def get_query(cls, model, info, filter=None, **kwargs):
        """Return the base query for ``model``, narrowed by each ``filter`` entry.

        Each ``(key, value)`` pair of ``filter`` is applied in turn via
        ``Filter.add_filter``; an empty or missing filter leaves the query as is.
        """
        query = super(SQLAlchemyFilteredConnectionField, cls).get_query(model, info, **kwargs)
        if filter:
            for key, value in filter.items():
                query = Filter.add_filter(query, model, key, value)
        return query

    @staticmethod
    def _field_for(info):
        """Return the root Query attribute that declared the field being resolved.

        ``info.field_name`` is the GraphQL name. With graphene's default
        ``auto_camelcase`` a Python attribute ``all_users`` is exposed as
        ``allUsers``, so fall back to the snake_case spelling when the exact
        name is not an attribute of the Query type.
        """
        import re

        query_type = info.schema._query
        try:
            return getattr(query_type, info.field_name)
        except AttributeError:
            # Same conversion graphene uses to undo its camelCasing.
            snake = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", info.field_name)
            snake = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", snake).lower()
            return getattr(query_type, snake)

    @classmethod
    def resolve_connection(cls, connection_type, model, info, args, resolved):
        """Check required filters are present, then resolve as the base class does.

        :raises RequiredFilterException: if any filter key listed in the
            field's ``required`` is absent from the request's ``filter``.
        """
        # A GraphQL ``filter: null`` arrives as None rather than a missing key.
        filters = args.get("filter") or {}
        required_filters = cls._field_for(info).required
        missing_filters = set(required_filters) - set(filters)
        if missing_filters:
            # Sorted so the error message is deterministic.
            raise RequiredFilterException(sorted(missing_filters))
        return super(SQLAlchemyFilteredConnectionField, cls).resolve_connection(
            connection_type, model, info, args, resolved)
class RequiredFilterException(Exception):
    """Raised when a query omits filters that its field declares as required.

    :param missing: iterable of the missing filter names. They are sorted so
        the message is deterministic even when a ``set`` is passed, and are
        kept on the ``missing`` attribute for callers that want them.
    """

    def __init__(self, missing):
        # Snapshot once: keeps a one-shot iterator usable and fixes the order.
        self.missing = tuple(sorted(missing))
        super().__init__("In order to resolve this query, the following filters are still required: {}".format(
            ", ".join(self.missing)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment