Last active: March 14, 2022 11:50
-
-
Save ThirVondukr/4f37a8b5f67d6677621f3b3d7f455da6 to your computer and use it in GitHub Desktop.
Strawberry GraphQL Cursor Pagination
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import base64
import enum
from typing import Annotated, Any, Optional

from sqlalchemy import select, tuple_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeMeta, InstrumentedAttribute
from sqlalchemy.sql import Select

from gql.modules.users._fields import UserOrder
def model_fields_enum(
    *columns: InstrumentedAttribute, name: str
) -> type[enum.Enum]:
    """Create an ``Enum`` type called *name* with one member per column.

    Every member is named ``column.name.upper()`` and holds ``column.prop``
    as its value; ``Paginator`` turns such a value back into the class-bound
    attribute via ``.class_attribute``.

    Args:
        *columns: ORM attributes to expose as enum members, in order.
        name: name of the generated enum type.

    Returns:
        The newly created ``Enum`` subclass.
    """
    members = {}
    for column in columns:
        # Later columns with the same upper-cased name overwrite earlier ones.
        members[column.name.upper()] = column.prop
    enum_type = enum.Enum(  # type: ignore
        name,
        names=members,
        module=__name__,
    )
    return enum_type
def encode_cursor(
    model: DeclarativeMeta,
    order_by: list[InstrumentedAttribute],
) -> str:
    """Encode the ordering-key values of a mapped row as an opaque cursor.

    Each value is converted with ``str()``, UTF-8 encoded and base64-encoded.
    The segments are joined with ``":"``, a character that never appears in
    standard base64 output, so the cursor can be split on it unambiguously.

    Args:
        model: the ORM instance (a fetched row) whose key values are encoded.
        order_by: the ORM attributes that define the pagination ordering.

    Returns:
        The cursor string, one base64 segment per ``order_by`` attribute.
    """
    segments = []
    for attr in order_by:
        # Read the value via the mapped attribute name (``attr.key``). The
        # underlying column name can differ from it, e.g.
        # ``full_name = Column("name", ...)``, and then getattr would miss.
        # NOTE(review): ``str(None)`` yields "None"; NULLs in ordering
        # columns will not round-trip meaningfully — confirm columns are
        # non-nullable.
        raw = str(getattr(model, attr.key))
        segments.append(base64.b64encode(raw.encode()).decode())
    return ":".join(segments)
def decode_cursor(
    cursor: str,
    order_by: list[InstrumentedAttribute],
) -> list[Any]:
    """Decode a cursor back into its ordering-key values (as strings).

    The cursor is expected to be ``":"``-separated base64 segments of UTF-8
    text, one segment per ``order_by`` attribute. Cursors come from clients,
    so the structure is validated before use.

    Args:
        cursor: the opaque cursor string received from the client.
        order_by: the ordering attributes; only its length is used, to check
            that the cursor has the expected number of segments.

    Returns:
        The decoded values, one ``str`` per ``order_by`` attribute.

    Raises:
        ValueError: if the segment count does not match ``order_by``, or a
            segment is not valid base64 or valid UTF-8.
    """
    segments = cursor.split(":")
    if len(segments) != len(order_by):
        raise ValueError(
            f"Malformed cursor: expected {len(order_by)} segment(s), "
            f"got {len(segments)}"
        )
    try:
        # validate=True rejects non-alphabet characters instead of silently
        # discarding them. binascii.Error and UnicodeDecodeError are both
        # ValueError subclasses.
        return [
            base64.b64decode(segment.encode(), validate=True).decode()
            for segment in segments
        ]
    except ValueError as err:
        raise ValueError("Malformed cursor") from err
class Paginator:
    """Keyset (cursor) pagination over a SQLAlchemy ``Select``.

    Rows are ordered lexicographically by the ``order_by`` attributes. Each
    returned edge carries a cursor built by ``encode_cursor`` from those same
    attributes, and ``after``/``before`` cursors bound the window with a
    row-value comparison.
    """

    def __init__(
        self,
        query: Select,
        order_by: list[enum.Enum] | list[InstrumentedAttribute],
    ):
        """Store the base query and normalize the ordering attributes.

        Args:
            query: base ``Select`` to paginate (without ORDER BY/LIMIT).
            order_by: ORM attributes, or enum members whose ``value`` is a
                mapper property (as produced by ``model_fields_enum``).
        """
        self.query = query
        # Enum members hold a mapper property; ``.class_attribute`` resolves
        # it back to the class-bound ORM attribute.
        self.order_by: list[InstrumentedAttribute] = [
            col
            if isinstance(col, InstrumentedAttribute)
            else col.value.class_attribute
            for col in order_by
        ]

    def _build_query(
        self,
        after: str,
        before: str,
        limit: Optional[int],
        backwards: bool,
    ) -> Select:
        """Apply ordering, cursor bounds and the look-ahead LIMIT to the query."""
        key = tuple_(*self.order_by)
        if backwards:
            # Descending order so LIMIT keeps the *last* rows of the window;
            # the caller flips them back to ascending order. Per-column DESC
            # is equivalent to a descending row-value sort but portable.
            query = self.query.order_by(*(col.desc() for col in self.order_by))
        else:
            query = self.query.order_by(*self.order_by)
        # NOTE(review): decode_cursor returns strings, so these comparisons
        # bind string parameters against the ordering columns; non-text
        # columns may need an explicit cast depending on the driver — confirm.
        if after:
            query = query.filter(
                key > tuple_(*decode_cursor(after, self.order_by))
            )
        if before:
            query = query.filter(
                key < tuple_(*decode_cursor(before, self.order_by))
            )
        if limit is not None:
            # Fetch one extra row to find out whether more rows exist.
            query = query.limit(limit + 1)
        return query

    async def paginate(
        self,
        after: str,
        before: str,
        first: Optional[int],
        last: Optional[int],
        session: AsyncSession,
    ) -> tuple[list[Edge[Any]], PageInfo]:
        """Fetch one page of rows and build its edges and page info.

        Args:
            after: cursor; only rows with a strictly greater key are
                returned. Ignored when empty.
            before: cursor; only rows with a strictly smaller key are
                returned. Ignored when empty.
            first: return at most this many rows from the start of the window.
            last: return at most this many rows from the end of the window.
            session: async session used to execute the query.

        Returns:
            ``(edges, page_info)`` with edges in ascending key order.

        Raises:
            ValueError: if ``first`` or ``last`` is negative, if both are
                given, or if a cursor is malformed.
        """
        # ``is not None`` so that 0 is treated as a real (empty) page size.
        if first is not None and first < 0:
            raise ValueError("'first' must be non-negative")
        if last is not None and last < 0:
            raise ValueError("'last' must be non-negative")
        if first is not None and last is not None:
            raise ValueError("'first' and 'last' cannot be used together")

        backwards = last is not None
        limit = last if backwards else first
        query = self._build_query(after, before, limit, backwards)
        nodes = list(await session.scalars(query))

        has_previous_page = False
        has_next_page = False
        if backwards:
            # Rows were fetched in descending order; restore ascending order.
            nodes.reverse()
        if limit is not None and len(nodes) > limit:
            if backwards:
                has_previous_page = True
                # The look-ahead row is the earliest one. Slice by index
                # rather than ``nodes[-limit:]``, which keeps everything
                # when limit == 0.
                nodes = nodes[len(nodes) - limit:]
            else:
                has_next_page = True
                nodes = nodes[:limit]

        cursors = [encode_cursor(node, self.order_by) for node in nodes]
        edges = [
            Edge(node=node, cursor=cursor)
            for node, cursor in zip(nodes, cursors)
        ]
        page_info = PageInfo(
            has_previous_page=has_previous_page,
            has_next_page=has_next_page,
            start_cursor=cursors[0] if cursors else "",
            end_cursor=cursors[-1] if cursors else "",
        )
        return edges, page_info
@inject
async def all_users(
    session: Annotated[AsyncSession, Inject],
    order_by: list[UserOrder],
    after: str = "",
    before: str = "",
    first: Optional[int] = None,
    last: Optional[int] = None,
) -> Connection[Edge[UserType]]:
    """Resolve a cursor-paginated connection over every ``User`` row.

    Args:
        session: injected async database session.
        order_by: ordering attributes/enum members passed to ``Paginator``.
        after: cursor lower bound (exclusive); empty string means none.
        before: cursor upper bound (exclusive); empty string means none.
        first: page size counted from the start of the window.
        last: page size counted from the end of the window.

    Returns:
        A ``Connection`` wrapping the page's edges and page info.
    """
    edges, page_info = await Paginator(select(User), order_by).paginate(
        after=after,
        before=before,
        first=first,
        last=last,
        session=session,
    )
    return Connection(edges=edges, page_info=page_info)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.