Last active
August 17, 2025 03:26
-
-
Save zekka-lotushealth/a70866ab563085f407ee77ff4ae66944 to your computer and use it in GitHub Desktop.
fastapi_mcp -- multiserver support
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextlib import asynccontextmanager | |
import json | |
import traceback | |
from urllib.parse import quote | |
import anyio | |
import httpx | |
from typing import AsyncGenerator, Dict, Optional, Any, List, Union | |
from fastapi import FastAPI, Request, APIRouter | |
from fastapi.openapi.utils import get_openapi | |
from mcp import ErrorData, JSONRPCError, JSONRPCRequest | |
from mcp.server.lowlevel.server import Server | |
import mcp.types as types | |
from sse_starlette import EventSourceResponse | |
from starlette.types import Receive, Scope, Send | |
from uuid import UUID, uuid4 | |
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream | |
from fastapi import Response, HTTPException | |
from fastapi.responses import JSONResponse | |
from pydantic import ValidationError | |
from mcp.types import JSONRPCMessage | |
from lotushealth.shared import get_logger | |
logger = get_logger(__name__) | |
def _lower_errors(message: JSONRPCMessage | ValidationError) -> JSONRPCMessage:
    """
    Turn whatever came out of request parsing into a JSONRPCMessage.

    A JSONRPCMessage is handed back untouched. A pydantic ValidationError is
    converted into a JSON-RPC "Parse error" (code -32700) whose `data` field
    carries the stringified validation error.
    """
    if not isinstance(message, ValidationError):
        return message

    # The request never parsed, so its real ID is not available.
    parse_failure = JSONRPCError(
        jsonrpc="2.0",
        id="unknown",
        error=ErrorData(
            code=-32700,
            message="Parse error",
            data={"validation_error": str(message)},
        ),
    )
    return JSONRPCMessage(root=parse_failure)
class FastApiProxyingSseTransport(object):
    """
    SSE transport whose client-to-server leg is routed through a ProxyingTransport.

    `connect_sse` serves the long-lived SSE response and yields the
    reader/writer pair of a local session; `handle_fastapi_post_message`
    accepts the client's POSTed messages and forwards them to the session
    through the transport's proxy.
    """

    def __init__(self, messages_path: str):
        """
        Create the transport.

        `messages_path` is the path advertised to the client (in the initial
        "endpoint" SSE event) as the place to POST its messages.
        """
        self._endpoint: str = messages_path
        # NOTE(review): ProxyingTransport and RedisTransportProxy are not among
        # the imports visible in this file -- confirm they are in scope.
        self._transport: ProxyingTransport = ProxyingTransport(RedisTransportProxy())

    @asynccontextmanager
    async def connect_sse(
        self, scope: Scope, receive: Receive, send: Send
    ) -> AsyncGenerator[
        tuple[
            MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
            MemoryObjectSendStream[types.JSONRPCMessage],
        ],
        None,
    ]:
        """
        Open an SSE connection and yield the session's (reader, writer) streams.

        A fresh session ID is generated, a local session is created for it, and
        an EventSourceResponse is started in a task group. The first SSE event
        is an "endpoint" event carrying the URI the client must POST to; every
        message subsequently read from the session's `sse_output_source` is
        sent as a "message" event.

        Raises:
            ValueError: if the ASGI scope is not an HTTP scope.
        """
        if scope["type"] != "http":
            logger.error("connect_sse received non-HTTP request")
            raise ValueError("connect_sse can only handle HTTP requests")

        logger.debug("Setting up SSE connection")
        session_id = uuid4()
        session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"

        async with self._transport.create_local_session(session_id) as local_session:
            # Unbuffered hand-off between sse_writer and the EventSourceResponse.
            sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
                dict[str, object]
            ](0)

            async def sse_writer():
                logger.debug("Starting SSE writer")
                async with sse_stream_writer:
                    # Tell the client where to POST before relaying anything else.
                    await sse_stream_writer.send(
                        {"event": "endpoint", "data": session_uri}
                    )
                    logger.debug(f"Sent endpoint event: {session_uri}")

                    async for message in local_session.sse_output_source:
                        logger.debug(f"Sending message via SSE: {message}")
                        await sse_stream_writer.send(
                            {
                                "event": "message",
                                "data": message.model_dump_json(
                                    by_alias=True, exclude_none=True
                                ),
                            }
                        )

            async with anyio.create_task_group() as tg:
                response = EventSourceResponse(
                    content=sse_stream_reader, data_sender_callable=sse_writer
                )
                logger.debug("Starting SSE response task")
                tg.start_soon(response, scope, receive, send)

                # NOTE(review): nothing visible here ends the caller's body when
                # the SSE response finishes (e.g. on client disconnect) --
                # confirm the session is torn down elsewhere, otherwise the
                # proxy listener for this session may outlive the client.
                logger.debug("Yielding read and write streams")
                yield local_session.reader, local_session.writer

    async def handle_fastapi_post_message(self, request: Request) -> Response:
        """
        A reimplementation of the handle_post_message method of SseServerTransport
        that integrates better with FastAPI.

        A few good reasons for doing this:
        1. Avoid mounting a whole Starlette app and instead use a more FastAPI-native
           approach. Mounting has some known issues and limitations.
        2. Avoid re-constructing the scope, receive, and send from the request, as done
           in the original implementation.
        3. Use FastAPI's native response handling mechanisms and exception patterns to
           avoid unexpected rabbit holes.

        The combination of mounting a whole Starlette app and reconstructing the scope
        and send from the request proved to be especially error-prone for us when using
        tracing tools like Sentry, which had destructive effects on the request object
        when using the original implementation.

        Returns:
            202 "Accepted" once the message has been forwarded to the session, or
            a 400 JSON response if the body is not a valid JSON-RPC message (in
            which case a parse-error message is forwarded to the session instead).

        Raises:
            HTTPException: 400 if `session_id` is missing or malformed, or the
                body cannot be processed; 404 if the session is not advertised.
        """
        logger.debug("Handling POST message with FastAPI patterns")

        session_id_param = request.query_params.get("session_id")
        if session_id_param is None:
            logger.warning("Received request without session_id")
            raise HTTPException(status_code=400, detail="session_id is required")

        try:
            session_id = UUID(hex=session_id_param)
            logger.debug(f"Parsed session ID: {session_id}")
        except ValueError:
            logger.warning(f"Received invalid session ID: {session_id_param}")
            raise HTTPException(status_code=400, detail="Invalid session ID") from None

        remote_session = await self._transport.fetch_remote_session(session_id)
        if not remote_session:
            logger.warning(f"Could not find session for ID: {session_id}")
            raise HTTPException(status_code=404, detail="Could not find session")

        body = await request.body()
        # The body is untrusted. Decode leniently for logging so that invalid
        # UTF-8 cannot raise here, outside the handlers below, and surface as a
        # 500; the strict validation of the raw bytes happens in the try block.
        body_text = body.decode("utf-8", errors="replace")
        logger.debug(f"Received JSON: {body_text}")

        try:
            message = JSONRPCMessage.model_validate_json(body)
            logger.debug(f"Validated client message: {message}")
        except ValidationError as err:
            logger.error(f"Failed to parse message: {err}")
            # Forward a JSON-RPC parse error to the session (awaited inline).
            await remote_session.send(_lower_errors(err))
            return JSONResponse(
                content={"error": "Could not parse message"}, status_code=400
            )
        except Exception as e:
            logger.error(f"Error processing request body: {e}")
            raise HTTPException(status_code=400, detail="Invalid request body") from e

        # The send is awaited inline; the 202 is returned once it completes.
        logger.debug("Accepting message, will send in background")
        await remote_session.send(message)

        return JSONResponse(content={"message": "Accepted"}, status_code=202)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABC, abstractmethod | |
import asyncio | |
from collections.abc import AsyncGenerator, Awaitable | |
from contextlib import AbstractAsyncContextManager, asynccontextmanager | |
import sys | |
from typing import AsyncContextManager, Callable, override | |
from uuid import UUID | |
import anyio | |
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream | |
from mcp import types | |
from mcp.types import JSONRPCMessage | |
from lotushealth.shared import get_logger, get_redis | |
logger = get_logger(__name__) | |
SESSION_EXPIRY = 3600 | |
class ProxyingTransport(ABC):
    """
    MCP-over-SSE is defined using a "two streams" representation of JSONRPC:
    - a reader (from client to server)
    - a writer (from server to client)

    A server stands between the reader and the writer -- processing reader jobs,
    getting results, and sending them to the writer.

    There's nothing wrong with this kind of interface, but there's an inherent
    mismatch between this interface and the actual SSE interface. The client
    ultimately sends the server messages by opening new requests. Something
    needs to move those requests from the client's message to the reader stream.

    In the original implementation, this was an ad hoc process:
    - Figure out what the message actually says (parse it)
    - Find the reader stream in a big table
    - Send the message directly to the reader stream

    ProxyingTransport formalizes this process:
    - It provides a `create_local_session` method that is roughly
      interface-compatible with `connect_sse`.
    - It provides a `RemoteSession` type that is roughly interface-compatible
      with the send() side of the reader stream.

    However, it delegates the details of the connection to a `TransportProxy`
    object, which manages (1) tracking advertisements for reader streams and
    (2) transferring messages meant for reader streams over the network.

    Of course it's possible to implement the old logic on top of it using a
    null TransportProxy -- but more importantly, you can also transfer messages
    over systems like Redis PubSub, allowing the server to be load-balanced.
    """

    def __init__(self, proxy: "AbstractTransportProxy"):
        """
        Create a new ProxyingTransport that delegates advertisement and message
        transfer to `proxy`.
        """
        self._proxy: AbstractTransportProxy = proxy

    @asynccontextmanager
    async def create_local_session(
        self, session_id: UUID
    ) -> AsyncGenerator["LocalSession", None]:
        """
        Create a new LocalSession.

        This does the anyio-specific parts of getting an SSE connection ready, and replaces
        the bulk of connect_sse.

        Two unbuffered memory object streams are created: one carrying messages
        from the proxy to the session's reader, one carrying messages written by
        the server towards the SSE output. All four stream ends are closed when
        the context exits.
        """
        logger.debug("Setting up SSE connection")

        read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage]
        read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage]
        write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
        write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]

        read_stream_writer, read_stream = anyio.create_memory_object_stream[
            types.JSONRPCMessage
        ](0)
        write_stream, write_stream_reader = anyio.create_memory_object_stream[
            types.JSONRPCMessage
        ](0)

        try:
            async with self._link_local_to_remote(session_id, read_stream_writer):
                yield LocalSession(
                    read_stream,
                    write_stream,
                    write_stream_reader,
                )
        finally:
            # Close the send halves first so consumers observe end-of-stream,
            # then the receive halves so no stream end is left unclosed.
            # close() is idempotent, so ends already closed by the consumer
            # are fine.
            read_stream_writer.close()
            write_stream.close()
            read_stream.close()
            write_stream_reader.close()

    async def fetch_remote_session(self, session_id: UUID) -> "RemoteSession | None":
        """
        Find the RemoteSession (as advertised by the proxy) for a given session ID.

        If the session does not exist, return None.
        """
        if not await self._proxy.remote_exists(session_id):
            return None
        return RemoteSession(self._proxy, session_id)

    @asynccontextmanager
    async def _link_local_to_remote(
        self,
        session_id: UUID,
        read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage],
    ) -> AsyncGenerator[None, None]:
        """
        Link the local reader to the proxy.

        A background task pumps every message the proxy yields for `session_id`
        into `read_stream_writer`. The context body only starts once the proxy
        has signalled that it is listening.

        The proxy's read operation will be canceled when the context manager is closed.
        """
        ready = asyncio.Event()

        async def _receiver() -> None:
            try:
                async for i in self._proxy.listen_to_remote(session_id, ready):
                    await read_stream_writer.send(i)
            finally:
                # If the listener ends before ever signalling readiness, release
                # the waiter below instead of leaving it blocked.
                ready.set()

        async with anyio.create_task_group() as tg:
            tg.start_soon(_receiver)
            try:
                _ = await ready.wait()
                yield
            finally:
                tg.cancel_scope.cancel()
class AbstractTransportProxy(ABC):
    """
    Interface for the component that connects senders to a session's receiver.

    A TransportProxy sits between:
    - any number of senders,
    - a "remote": the public, advertised stand-in for a receiver, and
    - the single receiver task, running on some server, that actually
      consumes the messages.
    """

    @abstractmethod
    async def remote_exists(self, session_id: UUID) -> bool:
        """
        Report whether a remote is advertised for `session_id`.

        Must return True whenever a receiver exists. It is allowed to keep
        returning True for some time after the receiver has stopped listening.
        """
        raise NotImplementedError

    @abstractmethod
    async def send_to_remote(self, session_id: UUID, message: JSONRPCMessage) -> None:
        """
        Deliver `message` to the remote for `session_id`.

        May block. Should fail silently when the remote does not exist, because
        some PubSub systems cannot tell a publisher that nobody is subscribed.
        """
        raise NotImplementedError

    @abstractmethod
    def listen_to_remote(
        self, session_id: UUID, started: asyncio.Event
    ) -> AsyncGenerator[JSONRPCMessage, None]:
        """
        Create the remote for `session_id` and yield the messages sent to it.

        Implementations must set `started` as soon as they are listening, at
        which point `remote_exists` must return True. Setting it any earlier
        would open a race in which a client sends to a remote that is not yet
        present.

        When the generator is collected (GeneratorExit is raised), every
        resource it holds must be released.
        """
        raise NotImplementedError
class LocalTransportProxy(AbstractTransportProxy):
    """
    The LocalTransportProxy is a transport proxy that does no additional work.

    It connects senders on this server to a remote that is a MemoryObjectSendStream,
    and connects that to a receiver that is just a standard poll on that stream.
    """

    def __init__(self):
        """
        Create a new LocalTransportProxy with no registered remotes.
        """
        # Send halves of the per-session streams, keyed by session ID.
        self._remotes: dict[UUID, MemoryObjectSendStream[JSONRPCMessage]] = {}

    @override
    async def remote_exists(self, session_id: UUID) -> bool:
        """
        Return True if a remote exists. (Although the remote is, of course,
        the same server.)
        """
        return session_id in self._remotes

    @override
    async def send_to_remote(self, session_id: UUID, message: JSONRPCMessage) -> None:
        """
        Send this message to the receiver on the same server.

        The send blocks until the listener takes the message. It returns
        without error if no remote is registered, or if the listener goes away
        while the send is pending.
        """
        stream_writer = self._remotes.get(session_id)
        if stream_writer is None:
            return
        try:
            await stream_writer.send(message)
        except (anyio.ClosedResourceError, anyio.BrokenResourceError):
            # ClosedResourceError: our send half was closed by the listener.
            # BrokenResourceError: the listener's receive half was closed.
            # Either way the remote is gone, and the contract is to fail silently.
            pass

    @override
    async def listen_to_remote(
        self, session_id: UUID, started: asyncio.Event
    ) -> AsyncGenerator[JSONRPCMessage, None]:
        """
        Listen for requests sent to our remote.

        Registers an unbuffered stream under `session_id`, signals `started`,
        and yields every message sent to it. The registration is removed when
        the generator exits.
        """
        stream_writer, stream_reader = anyio.create_memory_object_stream[
            JSONRPCMessage
        ](0)
        with stream_writer, stream_reader:
            self._remotes[session_id] = stream_writer
            # The remote is now visible to remote_exists/send_to_remote, so
            # callers waiting on `started` may proceed. Without this signal the
            # waiter would block forever.
            started.set()
            try:
                async for msg in stream_reader:
                    yield msg
            finally:
                self._remotes.pop(session_id, None)
class RedisTransportProxy(AbstractTransportProxy):
    """
    The RedisTransportProxy is a transport proxy that uses Redis to advertise the
    remotes and PubSub to communicate to them.

    The same string is used both as the advertisement key and as the PubSub
    channel name for a session.
    """

    def _remote_key(self, session_id: UUID) -> str:
        """Return the Redis key / channel name for `session_id`."""
        return f"mcpRedisTransportProxy:remote:{session_id}"

    @override
    async def remote_exists(self, session_id: UUID) -> bool:
        """
        Check if a remote session exists.

        The check is done against a Redis key, which is set when the remote session is created.
        """
        redis = get_redis()
        key = self._remote_key(session_id)
        return bool(await redis.exists(key))

    @override
    async def send_to_remote(self, session_id: UUID, message: JSONRPCMessage) -> None:
        """
        Send a message to the remote.

        The message is published to a Redis channel, which is subscribed to by the receiver.
        """
        redis = get_redis()
        key = self._remote_key(session_id)
        # Serialize the same way the SSE writer does (aliases, no null fields).
        payload = message.model_dump_json(by_alias=True, exclude_none=True)
        await redis.publish(key, payload.encode("utf8"))

    @override
    async def listen_to_remote(
        self, session_id: UUID, started: asyncio.Event
    ) -> AsyncGenerator[JSONRPCMessage, None]:
        """
        Listen for messages from the remote.

        Creates a key in Redis that advertises the existence of the receiver to
        `remote_exists`. The key carries a TTL of SESSION_EXPIRY seconds and is
        renewed periodically for as long as this generator is running, so a
        live listener stays advertised.

        Raises:
            AssertionError: if PubSub delivers anything other than a "message".
        """
        redis = get_redis()
        key = self._remote_key(session_id)
        # Renew well before the TTL lapses.
        refresh_interval = SESSION_EXPIRY / 2
        loop = asyncio.get_running_loop()

        async with redis.pubsub(ignore_subscribe_messages=True) as pubsub:
            # Subscribe before advertising, so the key never points at a
            # channel nobody is listening on.
            await pubsub.subscribe(key)
            await redis.set(key, "1", ex=SESSION_EXPIRY)
            next_refresh = loop.time() + refresh_interval
            started.set()

            while True:
                # Wait for a message, but no longer than until the next renewal.
                timeout = max(0.0, next_refresh - loop.time())
                msg = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=timeout
                )

                if loop.time() >= next_refresh:
                    # Re-set rather than EXPIRE so a key that has already
                    # vanished is re-created.
                    await redis.set(key, "1", ex=SESSION_EXPIRY)
                    next_refresh = loop.time() + refresh_interval

                if msg is None:
                    # Timed out, or a subscribe confirmation was filtered out.
                    continue
                if msg["type"] != "message":
                    raise AssertionError(f"unexpected message type: {msg['type']}")
                # model_validate_json takes bytes or str, so this works whether
                # or not the client decodes responses.
                yield JSONRPCMessage.model_validate_json(msg["data"])
class LocalSession:
    """
    Bundle of the stream ends that make up one locally-served session.

    Two logical streams are involved:
    - the reader stream, carrying JSONRPCMessages sent to us by remote services
      that use a compatible TransportProxy;
    - the writer stream, carrying JSONRPCMessages that are expected to be sent
      to the user soon.

    A LocalSession presumably comes from a ProxyingTransport, which has already
    wired up the reader. It knows nothing about the user, so wiring the writer
    to the user is the holder's job (via `sse_output_source`).
    """

    def __init__(
        self,
        read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage],
        write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
        sse_output_source: MemoryObjectReceiveStream[types.JSONRPCMessage],
    ):
        """Store the three stream ends; nothing is opened or closed here."""
        self._sse_output_source = sse_output_source
        self._write_stream = write_stream
        self._read_stream = read_stream

    @property
    def reader(self) -> MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]:
        """
        The reader: receive end carrying JSONRPCMessages sent to us by the remote.
        """
        return self._read_stream

    @property
    def writer(self) -> MemoryObjectSendStream[types.JSONRPCMessage]:
        """
        The writer: JSONRPCMessages written here will eventually be seen by
        the user.
        """
        return self._write_stream

    @property
    def sse_output_source(self) -> MemoryObjectReceiveStream[types.JSONRPCMessage]:
        """
        The link between `writer` (where MCP code writes its JSONRPCMessages)
        and the SSE output stream (which must be connected in some way to the
        existing SSE code).
        """
        return self._sse_output_source
class RemoteSession:
    """
    Handle to a session that some other task is serving.

    That task may live on this server, or it may live on Mars.

    Holding one witnesses that the session existed when the handle was created;
    it does not guarantee that the session still exists.
    """

    def __init__(
        self,
        proxy: AbstractTransportProxy,
        session_id: UUID,
    ):
        """
        Bind the handle to `proxy` and `session_id`.
        """
        self._session_id = session_id
        self._proxy = proxy

    async def send(self, message: JSONRPCMessage) -> None:
        """
        Forward `message` to the remote session via the proxy's send_to_remote.

        If the remote session no longer exists, this will silently fail.
        """
        proxy, session_id = self._proxy, self._session_id
        await proxy.send_to_remote(session_id, message)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment