-
-
Save adriangb/b21424afee4b2464399e3592fe86b601 to your computer and use it in GitHub Desktop.
Initial implementation of a hook system for building ASGI middlewares.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import ( | |
Awaitable, | |
Iterable, | |
Mapping, | |
Optional, | |
Protocol, | |
Tuple, | |
TypeVar, | |
Union, | |
cast, | |
) | |
from asgiref.typing import ( | |
ASGI3Application, | |
ASGIReceiveCallable, | |
ASGISendCallable, | |
ASGISendEvent, | |
ASGIReceiveEvent, | |
HTTPResponseBodyEvent, | |
HTTPResponseStartEvent, | |
Scope, | |
) | |
from starlette.datastructures import MutableHeaders | |
class ScopeHook(Protocol):
    """Structural type for a hook that is given an ASGI connection scope.

    Implementations may be synchronous (returning ``None``) or return an
    awaitable for the caller to await.
    """

    def __call__(self, scope: Scope, /) -> Optional[Awaitable[None]]:
        """Handle *scope*; optionally return an awaitable."""
class SendHook(Protocol):
    """Structural type for a hook that observes an outgoing ASGI send event.

    It receives the connection scope and the message. Implementations may be
    synchronous (returning ``None``) or return an awaitable for the caller to
    await.
    """

    def __call__(
        self, scope: Scope, message: ASGISendEvent, /
    ) -> Optional[Awaitable[None]]:
        """Handle *message* sent on the connection described by *scope*."""
class ReceiveHook(Protocol):
    """Structural type for a hook that observes an incoming ASGI receive event.

    It receives the connection scope and the message. Implementations may be
    synchronous (returning ``None``) or return an awaitable for the caller to
    await.
    """

    def __call__(
        self, scope: Scope, message: ASGIReceiveEvent, /
    ) -> Optional[Awaitable[None]]:
        """Handle *message* received on the connection described by *scope*."""
class HookMiddleware:
    """ASGI middleware that runs optional hooks around a wrapped application.

    * ``scope_hook(scope)`` is called once per invocation, before ``app``.
    * ``receive_hook(scope, message)`` is called on every event the app
      receives, after it is read from the server and before the app sees it.
    * ``send_hook(scope, message)`` is called on every event the app sends,
      before it is forwarded to the server, so it may mutate the message.

    Each hook may be a plain function returning ``None`` or may return an
    awaitable, which is awaited before processing continues.
    """

    def __init__(
        self,
        app: "ASGI3Application",
        scope_hook: "Optional[ScopeHook]" = None,
        send_hook: "Optional[SendHook]" = None,
        receive_hook: "Optional[ReceiveHook]" = None,
    ) -> None:
        self._app = app
        self._scope_hook = scope_hook
        self._send_hook = send_hook
        self._receive_hook = receive_hook

    @staticmethod
    async def _run_hook(result: "Union[Awaitable[None], None]") -> None:
        """Await a hook's return value when it is an awaitable."""
        if result is not None:
            await result

    async def __call__(
        self,
        scope: "Scope",
        receive: "ASGIReceiveCallable",
        send: "ASGISendCallable",
    ) -> None:
        """Run the scope hook, then the app with hook-wrapped receive/send."""
        if self._scope_hook is not None:
            await self._run_hook(self._scope_hook(scope))

        send_hook = self._send_hook
        receive_hook = self._receive_hook

        async def wrapped_send(message: "ASGISendEvent") -> None:
            if send_hook is not None:
                await self._run_hook(send_hook(scope, message))
            # Always forward the (possibly mutated) message to the server.
            await send(message)

        async def wrapped_receive() -> "ASGIReceiveEvent":
            message = await receive()
            if receive_hook is not None:
                await self._run_hook(receive_hook(scope, message))
            return message

        # Only interpose wrappers for hooks that are actually configured.
        await self._app(
            scope,
            wrapped_receive if receive_hook is not None else receive,
            wrapped_send if send_hook is not None else send,
        )
def http_response_start_filter(hook: SendHook) -> SendHook:
    """Restrict *hook* so it only runs for ``http.response.start`` messages.

    The returned coroutine function silently skips every other message type.
    If *hook* returns an awaitable, that awaitable is awaited.
    """

    async def only_on_start(scope: Scope, message: ASGISendEvent) -> None:
        if message["type"] != "http.response.start":
            return
        outcome = hook(scope, message)
        if outcome is None:
            return
        await outcome

    return only_on_start
def add_headers(
    headers: "Union[Iterable[Tuple[str, str]], Mapping[str, str]]"
) -> "SendHook":
    """Build a send hook that appends *headers* to outgoing HTTP responses.

    *headers* may be a mapping or an iterable of ``(name, value)`` pairs. It is
    consumed once, when this function is called. Names are lower-cased, and
    names and values are encoded as latin-1 byte strings, as ASGI headers
    require.

    The returned hook only acts on ``http.response.start`` messages and ignores
    every other message type. It is therefore safe to install with or without
    ``http_response_start_filter``. The new headers are placed after any the
    app already set. A start message without a ``headers`` key is treated as
    having none.

    Raises:
        UnicodeEncodeError: if a header name or value is not latin-1
            encodable (raised here, not on each response).
    """
    pairs = headers.items() if isinstance(headers, Mapping) else headers
    # Encode once up front rather than on every response.
    encoded = [
        (key.lower().encode("latin-1"), value.encode("latin-1"))
        for key, value in pairs
    ]

    def wrapped_send(scope: "Scope", message: "ASGISendEvent") -> None:
        if message["type"] != "http.response.start":
            return
        # Build a new list so the app's own header iterable is not mutated.
        message["headers"] = [  # type: ignore
            *message.get("headers", ()),  # type: ignore
            *encoded,
        ]

    return wrapped_send
async def app(
    scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
    """Demo ASGI application that always replies 200 ``Hello, world!``.

    *scope* and *receive* are ignored. It sends a start event with an empty
    header list, followed by a single final body event.
    """
    response = (
        HTTPResponseStartEvent(type="http.response.start", status=200, headers=[]),
        HTTPResponseBodyEvent(
            type="http.response.body", body=b"Hello, world!", more_body=False
        ),
    )
    for event in response:
        await send(event)
# Demo wiring: serve ``app`` through HookMiddleware with a send hook that,
# restricted to ``http.response.start`` messages, appends ``x-foo: bar``.
wrapped_app = HookMiddleware(
    app,
    send_hook=http_response_start_filter(add_headers([("x-foo", "bar")])),
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment