Created January 13, 2023 22:15
-
-
Save a-recknagel/9684999d2f896b2f19ff20cefae1fb85 to your computer and use it in GitHub Desktop.
Batchwise factory for FastAPI routes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect
from collections.abc import Callable
from typing import Literal, get_type_hints

import uvicorn
from fastapi import Depends, FastAPI
from fastapi.routing import APIRoute
from pydantic import BaseModel
# Body of the demo route below; a single string payload.
# (Comments instead of docstrings: pydantic/FastAPI publish docstrings in the OpenAPI schema.)
class Request(BaseModel):
    data: str


app = FastAPI()


# Demo route: answers with the posted ``data`` plus a ".foo" suffix.
@app.post("/foo", response_model=str)
def foo(r: Request):
    suffixed = r.data + ".foo"
    return suffixed
def batchwise(
    app: FastAPI,
    base: str,
    path=None,
    docstring=None,
    reuse_dependencies=True,  # not implemented yet: dependencies are always shared
    errors: Literal["drop", "strict", "explicit"] = "strict",  # not implemented yet
) -> Callable:
    """Register a batch variant of the POST route at ``base`` and return its endpoint.

    The batch route takes a JSON list of the original route's request bodies and
    responds with a list holding the original endpoint's result for each item, in
    order. Every other parameter of the original endpoint (dependencies and
    query/path/header parameters) is declared once on the batch route. It is
    therefore resolved once per batch request and shared by all per-item calls.
    Route-level ``dependencies`` of the original route are copied to the batch route.

    Args:
        app: Application holding the original route; the batch route is added to it.
        base: Path of the original route, e.g. ``"/foo"``.
        path: Path of the batch route; defaults to ``base`` with ``"s"`` appended.
        docstring: Becomes ``__doc__`` of the generated endpoint.
        reuse_dependencies: Not implemented; dependencies are always shared across items.
        errors: Not implemented; an exception for any item fails the whole batch request.

    Returns:
        The generated endpoint, already registered on ``app``.

    Raises:
        ValueError: ``base`` lacks a leading slash, no API route of ``app`` has that
            path, or the route has no single top-level body parameter.
        NotImplementedError: The route accepts a method other than POST.
    """
    if not base.startswith("/"):
        raise ValueError(f"A route has to start with a leading slash, got {base=}")

    # Only APIRoutes carry `endpoint`/`body_field`. Plain starlette routes, such as the
    # docs pages, do not. Several routes may share a path (one per method): prefer POST.
    candidates = [r for r in app.routes if isinstance(r, APIRoute) and r.path == base]
    if not candidates:
        raise ValueError(f"{base=} not listed in current routing table.")
    origin = next((r for r in candidates if "POST" in r.methods), candidates[0])

    # Validate before registering anything, so a failure leaves `app` untouched.
    if any(method != "POST" for method in origin.methods):
        # would be nice to support GET, but I have no clue about query params
        raise NotImplementedError("We only support 'post' for now.")

    endpoint = origin.endpoint
    signature = inspect.signature(endpoint)
    body_field = origin.body_field
    # For a single, non-embedded body parameter, the body field carries that parameter's
    # name. With no body, an embedded/multi-param body, or a body declared only in a
    # sub-dependency, there is no one endpoint argument to feed with each list entry.
    if body_field is None or body_field.name not in signature.parameters:
        raise ValueError(
            f"{base=} needs exactly one non-embedded body parameter to be batched."
        )
    body_name = body_field.name

    # Resolve string annotations (e.g. under `from __future__ import annotations`)
    # against the endpoint's own module; include_extras keeps Annotated[...] markers.
    try:
        hints = get_type_hints(endpoint, include_extras=True)
    except TypeError:  # e.g. functools.partial endpoints: keep the raw annotations
        hints = {}

    item_model = origin.response_model
    batch_model = list[item_model] if item_model is not None else None
    # Same parameters as the original endpoint (names, kinds, defaults such as
    # Depends(...)), except that the body parameter now takes a list of bodies.
    batch_signature = signature.replace(
        parameters=[
            param.replace(
                annotation=list[body_field.type_]
                if name == body_name
                else hints.get(name, param.annotation)
            )
            for name, param in signature.parameters.items()
        ],
        return_annotation=inspect.Signature.empty if batch_model is None else batch_model,
    )

    def _per_item_calls(args, kwargs):
        """Yield (args, kwargs) of one original-endpoint call per item of the batch body."""
        bound = batch_signature.bind(*args, **kwargs)
        rqs = bound.arguments[body_name]
        for rq in rqs:
            bound.arguments[body_name] = rq
            yield bound.args, bound.kwargs

    # Items run sequentially; the first exception propagates and aborts the batch.
    if inspect.iscoroutinefunction(endpoint):
        async def batch_endpoint(*args, **kwargs):
            return [await endpoint(*a, **kw) for a, kw in _per_item_calls(args, kwargs)]
    else:
        def batch_endpoint(*args, **kwargs):
            return [endpoint(*a, **kw) for a, kw in _per_item_calls(args, kwargs)]

    # inspect.signature() honours __signature__, so FastAPI sees the parameters built
    # above rather than the generic (*args, **kwargs).
    batch_endpoint.__signature__ = batch_signature
    batch_endpoint.__name__ = batch_endpoint.__qualname__ = f"batchified_{origin.name}"
    batch_endpoint.__doc__ = docstring

    path = path if path is not None else f"{origin.path}s"
    app.post(
        path,
        dependencies=[*origin.dependencies],
        response_model=batch_model,
    )(batch_endpoint)
    return batch_endpoint
# Register POST /foos, the batch counterpart of POST /foo.
batchwise(
    app,
    base="/foo",
    docstring="Helper route.\n\nReduces network traffic.",
)


if __name__ == "__main__":
    # Serve the demo app on all IPv4 interfaces, port 5000.
    uvicorn.run(app, host="0.0.0.0", port=5000)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment