Skip to content

Instantly share code, notes, and snippets.

@dmontagu
Last active June 14, 2019 13:18
Show Gist options
  • Save dmontagu/9abbeb86fd53556e2c3d9bf8908f81bb to your computer and use it in GitHub Desktop.
Save dmontagu/9abbeb86fd53556e2c3d9bf8908f81bb to your computer and use it in GitHub Desktop.
FastAPI app with response shape wrapping
from functools import lru_cache
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, create_model
from starlette.requests import Request
from starlette.responses import JSONResponse
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
class Error(BaseModel):
    # One entry of the response envelope's ``errors`` list.
    # Documented with comments rather than a docstring: pydantic would pick a
    # class docstring up as the model's JSON-schema / OpenAPI description.
    kind: str  # category; the endpoints here use "expected", the exception handler uses the exception's class name
    detail: str  # human-readable message (the exception handler uses str(exc))
# Type variable for the payload model carried in an envelope's ``data`` field;
# bound to BaseModel because only pydantic models can be wrapped.
T = TypeVar("T", bound=BaseModel)
# Field types of the envelope's ``errors`` and ``context`` entries.
ErrorsT = List[Error]
ContextT = Dict[str, Any]
# Unbounded cache: the key space is the set of payload model classes, and an
# evicted entry would make a later call build a *second* model class with the
# same ``StandardData[...]`` name instead of reusing the first one.
@lru_cache(maxsize=None)
def get_standard_response_model(cls: Type[BaseModel]) -> Type[BaseModel]:
    """Return the envelope model that wraps payload model ``cls``.

    The generated pydantic model is named ``StandardData[<cls.__name__>]`` and
    has three fields:

    * ``context`` -- required, ``ContextT``
    * ``errors``  -- required, ``ErrorsT``
    * ``data``    -- ``Optional[cls]``, defaulting to ``None``

    Results are memoised per class, so every call for the same ``cls`` (e.g.
    ``StandardResponse[cls]`` in a route decorator and again when building the
    response) returns the very same model class.

    Raises:
        TypeError: If ``cls`` is not a ``BaseModel`` subclass.
    """
    # Explicit check instead of ``assert`` (stripped under ``python -O``); it
    # also gives a clear message for non-class arguments.
    if not (isinstance(cls, type) and issubclass(cls, BaseModel)):
        raise TypeError(f"expected a pydantic BaseModel subclass, got {cls!r}")
    return create_model(
        f"StandardData[{cls.__name__}]",
        context=(ContextT, ...),
        errors=(ErrorsT, ...),
        data=(Optional[cls], None),
    )
class StandardResponse(Generic[T]):
    """Factory for the standard response envelope around a payload model.

    ``StandardResponse[Model]`` evaluates to the pydantic envelope model from
    ``get_standard_response_model`` (usable as a FastAPI ``response_model``).
    ``StandardResponse(data, request)`` returns an *instance of that envelope
    model*, not of this class.
    """

    def __class_getitem__(cls, item):
        # Subscription yields the generated envelope model rather than a
        # typing alias, so it can be passed straight to ``response_model``.
        return get_standard_response_model(item)

    def __new__(cls, data: Union[T, Type[T]], request: Optional[Request] = None) -> "StandardResponse[T]":
        """Build an envelope instance for ``data``.

        Args:
            data: Either a payload model *instance*, which becomes the
                envelope's ``data``, or a payload model *class*, which fixes
                the envelope type while ``data`` is left ``None`` (e.g. when
                only errors are reported).
            request: If given, ``context`` and ``errors`` are read from
                ``request.state`` (initialised by ``context_middleware``);
                otherwise both are empty.

        Returns:
            An instance of ``StandardResponse[<payload class>]``.

        Raises:
            TypeError: If ``data`` is neither a ``BaseModel`` instance nor a
                ``BaseModel`` subclass.
        """
        if request is not None:
            context = request.state.context  # type: ignore
            errors = request.state.errors  # type: ignore
        else:
            context = {}
            errors = []

        response_data: Optional[BaseModel]
        if isinstance(data, BaseModel):
            model_cls = type(data)
            response_data = data
        elif isinstance(data, type) and issubclass(data, BaseModel):
            # Bare model class: the envelope carries only context and errors.
            model_cls = data
            response_data = None
        else:
            # Explicit check instead of ``assert`` (stripped under ``python -O``).
            raise TypeError(
                f"StandardResponse expects a BaseModel instance or subclass, got {data!r}"
            )

        response_type = get_standard_response_model(model_cls)
        # noinspection PyTypeChecker
        return response_type(context=context, errors=errors, data=response_data)  # type: ignore
class MyResponse1(BaseModel):
    # Example payload model wrapped by the /1, /expected-error and
    # /unexpected-error routes. Documented with a comment, not a docstring, so
    # the generated JSON schema is unchanged.
    text: str
class MyResponse2(BaseModel):
    # Example payload model wrapped by the /2 route. Documented with a comment,
    # not a docstring, so the generated JSON schema is unchanged.
    number: int
# The FastAPI application; the routes, middleware and exception handler below
# are all registered on it via decorators.
app = FastAPI()
@app.get("/1", response_model=StandardResponse[MyResponse1])
def get_response_1(request: Request) -> StandardResponse[MyResponse1]:
    # Happy path: wrap a MyResponse1 payload together with this request's
    # context and (empty) error list. Comment rather than docstring because
    # FastAPI would publish a docstring as the OpenAPI description.
    add_context(request, "endpoint", "1")
    return StandardResponse(MyResponse1(text="hello world"), request=request)
@app.get("/2", response_model=StandardResponse[MyResponse2])
def get_response_2(request: Request) -> StandardResponse[MyResponse2]:
    # Happy path with a different payload type, showing that each payload
    # model gets its own envelope model. Comment rather than docstring because
    # FastAPI would publish a docstring as the OpenAPI description.
    add_context(request, "endpoint", "2")
    return StandardResponse(MyResponse2(number=42), request=request)
@app.get("/expected-error", response_model=StandardResponse[MyResponse1])
def get_expected_error(request: Request) -> StandardResponse[MyResponse1]:
    # Handled-failure demo: record an error on the request and return the
    # envelope with ``data`` = None. Passing the payload *class* (not an
    # instance) keeps the envelope typed as StandardResponse[MyResponse1].
    # Return annotation and keyword ``request=`` match the sibling routes;
    # comment rather than docstring so the OpenAPI description is unchanged.
    add_context(request, "endpoint", "expected-error")
    add_error(request, kind="expected", detail="expected error")
    return StandardResponse(MyResponse1, request=request)
@app.get("/unexpected-error", response_model=StandardResponse[MyResponse1])
def get_unexpected_error(request: Request):
    # Unhandled-failure demo: an error is recorded first, then a RuntimeError
    # escapes the route and is turned into a 500 envelope by the catch-all
    # exception handler, which also reads the errors stored on the request.
    # Comment rather than docstring so the OpenAPI description is unchanged.
    add_context(request, "endpoint", "unexpected-error")
    add_error(request, kind="expected", detail="expected error")
    failure_message = "whoops"
    raise RuntimeError(failure_message)
def add_context(request: Request, key: str, value: Any):
    """Store ``value`` under ``key`` in the request's context dict.

    Expects ``request.state.context`` to have been initialised (as
    ``context_middleware`` does); the dict is later copied into the response
    envelope's ``context`` field.
    """
    context_dict = request.state.context  # type: ignore
    context_dict[key] = value
def add_error(request: Request, kind: str, detail: str):
    """Append ``Error(kind=kind, detail=detail)`` to the request's error list.

    Expects ``request.state.errors`` to have been initialised (as
    ``context_middleware`` does).
    """
    # Look the list up before building the Error, as the original expression did.
    error_list = request.state.errors  # type: ignore
    error_list.append(Error(kind=kind, detail=detail))
@app.middleware("http")
async def context_middleware(request: Request, call_next):
    """Give each request fresh, empty ``context`` and ``errors`` containers.

    ``add_context`` / ``add_error`` write into them, and ``StandardResponse``
    and the exception handler read them back when building the envelope.
    """
    fresh_context: Dict[str, Any] = {}
    fresh_errors: List[Error] = []
    request.state.context = fresh_context  # type: ignore
    request.state.errors = fresh_errors  # type: ignore
    return await call_next(request)
@app.exception_handler(Exception)
async def validation_exception_handler(request, exc):
    """Render any uncaught exception as a 500 standard-envelope response.

    Despite its name, this handler is registered for ``Exception`` and so
    receives every uncaught error, not only validation errors.

    The exception becomes the first ``errors`` entry (class name as ``kind``,
    ``str(exc)`` as ``detail``), followed by any errors already recorded on
    the request; ``context`` is the request's context and ``data`` is ``None``.
    """
    # ``request.state.errors`` / ``.context`` exist only once
    # ``context_middleware`` has run. Fall back to empty values so a failure
    # raised before that point still yields a well-formed envelope instead of
    # an AttributeError from inside the error handler itself.
    state = request.state
    prior_errors = getattr(state, "errors", [])
    context = getattr(state, "context", {})
    payload = {
        "errors": [Error(kind=type(exc).__name__, detail=str(exc))] + prior_errors,
        "context": context,
        "data": None,
    }
    return JSONResponse(jsonable_encoder(payload), status_code=500)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment