-
-
Save ochafik/f0a105300f1d0bc007b21769a07f9381 to your computer and use it in GitHub Desktop.
MCP Proxy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import { afterEach, describe, it, expect } from "bun:test"; | |
| import winston from "winston"; | |
| import z from "zod"; | |
| import { Client } from "@modelcontextprotocol/sdk/client/index.js"; | |
| import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; | |
| import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; | |
| import { CallToolRequest, ClientCapabilities, ElicitRequest, ElicitRequestSchema, PingRequestSchema } from "@modelcontextprotocol/sdk/types.js"; | |
| import { createProxyServer } from "./proxy.js"; | |
| import { CleanupFunction } from "../shared/cleanup-utils.js"; | |
| import { getPort, registerStreamableHttpMcpServer, ServerFactory, setupExpressServer } from "../shared/mcp-utils.js"; | |
// Debug-level, colorized console logger handed to the target servers, the
// proxy and the client handlers created in this test file.
const logFormat = winston.format.combine(
  winston.format.colorize(),
  winston.format.timestamp(),
  winston.format.errors({ stack: true }),
  winston.format.simple(),
);
const consoleTransport = new winston.transports.Console({});
const logger = winston.createLogger({
  level: 'debug',
  format: logFormat,
  transports: [consoleTransport],
});
/**
 * Shape of a proxy configuration: a port, an MCP path and, optionally, the
 * endpoint of the server being proxied. Not referenced by the tests below.
 */
export type Config = {
  mcpPath: string;
  port: number;
  proxiedServerEndpoint?: URL;
};
| describe('Proxy server', () => { | |
// Teardown callbacks pushed by the setup helpers; run and emptied after each test.
const cleanups: CleanupFunction[] = [];

afterEach(async () => {
  // A failing cleanup is logged and must not prevent the remaining ones from running.
  const runQuietly = async (task: CleanupFunction) => {
    try {
      await task();
    } catch (failure) {
      console.error('Error during cleanup:', failure);
    }
  };
  for (const task of cleanups) {
    await runQuietly(task);
  }
  // Empty the array in place so the next test starts with no pending cleanups.
  cleanups.splice(0, cleanups.length);
});
/**
 * Starts an Express server on a random port that exposes the MCP server built
 * by `createServer` over streamable HTTP at `/mcp`, and queues its shutdown
 * in `cleanups`.
 *
 * @param serverName Name passed to `registerStreamableHttpMcpServer`.
 * @param createServer Factory producing the MCP server to expose.
 * @returns The URL of the running server's `/mcp` endpoint.
 */
async function setupTargetServer(serverName: string, createServer: ServerFactory<McpServer>) {
  const { close, server } = await setupExpressServer({
    logger,
    port: 0, // Use a random port
    callback: async (app, addCleanup) => {
      // Awaited for consistency with proxy.ts, which awaits this same helper:
      // harmless if it is synchronous, and required for `cleanup` to be
      // defined if it is asynchronous.
      const { cleanup } = await registerStreamableHttpMcpServer(serverName, app, {
        logger,
        path: '/mcp',
        createServer,
      });
      addCleanup(cleanup);
    },
  });
  cleanups.push(close);

  const targetEndpoint = new URL(`http://localhost:${getPort(server)}/mcp`);
  console.log(`Target server running: ${targetEndpoint}`);
  return { targetEndpoint };
}
/**
 * Builds an MCP client with the given capabilities, lets `cb` register extra
 * handlers on it, then connects it to `endpoint` over streamable HTTP.
 *
 * @param endpoint MCP endpoint to connect to.
 * @param capabilities Capabilities the client declares.
 * @param cb Optional hook invoked with the client before it connects.
 * @returns The connected client.
 */
async function createClient(endpoint: URL, capabilities: ClientCapabilities, cb?: (client: Client) => Promise<void>) {
  const clientInfo = { name: 'target-client', version: '1.0.0' };
  const client = new Client(clientInfo, { capabilities });

  // Answer pings coming from the server side and log their request id.
  client.setRequestHandler(PingRequestSchema, async (_request, extra) => {
    logger.info(`Client received MCP ping request from server: ${extra.requestId}`);
    return {};
  });

  // Extra handlers are installed before the connection is established.
  if (cb) {
    await cb(client);
  }

  const transport = new StreamableHTTPClientTransport(endpoint);
  await client.connect(transport);
  return client;
}
// Arguments accepted by the `echo` tool: a single string `message`.
const EchoInputSchema = z.object({ message: z.string() });

// Structured result of the `echo` tool: a single string `echo`.
const EchoOutputSchema = z.object({ echo: z.string() });
/**
 * Launches a target MCP server named `echo-server` exposing one `echo` tool
 * whose structured result is `{ echo: "Echo: <message>" }`.
 *
 * @returns Whatever `setupTargetServer` returns (the target endpoint).
 */
async function setupEchoMcpServer() {
  const buildEchoServer: ServerFactory<McpServer> = async _initialize => {
    const serverInfo = { name: 'echo-server', version: '1.0.0' };
    const echoServer = new McpServer(serverInfo, {
      instructions: "Echoes messages",
      // capabilities: {
      //   elicitation: {},
      // }
    });

    const toolConfig = {
      inputSchema: EchoInputSchema.shape,
      outputSchema: EchoOutputSchema.shape,
    };
    echoServer.registerTool('echo', toolConfig, async ({ message }) => {
      const echo = `Echo: ${message}`;
      // Text content is left empty; only structured content is returned.
      // content: [{ type: 'text', text: `Echo: ${message}` }],
      return { content: [], structuredContent: { echo } };
    });

    return echoServer;
  };

  return await setupTargetServer('echo-server', buildEchoServer);
}
| // async function setupClassifierServer(classifyMcpTool: (args: ClassifierInput) => Promise<ClassifierOutput>) { | |
| // classifyMcpTool ??= async _inputs => { | |
| // return { | |
| // decision: 'allow', | |
| // }; | |
| // }; | |
| // return setupTargetServer('classifier-server', async _initialize => { | |
| // const server = new McpServer({ | |
| // name: 'classifier-server', | |
| // version: '1.0.0', | |
| // }, { | |
| // capabilities: { | |
| // }, | |
| // instructions: "Echoes messages", | |
| // }); | |
| // server.registerTool( | |
| // 'classify_mcp_tool', | |
| // { | |
| // inputSchema: ClassifierInputSchema.shape, | |
| // outputSchema: ClassifierOutputSchema.shape, | |
| // }, | |
| // async (args: ClassifierInput, _extra) => { | |
| // return <ClassifierOutput>{ | |
| // content: [], | |
| // structuredContent: await classifyMcpTool(args), | |
| // }; | |
| // } | |
| // ); | |
| // return server; | |
| // }); | |
| // } | |
/**
 * Starts an HTTP proxy on a random localhost port in front of
 * `proxiedServerEndpoint`, serving MCP at `/test-mcp`, and queues its
 * shutdown in `cleanups`.
 *
 * @param proxiedServerEndpoint Endpoint of the MCP server to proxy.
 * @returns The proxy's endpoint URL, its net server and its `elicitInput` function.
 * @throws Error if `createProxyServer` returned no net server.
 */
async function setupProxyServer(proxiedServerEndpoint: URL) {
  const mcpPath = '/test-mcp';
  const outerConfig = { type: 'http', port: 0, host: 'localhost', path: mcpPath } as const;
  const innerConfig = { type: 'http', endpoint: proxiedServerEndpoint } as const;

  const proxy = await createProxyServer(outerConfig, innerConfig, logger);
  cleanups.push(proxy.close);

  const { netServer, elicitInput } = proxy;
  if (!netServer) {
    throw new Error('Proxy server did not create a net server');
  }

  const port = getPort(netServer);
  const proxyEndpoint = new URL(`http://localhost:${port}${mcpPath}`);
  console.log(`Proxy server running: ${proxyEndpoint}`);
  return { proxyEndpoint, netServer, elicitInput };
}
// A client talking through the proxy must see the same tool list and the same
// tool result as a client talking to the target server directly.
it('should proxy tool calls target server', async () => {
  const { targetEndpoint } = await setupEchoMcpServer();
  const { proxyEndpoint } = await setupProxyServer(targetEndpoint);

  const directClient = await createClient(targetEndpoint, {});
  const viaProxyClient = await createClient(proxyEndpoint, {});

  const directTools = await directClient.listTools();
  const viaProxyTools = await viaProxyClient.listTools();
  expect(viaProxyTools).toEqual(directTools);

  const echoCall: CallToolRequest['params'] = {
    name: 'echo',
    arguments: { message: 'Hello, world!' },
  };
  const directResult = await directClient.callTool(echoCall);
  const viaProxyResult = await viaProxyClient.callTool(echoCall);
  expect(viaProxyResult).toEqual(directResult);
});
// An elicitation issued through the proxy's `elicitInput` must reach a
// connected client that declared the `elicitation` capability, and the
// client's answer must be returned to the caller.
it('should send elicitations to client', async () => {
  const { targetEndpoint } = await setupEchoMcpServer();
  const { proxyEndpoint, elicitInput } = await setupProxyServer(targetEndpoint);
  console.log(`Proxy endpoint: ${proxyEndpoint}`);

  // Every elicit request the client receives is recorded for the assertions below.
  const elicitRequests: ElicitRequest[] = [];
  await createClient(proxyEndpoint, { elicitation: {} }, async client => {
    client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
      logger.info(`Received elicit request: ${JSON.stringify(request)}, extra: ${JSON.stringify(extra)}`);
      elicitRequests.push(request);
      return {
        action: 'accept',
        content: {
          testField: 'yay',
        },
      };
    });
  });

  const elicitParams: ElicitRequest['params'] = {
    message: 'Test elicit input',
    requestedSchema: {
      type: 'object',
      properties: {
        testField: {
          type: 'string',
        },
      },
      required: ['testField'],
    },
  };
  console.log(`Elicit params: ${JSON.stringify(elicitParams)}`);
  const result = await elicitInput(elicitParams);
  console.log(`Elicit result: ${JSON.stringify(result)}`);

  // The result is the whole elicit response ({ action, content }), not just
  // its content. It must be compared structurally: `toBe` checks identity and
  // can never match a fresh object literal.
  expect(result).toMatchObject({ action: 'accept', content: { testField: 'yay' } });
  expect(elicitRequests).toHaveLength(1);
  expect(elicitRequests[0].params.message).toBe('Test elicit input');
  expect(elicitRequests[0].params.requestedSchema).toEqual(elicitParams.requestedSchema);
});
| // it('should call the classifier w/ tool calls', async () => { | |
| // const classificationInputs: ClassifierInput[] = []; | |
| // const proxiedServerEndpoint = await setupEchoMcpServer(); | |
| // const classifierEndpoint = await setupClassifierServer(async inputs => { | |
| // classificationInputs.push(inputs); | |
| // return { | |
| // decision: 'allow', | |
| // }; | |
| // }); | |
| // const { proxyEndpoint } = await setupProxyServer({proxiedServerEndpoint, classifierEndpoint}); | |
| // const targetClient = await createClient(proxiedServerEndpoint, { elicitation: false }); | |
| // const proxyClient = await createClient(proxyEndpoint, { elicitation: false }); | |
| // const params: CallToolRequest['params'] = {name: 'echo', arguments: { message: 'Hello, world!' }}; | |
| // const targetTools = await targetClient.listTools(); | |
| // const proxyTools = await proxyClient.listTools(); | |
| // expect(proxyTools).toEqual(targetTools); | |
| // expect(proxyTools.tools).toHaveLength(1); | |
| // expect(proxyTools.tools[0].name).toBe('echo'); | |
| // // expect(proxyTools.tools[0].inputSchema).toEqual(EchoInputSchema.shape); | |
| // // expect(proxyTools.tools[0].outputSchema).toEqual(EchoOutputSchema.shape); | |
| // const targetResult = await targetClient.callTool(params); | |
| // const proxyResult = await proxyClient.callTool(params); | |
| // expect(proxyResult).toEqual(targetResult); | |
| // expect(classificationInputs).toHaveLength(1); | |
| // expect(classificationInputs[0].request.name).toEqual('echo'); | |
| // expect(classificationInputs[0].request.arguments).toEqual({ message: 'Hello, world!' }); | |
| // // expect(classificationInputs[0]).toEqual({ | |
| // // name: 'echo', | |
| // // tool: { | |
| // // name: 'echo', | |
| // // inputSchema: EchoInputSchema.shape, | |
| // // outputSchema: EchoOutputSchema.shape, | |
| // // }, | |
| // // arguments: { message: 'Hello, world!' }, | |
| // // }); | |
| // }); | |
| }); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import net from "net"; | |
| import { Client } from "@modelcontextprotocol/sdk/client/index.js"; | |
| import { Server } from "@modelcontextprotocol/sdk/server/index.js"; | |
| import { | |
| ListPromptsRequestSchema, | |
| GetPromptRequestSchema, | |
| ListToolsRequestSchema, | |
| CallToolRequestSchema, | |
| CallToolResultSchema, | |
| CompatibilityCallToolResultSchema, | |
| ListResourcesRequestSchema, | |
| ListResourceTemplatesRequestSchema, | |
| ReadResourceRequestSchema, | |
| ElicitRequestSchema, | |
| CreateMessageRequest, | |
| PingRequestSchema, | |
| LoggingLevelSchema, | |
| SetLevelRequestSchema, | |
| LoggingMessageNotificationSchema, | |
| CompleteRequestSchema, | |
| ResourceUpdatedNotificationSchema, | |
| ResourceListChangedNotificationSchema, | |
| PromptListChangedNotificationSchema, | |
| CreateMessageResultSchema, | |
| CreateMessageRequestSchema, | |
| ListRootsRequestSchema, | |
| RootsListChangedNotificationSchema, | |
| JSONRPCMessageSchema, | |
| } from "@modelcontextprotocol/sdk/types.js"; | |
| import { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js'; | |
| import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; | |
| import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; | |
| import { ElicitRequest, ElicitResult, InitializeRequest, isJSONRPCRequest, JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; | |
| import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; | |
| import { registerStreamableHttpMcpServer, SessionData, setupExpressServer } from "../shared/mcp-utils.js"; | |
| import winston from "winston"; | |
| import { ca } from "zod/v4/locales"; | |
/** Outer (client-facing) side served over HTTP at `host:port` under `path`. */
export type OuterMcpHttpConfig = {
  type: 'http';
  host: string;
  port: number;
  path: string;
};

/** Inner (proxied) MCP server reached over HTTP at `endpoint`. */
export type InnerMcpHttpConfig = {
  type: 'http';
  endpoint: URL;
};

/** Inner (proxied) MCP server reached over stdio using `command` and `args`. */
export type InnerMcpStdioConfig = {
  type: 'stdio';
  command: string;
  args: string[];
};

/** Outer (client-facing) side served over stdio. */
export type OuterMcpStdioConfig = {
  type: 'stdio';
};

/** How the proxy exposes itself, discriminated by `type`. */
export type OuterMcpConfig = OuterMcpStdioConfig | OuterMcpHttpConfig;

/** How the proxy reaches the proxied server, discriminated by `type`. */
export type InnerMcpConfig = InnerMcpStdioConfig | InnerMcpHttpConfig;
/**
 * Reads one newline-terminated line from `process.stdin`.
 *
 * Input is pulled one unit at a time so nothing after the newline is
 * consumed; the rest stays buffered in stdin for whoever reads it next.
 * Bytes are accumulated and decoded as UTF-8 only once the whole line is
 * available, so multi-byte characters are not corrupted.
 *
 * A line that is exactly `conda activate py311\n` is discarded (noise from
 * vscode) and reading continues with the following line.
 *
 * @returns A promise resolving to the line, including its trailing `\n`.
 *   It rejects if stdin ends before a complete line has been read.
 */
function readLine() {
  return new Promise<string>((resolve, reject) => {
    const NEWLINE = 0x0a;
    let pending: Buffer[] = [];

    const detach = () => {
      process.stdin.removeListener('readable', onReadable);
      process.stdin.removeListener('end', onEnd);
    };

    const onEnd = () => {
      detach();
      reject(new Error('stdin ended before a complete line was read'));
    };

    const onReadable = () => {
      let chunk: Buffer | string | null;
      while ((chunk = process.stdin.read(1)) !== null) {
        // read() yields a string instead of a Buffer when an encoding was set on stdin.
        const bytes = typeof chunk === 'string' ? Buffer.from(chunk, 'utf8') : chunk;
        pending.push(bytes);
        if (bytes[bytes.length - 1] !== NEWLINE) {
          continue;
        }
        const line = Buffer.concat(pending).toString('utf8');
        pending = [];
        if (line === 'conda activate py311\n') {
          continue; // Ignore this line, it's just noise from vscode.
        }
        detach();
        resolve(line);
        break;
      }
    };

    process.stdin.on('readable', onReadable);
    process.stdin.on('end', onEnd);
  });
}
/**
 * Creates an MCP proxy that exposes an "inner" MCP server through an "outer"
 * transport.
 *
 * For each outer connection, `createServer` connects a fresh `Client` to the
 * inner server (stdio command or HTTP endpoint) and builds a `Server` that
 * mirrors the inner server's identity, instructions and selected
 * capabilities. Requests and notifications are forwarded both ways.
 *
 * - `outer.type === 'stdio'`: the first line on stdin must be the JSON-RPC
 *   `initialize` request. It is read up front so the client's capabilities are
 *   known, then replayed into the stdio transport once the server is connected.
 * - `outer.type === 'http'`: an Express server is started on `outer.host` /
 *   `outer.port`, and the proxy is registered at `outer.path`.
 *
 * @param outer How the proxy is exposed to clients.
 * @param inner How the proxy reaches the proxied server.
 * @param logger Logger used for connection errors and passed to the HTTP helpers.
 * @returns `netServer` (HTTP only); `elicitInput`, which sends an elicitation
 *   to the connected client(s); and `close`, which shuts the proxy down.
 * @throws Error if the first stdio message cannot be parsed or is not an
 *   `initialize` request, or if the inner server cannot be reached.
 */
export async function createProxyServer(outer: OuterMcpConfig, inner: InnerMcpConfig, logger: winston.Logger) : Promise<{
  netServer?: net.Server
  elicitInput: (params: ElicitRequest['params'], options?: RequestOptions) => Promise<ElicitResult>,
  close: () => Promise<void>,
}> {
  /** Builds one proxy `Server` backed by its own `Client` connected to the inner server. */
  async function createServer(initialize: InitializeRequest['params']): Promise<Server> {
    const clientCapabilities = initialize.capabilities;
    const client = new Client({
      name: initialize.clientInfo.name,
      title: initialize.clientInfo.title,
      version: initialize.clientInfo.version,
    }, {
      capabilities: {
        // Copy capabilities explicitly to avoid new ones creeping in.
        sampling: clientCapabilities.sampling ? {} : undefined,
        elicitation: clientCapabilities.elicitation ? {} : undefined, // Enable elicitation on proxy server if client supports it.
        roots: clientCapabilities.roots ? {} : undefined,
      },
    });
    try {
      if (inner.type === 'stdio') {
        await client.connect(new StdioClientTransport({command: inner.command, args: inner.args}));
      } else {
        await client.connect(new StreamableHTTPClientTransport(inner.endpoint));
      }
    } catch (error) {
      logger.error('Error connecting to inner MCP server:');
      logger.error(inner);
      logger.error(error);
      throw new Error(`Could not connect to inner MCP server: ${error}`);
    }

    const serverVersion = client.getServerVersion();
    const serverCapabilities = client.getServerCapabilities();
    const server = new Server({
      name: serverVersion?.name ?? '?',
      title: serverVersion?.title,
      version: serverVersion?.version ?? '?',
    }, {
      capabilities: {
        // Copy capabilities explicitly to avoid new ones creeping in.
        completions: serverCapabilities?.completions,
        elicitation: serverCapabilities?.elicitation,
        logging: serverCapabilities?.logging ? {} : undefined,
        prompts: serverCapabilities?.prompts ? {
          listChanged: serverCapabilities.prompts.listChanged ? true : undefined,
        } : undefined,
        resources: serverCapabilities?.resources ? {
          subscribe: serverCapabilities.resources.subscribe ? true : undefined,
          listChanged: serverCapabilities.resources.listChanged ? true : undefined,
        } : undefined,
        tools: serverCapabilities?.tools ? {
          listChanged: serverCapabilities.tools.listChanged ? true : undefined,
        } : undefined,
      },
      instructions: client.getInstructions(),
    });

    // Close the inner connection when the outer one closes; otherwise every
    // outer session leaves its inner client (and, for stdio, the spawned
    // command) running.
    const previousOnClose = server.onclose;
    server.onclose = () => {
      previousOnClose?.();
      client.close().catch(error => {
        logger.error(`Error closing inner MCP client: ${error}`);
      });
    };

    // Wire ping handlers.
    server.setRequestHandler(PingRequestSchema, async (_request, _extra) => {
      return client.ping(); // TODO: remap related ids: {relatedRequestId: extra.requestId});
    });
    client.setRequestHandler(PingRequestSchema, async (_request, _extra) => {
      return server.ping();
    });

    if (serverCapabilities?.prompts) {
      server.setRequestHandler(ListPromptsRequestSchema, (request, extra) => {
        return client.listPrompts(request.params, extra);
      });
      server.setRequestHandler(GetPromptRequestSchema, (request, extra) => {
        return client.getPrompt(request.params, extra);
      });
      if (serverCapabilities.prompts?.listChanged) {
        client.setNotificationHandler(PromptListChangedNotificationSchema, _ => {
          return server.sendPromptListChanged();
        });
      }
    }

    if (serverCapabilities?.tools) {
      server.setRequestHandler(ListToolsRequestSchema, (request, extra) => {
        return client.listTools(request.params, extra);
      });
      server.setRequestHandler(CallToolRequestSchema, (request, extra) => {
        return client.callTool(request.params, CallToolResultSchema, extra);
      });
    }

    if (serverCapabilities?.logging) {
      server.setRequestHandler(SetLevelRequestSchema, (request, extra) => {
        return client.setLoggingLevel(request.params.level, extra);
      });
      client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => {
        return server.sendLoggingMessage(notification.params);
      });
    }

    if (serverCapabilities?.completions) {
      server.setRequestHandler(CompleteRequestSchema, (request, extra) => {
        return client.complete(request.params, extra);
      });
    }

    if (serverCapabilities?.resources) {
      server.setRequestHandler(ListResourcesRequestSchema, (request, extra) => {
        return client.listResources(request.params, extra);
      });
      server.setRequestHandler(ListResourceTemplatesRequestSchema, (request, extra) => {
        return client.listResourceTemplates(request.params, extra);
      });
      server.setRequestHandler(ReadResourceRequestSchema, (request, extra) => {
        return client.readResource(request.params, extra);
      });
      // The capability flag is `subscribe` (the one advertised above), not
      // `update`; testing `update` meant this branch never ran.
      // NOTE(review): resources/subscribe and resources/unsubscribe requests
      // are not forwarded here even though `subscribe` is advertised —
      // confirm whether that is intended.
      if (serverCapabilities.resources?.subscribe) {
        client.setNotificationHandler(ResourceUpdatedNotificationSchema, notification => {
          return server.sendResourceUpdated(notification.params);
        });
      }
      if (serverCapabilities.resources?.listChanged) {
        client.setNotificationHandler(ResourceListChangedNotificationSchema, _ => {
          return server.sendResourceListChanged();
        });
      }
    }

    // Requests initiated by the inner server are forwarded to the outer
    // client, but only for capabilities that client declared.
    if (clientCapabilities.elicitation) {
      client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
        return server.elicitInput(request.params, extra);
      });
    }
    if (clientCapabilities.sampling) {
      client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => {
        return server.createMessage(request.params, extra);
      });
    }
    if (clientCapabilities.roots) {
      client.setRequestHandler(ListRootsRequestSchema, async (request, extra) => {
        return server.listRoots(request.params, extra);
      });
      // if (clientCapabilities.roots?.listChanged) {
      //   client.setNotificationHandler(RootsListChangedNotificationSchema, _ => {
      //     return server.sendRootsListChanged();
      //   });
      // }
    }

    return server;
  }

  // This function is already async, so failures are thrown directly instead
  // of being routed through a `new Promise(async ...)` executor.
  if (outer.type === 'stdio') {
    const line = await readLine();
    let message: JSONRPCMessage;
    try {
      message = JSONRPCMessageSchema.parse(JSON.parse(line));
    } catch (error) {
      throw new Error(`Could not parse initial message: ${error}: ${line}`);
    }
    if (!isJSONRPCRequest(message) || message.method !== 'initialize') {
      throw new Error(`Expected initialize message, got: ${JSON.stringify(message)}`);
    }
    const initialize = message as any as InitializeRequest;

    const server = await createServer(initialize.params);
    const outerTransport = new StdioServerTransport();
    await server.connect(outerTransport);

    if (!outerTransport.onmessage) {
      throw new Error('transport.onmessage should have been defined by the server');
    }
    // Replay the initialize request that readLine() already consumed from stdin.
    outerTransport.onmessage(message);

    return {
      elicitInput: async (params: ElicitRequest['params'], options?: RequestOptions) => {
        return server.elicitInput(params, options);
      },
      close: () => server.close(),
    };
  }

  // Session map shared with registerStreamableHttpMcpServer; it is read below
  // to find the servers of connected clients.
  const transports = new Map<string, SessionData<Server>>();
  const {server: netServer, close} = await setupExpressServer({
    port: outer.port,
    host: outer.host,
    callback: async (app, addCleanup) => {
      const {cleanup} = await registerStreamableHttpMcpServer('proxy', app, {
        path: outer.path,
        transports,
        createServer,
        logger,
      });
      addCleanup(cleanup);
    },
    logger,
  });

  return {
    netServer,
    elicitInput: async (params: ElicitRequest['params'], options?: RequestOptions) => {
      // Ask every session whose client supports elicitation; the first one to
      // answer successfully wins.
      const promises = Array.from(transports.values())
        .filter(d => d.server.getClientCapabilities()?.elicitation)
        .map(d => {
          return d.server.elicitInput(params, options);
        });
      if (promises.length === 0) {
        // No connected client can be asked.
        return {
          action: 'cancel',
        }
      } else {
        return await Promise.any(promises);
      }
    },
    close,
  };
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment