Last active
October 6, 2021 23:36
-
-
Save ThirVondukr/27668d96d492b456b6c68b977efb17d8 to your computer and use it in GitHub Desktop.
Strawberry Input types generation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dataclasses | |
import typing | |
from types import SimpleNamespace | |
from typing import Optional, List | |
import strawberry | |
class Op(SimpleNamespace):
    """Names of the supported filter operations.

    Each value is also used verbatim as a field name of the generated filter
    input types, so every value must be a valid Python identifier (``in`` is a
    keyword, hence ``in_``).
    """

    # Equality.
    eq, neq = "eq", "neq"
    # Ordering.
    lt, lte, gt, gte = "lt", "lte", "gt", "gte"
    # Containment: the operand is the container type itself.
    contains, not_contains = "contains", "not_contains"
    # Inclusion: the operand is a list of candidate values.
    in_, not_in = "in_", "not_in"
# Operations whose operand has the same type as the filtered field.
_OP_COMPARISONS = {Op.eq, Op.neq, Op.lt, Op.lte, Op.gt, Op.gte}
_SAME_TYPE_OP = _OP_COMPARISONS
# Operations whose operand is a list of values of the field's type.
_INCLUSION_OP = {Op.in_, Op.not_in}
# Operations whose operand is the (container) field type itself.
_CONTAINS_OP = {Op.contains, Op.not_contains}

# Operations offered per filterable type. Parameterized generics such as
# ``list[int]`` are looked up by their origin (``list``).
FILTER_MAP: dict[type, set[str]] = {}
FILTER_MAP[bool] = {Op.eq, Op.neq}
FILTER_MAP[int] = {*_OP_COMPARISONS, *_INCLUSION_OP}
FILTER_MAP[str] = {*_OP_COMPARISONS, *_INCLUSION_OP, *_CONTAINS_OP}
FILTER_MAP[set] = {*_CONTAINS_OP}
FILTER_MAP[list] = {*_CONTAINS_OP}
def create_filter_name(type_):
    """Return the GraphQL type name of the filter generated for ``type_``.

    Type arguments are spelled before their container, each capitalized, and
    ``"Filter"`` is appended: ``str`` -> ``"StrFilter"``, ``list[int]`` ->
    ``"IntListFilter"``, ``dict[str, int]`` -> ``"StrIntDictFilter"``.
    Nested generics are expanded recursively (``list[list[int]]`` ->
    ``"IntListListFilter"``) so that distinct nested types do not end up with
    the same GraphQL name.
    """

    def _spell(tp) -> str:
        name = getattr(tp, "__name__", None)
        if name is None:
            # Some typing aliases (e.g. ``typing.List[int]`` on older Pythons)
            # have no ``__name__``; fall back to their generic origin.
            name = (typing.get_origin(tp) or tp).__name__
        args = "".join(_spell(arg) for arg in typing.get_args(tp))
        return args + name.capitalize()

    return _spell(type_) + "Filter"
# Filter classes already generated, keyed by the filtered type. Returning the
# cached class keeps repeated calls from defining a second input type with the
# same GraphQL name.
_FILTER_CACHE: dict[object, type] = {}


def create_filter(type_: type):
    """Build (or return the cached) Strawberry input type for filtering ``type_``.

    The operations come from ``FILTER_MAP``, looked up by the generic origin of
    ``type_`` (``list`` for ``list[int]``) or by ``type_`` itself. Each
    operation becomes an optional field named after the operation, with a
    default of ``None``:

    * comparison ops (``_SAME_TYPE_OP``) take a value of ``type_``;
    * inclusion ops (``_INCLUSION_OP``) take a list of ``type_``;
    * contains ops (``_CONTAINS_OP``) take the container type, re-parameterized
      with its type argument when it has exactly one.

    Operations that belong to none of these groups are skipped.

    Raises:
        KeyError: if ``FILTER_MAP`` has no entry for ``type_``.
    """
    cached = _FILTER_CACHE.get(type_)
    if cached is not None:
        return cached

    origin = typing.get_origin(type_) or type_
    try:
        operations = FILTER_MAP[origin]
    except KeyError:
        raise KeyError(f"no filter operations registered for {origin!r}") from None

    fields = []
    # Sort so the field order (and therefore the GraphQL schema) is stable:
    # iterating a set of str depends on the per-process hash seed.
    for op in sorted(operations):
        if op in _SAME_TYPE_OP:
            annotation = Optional[type_]
        elif op in _INCLUSION_OP:
            annotation = Optional[List[type_]]
        elif op in _CONTAINS_OP:
            generic_args = typing.get_args(type_)
            if len(generic_args) == 1:
                annotation = Optional[origin[generic_args[0]]]
            else:
                annotation = Optional[origin]
        else:
            continue
        fields.append((op, annotation, dataclasses.field(default=None)))

    filter_ = dataclasses.make_dataclass(create_filter_name(type_), fields=fields)
    result = strawberry.input(filter_)
    _FILTER_CACHE[type_] = result
    return result
# Ready-made filter input types for the scalar types, created in the order
# str, int, bool.
StrFilter, IntFilter, BoolFilter = map(create_filter, (str, int, bool))
# Example query root that takes one argument of each generated filter type.
@strawberry.type
class Root:
    @strawberry.field
    def test(
        self,
        # The filters are referenced by their names in this module. No
        # ``filters`` module is imported here, and parameter annotations are
        # evaluated when the class is created, so ``filters.IntFilter`` would
        # raise NameError.
        number: IntFilter,
        boolean: BoolFilter,
        string: StrFilter,
        int_list_filter: create_filter(list[int]),
        str_list_filter: create_filter(list[str]),
    ) -> int:
        # Placeholder resolver: the arguments exist only to exercise the
        # schema; the result is a constant.
        return 42
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment