Last active
June 14, 2019 13:18
-
-
Save dmontagu/9abbeb86fd53556e2c3d9bf8908f81bb to your computer and use it in GitHub Desktop.
FastAPI app with response shape wrapping
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import lru_cache | |
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union | |
from pydantic import BaseModel, create_model | |
from starlette.requests import Request | |
from starlette.responses import JSONResponse | |
from fastapi import FastAPI | |
from fastapi.encoders import jsonable_encoder | |
class Error(BaseModel):
    """One entry of the ``errors`` list carried by every wrapped response."""

    # Short label identifying the category of the error.
    kind: str
    # Human-readable description of the error.
    detail: str
# Type of the ``errors`` field of the response envelope.
ErrorsT = List[Error]
# Type of the ``context`` field of the response envelope: string keys, any values.
ContextT = Dict[str, Any]
# Type parameter for the wrapped payload; constrained to pydantic models.
T = TypeVar("T", bound=BaseModel)
# Unbounded on purpose: the key space is the finite set of payload model
# classes, and evicting an entry would make a later call build a *different*
# class object for the same payload model.
@lru_cache(maxsize=None)
def get_standard_response_model(cls: Type[BaseModel]) -> Type[BaseModel]:
    """Return the envelope model that wraps the payload model ``cls``.

    The envelope is a pydantic model created with ``create_model`` and named
    ``StandardData[<cls.__name__>]``. It has three fields: required ``context``
    (``ContextT``) and ``errors`` (``ErrorsT``), and an optional ``data`` of
    type ``cls`` that defaults to ``None``.

    Args:
        cls: The pydantic model class used for the ``data`` field.

    Returns:
        The envelope model class. Results are cached, so repeated calls with
        the same ``cls`` return the same class object.

    Raises:
        TypeError: If ``cls`` is not a subclass of ``BaseModel``.
    """
    # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
    if not (isinstance(cls, type) and issubclass(cls, BaseModel)):
        raise TypeError(f"expected a BaseModel subclass, got {cls!r}")
    return create_model(
        f"StandardData[{cls.__name__}]",
        context=(ContextT, ...),
        errors=(ErrorsT, ...),
        data=(Optional[cls], None),
    )
class StandardResponse(Generic[T]):
    """Entry point for the ``context`` / ``errors`` / ``data`` response envelope.

    ``StandardResponse[Model]`` evaluates to the envelope model class returned
    by ``get_standard_response_model(Model)``. Calling ``StandardResponse(...)``
    returns an instance of that envelope model rather than an instance of this
    class.
    """

    def __class_getitem__(cls, item):
        """Return the envelope model class for the payload model ``item``."""
        return get_standard_response_model(item)

    def __new__(cls, data: Union[T, Type[T]], request: Optional[Request] = None) -> "StandardResponse[T]":
        """Build an envelope instance around ``data``.

        Args:
            data: Either a payload model instance, which becomes the ``data``
                field, or a payload model class, in which case ``data`` is
                ``None`` and the class only selects the envelope type.
            request: When given, ``context`` and ``errors`` are read from
                ``request.state``; otherwise they default to ``{}`` and ``[]``.

        Returns:
            An instance of the envelope model for the payload type.

        Raises:
            TypeError: If ``data`` is neither a ``BaseModel`` instance nor a
                ``BaseModel`` subclass.
        """
        if request is None:
            context = {}
            errors = []
        else:
            context = request.state.context  # type: ignore
            errors = request.state.errors  # type: ignore

        # noinspection PyUnusedLocal
        response_data: Optional[BaseModel]
        if isinstance(data, BaseModel):
            payload_model = type(data)
            response_data = data
        elif isinstance(data, type) and issubclass(data, BaseModel):
            # Only the payload type is known (e.g. an error response): no data.
            payload_model = data
            response_data = None
        else:
            # Explicit raise rather than ``assert`` so the check survives ``python -O``.
            raise TypeError(f"expected a BaseModel instance or subclass, got {data!r}")

        response_type = get_standard_response_model(payload_model)
        # noinspection PyTypeChecker
        return response_type(context=context, errors=errors, data=response_data)  # type: ignore
class MyResponse1(BaseModel):
    """Example payload model with a single text field."""

    text: str
class MyResponse2(BaseModel):
    """Example payload model with a single integer field."""

    number: int
# Application instance that the routes, middleware and exception handler below attach to.
app = FastAPI()
@app.get("/1", response_model=StandardResponse[MyResponse1])
def get_response_1(request: Request) -> StandardResponse[MyResponse1]:
    """Serve ``/1``: a ``MyResponse1`` payload inside the standard envelope."""
    # Tag the request so the envelope's ``context`` names this endpoint.
    add_context(request, "endpoint", "1")
    return StandardResponse(MyResponse1(text="hello world"), request=request)
@app.get("/2", response_model=StandardResponse[MyResponse2])
def get_response_2(request: Request) -> StandardResponse[MyResponse2]:
    """Serve ``/2``: a ``MyResponse2`` payload inside the standard envelope."""
    # Tag the request so the envelope's ``context`` names this endpoint.
    add_context(request, "endpoint", "2")
    return StandardResponse(MyResponse2(number=42), request=request)
@app.get("/expected-error", response_model=StandardResponse[MyResponse1])
def get_expected_error(request: Request):
    """Serve ``/expected-error``: an envelope with a recorded error and no data."""
    add_context(request, "endpoint", "expected-error")
    add_error(request, kind="expected", detail="expected error")
    # Passing the model class (not an instance) yields ``data=None``.
    return StandardResponse(MyResponse1, request=request)
@app.get("/unexpected-error", response_model=StandardResponse[MyResponse1])
def get_unexpected_error(request: Request):
    """Serve ``/unexpected-error``: record context and an error, then fail.

    Raises:
        RuntimeError: Always, after the context and error have been recorded
            on the request state.
    """
    add_context(request, "endpoint", "unexpected-error")
    # Recorded before the failure, so it is already on the request state
    # when the exception propagates.
    add_error(request, kind="expected", detail="expected error")
    raise RuntimeError("whoops")
def add_context(request: Request, key: str, value: Any):
    """Store ``value`` under ``key`` in the request's context mapping.

    Args:
        request: Request whose ``state.context`` mapping is updated in place.
        key: Context entry name; an existing entry with this name is replaced.
        value: Value to record.
    """
    context = request.state.context  # type: ignore
    context[key] = value
def add_error(request: Request, kind: str, detail: str):
    """Append a new ``Error`` to the request's error list.

    Args:
        request: Request whose ``state.errors`` list is extended in place.
        kind: Value for the error's ``kind`` field.
        detail: Value for the error's ``detail`` field.
    """
    error = Error(kind=kind, detail=detail)
    request.state.errors.append(error)  # type: ignore
@app.middleware("http")
async def context_middleware(request: Request, call_next):
    """Give each request an empty ``context`` dict and ``errors`` list.

    Both are stored on ``request.state`` before the request is passed on,
    which is where ``add_context`` and ``add_error`` look for them.
    """
    state = request.state
    state.context = {}  # type: ignore
    state.errors = []  # type: ignore
    response = await call_next(request)
    return response
@app.exception_handler(Exception)
async def validation_exception_handler(request, exc):
    """Convert an unhandled exception into a 500 response in the standard shape.

    Despite its name, this handler is registered for ``Exception``, not only
    for validation errors.

    Args:
        request: The request being handled when ``exc`` was raised.
        exc: The exception to report.

    Returns:
        A ``JSONResponse`` with status 500 whose body has ``errors`` (the
        exception first, then any errors already recorded on the request),
        ``context`` and ``data`` set to ``None``.
    """
    state = request.state
    # Fall back to empty values when the state was never populated, so that the
    # error handler itself cannot fail with AttributeError and mask ``exc``.
    recorded_errors = getattr(state, "errors", [])
    context = getattr(state, "context", {})
    body = {
        "errors": [Error(kind=type(exc).__name__, detail=str(exc)), *recorded_errors],
        "context": context,
        "data": None,
    }
    return JSONResponse(jsonable_encoder(body), status_code=500)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment