Created
September 29, 2019 23:01
-
-
Save wapiflapi/9076882aff1bcc1846b2c65e986932d9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# @wapiflapi gqldl early draft. | |
# | |
# The goal of this is to manage dataloaders for multiple types while | |
# at the same time providing an easy integration for relay-compliance. | |
# - There will be full utility support for the relay spec. | |
# - Integration with ariadne WILL be easy and documented. | |
# - Integration with graphene MIGHT be documented. | |
# | |
# Feel free to comment, but documentation, tests and a proper release | |
# are coming soon(tm).
import abc | |
import base64 | |
import binascii | |
import functools | |
import aiodataloader | |
from papi.logging import logger | |
import gc; gc.disable() | |
class DataLoader(aiodataloader.DataLoader):
    """``aiodataloader.DataLoader`` that also stores a ``context`` object.

    All positional and remaining keyword arguments are forwarded untouched
    to ``aiodataloader.DataLoader.__init__``.

    Args:
        *args: Forwarded to the parent constructor.
        context: Arbitrary object kept on the instance as ``self.context``;
            defaults to ``None``.
        **kwargs: Forwarded to the parent constructor.
    """

    def __init__(self, *args, context=None, **kwargs):
        # Construction trace. This used to be a bare print(), which wrote to
        # stdout every time a loader was created; it now goes through the
        # module logger at debug level. The message is formatted eagerly so
        # it does not depend on the logger's own formatting convention.
        logger.debug("CUSTOM: %s, %s" % (args, kwargs))
        self.context = context
        super().__init__(*args, **kwargs)
def relay_connection(resolver):
    """Wrap a paginating resolver so it yields a relay connection mapping.

    The decorated ``resolver`` is awaited and must produce a 3-tuple
    ``(has_previous_page, has_next_page, items)`` where ``items`` is an
    indexable sequence of ``(key, cursor)`` pairs.

    The wrapper returns a dict with:
      * ``"pageInfo"``: the two page flags plus ``startCursor`` /
        ``endCursor`` taken from the first / last pair (``None`` when
        ``items`` is empty);
      * ``"edges"``: one ``{"cursor": ..., "key": ...}`` dict per pair.

    Raises:
        TypeError: if any cursor is not a ``str``.
    """

    @functools.wraps(resolver)
    async def wrapped_resolver(obj, info, *args, **kwargs):
        page = await resolver(obj, info, *args, **kwargs)
        has_previous, has_next, pairs = page

        # Every cursor has to be a string before anything is built.
        for _, candidate in pairs:
            if isinstance(candidate, str):
                continue
            raise TypeError(f"Invalid cursor {candidate}: must be str.")

        if pairs:
            first_cursor = pairs[0][1]
            last_cursor = pairs[-1][1]
        else:
            first_cursor = last_cursor = None

        edges = []
        for pair_key, pair_cursor in pairs:
            edges.append({"cursor": pair_cursor, "key": pair_key})

        page_info = {
            "hasPreviousPage": has_previous,
            "hasNextPage": has_next,
            "startCursor": first_cursor,
            "endCursor": last_cursor,
        }
        return {"pageInfo": page_info, "edges": edges}

    return wrapped_resolver
class Relay(abc.ABC):
    """Registry of per-typename resolvers with relay node/connection helpers.

    Concrete subclasses implement ``get_gdl`` and ``setup_edge_node``.
    Resolvers are registered per typename with the ``resolve`` decorator and
    are later awaited by ``resolve_object``.
    """

    def __init__(self):
        # typename -> resolver registered through ``resolve``.
        self.typename_resolver_map = {}

    @abc.abstractmethod
    def get_gdl(self, obj, info):
        """Return the global data loader to use for ``(obj, info)``.

        This is a plain (non-async) method: every call site in this class
        uses the return value directly, without awaiting it. The returned
        object must provide ``to_global_id``, ``load_global_id`` and
        ``load_type_key`` (presumably a ``GlobalDataLoader`` -- confirm
        against the concrete subclass).
        """

    @abc.abstractmethod
    def setup_edge_node(self, edgename, typename, resolver):
        """Install ``resolver`` for the edge type ``edgename``.

        This is a plain (non-async) method: ``connection`` calls it
        synchronously and ignores the return value. ``resolver`` is an async
        callable taking ``(root, info)``; presumably it should be bound to
        the edge's node field -- confirm against the schema integration.
        """

    def resolve(self, typename):
        """Return a decorator registering a resolver for ``typename``.

        The decorated resolver is stored and returned unchanged. It is
        later awaited as ``resolver(root, info, obj)`` and must return a
        mutable mapping (see ``resolve_object``).
        """
        def decorator(resolver):
            self.typename_resolver_map[typename] = resolver
            return resolver
        return decorator

    async def resolve_object(self, root, info, obj):
        """Run the registered resolver for ``obj`` and ensure it has an id.

        The typename is read from ``obj["__typename"]`` for dicts, or from
        the ``__typename`` attribute otherwise. If the resolved mapping has
        no ``"id"``, one is built from the typename and ``resolved["key"]``
        via ``get_gdl(root, info).to_global_id``.

        Raises:
            TypeError: if ``obj`` carries no typename, or the resolved
                mapping has neither ``"id"`` nor ``"key"``.
            NotImplementedError: if no resolver is registered for the
                typename.
        """
        if isinstance(obj, dict):
            typename = obj.get("__typename", None)
        else:
            # Private name mangling only rewrites identifiers written inside
            # a class body; the string passed to getattr() is used verbatim,
            # so this really looks up an attribute called "__typename".
            typename = getattr(obj, "__typename", None)

        if typename is None:
            raise TypeError(
                f"No __typename attribute or key found in {obj}.")

        try:
            resolver = self.typename_resolver_map[typename]
        except KeyError:
            raise NotImplementedError(
                f"No resolver registered for {typename} using relay.resolve()"
            ) from None

        # TODO: Should we still handle this being an object and not a dict ?
        resolved = await resolver(root, info, obj)

        if "id" not in resolved:
            try:
                key = resolved["key"]
            except KeyError:
                raise TypeError(
                    f"No id attribute or key found in {obj} and no key provided."
                ) from None
            resolved["id"] = self.get_gdl(root, info).to_global_id(
                typename, key,
            )

        return resolved

    async def resolve_global_id(self, root, info, id):
        """Load the object behind global id ``id`` and resolve it."""
        loaded = await self.get_gdl(root, info).load_global_id(id)
        return await self.resolve_object(root, info, loaded)

    async def resolve_type_key(self, root, info, typename, key, index="id"):
        """Load ``typename`` by ``key`` (using ``index``) and resolve it."""
        loaded = await self.get_gdl(root, info).load_type_key(
            typename, key, index=index)
        return await self.resolve_object(root, info, loaded)

    def connection(self, typename, index="id"):
        """Set up the edge node resolver for ``typename`` connections.

        Calls ``setup_edge_node`` for ``f"{typename}Edge"`` with a resolver
        that loads the node from ``root["key"]``, then returns the
        ``relay_connection`` decorator for the connection resolver.
        """
        async def resolver(root, info):
            return await self.resolve_type_key(
                root, info, typename=typename, key=root["key"], index=index)

        self.setup_edge_node(f"{typename}Edge", typename, resolver)
        return relay_connection
class GlobalDataLoader(abc.ABC):
    """Collection of dataloaders addressed by ``(typename, index)``.

    Concrete subclasses must provide ``typename_loadertype_map``: a mapping
    from ``(typename, index)`` to a loader factory that is called as
    ``factory(context=context)`` when an instance is created. Because
    ``register_loadertype`` assigns into that mapping through the class, it
    has to be a mutable class-level mapping on the subclass.

    Global ids are ``base64("<typename>:<key>")``.
    """

    @staticmethod
    def to_global_id(typename, key):
        """Encode ``typename`` and ``key`` into a global id string.

        Raises:
            TypeError: if either argument is empty or not a ``str``.
        """
        # Let's not serialize/un-serialize keys: not our job.
        if not typename or not isinstance(typename, str):
            raise TypeError(
                f"Invalid typename {typename}: must be non-empty str.")
        if not key or not isinstance(key, str):
            raise TypeError(
                f"Invalid key {key}: must be non-empty str.")
        gid = f"{typename}:{key}"
        return base64.b64encode(gid.encode("utf8")).decode("utf8")

    @staticmethod
    def from_global_id(gid):
        """Decode a global id into ``(typename, key)``.

        The decoded text is split on the first ``":"`` only, so keys may
        themselves contain colons.

        Raises:
            TypeError: if ``gid`` is not a ``str``, is not valid
                base64/UTF-8, or lacks a typename or a key.
        """
        # Non-str input used to escape as AttributeError from ``.encode``;
        # report it like every other malformed id.
        if not isinstance(gid, str):
            raise TypeError(f"Received invalid gid {gid}")
        try:
            decoded = base64.b64decode(gid.encode("utf8")).decode("utf8")
        except (binascii.Error, UnicodeError):
            raise TypeError(f"Received invalid gid {gid}") from None
        typename, _, key = decoded.partition(':')
        if not typename or not key:
            raise TypeError(f"Invalid global ID {decoded}")
        return typename, key

    @classmethod
    def register_loadertype(cls, typename, loadertype, index="id"):
        """Register the loader factory for ``(typename, index)``.

        ``loadertype`` is stored as-is when it is a subclass of
        ``aiodataloader.DataLoader``. Anything else is wrapped as
        ``functools.partial(DataLoader, loadertype)``, i.e. it becomes the
        first positional argument of this module's ``DataLoader``
        (presumably the batch load function -- see aiodataloader).

        Returns:
            The ``loadertype`` that was passed in, so that using this method
            through the ``loadertype`` decorator leaves the decorated name
            bound to the original object instead of ``None``.
        """
        factory = loadertype
        if not (isinstance(loadertype, type)
                and issubclass(loadertype, aiodataloader.DataLoader)):
            factory = functools.partial(DataLoader, loadertype)
        cls.typename_loadertype_map[(typename, index)] = factory
        return loadertype

    @classmethod
    def loadertype(cls, typename, index="id"):
        """Return a decorator that registers its target for ``typename``."""
        return functools.partial(
            cls.register_loadertype, typename, index=index)

    @classmethod
    def enforce_typed_objects(cls, typename, objects, index="id", keys=None):
        """Stamp ``__typename`` and ``id`` onto every loaded object, in place.

        When ``keys`` is given and ``index == "id"``, the global id is
        computed from the matching key. Otherwise it is taken from the
        object's own ``"id"`` key (dicts) or ``id`` attribute (other
        objects).

        Returns:
            The same ``objects`` sequence that was passed in.

        Raises:
            TypeError: if ``keys`` and ``objects`` differ in length, or an
                object yields no id.
            AssertionError: if an object already carries an id that differs
                from the computed one. This check is skipped under
                ``python -O``.
        """
        if keys is not None and len(keys) != len(objects):
            raise TypeError(
                "keys should be None or the same length as objects.")

        derive_from_keys = keys is not None and index == "id"

        for position, obj in enumerate(objects):
            is_mapping = isinstance(obj, dict)

            if derive_from_keys:
                gid = cls.to_global_id(typename, keys[position])
            elif is_mapping:
                gid = obj.get("id", None)
            else:
                gid = getattr(obj, "id", None)

            if gid is None:
                raise TypeError(
                    f"Loader for {typename} with index={index} returned "
                    f"something without 'id' attribute or key."
                )

            if is_mapping:
                assert obj.get("id", gid) == gid, (
                    f"Object id {obj.get('id')!r} does not match {gid!r}.")
                obj["__typename"] = typename
                obj["id"] = gid
            else:
                assert getattr(obj, "id", gid) == gid, (
                    f"Object id {getattr(obj, 'id', None)!r} "
                    f"does not match {gid!r}.")
                setattr(obj, "__typename", typename)
                setattr(obj, "id", gid)

        return objects

    @property
    @abc.abstractmethod
    def typename_loadertype_map(self):
        """Mapping of ``(typename, index)`` to loader factory."""

    def __init__(self, context=None):
        """Instantiate one loader per registered ``(typename, index)``.

        Each factory is called with ``context=context``.
        """
        self.typename_loader_map = {
            (typename, index): loadertype(context=context)
            for (typename, index), loadertype
            in self.typename_loadertype_map.items()
        }

    def get_type_loader(self, typename, index="id"):
        """Return the loader instance for ``(typename, index)``.

        Raises:
            TypeError: if no such loader was registered.
        """
        try:
            return self.typename_loader_map[(typename, index)]
        except KeyError:
            raise TypeError(
                f"No loader registered for {typename} with index={index}."
            ) from None

    async def load_global_id(self, gid):
        """Load the single object identified by global id ``gid``."""
        typename, key = self.from_global_id(gid)
        return await self.load_type_key(typename, key)

    async def load_global_ids(self, gids):
        """Load several objects by global id, preserving input order.

        The loads are started together rather than awaited one by one, so
        the underlying loaders get a chance to batch keys requested at the
        same time. If any load fails, the first exception is raised; loads
        that are still in flight are not cancelled.
        """
        # Function-scope import: the module does not import asyncio at the
        # top level.
        import asyncio

        pending = [self.load_global_id(gid) for gid in gids]
        if not pending:
            return []
        return list(await asyncio.gather(*pending))

    async def load_type_key(self, typename, key, index="id"):
        """Load the single ``typename`` object for ``key`` on ``index``."""
        return (await self.load_type_keys(typename, [key], index=index))[0]

    async def load_type_keys(self, typename, keys, index="id"):
        """Load ``typename`` objects for ``keys`` and stamp type and id.

        Delegates to the registered loader's ``load_many`` and passes the
        result through ``enforce_typed_objects``.
        """
        loader = self.get_type_loader(typename, index=index)
        return self.enforce_typed_objects(
            typename, await loader.load_many(keys), index=index, keys=keys)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment