Skip to content

Instantly share code, notes, and snippets.

@zekka-lotushealth
Last active August 17, 2025 03:26
Show Gist options
  • Save zekka-lotushealth/a70866ab563085f407ee77ff4ae66944 to your computer and use it in GitHub Desktop.
Save zekka-lotushealth/a70866ab563085f407ee77ff4ae66944 to your computer and use it in GitHub Desktop.
fastapi_mcp -- multiserver support
from contextlib import asynccontextmanager
import json
import traceback
from urllib.parse import quote
import anyio
import httpx
from typing import AsyncGenerator, Dict, Optional, Any, List, Union
from fastapi import FastAPI, Request, APIRouter
from fastapi.openapi.utils import get_openapi
from mcp import ErrorData, JSONRPCError, JSONRPCRequest
from mcp.server.lowlevel.server import Server
import mcp.types as types
from sse_starlette import EventSourceResponse
from starlette.types import Receive, Scope, Send
from uuid import UUID, uuid4
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from fastapi import Response, HTTPException
from fastapi.responses import JSONResponse
from pydantic import ValidationError
from mcp.types import JSONRPCMessage
from lotushealth.shared import get_logger
# Module-level logger, obtained from the project's shared logging helper.
logger = get_logger(__name__)
def _lower_errors(message: JSONRPCMessage | ValidationError) -> JSONRPCMessage:
    """
    Normalize a client message that may have failed validation.

    A JSONRPCMessage is passed through untouched. A pydantic ValidationError is
    converted into a JSON-RPC error message with code -32700 ("Parse error")
    whose `data` carries the text of the validation error.
    """
    if not isinstance(message, ValidationError):
        return message

    parse_error = ErrorData(
        code=-32700,  # JSON-RPC "Parse error"
        message="Parse error",
        data={"validation_error": str(message)},
    )
    # The request ID cannot be recovered from input that failed validation,
    # so a fixed placeholder ID is used instead.
    return JSONRPCMessage(
        root=JSONRPCError(jsonrpc="2.0", id="unknown", error=parse_error)
    )
class FastApiProxyingSseTransport(object):
    """
    An MCP-over-SSE server transport whose sessions are reached through a
    ProxyingTransport instead of an in-process lookup table.

    `connect_sse` serves the long-lived SSE stream for a new session, and
    `handle_fastapi_post_message` accepts a message POSTed by the client and
    forwards it to that session through the transport proxy.
    """

    def __init__(
        self,
        messages_path: str,
        proxy: "AbstractTransportProxy | None" = None,
    ):
        """
        Create a new transport.

        :param messages_path: the path clients POST their messages to; it is
            advertised to each client in the initial `endpoint` SSE event.
        :param proxy: the transport proxy used to advertise sessions and relay
            messages to them. Defaults to a new `RedisTransportProxy`.
        """
        self._endpoint: str = messages_path
        if proxy is None:
            proxy = RedisTransportProxy()
        self._transport: ProxyingTransport = ProxyingTransport(proxy)

    @asynccontextmanager
    async def connect_sse(
        self, scope: Scope, receive: Receive, send: Send
    ) -> AsyncGenerator[
        tuple[
            MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
            MemoryObjectSendStream[types.JSONRPCMessage],
        ],
        None,
    ]:
        """
        Open an SSE stream for a new MCP session and yield its (read, write) streams.

        A fresh session ID is registered as a local session with the proxying
        transport. The client first receives an `endpoint` event pointing at
        `<messages_path>?session_id=<hex>`. After that, everything the MCP server
        writes to the yielded write stream is relayed to the client as `message`
        events.

        :raises ValueError: if `scope` is not an HTTP scope.
        """
        if scope["type"] != "http":
            logger.error("connect_sse received non-HTTP request")
            raise ValueError("connect_sse can only handle HTTP requests")

        logger.debug("Setting up SSE connection")
        session_id = uuid4()
        # Clients POST their messages to this URI; handle_fastapi_post_message
        # parses the session_id query parameter back with UUID(hex=...).
        session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"

        async with self._transport.create_local_session(session_id) as local_session:
            sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
                dict[str, object]
            ](0)

            async def sse_writer():
                logger.debug("Starting SSE writer")
                async with sse_stream_writer:
                    # The first event tells the client where to POST messages.
                    await sse_stream_writer.send(
                        {"event": "endpoint", "data": session_uri}
                    )
                    logger.debug(f"Sent endpoint event: {session_uri}")
                    # Relay everything the MCP server writes to the client.
                    async for message in local_session.sse_output_source:
                        logger.debug(f"Sending message via SSE: {message}")
                        await sse_stream_writer.send(
                            {
                                "event": "message",
                                "data": message.model_dump_json(
                                    by_alias=True, exclude_none=True
                                ),
                            }
                        )

            async with anyio.create_task_group() as tg:
                response = EventSourceResponse(
                    content=sse_stream_reader, data_sender_callable=sse_writer
                )
                logger.debug("Starting SSE response task")
                tg.start_soon(response, scope, receive, send)
                logger.debug("Yielding read and write streams")
                yield local_session.reader, local_session.writer

    async def handle_fastapi_post_message(self, request: Request) -> Response:
        """
        A reimplementation of the handle_post_message method of SseServerTransport
        that integrates better with FastAPI.

        A few good reasons for doing this:
        1. Avoid mounting a whole Starlette app and instead use a more FastAPI-native
           approach. Mounting has some known issues and limitations.
        2. Avoid re-constructing the scope, receive, and send from the request, as done
           in the original implementation.
        3. Use FastAPI's native response handling mechanisms and exception patterns to
           avoid unexpected rabbit holes.

        The combination of mounting a whole Starlette app and reconstructing the scope
        and send from the request proved to be especially error-prone for us when using
        tracing tools like Sentry, which had destructive effects on the request object
        when using the original implementation.

        Returns 202 once the message has been handed to the session's transport
        proxy. Returns 400 with a JSON body if the body is not a valid JSON-RPC
        message; in that case a JSON-RPC parse error is also sent to the session.

        :raises HTTPException: 400 for a missing or malformed `session_id` or an
            unreadable body; 404 if no session is advertised under that ID.
        """
        logger.debug("Handling POST message with FastAPI patterns")
        session_id_param = request.query_params.get("session_id")
        if session_id_param is None:
            logger.warning("Received request without session_id")
            raise HTTPException(status_code=400, detail="session_id is required")

        try:
            session_id = UUID(hex=session_id_param)
            logger.debug(f"Parsed session ID: {session_id}")
        except ValueError as err:
            logger.warning(f"Received invalid session ID: {session_id_param}")
            raise HTTPException(status_code=400, detail="Invalid session ID") from err

        remote_session = await self._transport.fetch_remote_session(session_id)
        if not remote_session:
            logger.warning(f"Could not find session for ID: {session_id}")
            raise HTTPException(status_code=404, detail="Could not find session")

        body = await request.body()
        # The body is untrusted. Decode it leniently for logging so that invalid
        # UTF-8 is rejected by the validation below (400) rather than raising
        # UnicodeDecodeError on this line (500).
        logger.debug(f"Received JSON: {body.decode('utf-8', errors='replace')}")
        try:
            message = JSONRPCMessage.model_validate_json(body)
            logger.debug(f"Validated client message: {message}")
        except ValidationError as err:
            logger.error(f"Failed to parse message: {err}")
            # Also deliver a JSON-RPC parse error to the session itself; this is
            # awaited before the 400 response is returned.
            await remote_session.send(_lower_errors(err))
            return JSONResponse(
                content={"error": "Could not parse message"}, status_code=400
            )
        except Exception as e:
            logger.error(f"Error processing request body: {e}")
            raise HTTPException(status_code=400, detail="Invalid request body") from e

        logger.debug("Accepting message, will send in background")
        # The send is awaited before the 202 is returned. How long it can block
        # depends on the transport proxy's send_to_remote.
        await remote_session.send(message)
        return JSONResponse(content={"message": "Accepted"}, status_code=202)
from abc import ABC, abstractmethod
import asyncio
from collections.abc import AsyncGenerator, Awaitable
from contextlib import AbstractAsyncContextManager, asynccontextmanager
import sys
from typing import AsyncContextManager, Callable, override
from uuid import UUID
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import types
from mcp.types import JSONRPCMessage
from lotushealth.shared import get_logger, get_redis
# Module-level logger for the proxying-transport code below.
logger = get_logger(__name__)
# Lifetime, in seconds, of the Redis key that RedisTransportProxy sets (via
# `ex=`) to advertise a live session.
SESSION_EXPIRY = 3600
class ProxyingTransport(ABC):
    """
    MCP-over-SSE is defined using a "two streams" representation of JSONRPC:
    - a reader (from client to server)
    - a writer (from server to client)

    A server stands between the reader and the writer -- processing reader jobs,
    getting results, and sending them to the writer.

    There's nothing wrong with this kind of interface, but there's an inherent
    mismatch between this interface and the actual SSE interface. The client
    ultimately sends the server messages by opening new requests. Something
    needs to move each message from the client's request onto the reader stream.

    In the original implementation, this was an ad hoc process:
    - Figure out what the message actually says (parse it)
    - Find the reader stream in a big table
    - Send the message directly to the reader stream

    ProxyingTransport formalizes this process:
    - It provides a `create_local_session` method that is roughly
      interface-compatible with `connect_sse`.
    - It provides a `RemoteSession` type that is roughly interface-compatible
      with the send() side of the reader stream.

    However, it delegates the details of the connection to a `TransportProxy`
    object, which manages (1) tracking advertisements for reader streams and
    (2) transferring messages meant for reader streams over the network.

    Of course it's possible to implement the old logic on top of it using a
    null TransportProxy -- but more importantly, you can also transfer messages
    over systems like Redis PubSub, allowing the server to be load-balanced.
    """

    def __init__(self, proxy: "AbstractTransportProxy"):
        """
        Create a new ProxyingTransport.

        :param proxy: advertises local sessions and delivers messages to them.
        """
        self._proxy: AbstractTransportProxy = proxy

    @asynccontextmanager
    async def create_local_session(
        self, session_id: UUID
    ) -> AsyncGenerator["LocalSession", None]:
        """
        Create a new LocalSession.

        This does the anyio-specific parts of getting an SSE connection ready, and
        replaces the bulk of connect_sse. Messages the proxy receives for
        `session_id` are forwarded into the session's reader.

        All four ends of the session's two memory streams are closed when the
        context exits.
        """
        logger.debug("Setting up SSE connection")
        read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage]
        read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage]
        write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
        write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
        # Buffer size 0: every send waits until a receiver takes the message.
        read_stream_writer, read_stream = anyio.create_memory_object_stream[
            types.JSONRPCMessage
        ](0)
        write_stream, write_stream_reader = anyio.create_memory_object_stream[
            types.JSONRPCMessage
        ](0)
        # Close every stream end on exit, not only the two send sides, so no
        # end is left for the garbage collector to clean up.
        with read_stream_writer, read_stream, write_stream, write_stream_reader:
            async with self._link_local_to_remote(session_id, read_stream_writer):
                yield LocalSession(
                    read_stream,
                    write_stream,
                    write_stream_reader,
                )

    async def fetch_remote_session(self, session_id: UUID) -> "RemoteSession | None":
        """
        Find the RemoteSession (as advertised by the proxy) for a given session ID.

        If the session does not exist, return None.
        """
        if not await self._proxy.remote_exists(session_id):
            return None
        return RemoteSession(self._proxy, session_id)

    @asynccontextmanager
    async def _link_local_to_remote(
        self,
        session_id: UUID,
        read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage],
    ) -> AsyncGenerator[None, None]:
        """
        Link the local reader to the proxy.

        A background task forwards every message the proxy yields for
        `session_id` into `read_stream_writer`.

        Entering the context waits until the proxy sets its `started` event, or
        until the forwarding task has stopped, so a failing listener cannot hang
        the caller. The forwarding task is cancelled when the context exits.
        """
        ready = asyncio.Event()

        async def _receiver() -> None:
            try:
                async for message in self._proxy.listen_to_remote(session_id, ready):
                    await read_stream_writer.send(message)
            finally:
                # Release the waiter below even if the listener failed or ended
                # without ever signalling `ready`.
                ready.set()

        async with anyio.create_task_group() as tg:
            tg.start_soon(_receiver)
            try:
                _ = await ready.wait()
                yield
            finally:
                tg.cancel_scope.cancel()
class AbstractTransportProxy(ABC):
    """
    Interface for the component that connects the parties of a session:

    - any number of senders,
    - a "remote" -- the public, advertised handle of a receiver, and
    - exactly one receiver task, running on some server.

    Implementations choose how remotes are advertised and how messages travel
    from the senders to the receiver.
    """

    @abstractmethod
    async def remote_exists(self, session_id: UUID) -> bool:
        """
        Report whether a remote is advertised for `session_id`.

        Must be True whenever a receiver is listening. It may stay True for a
        while after the receiver has stopped listening.
        """
        raise NotImplementedError

    @abstractmethod
    async def send_to_remote(self, session_id: UUID, message: JSONRPCMessage) -> None:
        """
        Deliver `message` to the remote for `session_id`.

        This may block. A missing remote must be treated as a silent no-op,
        not an error, because some PubSub systems cannot tell a publisher
        whether anyone is subscribed.
        """
        raise NotImplementedError

    @abstractmethod
    def listen_to_remote(
        self, session_id: UUID, started: asyncio.Event
    ) -> AsyncGenerator[JSONRPCMessage, None]:
        """
        Create the remote for `session_id` and yield the messages sent to it.

        Call `started.set()` as soon as listening has begun, and not before.
        From that moment `remote_exists` must return True; signalling any
        earlier would let a client send to a remote that is not there yet.

        When the generator is closed (GeneratorExit is raised inside it), every
        resource it acquired must be released.
        """
        raise NotImplementedError
class LocalTransportProxy(AbstractTransportProxy):
    """
    The LocalTransportProxy is a transport proxy that does no additional work.

    It connects senders on this server to a remote that is a MemoryObjectSendStream,
    and connects that to a receiver that is just a standard poll on that stream.
    Remotes exist only in this object's memory, so senders and the receiver must
    share the same LocalTransportProxy instance.
    """

    def __init__(self):
        """
        Create a new LocalTransportProxy with no registered remotes.
        """
        # session ID -> send side of the stream that the listener drains.
        self._remotes: dict[UUID, MemoryObjectSendStream[JSONRPCMessage]] = {}

    @override
    async def remote_exists(self, session_id: UUID) -> bool:
        """
        Return True if a remote exists. (Although the remote is, of course,
        the same server.)
        """
        return session_id in self._remotes

    @override
    async def send_to_remote(self, session_id: UUID, message: JSONRPCMessage) -> None:
        """
        Send this message to the receiver on the same server.

        This is a blocking send (the stream is unbuffered), but it returns
        immediately if no remote is registered. It also returns quietly if the
        remote's stream is closed before or during the send, per the
        `send_to_remote` "fail silently" contract.
        """
        stream_writer = self._remotes.get(session_id)
        if stream_writer is None:
            return
        try:
            await stream_writer.send(message)
        except (anyio.ClosedResourceError, anyio.BrokenResourceError):
            # ClosedResourceError: our send side was already closed.
            # BrokenResourceError: the listener closed its receive side while
            # we were waiting. Either way the remote is gone.
            pass

    @override
    async def listen_to_remote(
        self, session_id: UUID, started: asyncio.Event
    ) -> AsyncGenerator[JSONRPCMessage, None]:
        """
        Register a remote for `session_id` and yield the messages sent to it.

        `started` is set as soon as the remote is registered, at which point
        `remote_exists` returns True. The registration is removed when the
        generator finishes or is closed.
        """
        stream_writer, stream_reader = anyio.create_memory_object_stream[
            JSONRPCMessage
        ](0)
        with stream_writer, stream_reader:
            self._remotes[session_id] = stream_writer
            try:
                # Required by the listen_to_remote contract. Without it, the
                # caller (e.g. ProxyingTransport._link_local_to_remote) waits
                # on `started` forever.
                started.set()
                async for msg in stream_reader:
                    yield msg
            finally:
                # Only remove the entry if it is still ours.
                if self._remotes.get(session_id) is stream_writer:
                    del self._remotes[session_id]
class RedisTransportProxy(AbstractTransportProxy):
    """
    The RedisTransportProxy is a transport proxy that uses Redis to advertise the
    remotes and PubSub to communicate to them.

    Each session uses one name, `mcpRedisTransportProxy:remote:<session_id>`,
    for two purposes: as a plain key whose presence advertises the session, and
    as the PubSub channel that carries the session's messages.
    """

    def _remote_key(self, session_id: UUID) -> str:
        """Return the Redis key (and PubSub channel name) used for `session_id`."""
        return f"mcpRedisTransportProxy:remote:{session_id}"

    @override
    async def remote_exists(self, session_id: UUID) -> bool:
        """
        Check if a remote session exists.

        The check is done against a Redis key, which is set when the remote
        session starts listening.
        """
        redis = get_redis()
        key = self._remote_key(session_id)
        return bool(await redis.exists(key))

    @override
    async def send_to_remote(self, session_id: UUID, message: JSONRPCMessage) -> None:
        """
        Send a message to the remote.

        The message is published to the session's Redis channel, which its
        receiver subscribes to. The number of subscribers that received it is
        ignored, so publishing to a session with no listener is a silent no-op.
        """
        redis = get_redis()
        key = self._remote_key(session_id)
        await redis.publish(key, message.model_dump_json().encode("utf8"))

    @override
    async def listen_to_remote(
        self, session_id: UUID, started: asyncio.Event
    ) -> AsyncGenerator[JSONRPCMessage, None]:
        """
        Listen for messages sent to the remote.

        Subscribes to the session's channel first, and only then creates the
        key that advertises the receiver to `remote_exists`. This ordering means
        nothing can be published before the subscription is in place.

        While listening, the key is re-armed with `SESSION_EXPIRY` at least every
        `SESSION_EXPIRY / 2` seconds. Without this, a connection open longer than
        `SESSION_EXPIRY` would stop being advertised, and POSTs for a live
        session would be rejected.

        PubSub entries other than "message" (e.g. "pong") are skipped.
        """
        redis = get_redis()
        key = self._remote_key(session_id)
        refresh_interval = SESSION_EXPIRY / 2
        async with redis.pubsub(ignore_subscribe_messages=True) as pubsub:
            await pubsub.subscribe(key)
            await redis.set(key, "1", ex=SESSION_EXPIRY)
            next_refresh = anyio.current_time() + refresh_interval
            started.set()
            while True:
                # Wait no longer than the next refresh deadline; get_message
                # returns None on timeout (and for filtered subscribe replies).
                msg = await pubsub.get_message(
                    ignore_subscribe_messages=True,
                    timeout=max(0.0, next_refresh - anyio.current_time()),
                )
                if anyio.current_time() >= next_refresh:
                    await redis.set(key, "1", ex=SESSION_EXPIRY)
                    next_refresh = anyio.current_time() + refresh_interval
                if msg is None:
                    continue
                if msg["type"] != "message":
                    logger.debug(f"Ignoring unexpected PubSub message type: {msg['type']}")
                    continue
                # model_validate_json accepts str or bytes, so this works whether
                # or not the Redis client decodes responses.
                yield JSONRPCMessage.model_validate_json(msg["data"])
class LocalSession(object):
    """
    Handle for a session served by this process, bundling two streams:

    - the reader, which carries JSONRPCMessages that remote senders (using a
      compatible TransportProxy) addressed to this session, and
    - the writer, which carries JSONRPCMessages bound for the user.

    A LocalSession obtained from a ProxyingTransport already has its reader
    wired up. The transport knows nothing about the user, so the caller must
    drain the writer's output (via `sse_output_source`) out to the client.
    """

    def __init__(
        self,
        read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage],
        write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
        sse_output_source: MemoryObjectReceiveStream[types.JSONRPCMessage],
    ):
        # Stored as-is; exposed read-only through the properties below.
        self._read_stream = read_stream
        self._write_stream = write_stream
        self._sse_output_source = sse_output_source

    @property
    def reader(self) -> MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]:
        """
        Incoming side: the messages remote senders delivered to this session.
        """
        return self._read_stream

    @property
    def writer(self) -> MemoryObjectSendStream[types.JSONRPCMessage]:
        """
        Outgoing side: messages sent here are eventually shown to the user.
        """
        return self._write_stream

    @property
    def sse_output_source(self) -> MemoryObjectReceiveStream[types.JSONRPCMessage]:
        """
        Receive end paired with `writer`.

        Whatever MCP code sends to `writer` comes out here. The SSE code must
        consume this stream and forward its messages to the client.
        """
        return self._sse_output_source
class RemoteSession(object):
    """
    Handle for a session served by some other task. That task may be on this
    server or on a different machine entirely.

    Holding a RemoteSession proves only that the session existed when the handle
    was created. It does not prove that the session still exists.
    """

    def __init__(
        self,
        proxy: AbstractTransportProxy,
        session_id: UUID,
    ):
        """
        Wrap `session_id`, reachable through `proxy`, as a sendable handle.
        """
        self._proxy = proxy
        self._session_id = session_id

    async def send(self, message: JSONRPCMessage) -> None:
        """
        Forward `message` to the session through the proxy's `send_to_remote`.

        Under the `send_to_remote` contract, sending to a session that has gone
        away fails silently.
        """
        proxy, target = self._proxy, self._session_id
        await proxy.send_to_remote(target, message)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment