FastAPI middleware that overrides the Host header value with the X-Forwarded-Host header value, if that header exists.
With it, FastAPI (Starlette) builds trailing-slash redirect URLs from the X-Forwarded-Host header value instead of the original Host header.
| from typing import List, Tuple | |
| from starlette.types import ASGIApp, Receive, Scope, Send | |
# ASGI header list: (lowercased name, value) pairs as raw bytes.
Headers = List[Tuple[bytes, bytes]]


class ForwardedHostMiddleware:
    """ASGI middleware that replaces the ``Host`` header with ``X-Forwarded-Host``.

    For ``http`` and ``websocket`` scopes, when the request carries a non-empty
    ``X-Forwarded-Host`` header, its value becomes the ``Host`` header seen by
    the wrapped application. URLs the application builds from the request
    (e.g. trailing-slash redirects) therefore use the forwarded host. The
    ``X-Forwarded-Host`` header itself is removed from the header list. Any
    other scope type is passed through unchanged.

    NOTE(review): the forwarded header is trusted unconditionally. Only deploy
    this behind a proxy that sets or strips ``X-Forwarded-Host``. Otherwise a
    client can choose the host that ends up in generated URLs.
    """

    def __init__(self, app: ASGIApp):
        """Wrap the downstream ASGI application *app*."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Rewrite the request headers (http/websocket only), then delegate."""
        if scope["type"] in ("http", "websocket"):
            # The scope dict is updated in place before being handed on.
            scope["headers"] = self.remap_headers(
                scope["headers"], b"host", b"x-forwarded-host"
            )
        await self.app(scope, receive, send)

    def remap_headers(self, src: Headers, before: bytes, after: bytes) -> Headers:
        """Return a new header list whose *before* header takes *after*'s value.

        Every *before* and *after* entry is removed from *src*; all other
        headers keep their relative order. A single *before* entry is then
        appended, carrying:

        * the value of the last *after* entry, if that value is non-empty;
        * otherwise the value of the last *before* entry, even if it is empty
          (e.g. an empty ``Host:`` is preserved rather than dropped).

        If neither header is present, no *before* entry is added.

        :param src: incoming header pairs; names are compared as exact bytes.
        :param before: name of the header to overwrite (e.g. ``b"host"``).
        :param after: name of the header supplying the new value.
        :return: the rewritten header list; *src* is not modified.
        """
        remapped: Headers = []
        before_value = None
        after_value = None
        for header in src:
            name, value = header
            if name == before:
                before_value = value
            elif name == after:
                # NOTE(review): a proxy chain may send a comma-separated list
                # here (b"a.example, b.example"); it is used verbatim — confirm
                # whether a single entry should be selected instead.
                after_value = value
            else:
                remapped.append(header)

        if after_value:
            remapped.append((before, after_value))
        elif before_value is not None:
            # Compare against None, not truthiness: an empty Host value is a
            # real header and must survive when nothing overrides it.
            remapped.append((before, before_value))
        return remapped
"""Tests for ``attakei_net.middleware``."""
| from typing import Tuple | |
| from fastapi import APIRouter, FastAPI | |
| from fastapi.testclient import TestClient | |
| from attakei_net.middleware import ForwardedHostMiddleware | |
def configure_client() -> Tuple[FastAPI, TestClient]:
    """Build a fresh app wrapped in ForwardedHostMiddleware plus a client for it.

    :return: ``(app, client)`` — the FastAPI application and a TestClient
        bound to it; callers register their own routes on ``app``.
    """
    application = FastAPI()
    application.add_middleware(ForwardedHostMiddleware)
    return application, TestClient(application)
def test_forwarded():
    """A trailing-slash redirect should use the X-Forwarded-Host value as host."""
    app, client = configure_client()
    slash_router = APIRouter()
    slash_router.get("/{path}/")(lambda path: "OK")
    app.include_router(slash_router)

    # ``/sample`` only matches ``/{path}/`` with a trailing slash, so the
    # router answers with a redirect whose URL is built from the Host header.
    forwarded_headers = {"X-Forwarded-Host": "test2"}
    response = client.get("/sample", allow_redirects=False, headers=forwarded_headers)

    assert response.status_code == 307
    assert response.headers["location"].startswith("http://test2/")
def test_no_forwarded():
    """Without X-Forwarded-Host the redirect keeps the TestClient's default host."""
    app, client = configure_client()
    slash_router = APIRouter()
    slash_router.get("/{path}/")(lambda path: "OK")
    app.include_router(slash_router)

    response = client.get("/sample", allow_redirects=False)

    assert response.status_code == 307
    assert response.headers["location"].startswith("http://testserver/")