Last active
April 14, 2024 10:25
-
-
Save exhuma/ee51a71e186d07041eded0e90b7c8fbd to your computer and use it in GitHub Desktop.
Automatic Pagination with FastAPI and SQLAlchemy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastapi import Depends, FastAPI | |
from pydantic import BaseModel | |
from sqlalchemy import Column, Integer, String | |
from sqlalchemy.ext.declarative import declarative_base | |
from sqlalchemy.orm import Query, Session | |
from pagination import PaginatedList, paginated_get | |
# NOTE(review): ``declarative_base`` is imported from
# ``sqlalchemy.ext.declarative``, which SQLAlchemy 2.0 deprecates in favour of
# ``sqlalchemy.orm.declarative_base`` / ``DeclarativeBase`` — consider migrating.
Base = declarative_base()


# --- The SQLAlchemy Model ----------------------------------------------------
class CustomerDbModel(Base):
    """ORM mapping for rows of the ``customers`` table."""

    __tablename__ = "customers"

    # Integer primary key, also indexed.
    id = Column(Integer, index=True, primary_key=True)
    # ``String`` without a length; no further constraints are declared.
    name = Column(String)
# --- Core Business Logic ----------------------------------------------------- | |
# This is a simple example of a function that returns a SQLAlchemy query. A
# real-world example may be more complex and use joins, filters, etc.
# The main point here is that it returns a SQLAlchemy query, which does not
# need to know about pagination. This ensures that the core business logic is
# not coupled to the pagination logic and other API-layer concerns.
# While a real-world implementation will almost certainly require values from
# the user request (e.g. for filtering, sorting, etc.), it is important to
# separate *how* the user request is transformed into the query from *how* it
# is handled on the back-end. Using this pattern, this "mapping/decoupling"
# can be handled solely in the FastAPI route definition.
def get_customers(db: Session) -> Query[CustomerDbModel]:
    """
    Build the unpaginated query selecting all customers.

    :param db: The session the query is built on.
    :return: A query over ``CustomerDbModel``; it is returned without being
        iterated, so callers can still refine or paginate it.
    """
    return db.query(CustomerDbModel)
# --- FastAPI ----------------------------------------------------------------- | |
def get_db():
    """
    FastAPI dependency that yields a database session and closes it afterwards.

    NOTE(review): ``Session()`` is created here without an engine/bind;
    presumably a real application uses a configured ``sessionmaker`` —
    confirm before reusing this example.
    """
    # Leaving the ``with`` block calls ``session.close()``, also when the
    # code after ``yield`` is reached through an exception.
    with Session() as session:
        yield session
class CustomerApiModel(BaseModel):
    # Public API representation of a customer. This is documented with
    # comments rather than a docstring because Pydantic publishes class
    # docstrings as the JSON-schema description.

    # ``from_attributes`` allows ``model_validate`` to read fields from
    # arbitrary objects, such as the SQLAlchemy rows that the ``/customers``
    # route maps via ``api_mapper=CustomerApiModel.model_validate``. Without
    # it, Pydantic v2 only accepts dicts or model instances and raises a
    # ``ValidationError`` for ORM objects.
    model_config = {"from_attributes": True}

    name: str
APP = FastAPI()


@paginated_get(
    app=APP,
    path="/customers",
    # ``model_validate`` keeps the example short; any callable that turns one
    # row of the query into a ``CustomerApiModel`` could be used instead.
    api_mapper=CustomerApiModel.model_validate,
    response_model=PaginatedList[CustomerApiModel],
)
def customers(db: Session = Depends(get_db)):
    # Only the plain, unpaginated query from the business layer
    # (``get_customers``) is built here. ``paginated_get`` adds the HTTP
    # query parameters for paging and applies them to this query.
    customer_query = get_customers(db)
    return customer_query
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This module provides a decorator that paginates a SQLAlchemy query and returns | |
the results as a Pydantic model. | |
""" | |
import inspect
from typing import Annotated, Any, Callable

from fastapi import APIRouter, FastAPI
# Aliased because ``Query`` is already SQLAlchemy's query class (see below).
from fastapi import Query as QueryParam
from pydantic import BaseModel
from sqlalchemy.orm import Query
# Sort rank of every parameter kind, in the order Python requires inside a
# function signature. It is used as the sort key when merging the route
# function's parameters with the pagination wrapper's parameters.
_PARAMETER_ORDER = {
    kind: rank
    for rank, kind in enumerate(
        (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.KEYWORD_ONLY,
            inspect.Parameter.VAR_KEYWORD,
        ),
        start=1,
    )
}
class PaginatedList[T](BaseModel):
    # One page of results, as built by the routes ``paginated_get`` registers.
    # Comments (not a docstring) are used on purpose: Pydantic would publish
    # a class docstring as the JSON-schema description.
    # NOTE: the field order below is also the key order of the JSON output.
    items: list[T]  # the already-mapped items of the requested page
    page: int  # the requested page number (1-based)
    total_items: int  # number of rows in the full, unpaginated query
def paginated_get[ | |
T, U: BaseModel | |
]( | |
app: FastAPI | APIRouter, | |
path: str, | |
api_mapper: Callable[[T], U], | |
*args, | |
default_page_size: int = 25, | |
**kwargs, | |
): | |
""" | |
Register a route with FastAPI that returns automatically paginates an | |
SQLAlchemy . | |
:param app: The FastAPI application or router to register the route with. | |
:param path: The path to register the route at. | |
:param api_mapper: A function that maps the SQLAlchemy model (of one item of | |
the query) to a Pydantic model. | |
:param default_page_size: The default number of items per page. | |
:param args: Additional positional arguments to pass to the FastAPI route. | |
:param kwargs: Additional keyword arguments to pass to the FastAPI route. | |
Example: | |
>>> from fastapi import FastAPI | |
>>> from pydantic import BaseModel | |
>>> from sqlalchemy.orm import Query | |
>>> from flightlogs.pagination import PaginatedList, paginated_get | |
>>> | |
>>> app = FastAPI() | |
>>> | |
>>> class CustomerApiModel(BaseModel): | |
... name: str | |
>>> | |
>>> def mapper(data: CustomerDbModel) -> CustomerApiModel: | |
... return CustomerApiModel(name=data.label) | |
>>> | |
>>> def get_customers(db: Session) -> Query[CustomerDbModel]: | |
... result = db.query(CustomerDbModel) | |
... return result | |
>>> | |
>>> @paginated_get( | |
... app, | |
... "/customers", | |
... api_mapper=mapper, | |
... response_model=PaginatedList[CustomerApiModel], | |
... ) | |
... def customers(db: Session = Depends(get_db)): | |
... return get_customers(db) | |
""" | |
def decorator( | |
func: Callable[..., Query[Any]] | |
) -> Callable[..., PaginatedList[U]]: | |
""" | |
Decorator that registers a route with FastAPI that returns a paginated | |
list of items. | |
:param func: The function that returns the SQLAlchemy query to paginate. | |
""" | |
def wrapper( | |
*inner_args, | |
page: int = 1, | |
per_page: int = default_page_size, | |
**inner_kwargs, | |
) -> PaginatedList[U]: | |
data = func(*inner_args, **inner_kwargs) | |
total_count = data.count() | |
data = data.limit(per_page) | |
data = data.offset((page - 1) * per_page) | |
items = [api_mapper(row) for row in data] | |
return PaginatedList( | |
items=items, | |
page=page, | |
total_items=total_count, | |
) | |
# We need to merge the signatures of the original function and the | |
# wrapper function. This exposes the parameters to FastAPI. This enable | |
# all the functionality of FastAPI, such as automatic validation and | |
# documentation. | |
# | |
# When doing this, we need to make sure that the parameters are in the | |
# correct order for Python itself. | |
func_sig = inspect.signature(func) | |
wrapper_sig = inspect.signature(wrapper) | |
wrapper_params = list(wrapper_sig.parameters.values()) | |
wrapper_params.extend(func_sig.parameters.values()) | |
wrapper_params = [ | |
p | |
for p in sorted( | |
wrapper_params, | |
key=lambda p: _PARAMETER_ORDER.get(p.kind, 0), | |
) | |
if p.kind | |
not in ( | |
inspect.Parameter.VAR_POSITIONAL, | |
inspect.Parameter.VAR_KEYWORD, | |
) | |
] | |
wrapper_sig = wrapper_sig.replace(parameters=wrapper_params) | |
# We can't use "functools.wraps" here, because it would copy the | |
# signature of the wrapper function, undoing the work we did above. | |
wrapper.__signature__ = wrapper_sig | |
wrapper.__doc__ = func.__doc__ | |
wrapper.__name__ = func.__name__ | |
route = app.get(path, *args, **kwargs)(wrapper) | |
return route | |
return decorator |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment