-
-
Save ochafik/42c9f4a3a472ec31e630a82dafe13371 to your computer and use it in GitHub Desktop.
MCP proxy options for antechamber
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { afterEach, describe, it, expect } from "bun:test";
import winston from "winston";
import { Stream } from "winston/lib/winston/transports/index.js";
import z from "zod";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
import { McpServer, ResourceTemplate } from "@modelcontextprotocol/sdk/server/mcp.js";
import { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js";
import { CallToolRequest, ClientCapabilities, ElicitRequest, ElicitRequestSchema, ElicitResult, PingRequestSchema, CompleteRequestSchema, ListRootsRequestSchema, ResourceListChangedNotificationSchema, PromptListChangedNotificationSchema, SetLevelRequestSchema, ToolListChangedNotificationSchema } from "@modelcontextprotocol/sdk/types.js";
import { createProxy, createStdioProxy } from "./proxy_client_server.js";
import { CleanupFunction } from "../shared/cleanup-utils.js";
import { getPort, registerStreamableHttpMcpServer, ServerFactory, setupExpressServer } from "../shared/mcp-utils.js";
/**
 * Shared debug-level logger for the test servers, proxies and clients.
 * Output goes to the console as colorized, timestamped simple lines (with stack traces for errors).
 */
const logger = winston.createLogger({
  level: 'debug',
  transports: [new winston.transports.Console({})],
  format: winston.format.combine(
    winston.format.colorize(),
    winston.format.timestamp(),
    winston.format.errors({ stack: true }),
    winston.format.simple(),
  ),
});
/** Settings for running a proxy test server. */
export type Config = {
  /** Path under which the MCP endpoint is served. */
  mcpPath: string;
  /** Port to listen on. */
  port: number;
  /** Upstream MCP server to forward to, if any. */
  proxiedServerEndpoint?: URL;
};
| describe('Proxy server', () => { | |
| const cleanups: CleanupFunction[] = []; | |
| afterEach(async () => { | |
| for (const cleanup of cleanups) { | |
| try { | |
| await cleanup(); | |
| } catch (error) { | |
| console.error('Error during cleanup:', error); | |
| } | |
| } | |
| cleanups.length = 0; // Clear the cleanups array | |
| }); | |
| async function setupTargetServer(serverName: string, createServer: ServerFactory<McpServer>) { | |
| const { close, server } = await setupExpressServer({ | |
| logger, | |
| port: 0, // Use a random port | |
| callback: async (app, addCleanup) => { | |
| const { cleanup } = registerStreamableHttpMcpServer(serverName, app, { | |
| logger, | |
| path: '/mcp', | |
| createServer, | |
| }); | |
| addCleanup(cleanup); | |
| }, | |
| }); | |
| cleanups.push(close); | |
| const targetEndpoint = new URL(`http://localhost:${getPort(server)}/mcp`); | |
| console.log(`Target server running: ${targetEndpoint}`); | |
| return {targetEndpoint}; | |
| } | |
| async function createClient(endpoint: URL, capabilities: ClientCapabilities, cb?: (client: Client) => Promise<void>) { | |
| const client = new Client({ | |
| name: 'target-client', | |
| version: '1.0.0', | |
| }, { | |
| capabilities, | |
| }); | |
| client.setRequestHandler(PingRequestSchema, async (_, extra) => { | |
| logger.info(`Client received MCP ping request from server: ${extra.requestId}`); | |
| return {}; | |
| }); | |
| if (cb) { | |
| await cb(client); | |
| } | |
| await client.connect(new StreamableHTTPClientTransport(endpoint)); | |
| return client; | |
| } | |
| const EchoInputSchema = z.object({ | |
| message: z.string(), | |
| }); | |
| const EchoOutputSchema = z.object({ | |
| echo: z.string(), | |
| }); | |
| interface MockServerConfig { | |
| enableTools?: boolean; | |
| enablePrompts?: boolean; | |
| enableResources?: boolean; | |
| enableElicitation?: boolean; | |
| enableCompletion?: boolean; | |
| enableLogging?: boolean; | |
| enableSampling?: boolean; | |
| enableRoots?: boolean; | |
| customTools?: Array<{ | |
| name: string; | |
| description?: string; | |
| inputSchema?: any; | |
| outputSchema?: any; | |
| handler: (args: any) => Promise<any>; | |
| }>; | |
| } | |
| async function setupMockMcpServer(config?: MockServerConfig) { | |
| const { | |
| enableTools = true, | |
| enablePrompts = true, | |
| enableResources = true, | |
| enableElicitation = true, | |
| enableCompletion = true, | |
| enableLogging = true, | |
| enableSampling = false, | |
| enableRoots = false, | |
| customTools = [], | |
| } = config || {}; | |
| return await setupTargetServer('mock-server', async _initialize => { | |
| const server = new McpServer({ | |
| name: 'mock-server', | |
| version: '1.0.0', | |
| }, { | |
| instructions: "Mock MCP server for testing", | |
| capabilities: { | |
| elicitation: enableElicitation ? {} : undefined, | |
| completions: enableCompletion ? {} : undefined, | |
| logging: enableLogging ? {} : undefined, | |
| prompts: enablePrompts ? { listChanged: true } : undefined, | |
| resources: enableResources ? { | |
| subscribe: true, | |
| listChanged: true | |
| } : undefined, | |
| tools: enableTools ? { listChanged: true } : undefined, | |
| sampling: enableSampling ? {} : undefined, | |
| }, | |
| }); | |
| // Register tools | |
| if (enableTools) { | |
| // Echo tool (existing) | |
| server.registerTool( | |
| 'echo', | |
| { | |
| inputSchema: EchoInputSchema.shape, | |
| outputSchema: EchoOutputSchema.shape, | |
| }, | |
| async ({ message }) => ({ | |
| content: [], | |
| structuredContent: { | |
| echo: `Echo: ${message}`, | |
| }, | |
| }) | |
| ); | |
| // Mock tool that returns configurable responses | |
| server.registerTool( | |
| 'mockTool', | |
| { | |
| description: 'A mock tool that returns configurable responses', | |
| inputSchema: { | |
| action: z.string(), | |
| data: z.any().optional(), | |
| }, | |
| outputSchema: { | |
| result: z.string(), | |
| metadata: z.record(z.any()).optional(), | |
| }, | |
| }, | |
| async ({ action, data }) => ({ | |
| content: [], | |
| structuredContent: { | |
| result: `Mock action: ${action}`, | |
| metadata: data ? { receivedData: data } : undefined, | |
| }, | |
| }) | |
| ); | |
| // Special elicitation trigger tool | |
| if (enableElicitation) { | |
| server.registerTool( | |
| 'triggerElicitation', | |
| { | |
| description: 'Triggers elicitation with any payload', | |
| inputSchema: { | |
| message: z.string(), | |
| requestedSchema: z.any(), | |
| }, | |
| outputSchema: { | |
| elicitationResult: z.any(), | |
| }, | |
| }, | |
| async ({ message, requestedSchema }, extra) => { | |
| try { | |
| const result = await server.server.elicitInput({ | |
| message, | |
| requestedSchema: requestedSchema as any, | |
| }, extra); | |
| return { | |
| content: [], | |
| structuredContent: { | |
| elicitationResult: result, | |
| }, | |
| }; | |
| } catch (error) { | |
| return { | |
| content: [{ type: 'text', text: `Elicitation failed: ${error}` }], | |
| structuredContent: { | |
| elicitationResult: { error: String(error) }, | |
| }, | |
| }; | |
| } | |
| } | |
| ); | |
| } | |
| // Register custom tools | |
| for (const tool of customTools) { | |
| server.registerTool( | |
| tool.name, | |
| { | |
| description: tool.description, | |
| inputSchema: tool.inputSchema, | |
| outputSchema: tool.outputSchema, | |
| }, | |
| tool.handler | |
| ); | |
| } | |
| } | |
| // Register prompts | |
| if (enablePrompts) { | |
| // Simple prompt without arguments | |
| server.registerPrompt( | |
| 'mockPrompt', | |
| { | |
| description: 'A mock prompt that returns a template message', | |
| }, | |
| async () => ({ | |
| messages: [ | |
| { | |
| role: 'user' as const, | |
| content: { | |
| type: 'text' as const, | |
| text: 'This is a mock prompt response', | |
| }, | |
| }, | |
| ], | |
| }) | |
| ); | |
| // Prompt with arguments | |
| server.registerPrompt( | |
| 'mockPromptWithArgs', | |
| { | |
| description: 'A mock prompt that accepts arguments', | |
| argsSchema: { | |
| name: z.string(), | |
| value: z.string().optional(), | |
| }, | |
| }, | |
| async ({ name, value }) => ({ | |
| messages: [ | |
| { | |
| role: 'user' as const, | |
| content: { | |
| type: 'text' as const, | |
| text: `Mock prompt for ${name}${value !== undefined ? ` with value ${value}` : ''}`, | |
| }, | |
| }, | |
| ], | |
| }) | |
| ); | |
| } | |
| // Register resources | |
| if (enableResources) { | |
| // Static resource | |
| server.registerResource( | |
| 'mockStaticResource', | |
| 'mock://static', | |
| { | |
| description: 'A static mock resource', | |
| mimeType: 'application/json', | |
| }, | |
| async () => ({ | |
| contents: [ | |
| { | |
| uri: 'mock://static', | |
| mimeType: 'application/json', | |
| text: JSON.stringify({ type: 'static', data: 'mock data' }), | |
| }, | |
| ], | |
| }) | |
| ); | |
| // Resource template | |
| const { ResourceTemplate } = await import('@modelcontextprotocol/sdk/server/mcp.js'); | |
| server.registerResource( | |
| 'mockResourceTemplate', | |
| new ResourceTemplate('mock://{type}/{id}', { | |
| list: async () => ({ | |
| resources: [ | |
| { uri: 'mock://user/1', name: 'User 1', mimeType: 'application/json' }, | |
| { uri: 'mock://user/2', name: 'User 2', mimeType: 'application/json' }, | |
| { uri: 'mock://item/a', name: 'Item A', mimeType: 'application/json' }, | |
| { uri: 'mock://item/b', name: 'Item B', mimeType: 'application/json' }, | |
| ] | |
| }), | |
| complete: { | |
| type: async () => ['user', 'item'], | |
| id: async (_value, context) => { | |
| const type = context?.arguments?.type; | |
| if (type === 'user') return ['1', '2']; | |
| if (type === 'item') return ['a', 'b']; | |
| return []; | |
| }, | |
| }, | |
| }), | |
| { | |
| description: 'A mock resource template with variables', | |
| mimeType: 'application/json', | |
| }, | |
| async (uri, arguments_) => { | |
| const { type, id } = arguments_ as { type: string; id: string }; | |
| return { | |
| contents: [ | |
| { | |
| uri: `mock://${type}/${id}`, | |
| mimeType: 'application/json', | |
| text: JSON.stringify({ type, id, data: `Mock ${type} ${id}` }), | |
| }, | |
| ], | |
| }; | |
| } | |
| ); | |
| } | |
| // Register completion handler | |
| if (enableCompletion) { | |
| server.server.setRequestHandler(CompleteRequestSchema, async ({ ref, prompt }) => ({ | |
| completion: { | |
| values: [`Completed: ${prompt}`], | |
| total: 1, | |
| hasMore: false, | |
| } | |
| })); | |
| } | |
| // Register logging handler | |
| if (enableLogging) { | |
| const { SetLevelRequestSchema } = await import('@modelcontextprotocol/sdk/types.js'); | |
| server.server.setRequestHandler(SetLevelRequestSchema, async () => ({ | |
| // Logging level set successfully | |
| })); | |
| } | |
| // Register roots handler | |
| if (enableRoots) { | |
| server.server.setRequestHandler(ListRootsRequestSchema, async () => ({ | |
| roots: [ | |
| { uri: 'file:///mock/root1', name: 'Root 1' }, | |
| { uri: 'file:///mock/root2', name: 'Root 2' }, | |
| ] | |
| })); | |
| // Add triggerListRoots tool for testing | |
| // Note: Unlike elicitInput, listRoots is a client->server request, | |
| // so we can't call it from the server. Instead, we return the same | |
| // data that the ListRootsRequestSchema handler returns. | |
| server.registerTool( | |
| 'triggerListRoots', | |
| { | |
| description: 'Returns the roots data', | |
| outputSchema: { | |
| roots: z.array(z.object({ | |
| uri: z.string(), | |
| name: z.string().optional(), | |
| })) | |
| }, | |
| }, | |
| async () => { | |
| // Return the same data that the ListRootsRequestSchema handler returns | |
| return { | |
| content: [], | |
| structuredContent: { | |
| roots: [ | |
| { uri: 'file:///mock/root1', name: 'Root 1' }, | |
| { uri: 'file:///mock/root2', name: 'Root 2' }, | |
| ] | |
| }, | |
| }; | |
| } | |
| ); | |
| } | |
| return server; | |
| }); | |
| } | |
| // Keep backward compatibility | |
| async function setupEchoMcpServer() { | |
| return setupMockMcpServer({ | |
| enablePrompts: false, | |
| enableResources: false, | |
| enableElicitation: false, | |
| enableCompletion: false, | |
| enableLogging: false, | |
| enableTools: true, | |
| }); | |
| } | |
| // async function setupClassifierServer(classifyMcpTool: (args: ClassifierInput) => Promise<ClassifierOutput>) { | |
| // classifyMcpTool ??= async _inputs => { | |
| // return { | |
| // decision: 'allow', | |
| // }; | |
| // }; | |
| // return setupTargetServer('classifier-server', async _initialize => { | |
| // const server = new McpServer({ | |
| // name: 'classifier-server', | |
| // version: '1.0.0', | |
| // }, { | |
| // capabilities: { | |
| // }, | |
| // instructions: "Echoes messages", | |
| // }); | |
| // server.registerTool( | |
| // 'classify_mcp_tool', | |
| // { | |
| // inputSchema: ClassifierInputSchema.shape, | |
| // outputSchema: ClassifierOutputSchema.shape, | |
| // }, | |
| // async (args: ClassifierInput, _extra) => { | |
| // return <ClassifierOutput>{ | |
| // content: [], | |
| // structuredContent: await classifyMcpTool(args), | |
| // }; | |
| // } | |
| // ); | |
| // return server; | |
| // }); | |
| // } | |
| async function setupProxyServer(proxiedServerEndpoint: URL) { | |
| const mcpPath = '/test-mcp'; | |
| // let elicitInput: ElicitInputFunction | undefined; | |
| // const elicitInput = new Promise<ElicitInputFunction>(resolve => { | |
| const {server: netServer, close} = await setupExpressServer({ | |
| port: 0, | |
| callback: async (app, addCleanup) => { | |
| const {cleanup} = registerStreamableHttpMcpServer('proxy', app, { | |
| path: mcpPath, | |
| createServer: async initialize => { | |
| const mcpServer = await createProxy(new StreamableHTTPClientTransport(proxiedServerEndpoint), initialize, logger); | |
| // elicitInput = (params, options) => mcpServer.elicitInput(params, options); | |
| return mcpServer; | |
| }, | |
| logger, | |
| }); | |
| addCleanup(cleanup); | |
| }, | |
| logger, | |
| }); | |
| cleanups.push(close); | |
| if (!netServer) { | |
| throw new Error('Proxy server did not create a net server'); | |
| } | |
| const proxyEndpoint = new URL(`http://localhost:${getPort(netServer)}${mcpPath}`); | |
| console.log(`Proxy server running: ${proxyEndpoint}`); | |
| return { proxyEndpoint, netServer }; | |
| // return { proxyEndpoint, netServer, elicitInput }; | |
| } | |
| it('should proxy tool calls target server', async () => { | |
| const { targetEndpoint } = await setupEchoMcpServer(); | |
| const { proxyEndpoint } = await setupProxyServer(targetEndpoint); | |
| const targetClient = await createClient(targetEndpoint, {}); | |
| const proxyClient = await createClient(proxyEndpoint, {}); | |
| const targetTools = await targetClient.listTools(); | |
| const proxiedTools = await proxyClient.listTools(); | |
| expect(proxiedTools).toEqual(targetTools); | |
| const params: CallToolRequest['params'] = {name: 'echo', arguments: { message: 'Hello, world!' }}; | |
| const targetResult = await targetClient.callTool(params); | |
| const proxyResult = await proxyClient.callTool(params); | |
| expect(proxyResult).toEqual(targetResult); | |
| }); | |
| it('should expose all mock server capabilities', async () => { | |
| const { targetEndpoint } = await setupMockMcpServer(); | |
| const client = await createClient(targetEndpoint, {}); | |
| // Test tools | |
| const tools = await client.listTools(); | |
| expect(tools.tools.map(t => t.name)).toContain('echo'); | |
| expect(tools.tools.map(t => t.name)).toContain('mockTool'); | |
| expect(tools.tools.map(t => t.name)).toContain('triggerElicitation'); | |
| // Test echo tool | |
| const echoResult = await client.callTool({ | |
| name: 'echo', | |
| arguments: { message: 'test' } | |
| }); | |
| expect(echoResult.structuredContent).toEqual({ echo: 'Echo: test' }); | |
| // Test mock tool | |
| const mockResult = await client.callTool({ | |
| name: 'mockTool', | |
| arguments: { action: 'test-action', data: { key: 'value' } } | |
| }); | |
| expect((mockResult.structuredContent as any)?.result).toBe('Mock action: test-action'); | |
| expect((mockResult.structuredContent as any)?.metadata).toEqual({ receivedData: { key: 'value' } }); | |
| // Test prompts | |
| const prompts = await client.listPrompts(); | |
| expect(prompts.prompts.map(p => p.name)).toContain('mockPrompt'); | |
| expect(prompts.prompts.map(p => p.name)).toContain('mockPromptWithArgs'); | |
| // Test simple prompt | |
| const promptResult = await client.getPrompt({ name: 'mockPrompt' }); | |
| expect(promptResult.messages).toHaveLength(1); | |
| expect(promptResult.messages[0].content).toEqual({ | |
| type: 'text', | |
| text: 'This is a mock prompt response' | |
| }); | |
| // Test prompt with args | |
| const promptWithArgsResult = await client.getPrompt({ | |
| name: 'mockPromptWithArgs', | |
| arguments: { name: 'John', value: '42' } | |
| }); | |
| expect(promptWithArgsResult.messages[0].content).toEqual({ | |
| type: 'text', | |
| text: 'Mock prompt for John with value 42' | |
| }); | |
| // Test resources | |
| const resources = await client.listResources(); | |
| expect(resources.resources.map(r => r.name)).toContain('mockStaticResource'); | |
| expect(resources.resources.some(r => r.name.includes('User'))).toBe(true); | |
| // Test static resource | |
| const staticResource = await client.readResource({ uri: 'mock://static' }); | |
| expect(staticResource.contents).toHaveLength(1); | |
| const content = JSON.parse(staticResource.contents[0].text as string || '{}'); | |
| expect(content).toEqual({ type: 'static', data: 'mock data' }); | |
| // Test resource template | |
| const templateResource = await client.readResource({ uri: 'mock://user/1' }); | |
| expect(templateResource.contents).toHaveLength(1); | |
| const templateContent = JSON.parse(templateResource.contents[0].text as string || '{}'); | |
| expect(templateContent).toEqual({ type: 'user', id: '1', data: 'Mock user 1' }); | |
| }); | |
| it('should proxy listResourceTemplates', async () => { | |
| const { targetEndpoint } = await setupMockMcpServer(); | |
| const { proxyEndpoint } = await setupProxyServer(targetEndpoint); | |
| const targetClient = await createClient(targetEndpoint, {}); | |
| const proxyClient = await createClient(proxyEndpoint, {}); | |
| const targetTemplates = await targetClient.listResourceTemplates(); | |
| const proxiedTemplates = await proxyClient.listResourceTemplates(); | |
| expect(proxiedTemplates).toEqual(targetTemplates); | |
| expect(proxiedTemplates.resourceTemplates).toHaveLength(1); | |
| expect(proxiedTemplates.resourceTemplates[0].uriTemplate).toBe('mock://{type}/{id}'); | |
| }); | |
| it('should proxy ping requests bidirectionally', async () => { | |
| const { targetEndpoint } = await setupMockMcpServer(); | |
| const { proxyEndpoint } = await setupProxyServer(targetEndpoint); | |
| const proxyClient = await createClient(proxyEndpoint, {}); | |
| // Test ping from client to server through proxy | |
| const pingResult = await proxyClient.ping(); | |
| expect(pingResult).toEqual({}); | |
| }); | |
| it('should proxy complete requests', async () => { | |
| const { targetEndpoint } = await setupMockMcpServer({ enableCompletion: true }); | |
| const { proxyEndpoint } = await setupProxyServer(targetEndpoint); | |
| const proxyClient = await createClient(proxyEndpoint, {}); | |
| const completeResult = await proxyClient.complete({ | |
| ref: { type: 'ref/prompt' as const, name: 'test' }, | |
| argument: { name: 'arg1', value: 'val1' }, | |
| }); | |
| expect(completeResult.completion?.values).toHaveLength(1); | |
| expect(completeResult.completion?.values?.[0]).toContain('Completed:'); | |
| }); | |
| it('should handle roots listing through proxy', async () => { | |
| const { targetEndpoint } = await setupMockMcpServer({ enableRoots: true }); | |
| const { proxyEndpoint } = await setupProxyServer(targetEndpoint); | |
| const proxyClient = await createClient(proxyEndpoint, { roots: {} }); | |
| // Call the triggerListRoots tool to test roots listing | |
| const result = await proxyClient.callTool({ | |
| name: 'triggerListRoots', | |
| arguments: {} | |
| }); | |
| expect(result.structuredContent).toEqual({ | |
| roots: [ | |
| { uri: 'file:///mock/root1', name: 'Root 1' }, | |
| { uri: 'file:///mock/root2', name: 'Root 2' }, | |
| ] | |
| }); | |
| }); | |
| it('should proxy logging levels', async () => { | |
| const { targetEndpoint } = await setupMockMcpServer({ enableLogging: true }); | |
| const { proxyEndpoint } = await setupProxyServer(targetEndpoint); | |
| const proxyClient = await createClient(proxyEndpoint, {}); | |
| // Test setting logging level - this should work through the proxy | |
| const setLevelResult = await proxyClient.setLoggingLevel('debug'); | |
| expect(setLevelResult).toEqual({}); | |
| }); | |
| it('should handle elicitation through proxy', async () => { | |
| const { targetEndpoint } = await setupMockMcpServer({ enableElicitation: true }); | |
| const { proxyEndpoint } = await setupProxyServer(targetEndpoint); | |
| const elicitRequests: ElicitRequest[] = []; | |
| const proxyClient = await createClient(proxyEndpoint, { elicitation: {} }, async client => { | |
| client.setRequestHandler(ElicitRequestSchema, async (request, _extra) => { | |
| elicitRequests.push(request); | |
| return { | |
| action: 'accept' as const, | |
| content: { approved: true } | |
| }; | |
| }); | |
| }); | |
| // Call the triggerElicitation tool | |
| const result = await proxyClient.callTool({ | |
| name: 'triggerElicitation', | |
| arguments: { | |
| message: 'Test elicitation', | |
| requestedSchema: { | |
| type: 'object', | |
| properties: { | |
| approved: { type: 'boolean' } | |
| }, | |
| required: ['approved'] | |
| } | |
| } | |
| }); | |
| expect(elicitRequests).toHaveLength(1); | |
| expect(elicitRequests[0].params.message).toBe('Test elicitation'); | |
| expect((result.structuredContent as any)?.elicitationResult).toEqual({ | |
| action: 'accept', | |
| content: { approved: true } | |
| }); | |
| }); | |
| it('should handle sampling through proxy', async () => { | |
| const { targetEndpoint } = await setupMockMcpServer({ enableSampling: true }); | |
| const { proxyEndpoint } = await setupProxyServer(targetEndpoint); | |
| // Mock server needs to handle createMessage request | |
| const targetClient = await createClient(targetEndpoint, { sampling: {} }); | |
| // For now, just verify the capability is exposed through the proxy | |
| const proxyClient = await createClient(proxyEndpoint, { sampling: {} }); | |
| // The proxy should accept sampling capability from the client | |
| // but actual sampling implementation would need more setup | |
| }); | |
| it('should forward resource and prompt list change notifications', async () => { | |
| const { targetEndpoint } = await setupMockMcpServer(); | |
| const { proxyEndpoint } = await setupProxyServer(targetEndpoint); | |
| let resourceListChangedCount = 0; | |
| let promptListChangedCount = 0; | |
| let toolListChangedCount = 0; | |
| const proxyClient = await createClient(proxyEndpoint, {}, async client => { | |
| // Use the actual notification schemas | |
| client.setNotificationHandler(ResourceListChangedNotificationSchema, () => { | |
| resourceListChangedCount++; | |
| }); | |
| client.setNotificationHandler(PromptListChangedNotificationSchema, () => { | |
| promptListChangedCount++; | |
| }); | |
| client.setNotificationHandler(ToolListChangedNotificationSchema, () => { | |
| toolListChangedCount++; | |
| }); | |
| }); | |
| // Verify we can list resources through the proxy | |
| const resources = await proxyClient.listResources(); | |
| expect(resources.resources.length).toBeGreaterThan(0); | |
| // To actually test notification forwarding, we'd need to trigger changes | |
| // on the target server and verify they're forwarded to the client | |
| // For now, just verify the handlers are set up correctly | |
| }); | |
| // it('should send elicitations to client', async () => { | |
| // const { targetEndpoint } = await setupEchoMcpServer(); | |
| // const { proxyEndpoint, elicitInput } = await setupProxyServer(targetEndpoint); | |
| // console.log(`Proxy endpoint: ${proxyEndpoint}`); | |
| // const elicitRequests: ElicitRequest[] = []; | |
| // const proxyClient = await createClient(proxyEndpoint, { elicitation: {} }, async client => { | |
| // client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { | |
| // logger.info(`Received elicit request: ${JSON.stringify(request)}, extra: ${JSON.stringify(extra)}`); | |
| // elicitRequests.push(request); | |
| // return { | |
| // action: 'accept', | |
| // content: { | |
| // testField: 'yay', | |
| // } | |
| // } | |
| // }); | |
| // }); | |
| // const elicitParams: ElicitRequest['params'] = { | |
| // message: 'Test elicit input', | |
| // requestedSchema: { | |
| // 'type': 'object', | |
| // 'properties': { | |
| // 'testField': { | |
| // 'type': 'string', | |
| // }, | |
| // }, | |
| // required: ['testField'], | |
| // }, | |
| // }; | |
| // console.log(`Elicit params: ${JSON.stringify(elicitParams)}`); | |
| // const result = await elicitInput(elicitParams); | |
| // console.log(`Elicit result: ${JSON.stringify(result)}`); | |
| // expect(result as any).toBe({ testField: 'yay' }); | |
| // expect(elicitRequests).toHaveLength(1); | |
| // expect(elicitRequests[0].params.message).toBe('Test elicit input'); | |
| // expect(elicitRequests[0].params.requestedSchema).toEqual(elicitParams.requestedSchema); | |
| // }); | |
| // it('should call the classifier w/ tool calls', async () => { | |
| // const classificationInputs: ClassifierInput[] = []; | |
| // const proxiedServerEndpoint = await setupEchoMcpServer(); | |
| // const classifierEndpoint = await setupClassifierServer(async inputs => { | |
| // classificationInputs.push(inputs); | |
| // return { | |
| // decision: 'allow', | |
| // }; | |
| // }); | |
| // const { proxyEndpoint } = await setupProxyServer({proxiedServerEndpoint, classifierEndpoint}); | |
| // const targetClient = await createClient(proxiedServerEndpoint, { elicitation: false }); | |
| // const proxyClient = await createClient(proxyEndpoint, { elicitation: false }); | |
| // const params: CallToolRequest['params'] = {name: 'echo', arguments: { message: 'Hello, world!' }}; | |
| // const targetTools = await targetClient.listTools(); | |
| // const proxyTools = await proxyClient.listTools(); | |
| // expect(proxyTools).toEqual(targetTools); | |
| // expect(proxyTools.tools).toHaveLength(1); | |
| // expect(proxyTools.tools[0].name).toBe('echo'); | |
| // // expect(proxyTools.tools[0].inputSchema).toEqual(EchoInputSchema.shape); | |
| // // expect(proxyTools.tools[0].outputSchema).toEqual(EchoOutputSchema.shape); | |
| // const targetResult = await targetClient.callTool(params); | |
| // const proxyResult = await proxyClient.callTool(params); | |
| // expect(proxyResult).toEqual(targetResult); | |
| // expect(classificationInputs).toHaveLength(1); | |
| // expect(classificationInputs[0].request.name).toEqual('echo'); | |
| // expect(classificationInputs[0].request.arguments).toEqual({ message: 'Hello, world!' }); | |
| // // expect(classificationInputs[0]).toEqual({ | |
| // // name: 'echo', | |
| // // tool: { | |
| // // name: 'echo', | |
| // // inputSchema: EchoInputSchema.shape, | |
| // // outputSchema: EchoOutputSchema.shape, | |
| // // }, | |
| // // arguments: { message: 'Hello, world!' }, | |
| // // }); | |
| // }); | |
| }); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import net from "net"; | |
| import { Client } from "@modelcontextprotocol/sdk/client/index.js"; | |
| import { Server } from "@modelcontextprotocol/sdk/server/index.js"; | |
| import { | |
| ListPromptsRequestSchema, | |
| GetPromptRequestSchema, | |
| ListToolsRequestSchema, | |
| CallToolRequestSchema, | |
| CallToolResultSchema, | |
| ListResourcesRequestSchema, | |
| ListResourceTemplatesRequestSchema, | |
| ReadResourceRequestSchema, | |
| ElicitRequestSchema, | |
| PingRequestSchema, | |
| SetLevelRequestSchema, | |
| LoggingMessageNotificationSchema, | |
| CompleteRequestSchema, | |
| ResourceUpdatedNotificationSchema, | |
| ResourceListChangedNotificationSchema, | |
| PromptListChangedNotificationSchema, | |
| CreateMessageRequestSchema, | |
| ListRootsRequestSchema, | |
| JSONRPCMessageSchema, | |
| ToolListChangedNotificationSchema, | |
| } from "@modelcontextprotocol/sdk/types.js"; | |
| import { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js'; | |
| import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; | |
| import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; | |
| import { ElicitRequest, ElicitResult, InitializeRequest, isJSONRPCRequest, JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; | |
| import winston from "winston"; | |
| import { readLine } from "../shared/terminal-utils.js"; | |
| import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; | |
| export async function createProxy(clientTransport: Transport, initialize: InitializeRequest['params'], logger: winston.Logger): Promise<Server> { | |
| const clientCapabilities = initialize.capabilities; | |
| const client = new Client({ | |
| name: initialize.clientInfo.name, | |
| title: initialize.clientInfo.title, | |
| version: initialize.clientInfo.version, | |
| }, { | |
| capabilities: { | |
| // Copy capabilities explicitly to avoid new ones creeping in. | |
| sampling: clientCapabilities.sampling ? {} : undefined, | |
| elicitation: clientCapabilities.elicitation ? {} : undefined, // Enable elicitation on proxy server if client supports it. | |
| roots: clientCapabilities.roots ? {} : undefined, | |
| }, | |
| }); | |
| try { | |
| await client.connect(clientTransport); | |
| } catch (error) { | |
| logger.error('Error connecting to client transport'); | |
| logger.error(error); | |
| throw new Error(`Could not connect to client transport: ${error}`); | |
| } | |
| const serverVersion = client.getServerVersion(); | |
| const serverCapabilities = client.getServerCapabilities(); | |
| const server = new Server({ | |
| name: serverVersion?.name ?? '?', | |
| title: serverVersion?.title, | |
| version: serverVersion?.version ?? '?', | |
| }, { | |
| capabilities: { | |
| // Copy capabilities explicitly to avoid new ones creeping in. | |
| completions: serverCapabilities?.completions, | |
| elicitation: serverCapabilities?.elicitation, | |
| logging: serverCapabilities?.logging ? {} : undefined, | |
| prompts: serverCapabilities?.prompts ? { | |
| listChanged: serverCapabilities.prompts.listChanged ? true : undefined, | |
| } : undefined, | |
| resources: serverCapabilities?.resources ? { | |
| subscribe: serverCapabilities.resources.subscribe ? true : undefined, | |
| listChanged: serverCapabilities.resources.listChanged ? true : undefined, | |
| } : undefined, | |
| tools: serverCapabilities?.tools ? { | |
| listChanged: serverCapabilities.tools.listChanged ? true : undefined, | |
| } : undefined, | |
| }, | |
| instructions: client.getInstructions(), | |
| }); | |
| // Wire handlers. | |
| // TODO: remap related ids: {relatedRequestId: extra.requestId}); | |
| server.setRequestHandler(PingRequestSchema, async (_request, extra) => { | |
| return client.ping(extra); | |
| }); | |
| client.setRequestHandler(PingRequestSchema, async (_request, _extra) => { | |
| return server.ping(); | |
| }); | |
| if (serverCapabilities?.prompts) { | |
| server.setRequestHandler(ListPromptsRequestSchema, (request, extra) => { | |
| return client.listPrompts(request.params, extra); | |
| }); | |
| server.setRequestHandler(GetPromptRequestSchema, (request, extra) => { | |
| return client.getPrompt(request.params, extra); | |
| }); | |
| if (serverCapabilities.prompts?.listChanged) { | |
| client.setNotificationHandler(PromptListChangedNotificationSchema, _ => { | |
| return server.sendPromptListChanged(); | |
| }); | |
| } | |
| } | |
| if (serverCapabilities?.tools) { | |
| server.setRequestHandler(ListToolsRequestSchema, (request, extra) => { | |
| return client.listTools(request.params, extra); | |
| }); | |
| if (serverCapabilities.tools?.listChanged) { | |
| server.setNotificationHandler(ToolListChangedNotificationSchema, _ => { | |
| return server.sendToolListChanged(); | |
| }); | |
| } | |
| server.setRequestHandler(CallToolRequestSchema, (request, extra) => { | |
| return client.callTool(request.params, CallToolResultSchema, extra); | |
| }); | |
| } | |
| if (serverCapabilities?.logging) { | |
| server.setRequestHandler(SetLevelRequestSchema, (request, extra) => { | |
| return client.setLoggingLevel(request.params.level, extra); | |
| }); | |
| client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { | |
| return server.sendLoggingMessage(notification.params); | |
| }); | |
| } | |
| if (serverCapabilities?.completions) { | |
| server.setRequestHandler(CompleteRequestSchema, (request, extra) => { | |
| return client.complete(request.params, extra); | |
| }); | |
| } | |
| if (serverCapabilities?.resources) { | |
| server.setRequestHandler(ListResourcesRequestSchema, (request, extra) => { | |
| return client.listResources(request.params, extra); | |
| }); | |
| server.setRequestHandler(ListResourceTemplatesRequestSchema, (request, extra) => { | |
| return client.listResourceTemplates(request.params, extra); | |
| }); | |
| server.setRequestHandler(ReadResourceRequestSchema, (request, extra) => { | |
| return client.readResource(request.params, extra); | |
| }); | |
| if (serverCapabilities.resources?.update) { | |
| client.setNotificationHandler(ResourceUpdatedNotificationSchema, notification => { | |
| return server.sendResourceUpdated(notification.params); | |
| }); | |
| } | |
| if (serverCapabilities.resources?.listChanged) { | |
| client.setNotificationHandler(ResourceListChangedNotificationSchema, _ => { | |
| return server.sendResourceListChanged(); | |
| }); | |
| } | |
| } | |
| if (clientCapabilities.elicitation) { | |
| client.setRequestHandler(ElicitRequestSchema, async (request, extra) => { | |
| return server.elicitInput(request.params, extra); | |
| }); | |
| } | |
| if (clientCapabilities.sampling) { | |
| client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => { | |
| return server.createMessage(request.params, extra); | |
| }); | |
| } | |
| if (clientCapabilities.roots) { | |
| client.setRequestHandler(ListRootsRequestSchema, async (request, extra) => { | |
| return server.listRoots(request.params, extra); | |
| }); | |
| // if (clientCapabilities.roots?.listChanged) { | |
| // client.setNotificationHandler(RootsListChangedNotificationSchema, _ => { | |
| // return server.sendRootsListChanged(); | |
| // }); | |
| // } | |
| } | |
| return server; | |
| } | |
/**
 * Runs the `Server`-based proxy over stdio in front of a downstream MCP server
 * started from `inner` via `StdioClientTransport`.
 *
 * The first line returned by `readLine()` must be the upstream client's JSON-RPC
 * `initialize` request. Its params are passed to `createProxy`, the resulting proxy
 * server is connected to a `StdioServerTransport`, and the already-consumed
 * `initialize` message is replayed into that transport so the proxy answers it.
 *
 * @param inner command and arguments for the downstream server.
 * @param logger logger passed through to `createProxy`.
 * @returns `elicitInput`, which forwards to the proxy server's `elicitInput`, and
 *   `close`, which closes the proxy server and then the downstream transport.
 * @throws if the first line is not a valid JSON-RPC message, is not an `initialize`
 *   request, or if setting up either side fails. Transports created so far are
 *   closed (best effort) before rethrowing.
 */
export async function createStdioProxy(inner: {command: string, args: string[]}, logger: winston.Logger): Promise<{
  elicitInput: (params: ElicitRequest['params'], options?: RequestOptions) => Promise<ElicitResult>,
  close: () => Promise<void>,
}> {
  const line = await readLine();
  let message: JSONRPCMessage;
  try {
    message = JSONRPCMessageSchema.parse(JSON.parse(line));
  } catch (error) {
    throw new Error(`Could not parse initial message: ${error}: ${line}`);
  }
  if (!isJSONRPCRequest(message) || message.method !== 'initialize') {
    throw new Error(`Expected initialize message, got: ${JSON.stringify(message)}`);
  }
  const initialize = message as any as InitializeRequest;
  const transport = new StdioClientTransport({command: inner.command, args: inner.args});
  let server: Awaited<ReturnType<typeof createProxy>> | undefined;
  try {
    const proxyServer = await createProxy(transport, initialize.params, logger);
    server = proxyServer;
    const outerTransport = new StdioServerTransport();
    await proxyServer.connect(outerTransport);
    if (!outerTransport.onmessage) {
      throw new Error('transport.onmessage should have been defined by the server');
    }
    // Replay the initialize request consumed above so the proxy server handles it.
    outerTransport.onmessage(message);
    return {
      elicitInput: async (params: ElicitRequest['params'], options?: RequestOptions) => {
        return proxyServer.elicitInput(params, options);
      },
      close: async () => {
        try {
          await proxyServer.close();
        } finally {
          // The proxy server does not own the downstream transport; close it explicitly.
          await transport.close();
        }
      },
    };
  } catch (error) {
    // Don't leave a half-built proxy (and its downstream transport) behind.
    await Promise.allSettled([server?.close(), transport.close()]);
    throw error;
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import { CancelledNotification, InitializedNotification, InitializeRequest, isInitializeRequest, JSONRPCError, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse } from "@modelcontextprotocol/sdk/types.js"; | |
| import { ElicitRequest, ElicitResult, JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; | |
| import winston from "winston"; | |
| import { Transport, TransportSendOptions } from "@modelcontextprotocol/sdk/shared/transport.js"; | |
| import { beforeEach, describe, expect, it } from "bun:test"; | |
| import { createProxy, TypedRequest } from "./proxy_transports.js"; | |
/**
 * In-memory `Transport` test double. It does no I/O; it only counts `start`/`close`
 * invocations and records the argument tuple of every `send` call, in order.
 */
class MockTransport implements Transport {
  /** How many times `start()` has been called. */
  startCallCount = 0;
  /** How many times `close()` has been called. */
  closeCallCount = 0;
  /** `[message, options]` of each `send()` call, oldest first. */
  sendCalls: Parameters<Transport['send']>[] = [];

  async start(): Promise<void> {
    this.startCallCount += 1;
  }

  async send(message: JSONRPCMessage, options?: TransportSendOptions): Promise<void> {
    const recorded = this.sendCalls;
    recorded.push([message, options]);
  }

  async close(): Promise<void> {
    this.closeCallCount += 1;
  }

  // Callbacks installed by whoever consumes the transport (e.g. the proxy under test).
  onclose?: (() => void) | undefined;
  onerror?: ((error: Error) => void) | undefined;
  onmessage?: ((message: JSONRPCMessage) => void) | undefined;
}
/**
 * Params of an `elicitation/create` request that asks for exactly one required
 * string field, `name`, and forbids any other property.
 */
const elicitRequestParams: ElicitRequest['params'] = {
  message: "Hello, world!",
  requestedSchema: {
    type: "object",
    properties: {
      name: { type: "string" },
    },
    required: ["name"],
    additionalProperties: false,
  },
};

/**
 * JSON-RPC `initialize` request (id 0) from an upstream "test-client" at protocol
 * version 2025-06-18, advertising both `elicitation` and `sampling` capabilities.
 */
const initializeWithElicitation = {
  method: "initialize",
  params: {
    protocolVersion: "2025-06-18",
    capabilities: {
      elicitation: {},
      sampling: {},
    },
    clientInfo: {
      name: "test-client",
      version: "1.0.0",
    },
  },
  jsonrpc: "2.0",
  id: 0,
} as TypedRequest<InitializeRequest>;
describe('createProxy', () => {
  let logger: winston.Logger | undefined;
  let errors: Parameters<Parameters<typeof createProxy>['0']['onerror']>[] = [];
  let proxy: Awaited<ReturnType<typeof createProxy>> | undefined;
  let clientTransport: MockTransport | undefined;
  let serverTransport: MockTransport | undefined;

  // The first synthetic request the proxy sends upstream (synthetic ids count down from -1).
  const syntheticElicitRequest = {id: -1, jsonrpc: "2.0", method: "elicitation/create", params: elicitRequestParams};

  // Yields to the event loop so the proxy's async work (e.g. awaiting `initialize`) can run.
  const flush = () => new Promise(resolve => setTimeout(resolve, 0));

  beforeEach(async () => {
    // Reset per-test state; `errors` would otherwise accumulate across tests.
    errors = [];
    logger = winston.createLogger({
      level: 'debug',
      format: winston.format.simple(),
      transports: [
        new winston.transports.Console(),
      ],
    });
    clientTransport = new MockTransport();
    serverTransport = new MockTransport();
    proxy = await createProxy({
      clientTransport,
      serverTransport,
      logger: logger!,
      verbose: true,
      onerror: (source, error) => {
        errors.push([source, error]);
      },
    });
  });

  it('should forward the initialize message from server transport (upstream client) to client transport (downstream server)', async () => {
    serverTransport?.onmessage?.(initializeWithElicitation);
    expect(serverTransport!.sendCalls.length).toBe(0);
    expect(clientTransport!.sendCalls.length).toBe(1);
    expect(clientTransport!.sendCalls[0][0]).toEqual(initializeWithElicitation);
  });

  it('should forward the initialized notification from client transport (downstream server) to server transport (upstream client)', async () => {
    const initializedNotification: JSONRPCMessage & InitializedNotification =
      {"method":"notifications/initialized","jsonrpc":"2.0"};
    clientTransport?.onmessage?.(initializedNotification);
    expect(clientTransport!.sendCalls.length).toBe(0);
    expect(serverTransport!.sendCalls.length).toBe(1);
    expect(serverTransport!.sendCalls[0][0]).toEqual(initializedNotification);
  });

  it('should forward non-special-cased elicitation requests from client transport (downstream server) to server transport (upstream client)', async () => {
    const message: JSONRPCMessage =
      {method: "elicitation/create", params: elicitRequestParams, jsonrpc: "2.0", id: -123};
    clientTransport?.onmessage?.(message);
    expect(clientTransport!.sendCalls.length).toBe(0);
    expect(serverTransport!.sendCalls.length).toBe(1);
    expect(serverTransport!.sendCalls[0][0]).toEqual(message);
  });

  it('should send synthetic elicitInput to the server transport (upstream client), wait for initialize and not forward its response to anyone', async () => {
    const response = proxy?.elicitInput(elicitRequestParams);
    await flush();
    // Nothing is sent before the upstream client has sent `initialize`.
    expect(clientTransport!.sendCalls.length).toBe(0);
    expect(serverTransport!.sendCalls.length).toBe(0);

    serverTransport?.onmessage?.(initializeWithElicitation);
    await flush();
    expect(serverTransport!.sendCalls.length).toBe(1);
    expect(serverTransport!.sendCalls[0][0]).toEqual(syntheticElicitRequest);
    serverTransport!.sendCalls.length = 0;

    const elicitResult: ElicitResult = <ElicitResult>{
      action: 'accept',
      content: {name: "John Doe"},
    };
    const upstreamResponse: JSONRPCResponse = {
      id: -1,
      jsonrpc: "2.0",
      result: elicitResult,
    };
    serverTransport?.onmessage?.(upstreamResponse);
    expect(await response).toEqual(elicitResult);
    // The synthetic response was consumed: nothing was echoed upstream, and the client
    // transport still only holds the forwarded `initialize`.
    expect(serverTransport!.sendCalls.length).toBe(0);
    expect(clientTransport!.sendCalls.length).toBe(1);
    expect(clientTransport!.sendCalls[0][0]).toEqual(initializeWithElicitation);

    // Now simulate another elicitation response from the server (upstream client), confusingly
    // with the same id. It is forwarded as usual since the synthetic id was removed from the map.
    serverTransport?.onmessage?.(upstreamResponse);
    expect(serverTransport!.sendCalls.length).toBe(0);
    expect(clientTransport!.sendCalls.length).toBe(2);
    expect(clientTransport!.sendCalls[0][0]).toEqual(initializeWithElicitation);
    expect(clientTransport!.sendCalls[1][0]).toEqual(upstreamResponse);
  });

  it('should send synthetic elicitInput to the server transport (upstream client) and reflect errors', async () => {
    const response = proxy?.elicitInput(elicitRequestParams);
    serverTransport?.onmessage?.(initializeWithElicitation);
    await flush();
    // The synthetic request went upstream; clear it so later checks only see new traffic.
    expect(serverTransport!.sendCalls.length).toBe(1);
    expect(serverTransport!.sendCalls[0][0]).toEqual(syntheticElicitRequest);
    serverTransport!.sendCalls.length = 0;

    serverTransport?.onmessage?.(<JSONRPCError>{
      id: -1,
      jsonrpc: "2.0",
      error: {
        code: -32603,
        message: "Internal error",
      },
    });
    await expect(response).rejects.toThrow('Internal error');

    // The error response was special-cased: not forwarded anywhere.
    expect(serverTransport!.sendCalls.length).toBe(0);
    expect(clientTransport!.sendCalls.length).toBe(1);
    expect(clientTransport!.sendCalls[0][0]).toEqual(initializeWithElicitation);
  });

  it('should send synthetic elicitInput to the server transport (upstream client) and reflect cancellation', async () => {
    const response = proxy?.elicitInput(elicitRequestParams);
    serverTransport?.onmessage?.(initializeWithElicitation);
    await flush();
    // The synthetic request went upstream; clear it so later checks only see new traffic.
    expect(serverTransport!.sendCalls.length).toBe(1);
    expect(serverTransport!.sendCalls[0][0]).toEqual(syntheticElicitRequest);
    serverTransport!.sendCalls.length = 0;

    const interceptedCancelledNotification = <TypedRequest<CancelledNotification>>{
      jsonrpc: "2.0",
      method: "notifications/cancelled",
      params: {
        requestId: -1,
      },
    };
    const passThroughCancelledNotification = <TypedRequest<CancelledNotification>>{
      jsonrpc: "2.0",
      method: "notifications/cancelled",
      params: {
        requestId: 12345,
      },
    };
    serverTransport?.onmessage?.(passThroughCancelledNotification);
    serverTransport?.onmessage?.(interceptedCancelledNotification);
    await expect(response).rejects.toThrow('Request cancelled');

    // Only the unrelated cancellation is forwarded downstream; the synthetic one is consumed.
    expect(serverTransport!.sendCalls.length).toBe(0);
    expect(clientTransport!.sendCalls.length).toBe(2);
    expect(clientTransport!.sendCalls[0][0]).toEqual(initializeWithElicitation);
    expect(clientTransport!.sendCalls[1][0]).toEqual(passThroughCancelledNotification);
  });
});
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import { | |
| isJSONRPCResponse, | |
| isJSONRPCError, | |
| ElicitResultSchema, | |
| CancelledNotification, | |
| InitializeRequest, | |
| isInitializeRequest, | |
| isJSONRPCNotification, | |
| isJSONRPCRequest, | |
| JSONRPCResponseSchema, | |
| } from "@modelcontextprotocol/sdk/types.js"; | |
| import { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js'; | |
| import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; | |
| import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; | |
| import { ElicitRequest, ElicitResult, JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; | |
| import winston from "winston"; | |
| import { Transport, TransportSendOptions } from "@modelcontextprotocol/sdk/shared/transport.js"; | |
| import { CancelledNotificationSchema } from '@modelcontextprotocol/sdk/types.js'; | |
/** Type guard: true when `value` parses as an MCP `notifications/cancelled` notification. */
const isCancelledNotification = (value: unknown): value is CancelledNotification =>
  CancelledNotificationSchema.safeParse(value).success;

/**
 * A transport tagged with the proxy-side role it plays: `'client'` for the transport
 * passed as `clientTransport`, `'server'` for the one passed as `serverTransport`.
 */
type NamedTransport<T extends Transport = Transport> = {
  name: 'client' | 'server',
  transport: T,
};

/** The settle callbacks of a pending promise, kept so a message handler can settle it later. */
type PromiseHandlers<T> = {
  resolve: (response: T) => void,
  reject: (reason?: any) => void,
};
/**
 * Creates a transparent JSON-RPC proxy between two transports and lets the proxy
 * inject its own ("synthetic") requests towards the upstream client.
 *
 * Every message received on `serverTransport` is forwarded to `clientTransport`
 * and vice versa. Synthetic requests get negative ids (-1, -2, ...). A response,
 * error or `notifications/cancelled` for such an id is consumed by the proxy and
 * not forwarded, but only if it arrives from the transport the synthetic request
 * was sent on.
 *
 * @param clientTransport transport on the proxy's client side.
 * @param serverTransport transport on the proxy's server side; the `initialize`
 *   request received here gates `elicitInput`.
 * @param logger destination for diagnostic logging.
 * @param verbose when true, logs every message and every transport error.
 * @param onerror called with the transport's name whenever a transport reports an error.
 * @returns An object with two functions:
 *   - `elicitInput` waits for the `initialize` request received on
 *     `serverTransport`. It returns `{action: 'decline'}` if that request does not
 *     advertise elicitation; otherwise it sends a synthetic `elicitation/create`
 *     request to `serverTransport` and returns the parsed result.
 *   - `close` closes both transports and rejects any synthetic request still
 *     awaiting a reply.
 * @throws if either transport fails to start; both transports are then closed.
 */
export async function createProxy(
  {
    clientTransport,
    serverTransport,
    logger,
    onerror,
    verbose,
  } : {
    clientTransport: Transport,
    serverTransport: Transport,
    logger: winston.Logger,
    verbose: boolean,
    onerror: (source: NamedTransport['name'], error: any) => void,
  }) : Promise<{
    elicitInput: (params: ElicitRequest['params'], options?: RequestOptions) => Promise<ElicitResult>,
    close: () => Promise<void>,
  }> {
  const client: NamedTransport = {name: 'client', transport: clientTransport};
  const server: NamedTransport = {name: 'server', transport: serverTransport};

  // Synthetic ids count down from -1.
  // TODO: different space for client vs. server synthetic ids?
  let nextSyntheticMessageId = -1;
  // Pending synthetic requests by id, remembering which transport each was sent on so
  // that only replies coming back from that same transport are intercepted.
  const syntheticResponseHandlers =
    new Map<number | string, PromiseHandlers<JSONRPCMessage> & {target: NamedTransport['name']}>();

  // Removes and returns the pending synthetic handler for `id` when `source` is the
  // transport that request went out on; otherwise returns undefined and changes nothing.
  const takeSyntheticHandler = (source: NamedTransport, id: number | string) => {
    const handler = syntheticResponseHandlers.get(id);
    if (handler === undefined || handler.target !== source.name) {
      return undefined;
    }
    syntheticResponseHandlers.delete(id);
    return handler;
  };

  // Rejects (and forgets) every synthetic request still waiting for a reply.
  const rejectPendingSyntheticRequests = (reason: Error) => {
    const pending = [...syntheticResponseHandlers.values()];
    syntheticResponseHandlers.clear();
    for (const handler of pending) {
      handler.reject(reason);
    }
  };

  // Sends a proxy-originated message built by `factory` from a fresh synthetic id.
  // Requests resolve with the peer's response (or reject on error reply, cancellation or
  // send failure); notifications and errors resolve to undefined once sent.
  const sendSynthetic = async (
    target: NamedTransport,
    factory: (id: number) => JSONRPCMessage,
    options?: TransportSendOptions,
  ): Promise<JSONRPCMessage | undefined> => {
    const id = nextSyntheticMessageId--;
    const message = factory(id);
    if (isJSONRPCNotification(message) || isJSONRPCError(message)) {
      await target.transport.send(message, options);
      return undefined;
    }
    if (isJSONRPCRequest(message)) {
      return new Promise<JSONRPCMessage>((resolve, reject) => {
        const fail = (error: unknown) => {
          syntheticResponseHandlers.delete(id);
          reject(error);
        };
        // Register before sending so a fast reply always finds its handler.
        syntheticResponseHandlers.set(id, {resolve, reject, target: target.name});
        try {
          target.transport.send(message, options).catch(fail);
        } catch (error) {
          fail(error);
        }
      });
    }
    throw new Error(`Unexpected message type: ${JSON.stringify(message)}`);
  };

  // Resolves with the `initialize` request received on the server transport, once forwarded.
  const initializeRequest = new Promise<InitializeRequest>(resolve => {
    const propagateMessage = (source: NamedTransport, target: NamedTransport) => {
      source.transport.onmessage = async (message, extra) => {
        if (verbose) {
          logger.info(`[proxy]: Message from ${source.name} transport: ${JSON.stringify(message)}; extra: ${JSON.stringify(extra)}`);
        }
        // Replies to synthetic (proxy-injected) requests are settled here, never forwarded.
        if (isJSONRPCResponse(message) || isJSONRPCError(message)) {
          const handler = takeSyntheticHandler(source, message.id);
          if (handler) {
            if (isJSONRPCError(message)) {
              handler.reject(new Error(message.error.message));
            } else {
              handler.resolve(message);
            }
            return;
          }
        }
        const cancelledRequestId = isCancelledNotification(message) ? message.params.requestId : undefined;
        if (cancelledRequestId !== undefined) {
          const handler = takeSyntheticHandler(source, cancelledRequestId);
          if (handler) {
            handler.reject(new Error('Request cancelled'));
            return;
          }
        }
        try {
          await target.transport.send(message, {relatedRequestId: cancelledRequestId});
        } catch (error) {
          logger.error(`[proxy]: Error sending message to ${target.name}: ${error}`);
          logger.error(error);
        }
        if (source.name === 'server' && isInitializeRequest(message)) {
          resolve(message);
        }
      };
    };
    propagateMessage(server, client);
    propagateMessage(client, server);
  });

  const addErrorHandler = (transport: NamedTransport) => {
    transport.transport.onerror = async (error: Error) => {
      if (verbose) {
        logger.error(`[proxy]: Error from ${transport.name} transport: ${error.message}`);
      }
      onerror(transport.name, error);
    };
  };
  addErrorHandler(client);
  addErrorHandler(server);

  try {
    // Both start() calls are issued synchronously (server first) and then awaited, so a
    // failed start rejects createProxy instead of becoming an unhandled rejection.
    await Promise.all([server.transport.start(), client.transport.start()]);
  } catch (error) {
    await Promise.allSettled([client.transport.close(), server.transport.close()]);
    throw error;
  }

  return {
    elicitInput: async (params: ElicitRequest['params'], options?: RequestOptions) => {
      try {
        if (verbose) {
          logger.info(`[proxy]: elicitInput called with params: ${JSON.stringify(params)}, options: ${JSON.stringify(options)}`);
        }
        const init = await initializeRequest;
        if (!init.params.capabilities.elicitation) {
          if (verbose) {
            logger.info(`[proxy]: Upstream client does not support elicitation, declining`);
          }
          return <ElicitResult>{action: 'decline'};
        }
        const response = await sendSynthetic(server, id => ({
          jsonrpc: "2.0",
          id,
          ...<ElicitRequest>{
            method: "elicitation/create",
            params,
          },
        }), options);
        if (verbose) {
          logger.info(`[proxy]: elicitInput response: ${JSON.stringify(response)}`);
        }
        return ElicitResultSchema.parse(JSONRPCResponseSchema.parse(response).result);
      } catch (error) {
        logger.error(`Error during elicitInput: ${error instanceof Error ? error.message : String(error)}`);
        throw error;
      }
    },
    close: async () => {
      try {
        await Promise.all([
          client.transport.close(),
          server.transport.close(),
        ]);
      } finally {
        // Nothing can answer synthetic requests any more; don't leave callers hanging.
        rejectPendingSyntheticRequests(new Error('Proxy closed'));
      }
    },
  };
}
/**
 * Stdio flavour of `createProxy`. The upstream client is reached through this
 * process's stdio (`StdioServerTransport`). The downstream server is reached
 * through a `StdioClientTransport` built from `inner.command` / `inner.args`.
 *
 * @param inner command line used for the downstream `StdioClientTransport`.
 * @param opts logger, verbosity and error callback, handed to `createProxy` as-is.
 * @returns the `elicitInput` / `close` pair produced by `createProxy`.
 */
export async function createStdioProxy(
  inner: {command: string, args: string[]},
  opts: {
    logger: winston.Logger,
    verbose: boolean,
    onerror: (source: NamedTransport['name'], error: any) => void,
  }) : Promise<{
    elicitInput: (params: ElicitRequest['params'], options?: RequestOptions) => Promise<ElicitResult>,
    close: () => Promise<void>,
  }> {
  const {command, args} = inner;
  const downstreamServerSide = new StdioClientTransport({command, args});
  const upstreamClientSide = new StdioServerTransport();
  const proxy = await createProxy({
    clientTransport: downstreamServerSide,
    serverTransport: upstreamClientSide,
    ...opts,
  });
  return proxy;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment