# Gist Object905/324ad346be59d6cbe8fa83aac58e9429 (last active May 15, 2021 17:55).
# Generic pagination for FastAPI and SQLAlchemy (easily adapted to other ORMs/DB drivers).
# Note from GitHub: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it,
# open the file in an editor that reveals hidden Unicode characters.
import math
from typing import Generic, List, Optional, TypeVar

from fastapi import Query
from pydantic.generics import GenericModel
from starlette.datastructures import URL

# Page size used when the client does not send a ``pageSize`` query parameter.
DEFAULT_PAGE_SIZE = 25
# Largest ``pageSize`` a client may request.
MAX_PAGE_SIZE = 100

# Item type parameter of the generic ``Paginated`` response model.
ItemT = TypeVar("ItemT")


class Paginated(GenericModel, Generic[ItemT]):
    """Response envelope for one page of a paginated listing.

    Parametrize it with the item schema, e.g. ``Paginated[User]``, to use it as a
    FastAPI ``response_model``.
    """

    item_count: int  # total number of items across all pages
    page_count: int  # total number of pages
    current_page: int  # index of this page (PageInfo pages start at 0)
    # Link to the previous page; None (the default) when there is none.
    previous: Optional[str] = None
    # Link to the next page; None (the default) when there is none.
    next: Optional[str] = None
    data: List[ItemT]  # the items on this page


class PageInfo:
    """Paging request parameters plus helpers to paginate a SQLAlchemy query.

    Meant to be injected with ``page: PageInfo = Depends()`` in a FastAPI route,
    so that FastAPI fills ``page`` (zero-based page index) and ``pageSize`` from
    the query string. ``paginate`` then fetches a single page with LIMIT/OFFSET
    and wraps it in a ``Paginated`` response carrying next/previous links.
    """

    # Query-string name of the page-size parameter. The links built by
    # change_paging_url_params must use this same name, otherwise the endpoint
    # silently ignores the page size they carry.
    _PAGE_SIZE_PARAM = "pageSize"

    def __init__(
        self,
        page: int = Query(0, ge=0),
        page_size: int = Query(
            DEFAULT_PAGE_SIZE,
            # Reject 0 (total_page_count would divide by zero) and negative
            # sizes (they would be passed straight to LIMIT).
            ge=1,
            le=MAX_PAGE_SIZE,
            alias=_PAGE_SIZE_PARAM,
        ),
    ):
        self.page = page
        self.page_size = page_size

    def paginate(self, sqla_query, url: URL) -> Paginated[ItemT]:
        """Fetch the current page of *sqla_query* and wrap it in ``Paginated``.

        Args:
            sqla_query: SQLAlchemy query to paginate. It must support ``limit``,
                ``offset``, ``all``, ``order_by(None)`` and ``count``.
            url: The incoming request URL. The next/previous links are this URL
                with its paging query parameters replaced.

        Returns:
            The page's items, the total item and page counts, the current page
            index, and next/previous links when such pages exist.
        """
        items = self.get_current_page_items(sqla_query)
        item_count = self.get_total_item_count(sqla_query)
        paging_kwargs = {
            "item_count": item_count,
            "page_count": self.total_page_count(item_count),
            "current_page": self.page,
        }
        # next/previous are only passed when such a page exists; otherwise the
        # Optional fields are left unset and end up as None.
        if self.has_next_page(len(items), item_count):
            paging_kwargs["next"] = self.next_page_url(url)
        if self.has_previous_page():
            paging_kwargs["previous"] = self.previous_page_url(url)
        return Paginated(data=items, **paging_kwargs)

    @staticmethod
    def get_total_item_count(sqla_query):
        """Count all rows matched by *sqla_query*, ignoring pagination.

        ``order_by(None)`` cancels any ORDER BY before counting; ordering does
        not affect the count.
        """
        return sqla_query.order_by(None).count()

    def get_current_page_items(self, sqla_query):
        """Fetch only this page's rows by applying LIMIT/OFFSET to *sqla_query*."""
        return sqla_query.limit(self.page_size).offset(self.offset()).all()

    def offset(self) -> int:
        """Return how many items precede this page (pages are zero-based)."""
        return self.page * self.page_size

    def total_page_count(self, total_item_count: int) -> int:
        """Return the number of pages needed for *total_item_count* items (0 if none)."""
        return math.ceil(total_item_count / self.page_size)

    def has_next_page(self, current_page_item_count: int, total_item_count: int) -> bool:
        """Return whether items remain after the ones on this page."""
        seen_item_count = self.offset() + current_page_item_count
        return seen_item_count < total_item_count

    def has_previous_page(self) -> bool:
        """Return whether this is not the first page."""
        return self.page > 0

    def next_page_url(self, base_url: URL) -> str:
        """Return *base_url* pointing at the following page."""
        return self.next().change_paging_url_params(base_url)

    def previous_page_url(self, base_url: URL) -> str:
        """Return *base_url* pointing at the preceding page.

        Raises:
            ValueError: if this is page 0 (see ``prev``).
        """
        return self.prev().change_paging_url_params(base_url)

    def change_paging_url_params(self, url: URL) -> str:
        """Return *url* with its paging query parameters set to this page.

        Any existing ``page``/``pageSize`` values are replaced. A stray
        ``page_size`` parameter (the attribute name, which the endpoint does not
        read) is removed as well. Other query parameters are kept.
        """
        params = self._query_params()
        stale_keys = [*params, "page_size"]
        without_paging = url.remove_query_params(stale_keys)
        return str(without_paging.include_query_params(**params))

    def _query_params(self) -> dict:
        """Return the paging values keyed by their query-string names (see ``__init__``)."""
        return {"page": self.page, self._PAGE_SIZE_PARAM: self.page_size}

    def dict(self) -> dict:
        """Return the paging values keyed by attribute name."""
        return {"page": self.page, "page_size": self.page_size}

    def next(self) -> "PageInfo":
        """Return a PageInfo for the following page with the same page size."""
        return PageInfo(page=self.page + 1, page_size=self.page_size)

    def prev(self) -> "PageInfo":
        """Return a PageInfo for the preceding page with the same page size.

        Raises:
            ValueError: if this is already page 0.
        """
        if self.page == 0:
            raise ValueError("Can't go before page 0")
        return PageInfo(page=self.page - 1, page_size=self.page_size)


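# ### USAGE
#
# @router.get("/users", response_model=Paginated[User])
# def get_users(
#     request: Request,
#     page: PageInfo = Depends(),
#     db: Session = Depends(get_db),  # get_db: your dependency yielding a Session
# ):
#     # query_users must return an (unexecuted) SQLAlchemy Query
#     return page.paginate(query_users(db), request.url)
#
# Clients request pages with ``?page=<zero-based index>&pageSize=<n>``, where n
# may not exceed MAX_PAGE_SIZE.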
# (GitHub page footer: sign up or sign in to comment on this gist.)