$ python -m venv venv
$ source venv/bin/activate
$ python -m pip install anyio di pydantic quart
$ python poc.py
Created April 8, 2022 17:41
-
-
Save joeblackwaslike/cf46a4702c95fd6555b07da54983cf62 to your computer and use it in GitHub Desktop.
Quart Advanced Dependency Injection POC
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Quart Advanced Dependency Injection POC""" | |
import functools | |
import inspect | |
from pprint import pprint | |
from typing import ( | |
Union, | |
Any, | |
Optional, | |
List, | |
Callable, | |
Dict, | |
Mapping, | |
TypeVar, | |
Tuple, | |
AnyStr, | |
) | |
import anyio | |
from di.container import Container, bind_by_type | |
from di.dependant import Dependant, Injectable, Marker | |
from di.executors import AsyncExecutor | |
from di.typing import Annotated | |
from pydantic import BaseModel, Field | |
from quart import Quart, Blueprint, Request, request | |
from quart.wrappers import Request as QuartRequest | |
from werkzeug.datastructures import Authorization | |
# Placeholder type argument for the generic ``From*`` annotation aliases.
T = TypeVar("T")
# Blueprint that the example endpoint is registered on.
base = Blueprint("base", __name__)
# Helpers | |
def get_origin(param):
    """Return the innermost runtime type behind a parameter's annotation.

    Follows the ``__origin__`` chain of ``param.annotation`` until it reaches
    an object that has no such attribute, for example
    ``Annotated[List[str], m]`` -> ``List[str]`` -> ``list``.

    Args:
        param: Any object exposing an ``annotation`` attribute, such as an
            ``inspect.Parameter``.

    Returns:
        The last object in the ``__origin__`` chain. An annotation that has
        no ``__origin__`` at all (a plain class such as ``int``) is returned
        unchanged instead of raising ``AttributeError``.
    """
    # Start at the annotation itself so un-subscripted annotations are safe.
    origin = param.annotation
    while hasattr(origin, "__origin__"):
        origin = origin.__origin__
    return origin
# Markers | |
class HeaderParam(Marker):
    """Marker that resolves a parameter from the request headers.

    Args:
        alias: Explicit header name to read. When ``None`` the header name is
            the parameter name with underscores replaced by dashes.
    """

    def __init__(self, alias: Optional[str]) -> None:
        self.alias = alias
        super().__init__(call=None, scope="request", use_cache=False)

    def register_parameter(self, param: inspect.Parameter) -> Dependant[Any]:
        """Build the request-scoped dependency that produces ``param``.

        The returned dependency yields, in order of precedence:

        * a ``BaseModel`` subclass instance built from all headers as keyword
          arguments, when the annotated type is such a model;
        * the named header converted by calling the innermost origin type,
          when the header is present and the annotation is not the bare ``T``;
        * the whole headers object otherwise.
        """
        if self.alias is not None:
            name = self.alias
        else:
            name = param.name.replace("_", "-")

        def get_header(request: Annotated[Request, Marker()]) -> Any:
            headers = request.headers
            type_ = param.annotation.__origin__
            if inspect.isclass(type_) and issubclass(type_, BaseModel):
                return type_(**headers)
            if type_ is not T and name in headers:
                # Coerce with the concrete origin type (as JsonBody does)
                # rather than calling the Annotated alias, which raises
                # TypeError for non-instantiable aliases such as List[str].
                return get_origin(param)(headers[name])
            return headers

        return Dependant(get_header, scope="request")
class RequestBody(Marker):
    """Marker that resolves a parameter from the raw request payload."""

    def __init__(self) -> None:
        super().__init__(call=None, scope="request", use_cache=False)

    def register_parameter(self, param: inspect.Parameter) -> Dependant[Any]:
        """Build the request-scoped dependency that produces ``param``.

        The payload from ``request.get_data()`` is returned untouched when the
        annotation resolves to nothing usable or to the bare ``T``; it is
        UTF-8 decoded when ``str`` is requested for a ``bytes`` payload;
        otherwise it is passed to the resolved origin type.
        """

        async def get_body(request: Annotated[Request, Marker()]) -> Any:
            payload = await request.get_data()
            target = get_origin(param)
            if not target or target is T:
                return payload
            if target is str and isinstance(payload, bytes):
                return payload.decode("utf-8")
            return target(payload)

        return Dependant(get_body, scope="request")
class JsonBody(Marker):
    """Marker that resolves a parameter from the parsed JSON request body."""

    def __init__(self) -> None:
        super().__init__(call=None, scope="request", use_cache=False)

    def register_parameter(self, param: inspect.Parameter) -> Dependant[Any]:
        """Build the request-scoped dependency that produces ``param``.

        The returned dependency yields, in order of precedence:

        * a ``BaseModel`` subclass instance built from the parsed body as
          keyword arguments, when the annotated type is such a model;
        * the body field named after the parameter (underscores replaced by
          dashes) converted by the innermost origin type, when the body is a
          JSON object containing that field and the annotation is not the
          bare ``T``;
        * the parsed body itself otherwise.
        """
        name = param.name.replace("_", "-")

        async def get_body(request: Annotated[Request, Marker()]) -> Any:
            props = await request.get_json()
            type_ = param.annotation.__origin__
            if inspect.isclass(type_) and issubclass(type_, BaseModel):
                return type_(**props)
            # Key lookup only makes sense on a JSON object; a body that parsed
            # to something else (e.g. None or a list) falls through and is
            # returned as-is instead of raising TypeError.
            if type_ is not T and isinstance(props, dict) and name in props:
                return get_origin(param)(props[name])
            return props

        return Dependant(get_body, scope="request")
class QueryParam(Marker):
    """Marker that resolves a parameter from the URL query arguments."""

    def __init__(self) -> None:
        super().__init__(call=None, scope="request", use_cache=False)

    def register_parameter(self, param: inspect.Parameter) -> Dependant[Any]:
        """Build the request-scoped dependency that produces ``param``.

        The returned dependency yields, in order of precedence:

        * a ``BaseModel`` subclass instance built from all query arguments as
          keyword arguments, when the annotated type is such a model;
        * the argument named after the parameter (underscores replaced by
          dashes) converted by the innermost origin type, when it is present
          and the annotation is not the bare ``T``;
        * the whole ``request.args`` object otherwise.
        """
        name = param.name.replace("_", "-")

        def get_args(request: Annotated[Request, Marker()]) -> Any:
            args = request.args
            type_ = param.annotation.__origin__
            if inspect.isclass(type_) and issubclass(type_, BaseModel):
                return type_(**args)
            if type_ is not T and name in args:
                # Coerce with the concrete origin type (as JsonBody does)
                # rather than calling the Annotated alias, which raises
                # TypeError for non-instantiable aliases such as List[str].
                return get_origin(param)(args[name])
            return args

        return Dependant(get_args, scope="request")
class CookieParam(Marker):
    """Marker that resolves a parameter from the request cookies."""

    def __init__(self) -> None:
        super().__init__(call=None, scope="request", use_cache=False)

    def register_parameter(self, param: inspect.Parameter) -> Dependant[Any]:
        """Build the request-scoped dependency that produces ``param``.

        The returned dependency yields, in order of precedence:

        * a ``BaseModel`` subclass instance built from all cookies as keyword
          arguments, when the annotated type is such a model;
        * the cookie named after the parameter (underscores replaced by
          dashes) converted by the innermost origin type, when it is present
          and the annotation is not the bare ``T``;
        * the whole ``request.cookies`` object otherwise.
        """
        name = param.name.replace("_", "-")

        def get_cookies(request: Annotated[Request, Marker()]) -> Any:
            cookies = request.cookies
            type_ = param.annotation.__origin__
            if inspect.isclass(type_) and issubclass(type_, BaseModel):
                return type_(**cookies)
            if type_ is not T and name in cookies:
                # Coerce with the concrete origin type (as JsonBody does)
                # rather than calling the Annotated alias, which raises
                # TypeError for non-instantiable aliases such as List[str].
                return get_origin(param)(cookies[name])
            return cookies

        return Dependant(get_cookies, scope="request")
# FastAPI Style dependencies | |
async def common_parameters(q: Optional[str] = None, skip: int = 0, limit: int = 100):
    """Bundle the shared ``q`` / ``skip`` / ``limit`` parameters into a dict."""
    return dict(q=q, skip=skip, limit=limit)
# Annotations | |
# Generic aliases: subscript with a type (e.g. ``FromHeader[str]``) to ask for
# a converted value, or use bare to leave the type as the ``T`` placeholder.
FromHeader = Annotated[T, HeaderParam(alias=None)]
FromBody = Annotated[T, RequestBody()]
FromJson = Annotated[T, JsonBody()]
FromQuery = Annotated[T, QueryParam()]
FromCookie = Annotated[T, CookieParam()]
# Dict produced by the ``common_parameters`` dependency, request scoped.
CommonParams = Annotated[Dict, Marker(common_parameters, scope="request")]
# Dependencies | |
class Config(BaseModel):
    # Connection URI read by ``Database``; defaults to an in-memory SQLite URI.
    db_uri: str = "sqlite:///:memory:"
class Database(Injectable, scope="request"):
    """Request-scoped stand-in for a database connection.

    It only records its URI and a ``closed`` flag and prints what it is asked
    to do.
    """

    def __init__(self, config: Config):
        # Both attributes must exist before printing: __repr__ reads them.
        self.closed = False
        self.db_uri = config.db_uri
        print(f"Initializing db {self}")

    def __repr__(self):
        return f"{type(self).__name__}(uri={self.db_uri!r}, closed={self.closed!r})"

    async def execute(self, sql: str) -> None:
        """Print the SQL statement."""
        print(sql)

    def close(self):
        """Mark the database closed.

        Raises:
            RuntimeError: If the database has already been closed.
        """
        if self.closed is True:
            raise RuntimeError("Database already closed")
        self.closed = True
        print(f"Closing db {self}")
class Postgres(Database):
    """``Database`` subclass with no behaviour of its own."""
class Application(Quart, Injectable, scope="app"):
    """``Quart`` application that is also injectable in the ``app`` scope."""
class Request(QuartRequest, Injectable, scope="request"):
    """Quart request that is also injectable in the ``request`` scope.

    Defining this class rebinds the module-level name ``Request`` that was
    imported from ``quart`` at the top of the file.
    """
# Schemas | |
class HeadersModel(BaseModel):
    # Required fields, populated through their dashed header-name aliases.
    x_header_one: str = Field(..., alias="x-header-one")
    x_header_two: int = Field(..., alias="x-header-two")
class RequestItem(BaseModel):
    # Shape of the JSON payload used by the example request.
    id: int
    name: str
    tags: List[str]
class StyleParams(BaseModel):
    # Styling options, each with a default value.
    color: str = "default"
    theme: str = "default"
    logo: str = "default.png"
class AuthCookies(BaseModel):
    # Both fields default to None, so they are declared Optional rather than
    # plain ``str`` to keep the annotation consistent with the default.
    session: Optional[str] = None
    csrf: Optional[str] = None
# Endpoints | |
@base.get("/items/")
async def read_items(
    db: Database,
    x_header_one: FromHeader[str],
    header_two_val: Annotated[int, HeaderParam(alias="x-header-two")],
    headers_model: FromHeader[HeadersModel],
    headers: FromHeader,
    commons: CommonParams,
    body: FromJson,
    item: FromJson[RequestItem],
    name: FromJson[str],
    tags: FromJson[List[str]],
    style_params: FromQuery[StyleParams],
    color: FromQuery[str],
    query: FromQuery,
    session: FromCookie[str],
    auth: FromCookie[AuthCookies],
    cookies: FromCookie,
    raw_body: FromBody,
    bytes_body: FromBody[bytes],
    str_body: FromBody[str],
):
    """Echo every injected dependency back as a dict keyed by parameter name."""
    return {
        "commons": commons,
        "db": db,
        "x_header_one": x_header_one,
        "header_two_val": header_two_val,
        "headers_model": headers_model,
        "headers": headers,
        "body": body,
        "item": item,
        "name": name,
        "tags": tags,
        "style_params": style_params,
        "color": color,
        "query": query,
        "session": session,
        "auth": auth,
        "cookies": cookies,
        "raw_body": raw_body,
        "bytes_body": bytes_body,
        "str_body": str_body,
    }
# The framework takes a view and a request, resolves and injects the view's dependencies, and returns the view's result together with the solved dependency and its flattened sub-dependants.
async def framework(
    view: Callable,
    request: Request,
    app: Optional[Application] = None,
    container: Optional[Container] = None,
    executor: Optional[AsyncExecutor] = None,
    values: Optional[Mapping[str, Any]] = None,
):
    """Solve and execute ``view`` with its dependencies injected.

    Args:
        view: Callable whose parameters are resolved by the container.
        request: Request object made available to the dependencies.
        app: Application to inject; a new ``Application`` is created when
            omitted.
        container: ``di`` container to use; a new one is created when omitted.
        executor: Executor used to run the solved graph; defaults to a new
            ``AsyncExecutor``.
        values: Extra pre-resolved values handed to the execution.

    Returns:
        A 3-tuple ``(result, solved, flat_subdependants)``: the view's return
        value, the solved dependency, and the result of
        ``solved.get_flat_subdependants()``.
    """
    app = app or Application(__name__)
    container = container or Container()
    executor = executor or AsyncExecutor()
    # The request and app are always supplied as pre-resolved values.
    values = {**(values or {}), Request: request, Application: app}

    # NOTE(review): these binds are never undone, so reusing one container
    # across calls registers them again -- confirm against di's bind semantics.
    for provider in (bind_by_type(Dependant(Postgres, scope="request"), Database),):
        container.bind(provider)

    request_bind = bind_by_type(Dependant(lambda: request, scope="request"), Request)
    with container.bind(request_bind):
        solved = container.solve(
            Dependant(view, scope="request"),
            scopes=["request", "app"],
        )
        async with container.enter_scope("app") as app_state:
            async with container.enter_scope("request", state=app_state) as req_state:
                result = await container.execute_async(
                    solved,
                    executor=executor,
                    values=values,
                    state=req_state,
                )
                return result, solved, solved.get_flat_subdependants()
# Test request generator helpers | |
async def generate_request(
    app: Optional[Application] = None,
    path: str = "/",
    method: str = "GET",
    headers: Optional[Dict[str, Any]] = None,
    query_string: Optional[Dict[str, Any]] = None,
    scheme: str = "http",
    data: Optional[AnyStr] = None,
    json: Optional[Dict[str, Any]] = None,
    root_path: str = "",
    http_version: str = "1.1",
    scope_base: Optional[Dict[str, Any]] = None,
    auth: Optional[Union[Authorization, Tuple[str, str]]] = None,
):
    """Build an injectable ``Request`` from a Quart test request context.

    The arguments are forwarded to ``app.test_request_context``; the request
    active inside that context is then copied, attribute by attribute, into a
    new ``Request`` instance whose ``body`` is taken from the context request.

    Returns:
        The newly built ``Request``.
    """
    app = app or Application(__name__)
    context_kwargs = {
        "path": path,
        "method": method,
        "headers": headers,
        "query_string": query_string,
        "json": json,
        "scheme": scheme,
        "data": data,
        "root_path": root_path,
        "http_version": http_version,
        "scope_base": scope_base,
        "auth": auth,
    }
    # Constructor arguments copied verbatim from the context-bound request.
    copied_attrs = (
        "method",
        "scheme",
        "path",
        "query_string",
        "headers",
        "root_path",
        "http_version",
        "scope",
        "max_content_length",
        "body_timeout",
        "send_push_promise",
    )
    async with app.app_context():
        async with app.test_request_context(**context_kwargs):
            req = Request(**{attr: getattr(request, attr) for attr in copied_attrs})
            req.body = request.body
            return req
# Main | |
app = Application(__name__)

# Build a sample POST request carrying headers, cookies, a query string and a
# JSON payload.
req = anyio.run(
    functools.partial(
        generate_request,
        app=app,
        path="/v1/endpoint",
        method="POST",
        headers={
            "x-header-one": "one",
            "x-header-two": "2",
            "cookie": "session=123; csrf=adxfl3ils",
        },
        query_string={
            "q": "hello",
            "limit": 10,
            "page": 1,
            "offset": 10,
            "home": False,
            "color": "ffffff",
            "theme": "dark",
            "logo": "logo.png",
        },
        json={"id": 1, "name": "Joe", "tags": ["a", "b", "c"]},
    )
)

values = {"q": "some query?"}

# Run the view through the framework and show what was injected.
result, solved, dependencies = anyio.run(
    functools.partial(
        framework,
        view=read_items,
        values=values,
        request=req,
        app=app,
    )
)
pprint(result)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{'auth': AuthCookies(session='123', csrf='adxfl3ils'), | |
'body': {'id': 1, 'name': 'Joe', 'tags': ['a', 'b', 'c']}, | |
'bytes_body': b'{"id": 1, "name": "Joe", "tags": ["a", "b", "c"]}', | |
'color': 'ffffff', | |
'commons': {'limit': 100, 'q': None, 'skip': 0}, | |
'cookies': ImmutableMultiDict([('session', '123'), ('csrf', 'adxfl3ils')]), | |
'db': Postgres(uri='sqlite:///:memory:', closed=False), | |
'header_two_val': 2, | |
'headers': Headers([('x-header-one', 'one'), ('x-header-two', '2'), ('cookie', 'session=123; csrf=adxfl3ils'), ('User-Agent', 'Quart'), ('host', 'localhost'), ('Content-Type', 'application/json')]), | |
'headers_model': HeadersModel(x_header_one='one', x_header_two=2), | |
'item': RequestItem(id=1, name='Joe', tags=['a', 'b', 'c']), | |
'name': 'Joe', | |
'query': {'color': 'ffffff', | |
'home': 'False', | |
'limit': '10', | |
'logo': 'logo.png', | |
'offset': '10', | |
'page': '1', | |
'q': 'hello', | |
'theme': 'dark'}, | |
'raw_body': b'{"id": 1, "name": "Joe", "tags": ["a", "b", "c"]}', | |
'session': '123', | |
'str_body': '{"id": 1, "name": "Joe", "tags": ["a", "b", "c"]}', | |
'style_params': StyleParams(color='ffffff', theme='dark', logo='logo.png'), | |
'tags': ['a', 'b', 'c'], | |
'x_header_one': 'one'} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.