import asyncio
from contextlib import asynccontextmanager
import typing
import asyncio
from httpx._models import Request, Response
from httpx._transports.asgi import ASGITransport
from httpx._types import AsyncByteStream


class ASGIResponseByteStream(AsyncByteStream):
    """Expose an async generator of body chunks as an ``AsyncByteStream``.

    Iteration and closing are both delegated to the wrapped generator.
    """

    def __init__(self, stream: typing.AsyncGenerator[bytes, None]) -> None:
        # The generator that yields the response body chunks.
        self._stream = stream

    def __aiter__(self) -> typing.AsyncIterator[bytes]:
        """Return the async iterator of the wrapped generator."""
        chunks = self._stream
        return chunks.__aiter__()

    async def aclose(self) -> None:
        """Close the wrapped generator."""
        chunks = self._stream
        await chunks.aclose()


async def patch_handle_async_request(
    self: ASGITransport,
    request: Request,
) -> Response:
    """Dispatch *request* to the ASGI app and return a streaming response.

    Intended as a replacement for ``ASGITransport.handle_async_request``.
    The app is run in a background task; the ``Response`` is returned as
    soon as the app sends ``http.response.start``, and body chunks are
    handed to the response stream through a queue as the app sends them.

    Args:
        self: The transport whose ``app``, ``client``, ``root_path`` and
            ``raise_app_exceptions`` attributes are used.
        request: The request to dispatch; its stream must be an
            ``AsyncByteStream``.

    Returns:
        A ``Response`` whose stream yields body chunks as they are sent.

    Raises:
        Exception: Whatever the app raised, if it failed before sending
            ``http.response.start``.
        RuntimeError: If the app returned without sending
            ``http.response.start``.

    If the app fails after the response has started but before the final
    body message, the failure is re-raised to whoever is iterating the
    response stream instead of leaving that consumer blocked.
    """
    assert isinstance(request.stream, AsyncByteStream)

    # ASGI scope.
    scope = {
        "type": "http",
        "asgi": {"version": "3.0"},
        "http_version": "1.1",
        "method": request.method,
        "headers": [(k.lower(), v) for (k, v) in request.headers.raw],
        "scheme": request.url.scheme,
        "path": request.url.path,
        "raw_path": request.url.raw_path,
        "query_string": request.url.query,
        "server": (request.url.host, request.url.port),
        "client": self.client,
        "root_path": self.root_path,
    }

    # Request.
    request_body_chunks = request.stream.__aiter__()
    request_complete = False

    # Response.
    status_code = None
    response_headers = None
    sentinel = object()  # Marks the end of the body in ``body_queue``.
    body_queue: asyncio.Queue = asyncio.Queue()
    response_started = asyncio.Event()
    response_complete = asyncio.Event()

    # ASGI callables.

    async def receive() -> typing.Dict[str, typing.Any]:
        nonlocal request_complete

        if request_complete:
            # The request body is exhausted: report a disconnect only once
            # the response has been fully sent.
            await response_complete.wait()
            return {"type": "http.disconnect"}

        try:
            body = await request_body_chunks.__anext__()
        except StopAsyncIteration:
            request_complete = True
            return {"type": "http.request", "body": b"", "more_body": False}
        return {"type": "http.request", "body": body, "more_body": True}

    async def send(message: typing.Dict[str, typing.Any]) -> None:
        nonlocal status_code, response_headers

        if message["type"] == "http.response.start":
            assert not response_started.is_set()
            status_code = message["status"]
            response_headers = message.get("headers", [])
            response_started.set()

        elif message["type"] == "http.response.body":
            assert response_started.is_set()
            assert not response_complete.is_set()
            body = message.get("body", b"")
            more_body = message.get("more_body", False)

            # HEAD responses carry no body, so chunks are dropped.
            if body and request.method != "HEAD":
                await body_queue.put(body)

            if not more_body:
                await body_queue.put(sentinel)
                response_complete.set()

    async def run_app() -> None:
        try:
            await self.app(scope, receive, send)
        except Exception:  # noqa: PIE-786
            if self.raise_app_exceptions or not response_complete.is_set():
                raise
        finally:
            # If the app exits (or is cancelled) before finishing the
            # response, wake any consumer blocked on the queue so it does
            # not wait forever. The queue is unbounded, so this cannot block.
            if not response_complete.is_set():
                body_queue.put_nowait(sentinel)

    async def body_stream() -> typing.AsyncGenerator[bytes, None]:
        while True:
            chunk = await body_queue.get()
            # Identity check: the sentinel is a unique marker object.
            if chunk is sentinel:
                break
            yield chunk
        if not response_complete.is_set():
            # The sentinel came from ``run_app``'s cleanup, i.e. the app
            # exited early. Awaiting the task re-raises its exception, if any.
            await app_task

    # Keep a reference to the task: the event loop only holds a weak one.
    app_task = asyncio.create_task(run_app())
    started_waiter = asyncio.create_task(response_started.wait())
    try:
        # Wake up when the response starts *or* the app exits first, so an
        # app that fails before ``http.response.start`` cannot hang us.
        await asyncio.wait(
            {app_task, started_waiter},
            return_when=asyncio.FIRST_COMPLETED,
        )
    except BaseException:
        # The caller was cancelled or interrupted: do not leave the app
        # running in the background.
        app_task.cancel()
        raise
    finally:
        started_waiter.cancel()  # No-op if it already completed.

    if not response_started.is_set():
        # Only the app task can have completed here; surface its exception.
        app_task.result()
        raise RuntimeError("ASGI app returned without starting a response")

    assert status_code is not None
    assert response_headers is not None

    stream = ASGIResponseByteStream(body_stream())
    return Response(status_code, headers=response_headers, stream=stream)


@asynccontextmanager
async def patch_asgi_transport() -> typing.AsyncIterator[None]:
    """Temporarily replace ``ASGITransport.handle_async_request``.

    While the context is active, ``ASGITransport.handle_async_request`` is
    ``patch_handle_async_request``. The previous method is put back on
    exit, including when the body of the ``async with`` block raises.
    """
    restore = ASGITransport.handle_async_request
    ASGITransport.handle_async_request = patch_handle_async_request
    try:
        yield
    finally:
        # Always undo the class-level patch so a failure inside the context
        # does not leave the transport patched for unrelated code.
        ASGITransport.handle_async_request = restore