Last active: October 3, 2023 11:26
-
-
Save frankie567/7aad9491f47cd7442cd8e1e9073f6457 to your computer and use it in GitHub Desktop.
Type-hinted sorting fields dependency for FastAPI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from enum import StrEnum
from typing import Literal

from fastapi import FastAPI

from .sorting import Sorting
class SortingField(StrEnum): | |
FIELD_A = "field_a" | |
FIELD_B = "field_b" | |
app = FastAPI()


@app.get("/sorting-literal")
async def test_sorting_literal(sorting: Sorting[Literal["field_a", "field_b"]]):
    """Example route whose allowed sort fields are declared with a ``Literal``."""
    ...


# Fixed: this previously referenced the undefined name ``SortingFields``,
# which raised NameError at import; the enum is ``SortingField``.
@app.get("/sorting-enum")
async def test_sorting_enum(sorting: Sorting[SortingField]):
    """Example route whose allowed sort fields are declared with a ``StrEnum``."""
    ...
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import enum | |
from functools import cached_property | |
from typing import ( | |
Annotated, | |
Generic, | |
Literal, | |
TypeGuard, | |
TypeVar, | |
get_args, | |
get_origin, | |
) | |
from fastapi import Depends, Query | |
# Type variable for an allowed sort field. It is bound to ``str`` so string
# ``Literal`` values and ``StrEnum`` members (a ``str`` subclass) both fit.
ASF = TypeVar("ASF", bound=str)

# Parsed sort specification: a list of ``(field, is_descending)`` pairs.
SortingType = list[tuple[ASF, bool]]
class SortingGetterInvalidConfiguration(Exception):
    """Raised when ``Sorting`` is given a type it cannot derive fields from.

    Only ``Literal[...]`` types and ``Enum`` classes are supported as the
    source of allowed sort fields.
    """

    def __init__(self) -> None:
        super().__init__(
            "The type you provided to Sorting is not supported. "
            "Please use a `Literal` or an `Enum`."
        )
class SortingFieldNotAllowed(Exception):
    """Raised when a client requests sorting by a field that is not allowed.

    Attributes:
        field: The rejected field name.
        allowed_fields: The field names that would have been accepted.
    """

    def __init__(self, field: str, allowed_fields: set[str]) -> None:
        self.field = field
        self.allowed_fields = allowed_fields
        # Sort the names so the message is identical across runs: iteration
        # order of a set of strings depends on per-process hash randomization.
        message = (
            f'You cannot sort by the field "{field}". '
            f"Allowed fields are: {', '.join(sorted(allowed_fields))}"
        )
        super().__init__(message)
class SortingGetter(Generic[ASF]):
    """Callable dependency that parses a comma-separated ``sort`` parameter.

    The allowed fields come from the type argument the instance was created
    with, e.g. ``SortingGetter[Literal["a", "b"]]()`` or
    ``SortingGetter[SomeStrEnum]()``. A leading ``-`` on a field marks it as
    descending, so ``"-a,b"`` parses to ``[("a", True), ("b", False)]``.
    """

    def __call__(self, sort: str | None = Query(None)) -> SortingType[ASF]:
        """Parse ``sort`` into ``(field, is_desc)`` pairs, preserving order.

        An absent (``None``) or empty parameter yields an empty list.

        Raises:
            SortingFieldNotAllowed: if a field is not in ``allowed_fields``.
            SortingGetterInvalidConfiguration: if the allowed fields cannot
                be derived from the type argument (raised on the first check).
        """
        # Annotated with ``SortingType`` rather than ``Sorting``: ``Sorting``
        # is defined after this class, so referencing it in the signature
        # raised NameError when the module was imported.
        sorting: SortingType[ASF] = []
        # ``Query(None)`` makes the parameter optional; an omitted parameter
        # arrives as ``None`` and must not reach ``.split``.
        if not sort:
            return sorting
        for raw_field in sort.split(","):
            is_desc = raw_field.startswith("-")
            field = raw_field.removeprefix("-")
            if not self._is_allowed_field(field):
                raise SortingFieldNotAllowed(field, self.allowed_fields)
            sorting.append((field, is_desc))
        return sorting

    def _is_allowed_field(self, field: str) -> TypeGuard[ASF]:
        """Return whether ``field`` is one of ``allowed_fields``."""
        return field in self.allowed_fields

    @cached_property
    def allowed_fields(self) -> set[str]:
        """Field names accepted by this getter, derived from its type argument.

        Raises:
            SortingGetterInvalidConfiguration: if the instance was not created
                through a subscripted alias (``SortingGetter[X]()``), or ``X``
                is neither a ``Literal`` nor an ``Enum`` class.
        """
        # ``__orig_class__`` is only set on instances created by calling a
        # subscripted alias; report a missing one as a configuration error.
        orig_class = getattr(self, "__orig_class__", None)
        if orig_class is None:
            raise SortingGetterInvalidConfiguration()
        generic_type = get_args(orig_class)[0]
        if get_origin(generic_type) is Literal:
            return set(get_args(generic_type))
        if isinstance(generic_type, enum.EnumType):
            return {member.value for member in generic_type}
        raise SortingGetterInvalidConfiguration()
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Static view for type checkers: ``Sorting[X]`` is a parsed sort list.
    Sorting = Annotated[SortingType[ASF], Depends(SortingGetter[ASF])]
else:

    class Sorting:
        """Runtime ``Sorting[X]`` that wires a ``SortingGetter[X]`` dependency.

        Subscripting an ``Annotated`` alias substitutes only its type part;
        the ``Depends(...)`` metadata would keep ``SortingGetter[ASF]`` with
        the unbound TypeVar, so the getter could never see ``X``. Building the
        ``Annotated`` here puts the concrete type argument into the
        dependency. An *instance* is passed to ``Depends`` so FastAPI invokes
        ``SortingGetter.__call__`` instead of constructing the class.
        """

        def __class_getitem__(cls, item: object) -> object:
            return Annotated[SortingType[item], Depends(SortingGetter[item]())]
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.