Skip to content

Instantly share code, notes, and snippets.

@ochafik
Created September 1, 2025 16:19
Show Gist options
  • Select an option

  • Save ochafik/42c9f4a3a472ec31e630a82dafe13371 to your computer and use it in GitHub Desktop.

Select an option

Save ochafik/42c9f4a3a472ec31e630a82dafe13371 to your computer and use it in GitHub Desktop.
MCP proxy options for antechamber
import { afterEach, describe, it, expect } from "bun:test";
import winston from "winston";
import z from "zod";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { CallToolRequest, ClientCapabilities, ElicitRequest, ElicitRequestSchema, ElicitResult, PingRequestSchema, CompleteRequestSchema, ListRootsRequestSchema, ResourceListChangedNotificationSchema, PromptListChangedNotificationSchema, ToolListChangedNotificationSchema } from "@modelcontextprotocol/sdk/types.js";
import { createProxy, createStdioProxy } from "./proxy_client_server.js";
import { CleanupFunction } from "../shared/cleanup-utils.js";
import { getPort, registerStreamableHttpMcpServer, ServerFactory, setupExpressServer } from "../shared/mcp-utils.js";
import { Stream } from "winston/lib/winston/transports/index.js";
import { RequestOptions } from "@modelcontextprotocol/sdk/shared/protocol.js";
// Shared test logger: debug level, console output, colourised + timestamped
// single-line format, with error stacks rendered by the `errors` formatter.
const logger = winston.createLogger({
  level: 'debug',
  transports: [new winston.transports.Console({})],
  format: winston.format.combine(
    winston.format.colorize(),
    winston.format.timestamp(),
    winston.format.errors({ stack: true }),
    winston.format.simple(),
  ),
});
/** Settings for a proxy server instance. */
export interface Config {
  /** Port number for the server. */
  port: number;
  /** Optional URL of the upstream MCP endpoint being proxied. */
  proxiedServerEndpoint?: URL;
  /** Path under which the MCP endpoint is mounted. */
  mcpPath: string;
}
describe('Proxy server', () => {
// Teardown callbacks registered by the setup helpers during a test.
const cleanups: CleanupFunction[] = [];
// After each test, run every registered cleanup sequentially in registration
// order. Failures are logged rather than rethrown so later cleanups still run.
// The list is then emptied for the next test.
afterEach(async () => {
  for (const [, runCleanup] of cleanups.entries()) {
    try {
      await runCleanup();
    } catch (cleanupError) {
      console.error('Error during cleanup:', cleanupError);
    }
  }
  cleanups.length = 0;
});
/**
 * Starts an Express server on port 0 (random port) that serves the MCP server
 * built by `createServer` at `/mcp`. Registers the server's `close` in `cleanups`.
 *
 * @param serverName Name passed to `registerStreamableHttpMcpServer`.
 * @param createServer Factory for the MCP server instance(s).
 * @returns `{ targetEndpoint }` — the URL of the served MCP endpoint.
 */
async function setupTargetServer(serverName: string, createServer: ServerFactory<McpServer>) {
  const expressServer = await setupExpressServer({
    logger,
    port: 0, // Use a random port
    callback: async (app, addCleanup) => {
      const registration = registerStreamableHttpMcpServer(serverName, app, {
        logger,
        path: '/mcp',
        createServer,
      });
      addCleanup(registration.cleanup);
    },
  });
  cleanups.push(expressServer.close);
  const port = getPort(expressServer.server);
  const targetEndpoint = new URL(`http://localhost:${port}/mcp`);
  console.log(`Target server running: ${targetEndpoint}`);
  return { targetEndpoint };
}
/**
 * Creates an MCP client named 'target-client' that answers server pings with an
 * empty result (and logs them). If given, `cb` runs against the client before it
 * connects over Streamable HTTP to `endpoint`.
 *
 * @param endpoint MCP endpoint URL to connect to.
 * @param capabilities Capabilities the client declares.
 * @param cb Optional hook to customise the client (e.g. add handlers) pre-connect.
 * @returns The connected client.
 */
async function createClient(endpoint: URL, capabilities: ClientCapabilities, cb?: (client: Client) => Promise<void>) {
  const mcpClient = new Client({ name: 'target-client', version: '1.0.0' }, { capabilities });
  mcpClient.setRequestHandler(PingRequestSchema, async (_request, extra) => {
    logger.info(`Client received MCP ping request from server: ${extra.requestId}`);
    return {};
  });
  if (cb !== undefined) {
    await cb(mcpClient);
  }
  const transport = new StreamableHTTPClientTransport(endpoint);
  await mcpClient.connect(transport);
  return mcpClient;
}
// Zod schemas for the `echo` tool: its arguments and its structured result.
const EchoInputSchema = z.object({ message: z.string() });
const EchoOutputSchema = z.object({ echo: z.string() });
/**
 * Feature switches for `setupMockMcpServer`; any flag left undefined falls back
 * to the default chosen there. `customTools` are extra tools registered
 * alongside the built-in ones.
 */
type MockServerConfig = {
  enableTools?: boolean;
  enablePrompts?: boolean;
  enableResources?: boolean;
  enableElicitation?: boolean;
  enableCompletion?: boolean;
  enableLogging?: boolean;
  enableSampling?: boolean;
  enableRoots?: boolean;
  customTools?: Array<{
    name: string;
    description?: string;
    inputSchema?: any;
    outputSchema?: any;
    handler: (args: any) => Promise<any>;
  }>;
};
/**
 * Starts a configurable mock MCP server (via `setupTargetServer`, at `/mcp`)
 * for exercising the proxy.
 *
 * Defaults: tools, prompts, resources, elicitation, completion and logging are
 * on; sampling and roots are off. What each flag registers:
 * - tools: `echo`, `mockTool`, `triggerElicitation` (only if elicitation is also
 *   enabled), and any `customTools`
 * - prompts: `mockPrompt`, `mockPromptWithArgs`
 * - resources: `mock://static` and the `mock://{type}/{id}` template
 * - completion: answers every `completion/complete` with
 *   `Completed: <value of the argument being completed>`
 * - logging: accepts `logging/setLevel`
 * - roots: a `roots/list` handler plus a `triggerListRoots` tool returning the
 *   same data
 *
 * @param config Optional feature flags and extra tools.
 * @returns `{ targetEndpoint }` — the URL of the mock server's MCP endpoint.
 */
async function setupMockMcpServer(config?: MockServerConfig) {
  const {
    enableTools = true,
    enablePrompts = true,
    enableResources = true,
    enableElicitation = true,
    enableCompletion = true,
    enableLogging = true,
    enableSampling = false,
    enableRoots = false,
    customTools = [],
  } = config ?? {};
  return await setupTargetServer('mock-server', async _initialize => {
    // Returned both by the roots/list handler and by the triggerListRoots tool.
    const mockRoots = [
      { uri: 'file:///mock/root1', name: 'Root 1' },
      { uri: 'file:///mock/root2', name: 'Root 2' },
    ];
    const server = new McpServer({
      name: 'mock-server',
      version: '1.0.0',
    }, {
      instructions: "Mock MCP server for testing",
      capabilities: {
        elicitation: enableElicitation ? {} : undefined,
        completions: enableCompletion ? {} : undefined,
        logging: enableLogging ? {} : undefined,
        prompts: enablePrompts ? { listChanged: true } : undefined,
        resources: enableResources ? {
          subscribe: true,
          listChanged: true
        } : undefined,
        tools: enableTools ? { listChanged: true } : undefined,
        // NOTE(review): `sampling` is a client-side capability in MCP; advertising
        // it from a server looks unusual — confirm this is intended.
        sampling: enableSampling ? {} : undefined,
      },
    });
    // Register tools
    if (enableTools) {
      // Echo tool (existing)
      server.registerTool(
        'echo',
        {
          inputSchema: EchoInputSchema.shape,
          outputSchema: EchoOutputSchema.shape,
        },
        async ({ message }) => ({
          content: [],
          structuredContent: {
            echo: `Echo: ${message}`,
          },
        })
      );
      // Mock tool that echoes the action and (optionally) the received data
      server.registerTool(
        'mockTool',
        {
          description: 'A mock tool that returns configurable responses',
          inputSchema: {
            action: z.string(),
            data: z.any().optional(),
          },
          outputSchema: {
            result: z.string(),
            metadata: z.record(z.any()).optional(),
          },
        },
        async ({ action, data }) => ({
          content: [],
          structuredContent: {
            result: `Mock action: ${action}`,
            metadata: data ? { receivedData: data } : undefined,
          },
        })
      );
      // Tool that sends an elicitation request back to the calling client
      if (enableElicitation) {
        server.registerTool(
          'triggerElicitation',
          {
            description: 'Triggers elicitation with any payload',
            inputSchema: {
              message: z.string(),
              requestedSchema: z.any(),
            },
            outputSchema: {
              elicitationResult: z.any(),
            },
          },
          async ({ message, requestedSchema }, extra) => {
            try {
              const result = await server.server.elicitInput({
                message,
                requestedSchema: requestedSchema as any,
              }, extra);
              return {
                content: [],
                structuredContent: {
                  elicitationResult: result,
                },
              };
            } catch (error) {
              // Report the failure as a tool result instead of a protocol error.
              return {
                content: [{ type: 'text', text: `Elicitation failed: ${error}` }],
                structuredContent: {
                  elicitationResult: { error: String(error) },
                },
              };
            }
          }
        );
      }
      // Custom tools are only registered when tools are enabled.
      for (const tool of customTools) {
        server.registerTool(
          tool.name,
          {
            description: tool.description,
            inputSchema: tool.inputSchema,
            outputSchema: tool.outputSchema,
          },
          tool.handler
        );
      }
    }
    // Register prompts
    if (enablePrompts) {
      // Simple prompt without arguments
      server.registerPrompt(
        'mockPrompt',
        {
          description: 'A mock prompt that returns a template message',
        },
        async () => ({
          messages: [
            {
              role: 'user' as const,
              content: {
                type: 'text' as const,
                text: 'This is a mock prompt response',
              },
            },
          ],
        })
      );
      // Prompt with arguments
      server.registerPrompt(
        'mockPromptWithArgs',
        {
          description: 'A mock prompt that accepts arguments',
          argsSchema: {
            name: z.string(),
            value: z.string().optional(),
          },
        },
        async ({ name, value }) => ({
          messages: [
            {
              role: 'user' as const,
              content: {
                type: 'text' as const,
                text: `Mock prompt for ${name}${value !== undefined ? ` with value ${value}` : ''}`,
              },
            },
          ],
        })
      );
    }
    // Register resources
    if (enableResources) {
      // Static resource
      server.registerResource(
        'mockStaticResource',
        'mock://static',
        {
          description: 'A static mock resource',
          mimeType: 'application/json',
        },
        async () => ({
          contents: [
            {
              uri: 'mock://static',
              mimeType: 'application/json',
              text: JSON.stringify({ type: 'static', data: 'mock data' }),
            },
          ],
        })
      );
      // Resource template
      const { ResourceTemplate } = await import('@modelcontextprotocol/sdk/server/mcp.js');
      server.registerResource(
        'mockResourceTemplate',
        new ResourceTemplate('mock://{type}/{id}', {
          list: async () => ({
            resources: [
              { uri: 'mock://user/1', name: 'User 1', mimeType: 'application/json' },
              { uri: 'mock://user/2', name: 'User 2', mimeType: 'application/json' },
              { uri: 'mock://item/a', name: 'Item A', mimeType: 'application/json' },
              { uri: 'mock://item/b', name: 'Item B', mimeType: 'application/json' },
            ]
          }),
          complete: {
            type: async () => ['user', 'item'],
            // Suggested ids depend on the already-chosen `type` variable.
            id: async (_value, context) => {
              const type = context?.arguments?.type;
              if (type === 'user') return ['1', '2'];
              if (type === 'item') return ['a', 'b'];
              return [];
            },
          },
        }),
        {
          description: 'A mock resource template with variables',
          mimeType: 'application/json',
        },
        async (_uri, arguments_) => {
          const { type, id } = arguments_ as { type: string; id: string };
          return {
            contents: [
              {
                uri: `mock://${type}/${id}`,
                mimeType: 'application/json',
                text: JSON.stringify({ type, id, data: `Mock ${type} ${id}` }),
              },
            ],
          };
        }
      );
    }
    // Register completion handler
    if (enableCompletion) {
      // Echo the value of the argument being completed so tests can observe
      // exactly what reached the server. (The request object only carries
      // `method`/`params`; the previous `{ ref, prompt }` destructuring always
      // produced "Completed: undefined".)
      // NOTE(review): this is set on the underlying server after the resource
      // template is registered; it may replace McpServer's own completion
      // handler, in which case the template's `complete` callbacks never run.
      server.server.setRequestHandler(CompleteRequestSchema, async request => ({
        completion: {
          values: [`Completed: ${request.params.argument.value}`],
          total: 1,
          hasMore: false,
        }
      }));
    }
    // Register logging handler (accepts any level, no-op)
    if (enableLogging) {
      const { SetLevelRequestSchema } = await import('@modelcontextprotocol/sdk/types.js');
      server.server.setRequestHandler(SetLevelRequestSchema, async () => ({
        // Logging level set successfully
      }));
    }
    // Register roots handler
    if (enableRoots) {
      server.server.setRequestHandler(ListRootsRequestSchema, async () => ({
        roots: mockRoots,
      }));
      // Add triggerListRoots tool for testing.
      // Note: Unlike elicitInput, listRoots is a client->server request,
      // so we can't call it from the server. Instead, we return the same
      // data that the ListRootsRequestSchema handler returns.
      server.registerTool(
        'triggerListRoots',
        {
          description: 'Returns the roots data',
          outputSchema: {
            roots: z.array(z.object({
              uri: z.string(),
              name: z.string().optional(),
            }))
          },
        },
        async () => ({
          content: [],
          structuredContent: {
            roots: mockRoots,
          },
        })
      );
    }
    return server;
  });
}
/**
 * Backward-compatible helper: a mock server with only tools enabled (prompts,
 * resources, elicitation, completion and logging all off).
 *
 * @returns `{ targetEndpoint }` from `setupMockMcpServer`.
 */
async function setupEchoMcpServer() {
  const toolsOnly: MockServerConfig = {
    enableTools: true,
    enablePrompts: false,
    enableResources: false,
    enableElicitation: false,
    enableCompletion: false,
    enableLogging: false,
  };
  return setupMockMcpServer(toolsOnly);
}
// async function setupClassifierServer(classifyMcpTool: (args: ClassifierInput) => Promise<ClassifierOutput>) {
// classifyMcpTool ??= async _inputs => {
// return {
// decision: 'allow',
// };
// };
// return setupTargetServer('classifier-server', async _initialize => {
// const server = new McpServer({
// name: 'classifier-server',
// version: '1.0.0',
// }, {
// capabilities: {
// },
// instructions: "Echoes messages",
// });
// server.registerTool(
// 'classify_mcp_tool',
// {
// inputSchema: ClassifierInputSchema.shape,
// outputSchema: ClassifierOutputSchema.shape,
// },
// async (args: ClassifierInput, _extra) => {
// return <ClassifierOutput>{
// content: [],
// structuredContent: await classifyMcpTool(args),
// };
// }
// );
// return server;
// });
// }
/**
 * Starts an Express server on a random port (port 0) that serves, at
 * `/test-mcp`, an MCP proxy built by `createProxy`. Each call of the
 * `createServer` factory opens a new Streamable HTTP transport to
 * `proxiedServerEndpoint`. Registers the server's `close` in `cleanups`.
 *
 * TODO: expose the proxy's `elicitInput` to tests (see the disabled
 * "should send elicitations to client" test).
 *
 * @param proxiedServerEndpoint URL of the upstream MCP server to proxy.
 * @returns The proxy endpoint URL and the underlying net server.
 * @throws Error if `setupExpressServer` did not produce a net server.
 */
async function setupProxyServer(proxiedServerEndpoint: URL) {
  const mcpPath = '/test-mcp';
  const { server: netServer, close } = await setupExpressServer({
    port: 0,
    logger,
    callback: async (app, addCleanup) => {
      const registration = registerStreamableHttpMcpServer('proxy', app, {
        path: mcpPath,
        logger,
        createServer: async initialize => {
          const upstream = new StreamableHTTPClientTransport(proxiedServerEndpoint);
          return await createProxy(upstream, initialize, logger);
        },
      });
      addCleanup(registration.cleanup);
    },
  });
  cleanups.push(close);
  if (!netServer) {
    throw new Error('Proxy server did not create a net server');
  }
  const port = getPort(netServer);
  const proxyEndpoint = new URL(`http://localhost:${port}${mcpPath}`);
  console.log(`Proxy server running: ${proxyEndpoint}`);
  return { proxyEndpoint, netServer };
}
// Tool listings and tool-call results must be identical whether the echo
// server is reached directly or through the proxy.
it('should proxy tool calls target server', async () => {
  const { targetEndpoint } = await setupEchoMcpServer();
  const { proxyEndpoint } = await setupProxyServer(targetEndpoint);
  const directClient = await createClient(targetEndpoint, {});
  const proxiedClient = await createClient(proxyEndpoint, {});

  const directTools = await directClient.listTools();
  const proxiedTools = await proxiedClient.listTools();
  expect(proxiedTools).toEqual(directTools);

  const echoCall: CallToolRequest['params'] = { name: 'echo', arguments: { message: 'Hello, world!' } };
  const directResult = await directClient.callTool(echoCall);
  const proxiedResult = await proxiedClient.callTool(echoCall);
  expect(proxiedResult).toEqual(directResult);
});
// Sanity-check the mock server itself (no proxy): tools, prompts and
// resources all behave as the rest of the suite expects.
it('should expose all mock server capabilities', async () => {
  const { targetEndpoint } = await setupMockMcpServer();
  const client = await createClient(targetEndpoint, {});

  // Tools
  const toolList = await client.listTools();
  const toolNames = toolList.tools.map(tool => tool.name);
  expect(toolNames).toContain('echo');
  expect(toolNames).toContain('mockTool');
  expect(toolNames).toContain('triggerElicitation');

  const echoResult = await client.callTool({ name: 'echo', arguments: { message: 'test' } });
  expect(echoResult.structuredContent).toEqual({ echo: 'Echo: test' });

  const mockResult = await client.callTool({
    name: 'mockTool',
    arguments: { action: 'test-action', data: { key: 'value' } },
  });
  const mockStructured = mockResult.structuredContent as any;
  expect(mockStructured?.result).toBe('Mock action: test-action');
  expect(mockStructured?.metadata).toEqual({ receivedData: { key: 'value' } });

  // Prompts
  const promptList = await client.listPrompts();
  const promptNames = promptList.prompts.map(prompt => prompt.name);
  expect(promptNames).toContain('mockPrompt');
  expect(promptNames).toContain('mockPromptWithArgs');

  const simplePrompt = await client.getPrompt({ name: 'mockPrompt' });
  expect(simplePrompt.messages).toHaveLength(1);
  expect(simplePrompt.messages[0].content).toEqual({
    type: 'text',
    text: 'This is a mock prompt response',
  });

  const argsPrompt = await client.getPrompt({
    name: 'mockPromptWithArgs',
    arguments: { name: 'John', value: '42' },
  });
  expect(argsPrompt.messages[0].content).toEqual({
    type: 'text',
    text: 'Mock prompt for John with value 42',
  });

  // Resources
  const resourceList = await client.listResources();
  expect(resourceList.resources.map(resource => resource.name)).toContain('mockStaticResource');
  expect(resourceList.resources.some(resource => resource.name.includes('User'))).toBe(true);

  const staticRead = await client.readResource({ uri: 'mock://static' });
  expect(staticRead.contents).toHaveLength(1);
  const staticJson = JSON.parse(staticRead.contents[0].text as string || '{}');
  expect(staticJson).toEqual({ type: 'static', data: 'mock data' });

  const templateRead = await client.readResource({ uri: 'mock://user/1' });
  expect(templateRead.contents).toHaveLength(1);
  const templateJson = JSON.parse(templateRead.contents[0].text as string || '{}');
  expect(templateJson).toEqual({ type: 'user', id: '1', data: 'Mock user 1' });
});
// The proxy must return exactly the upstream resource templates.
it('should proxy listResourceTemplates', async () => {
  const { targetEndpoint } = await setupMockMcpServer();
  const { proxyEndpoint } = await setupProxyServer(targetEndpoint);
  const directClient = await createClient(targetEndpoint, {});
  const proxiedClient = await createClient(proxyEndpoint, {});

  const expected = await directClient.listResourceTemplates();
  const actual = await proxiedClient.listResourceTemplates();
  expect(actual).toEqual(expected);

  const { resourceTemplates } = actual;
  expect(resourceTemplates).toHaveLength(1);
  expect(resourceTemplates[0].uriTemplate).toBe('mock://{type}/{id}');
});
// Client -> proxy -> upstream ping must resolve with an empty result.
// (Only this direction is exercised here.)
it('should proxy ping requests bidirectionally', async () => {
  const { targetEndpoint } = await setupMockMcpServer();
  const { proxyEndpoint } = await setupProxyServer(targetEndpoint);
  const proxiedClient = await createClient(proxyEndpoint, {});
  expect(await proxiedClient.ping()).toEqual({});
});
// completion/complete must reach the upstream server and come back with the
// single "Completed: ..." suggestion the mock produces.
it('should proxy complete requests', async () => {
  const { targetEndpoint } = await setupMockMcpServer({ enableCompletion: true });
  const { proxyEndpoint } = await setupProxyServer(targetEndpoint);
  const proxiedClient = await createClient(proxyEndpoint, {});

  const { completion } = await proxiedClient.complete({
    ref: { type: 'ref/prompt' as const, name: 'test' },
    argument: { name: 'arg1', value: 'val1' },
  });
  const suggestions = completion?.values;
  expect(suggestions).toHaveLength(1);
  expect(suggestions?.[0]).toContain('Completed:');
});
// NOTE(review): `triggerListRoots` returns static data from the mock server,
// so this checks tool forwarding with a roots-capable client rather than an
// actual roots/list round-trip.
it('should handle roots listing through proxy', async () => {
  const { targetEndpoint } = await setupMockMcpServer({ enableRoots: true });
  const { proxyEndpoint } = await setupProxyServer(targetEndpoint);
  const proxiedClient = await createClient(proxyEndpoint, { roots: {} });

  const expectedRoots = [
    { uri: 'file:///mock/root1', name: 'Root 1' },
    { uri: 'file:///mock/root2', name: 'Root 2' },
  ];
  const { structuredContent } = await proxiedClient.callTool({
    name: 'triggerListRoots',
    arguments: {},
  });
  expect(structuredContent).toEqual({ roots: expectedRoots });
});
// logging/setLevel must be forwarded and acknowledged with an empty result.
it('should proxy logging levels', async () => {
  const { targetEndpoint } = await setupMockMcpServer({ enableLogging: true });
  const { proxyEndpoint } = await setupProxyServer(targetEndpoint);
  const proxiedClient = await createClient(proxyEndpoint, {});
  const ack = await proxiedClient.setLoggingLevel('debug');
  expect(ack).toEqual({});
});
// An elicitation issued by the upstream server during a tool call must be
// relayed to the downstream client, and the client's answer must flow back
// into the tool result.
it('should handle elicitation through proxy', async () => {
  const { targetEndpoint } = await setupMockMcpServer({ enableElicitation: true });
  const { proxyEndpoint } = await setupProxyServer(targetEndpoint);

  const receivedElicitations: ElicitRequest[] = [];
  const proxiedClient = await createClient(proxyEndpoint, { elicitation: {} }, async client => {
    client.setRequestHandler(ElicitRequestSchema, async (request, _extra) => {
      receivedElicitations.push(request);
      return { action: 'accept' as const, content: { approved: true } };
    });
  });

  const requestedSchema = {
    type: 'object',
    properties: {
      approved: { type: 'boolean' },
    },
    required: ['approved'],
  };
  const toolResult = await proxiedClient.callTool({
    name: 'triggerElicitation',
    arguments: { message: 'Test elicitation', requestedSchema },
  });

  expect(receivedElicitations).toHaveLength(1);
  expect(receivedElicitations[0].params.message).toBe('Test elicitation');
  const structured = toolResult.structuredContent as any;
  expect(structured?.elicitationResult).toEqual({
    action: 'accept',
    content: { approved: true },
  });
});
// Clients declaring the `sampling` capability must be able to connect both
// directly and through the proxy, and the proxy must present the upstream
// server's identity.
// TODO: exercise a real sampling/createMessage round-trip; the mock server has
// no tool that issues createMessage yet.
it('should handle sampling through proxy', async () => {
  const { targetEndpoint } = await setupMockMcpServer({ enableSampling: true });
  const { proxyEndpoint } = await setupProxyServer(targetEndpoint);
  const targetClient = await createClient(targetEndpoint, { sampling: {} });
  const proxyClient = await createClient(proxyEndpoint, { sampling: {} });
  // Previously this test made no assertions; at minimum pin that both
  // handshakes completed and the proxy mirrors the upstream server name.
  expect(targetClient.getServerCapabilities()).toBeDefined();
  expect(proxyClient.getServerCapabilities()).toBeDefined();
  expect(proxyClient.getServerVersion()?.name).toBe(targetClient.getServerVersion()?.name);
});
// Installs list-changed notification handlers on a proxied client and checks
// the proxy still serves resources. Nothing here triggers a list change on the
// upstream server yet, so the counters are collected but not asserted.
it('should forward resource and prompt list change notifications', async () => {
  const { targetEndpoint } = await setupMockMcpServer();
  const { proxyEndpoint } = await setupProxyServer(targetEndpoint);

  const listChangedCounts = { resources: 0, prompts: 0, tools: 0 };
  const proxiedClient = await createClient(proxyEndpoint, {}, async client => {
    client.setNotificationHandler(ResourceListChangedNotificationSchema, () => {
      listChangedCounts.resources++;
    });
    client.setNotificationHandler(PromptListChangedNotificationSchema, () => {
      listChangedCounts.prompts++;
    });
    client.setNotificationHandler(ToolListChangedNotificationSchema, () => {
      listChangedCounts.tools++;
    });
  });

  const { resources } = await proxiedClient.listResources();
  expect(resources.length).toBeGreaterThan(0);
});
// it('should send elicitations to client', async () => {
// const { targetEndpoint } = await setupEchoMcpServer();
// const { proxyEndpoint, elicitInput } = await setupProxyServer(targetEndpoint);
// console.log(`Proxy endpoint: ${proxyEndpoint}`);
// const elicitRequests: ElicitRequest[] = [];
// const proxyClient = await createClient(proxyEndpoint, { elicitation: {} }, async client => {
// client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
// logger.info(`Received elicit request: ${JSON.stringify(request)}, extra: ${JSON.stringify(extra)}`);
// elicitRequests.push(request);
// return {
// action: 'accept',
// content: {
// testField: 'yay',
// }
// }
// });
// });
// const elicitParams: ElicitRequest['params'] = {
// message: 'Test elicit input',
// requestedSchema: {
// 'type': 'object',
// 'properties': {
// 'testField': {
// 'type': 'string',
// },
// },
// required: ['testField'],
// },
// };
// console.log(`Elicit params: ${JSON.stringify(elicitParams)}`);
// const result = await elicitInput(elicitParams);
// console.log(`Elicit result: ${JSON.stringify(result)}`);
// expect(result as any).toBe({ testField: 'yay' });
// expect(elicitRequests).toHaveLength(1);
// expect(elicitRequests[0].params.message).toBe('Test elicit input');
// expect(elicitRequests[0].params.requestedSchema).toEqual(elicitParams.requestedSchema);
// });
// it('should call the classifier w/ tool calls', async () => {
// const classificationInputs: ClassifierInput[] = [];
// const proxiedServerEndpoint = await setupEchoMcpServer();
// const classifierEndpoint = await setupClassifierServer(async inputs => {
// classificationInputs.push(inputs);
// return {
// decision: 'allow',
// };
// });
// const { proxyEndpoint } = await setupProxyServer({proxiedServerEndpoint, classifierEndpoint});
// const targetClient = await createClient(proxiedServerEndpoint, { elicitation: false });
// const proxyClient = await createClient(proxyEndpoint, { elicitation: false });
// const params: CallToolRequest['params'] = {name: 'echo', arguments: { message: 'Hello, world!' }};
// const targetTools = await targetClient.listTools();
// const proxyTools = await proxyClient.listTools();
// expect(proxyTools).toEqual(targetTools);
// expect(proxyTools.tools).toHaveLength(1);
// expect(proxyTools.tools[0].name).toBe('echo');
// // expect(proxyTools.tools[0].inputSchema).toEqual(EchoInputSchema.shape);
// // expect(proxyTools.tools[0].outputSchema).toEqual(EchoOutputSchema.shape);
// const targetResult = await targetClient.callTool(params);
// const proxyResult = await proxyClient.callTool(params);
// expect(proxyResult).toEqual(targetResult);
// expect(classificationInputs).toHaveLength(1);
// expect(classificationInputs[0].request.name).toEqual('echo');
// expect(classificationInputs[0].request.arguments).toEqual({ message: 'Hello, world!' });
// // expect(classificationInputs[0]).toEqual({
// // name: 'echo',
// // tool: {
// // name: 'echo',
// // inputSchema: EchoInputSchema.shape,
// // outputSchema: EchoOutputSchema.shape,
// // },
// // arguments: { message: 'Hello, world!' },
// // });
// });
});
import net from "net";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import {
  ListPromptsRequestSchema,
  GetPromptRequestSchema,
  ListToolsRequestSchema,
  CallToolRequestSchema,
  CallToolResultSchema,
  ListResourcesRequestSchema,
  ListResourceTemplatesRequestSchema,
  ReadResourceRequestSchema,
  SubscribeRequestSchema,
  UnsubscribeRequestSchema,
  ElicitRequestSchema,
  PingRequestSchema,
  SetLevelRequestSchema,
  LoggingMessageNotificationSchema,
  CompleteRequestSchema,
  ResourceUpdatedNotificationSchema,
  ResourceListChangedNotificationSchema,
  PromptListChangedNotificationSchema,
  CreateMessageRequestSchema,
  ListRootsRequestSchema,
  JSONRPCMessageSchema,
  ToolListChangedNotificationSchema,
} from "@modelcontextprotocol/sdk/types.js";
import { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { ElicitRequest, ElicitResult, InitializeRequest, isJSONRPCRequest, JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import winston from "winston";
import { readLine } from "../shared/terminal-utils.js";
import { Transport } from "@modelcontextprotocol/sdk/shared/transport.js";
/**
 * Builds an MCP proxy between a downstream client and an upstream server.
 *
 * Connects an upstream `Client` over `clientTransport`, presenting the
 * downstream client's `clientInfo` and (a whitelisted copy of) its
 * capabilities. Then returns a `Server`, not yet connected to any transport,
 * that mirrors the upstream server's identity, instructions and (whitelisted)
 * capabilities and forwards requests/notifications in both directions.
 *
 * Closing the returned server also closes the upstream client.
 *
 * @param clientTransport Transport to the upstream (proxied) MCP server.
 * @param initialize Params of the downstream client's `initialize` request.
 * @param logger Logger for connection/teardown errors.
 * @returns The proxy `Server`; the caller connects it to a transport.
 * @throws Error if the upstream connection cannot be established.
 */
export async function createProxy(clientTransport: Transport, initialize: InitializeRequest['params'], logger: winston.Logger): Promise<Server> {
  const clientCapabilities = initialize.capabilities;
  const client = new Client({
    name: initialize.clientInfo.name,
    title: initialize.clientInfo.title,
    version: initialize.clientInfo.version,
  }, {
    capabilities: {
      // Copy capabilities explicitly to avoid new ones creeping in.
      sampling: clientCapabilities.sampling ? {} : undefined,
      // Only advertise elicitation upstream if the downstream client supports it.
      elicitation: clientCapabilities.elicitation ? {} : undefined,
      roots: clientCapabilities.roots ? {} : undefined,
    },
  });
  try {
    await client.connect(clientTransport);
  } catch (error) {
    logger.error('Error connecting to client transport');
    logger.error(error);
    throw new Error(`Could not connect to client transport: ${error}`);
  }

  const serverVersion = client.getServerVersion();
  const serverCapabilities = client.getServerCapabilities();
  const server = new Server({
    name: serverVersion?.name ?? '?',
    title: serverVersion?.title,
    version: serverVersion?.version ?? '?',
  }, {
    capabilities: {
      // Copy capabilities explicitly (as fresh objects / known flags only) so
      // unknown sub-capabilities from upstream don't creep in.
      completions: serverCapabilities?.completions ? {} : undefined,
      elicitation: serverCapabilities?.elicitation ? {} : undefined,
      logging: serverCapabilities?.logging ? {} : undefined,
      prompts: serverCapabilities?.prompts ? {
        listChanged: serverCapabilities.prompts.listChanged ? true : undefined,
      } : undefined,
      resources: serverCapabilities?.resources ? {
        subscribe: serverCapabilities.resources.subscribe ? true : undefined,
        listChanged: serverCapabilities.resources.listChanged ? true : undefined,
      } : undefined,
      tools: serverCapabilities?.tools ? {
        listChanged: serverCapabilities.tools.listChanged ? true : undefined,
      } : undefined,
    },
    instructions: client.getInstructions(),
  });

  // Each downstream session owns one upstream client: release it when the
  // downstream side closes so sessions don't leak upstream connections.
  server.onclose = () => {
    client.close().catch(error => {
      logger.error('Error closing upstream client');
      logger.error(error);
    });
  };

  // Wire handlers.
  // TODO: remap related ids: {relatedRequestId: extra.requestId});
  server.setRequestHandler(PingRequestSchema, async (_request, extra) => {
    return client.ping(extra);
  });
  client.setRequestHandler(PingRequestSchema, async (_request, _extra) => {
    return server.ping();
  });

  if (serverCapabilities?.prompts) {
    server.setRequestHandler(ListPromptsRequestSchema, (request, extra) => {
      return client.listPrompts(request.params, extra);
    });
    server.setRequestHandler(GetPromptRequestSchema, (request, extra) => {
      return client.getPrompt(request.params, extra);
    });
    if (serverCapabilities.prompts.listChanged) {
      client.setNotificationHandler(PromptListChangedNotificationSchema, _ => {
        return server.sendPromptListChanged();
      });
    }
  }

  if (serverCapabilities?.tools) {
    server.setRequestHandler(ListToolsRequestSchema, (request, extra) => {
      return client.listTools(request.params, extra);
    });
    server.setRequestHandler(CallToolRequestSchema, (request, extra) => {
      return client.callTool(request.params, CallToolResultSchema, extra);
    });
    if (serverCapabilities.tools.listChanged) {
      // The notification originates upstream, so listen on the client (the
      // previous code listened on the downstream server and never fired).
      client.setNotificationHandler(ToolListChangedNotificationSchema, _ => {
        return server.sendToolListChanged();
      });
    }
  }

  if (serverCapabilities?.logging) {
    server.setRequestHandler(SetLevelRequestSchema, (request, extra) => {
      return client.setLoggingLevel(request.params.level, extra);
    });
    client.setNotificationHandler(LoggingMessageNotificationSchema, notification => {
      return server.sendLoggingMessage(notification.params);
    });
  }

  if (serverCapabilities?.completions) {
    server.setRequestHandler(CompleteRequestSchema, (request, extra) => {
      return client.complete(request.params, extra);
    });
  }

  if (serverCapabilities?.resources) {
    server.setRequestHandler(ListResourcesRequestSchema, (request, extra) => {
      return client.listResources(request.params, extra);
    });
    server.setRequestHandler(ListResourceTemplatesRequestSchema, (request, extra) => {
      return client.listResourceTemplates(request.params, extra);
    });
    server.setRequestHandler(ReadResourceRequestSchema, (request, extra) => {
      return client.readResource(request.params, extra);
    });
    // The capability flag is `subscribe` (there is no `update` field). Since the
    // proxy advertises `subscribe`, it must also forward subscribe/unsubscribe
    // requests and the resulting resources/updated notifications.
    if (serverCapabilities.resources.subscribe) {
      server.setRequestHandler(SubscribeRequestSchema, (request, extra) => {
        return client.subscribeResource(request.params, extra);
      });
      server.setRequestHandler(UnsubscribeRequestSchema, (request, extra) => {
        return client.unsubscribeResource(request.params, extra);
      });
      client.setNotificationHandler(ResourceUpdatedNotificationSchema, notification => {
        return server.sendResourceUpdated(notification.params);
      });
    }
    if (serverCapabilities.resources.listChanged) {
      client.setNotificationHandler(ResourceListChangedNotificationSchema, _ => {
        return server.sendResourceListChanged();
      });
    }
  }

  // Server -> client requests: relay from upstream to the downstream client.
  if (clientCapabilities.elicitation) {
    client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
      return server.elicitInput(request.params, extra);
    });
  }
  if (clientCapabilities.sampling) {
    client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => {
      return server.createMessage(request.params, extra);
    });
  }
  if (clientCapabilities.roots) {
    client.setRequestHandler(ListRootsRequestSchema, async (request, extra) => {
      return server.listRoots(request.params, extra);
    });
    // TODO: forward roots/list_changed from the downstream client upstream when
    // clientCapabilities.roots.listChanged is set.
  }
  return server;
}
/**
 * Runs an MCP proxy over stdio in front of a child-process MCP server.
 *
 * Reads the first line from the terminal via `readLine()` and requires it to be
 * a JSON-RPC `initialize` request. Spawns the upstream server through
 * `StdioClientTransport` using `inner.command`/`inner.args`, and builds the proxy
 * with `createProxy`. It then connects the proxy to a `StdioServerTransport` and
 * replays the consumed `initialize` message into it.
 *
 * Written as a plain async function: the previous `new Promise(async ...)`
 * wrapper added nothing but a place for errors to go astray.
 *
 * @param inner Command and arguments that launch the upstream server.
 * @param logger Logger passed to `createProxy` and used for teardown errors.
 * @returns `elicitInput` to issue elicitations to the downstream client, and
 *   `close` to shut the proxy server down.
 * @throws Error if the first line is not a parseable `initialize` request, or
 *   if the proxy cannot be created/connected.
 */
export async function createStdioProxy(inner: {command: string, args: string[]}, logger: winston.Logger) : Promise<{
  elicitInput: (params: ElicitRequest['params'], options?: RequestOptions) => Promise<ElicitResult>,
  close: () => Promise<void>,
}> {
  const line = await readLine();
  let message: JSONRPCMessage;
  try {
    message = JSONRPCMessageSchema.parse(JSON.parse(line));
  } catch (error) {
    throw new Error(`Could not parse initial message: ${error}: ${line}`);
  }
  if (!isJSONRPCRequest(message) || message.method !== 'initialize') {
    throw new Error(`Expected initialize message, got: ${JSON.stringify(message)}`);
  }
  // NOTE(review): only `method` is checked; `params` are trusted to match
  // InitializeRequest without schema validation.
  const initialize = message as unknown as InitializeRequest;

  const transport = new StdioClientTransport({command: inner.command, args: inner.args});
  const server = await createProxy(transport, initialize.params, logger);
  const outerTransport = new StdioServerTransport();
  try {
    await server.connect(outerTransport);
    if (!outerTransport.onmessage) {
      throw new Error('transport.onmessage should have been defined by the server');
    }
    // The initialize request was already consumed from stdin above; hand it
    // to the server as if the transport had just received it.
    outerTransport.onmessage(message);
  } catch (error) {
    // Don't leave a half-initialised proxy behind before propagating.
    await server.close().catch(closeError => {
      logger.error('Error closing proxy server after failed startup');
      logger.error(closeError);
    });
    throw error;
  }

  return {
    elicitInput: async (params: ElicitRequest['params'], options?: RequestOptions) => {
      return server.elicitInput(params, options);
    },
    close: () => server.close(),
  };
}
import { CancelledNotification, InitializedNotification, InitializeRequest, isInitializeRequest, JSONRPCError, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse } from "@modelcontextprotocol/sdk/types.js";
import { ElicitRequest, ElicitResult, JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import winston from "winston";
import { Transport, TransportSendOptions } from "@modelcontextprotocol/sdk/shared/transport.js";
import { beforeEach, describe, expect, it } from "bun:test";
import { createProxy, TypedRequest } from "./proxy_transports.js";
/**
 * In-memory `Transport` test double. It counts `start()`/`close()` invocations
 * and records the arguments of every `send()`. The `on*` callbacks start
 * unset, for the code under test to install.
 */
class MockTransport implements Transport {
  startCallCount = 0;
  closeCallCount = 0;
  sendCalls: Parameters<Transport['send']>[] = [];

  start = async function (this: MockTransport): Promise<void> {
    this.startCallCount += 1;
  };

  send = async function (this: MockTransport, message: JSONRPCMessage, options?: TransportSendOptions): Promise<void> {
    const call: Parameters<Transport['send']> = [message, options];
    this.sendCalls.push(call);
  };

  close = async function (this: MockTransport): Promise<void> {
    this.closeCallCount += 1;
  };

  onclose?: (() => void) | undefined;
  onerror?: ((error: Error) => void) | undefined;
  onmessage?: ((message: JSONRPCMessage) => void) | undefined;
}
/**
 * Elicitation parameters shared by the tests: a greeting that asks for a single
 * required string field `name` and rejects any other field.
 */
const elicitRequestParams: ElicitRequest['params'] = {
  message: "Hello, world!",
  requestedSchema: {
    type: "object",
    properties: { name: { type: "string" } },
    required: ["name"],
    additionalProperties: false,
  },
};

/**
 * `initialize` request (id 0) from an upstream client that advertises both the
 * `elicitation` and `sampling` capabilities.
 */
const initializeWithElicitation = {
  method: "initialize",
  params: {
    protocolVersion: "2025-06-18",
    capabilities: { elicitation: {}, sampling: {} },
    clientInfo: { name: "test-client", version: "1.0.0" },
  },
  jsonrpc: "2.0",
  id: 0,
} as TypedRequest<InitializeRequest>;
describe('createProxy', () => {
  let logger: winston.Logger | undefined;
  // Errors reported through createProxy's `onerror`; reset in beforeEach so they don't leak across tests.
  let errors: Parameters<Parameters<typeof createProxy>['0']['onerror']>[] = [];
  let proxy: Awaited<ReturnType<typeof createProxy>> | undefined;
  let clientTransport: MockTransport | undefined;
  let serverTransport: MockTransport | undefined;

  // The first synthetic request the proxy injects towards the upstream client (synthetic ids start at -1).
  const syntheticElicitRequest = {id: -1, jsonrpc: "2.0", method: "elicitation/create", params: elicitRequestParams};

  /** Yields to the event loop so the proxy's pending async continuations can run. */
  const flush = () => new Promise(resolve => setTimeout(resolve, 0));

  beforeEach(async () => {
    errors = [];
    logger = winston.createLogger({
      level: 'debug',
      format: winston.format.simple(),
      transports: [
        new winston.transports.Console(),
      ],
    });
    clientTransport = new MockTransport();
    serverTransport = new MockTransport();
    proxy = await createProxy({
      clientTransport,
      serverTransport,
      logger: logger!,
      verbose: true,
      onerror: (source, error) => {
        errors.push([source, error]);
      }
    });
  });

  it('should forward the initialize message from server transport (upstream client) to client transport (downstream server)', async () => {
    serverTransport?.onmessage?.(initializeWithElicitation);
    expect(serverTransport?.sendCalls.length).toBe(0);
    expect(clientTransport?.sendCalls.length).toBe(1);
    expect(clientTransport?.sendCalls[0][0]).toEqual(initializeWithElicitation);
  });

  it('should forward the initialized notification from client transport (downstream server) to server transport (upstream client)', async () => {
    const initializedNotification: JSONRPCMessage & InitializedNotification =
      {"method":"notifications/initialized","jsonrpc":"2.0"};
    clientTransport?.onmessage?.(initializedNotification);
    expect(clientTransport?.sendCalls.length).toBe(0);
    expect(serverTransport?.sendCalls.length).toBe(1);
    expect(serverTransport?.sendCalls[0][0]).toEqual(initializedNotification);
  });

  it('should forward non-special-cased elicitation requests from client transport (downstream server) to server transport (upstream client)', async () => {
    const message: JSONRPCMessage =
      {method: "elicitation/create", params: elicitRequestParams, jsonrpc: "2.0", id: -123};
    clientTransport?.onmessage?.(message);
    expect(clientTransport?.sendCalls.length).toBe(0);
    expect(serverTransport?.sendCalls.length).toBe(1);
    expect(serverTransport?.sendCalls[0][0]).toEqual(message);
  });

  it('should send synthetic elicitInput to the server transport (upstream client), wait for initialize and not forward its response to anyone', async () => {
    const response = proxy?.elicitInput(elicitRequestParams);
    await flush();
    // Nothing is sent before the upstream client has initialized.
    expect(clientTransport!.sendCalls.length).toBe(0);
    expect(serverTransport!.sendCalls.length).toBe(0);

    serverTransport?.onmessage?.(initializeWithElicitation);
    await flush();
    expect(serverTransport!.sendCalls.length).toBe(1);
    expect(serverTransport!.sendCalls[0][0]).toEqual(syntheticElicitRequest);
    serverTransport!.sendCalls.length = 0;

    const elicitResult: ElicitResult = <ElicitResult>{
      action: 'accept',
      content: {name: "John Doe"},
    };
    const upstreamResponse: JSONRPCResponse = {
      id: -1,
      jsonrpc: "2.0",
      result: elicitResult,
    };
    serverTransport?.onmessage?.(upstreamResponse);
    expect(await response).toEqual(elicitResult);

    // The synthetic response is consumed by the proxy: the downstream server has
    // only seen the forwarded initialize request.
    expect(serverTransport!.sendCalls.length).toBe(0);
    expect(clientTransport!.sendCalls.length).toBe(1);
    expect(clientTransport!.sendCalls[0][0]).toEqual(initializeWithElicitation);

    // Now simulate another elicitation response from the server (upstream client), confusingly with the same id.
    // It should be forwarded as usual (the synthetic id has been removed from the pending map).
    serverTransport?.onmessage?.(upstreamResponse);
    expect(serverTransport!.sendCalls.length).toBe(0);
    expect(clientTransport!.sendCalls.length).toBe(2);
    expect(clientTransport!.sendCalls[0][0]).toEqual(initializeWithElicitation);
    expect(clientTransport!.sendCalls[1][0]).toEqual(upstreamResponse);
  });

  it('should send synthetic elicitInput to the server transport (upstream client) and reflect errors', async () => {
    const response = proxy?.elicitInput(elicitRequestParams);
    serverTransport?.onmessage?.(initializeWithElicitation);
    await flush();
    // The synthetic request went upstream; clear it so that any forwarding of the
    // error response would show up below.
    expect(serverTransport!.sendCalls.length).toBe(1);
    expect(serverTransport!.sendCalls[0][0]).toEqual(syntheticElicitRequest);
    serverTransport!.sendCalls.length = 0;

    serverTransport?.onmessage?.(<JSONRPCError>{
      id: -1,
      jsonrpc: "2.0",
      error: {
        code: -32603,
        message: "Internal error",
      },
    });
    await expect(response).rejects.toThrow('Internal error');

    // The error response was consumed by the proxy, not forwarded.
    expect(serverTransport!.sendCalls.length).toBe(0);
    expect(clientTransport!.sendCalls.length).toBe(1);
    expect(clientTransport!.sendCalls[0][0]).toEqual(initializeWithElicitation);
  });

  it('should send synthetic elicitInput to the server transport (upstream client) and reflect cancellation', async () => {
    const response = proxy?.elicitInput(elicitRequestParams);
    serverTransport?.onmessage?.(initializeWithElicitation);
    await flush();
    expect(serverTransport!.sendCalls.length).toBe(1);
    expect(serverTransport!.sendCalls[0][0]).toEqual(syntheticElicitRequest);
    serverTransport!.sendCalls.length = 0;

    const interceptedCancelledNotification = <TypedRequest<CancelledNotification>>{
      jsonrpc: "2.0",
      method: "notifications/cancelled",
      params: {
        requestId: -1,
      },
    };
    const passThroughCancelledNotification = <TypedRequest<CancelledNotification>>{
      jsonrpc: "2.0",
      method: "notifications/cancelled",
      params: {
        requestId: 12345,
      },
    };
    serverTransport?.onmessage?.(passThroughCancelledNotification);
    serverTransport?.onmessage?.(interceptedCancelledNotification);
    await expect(response).rejects.toThrow('Request cancelled');

    // Only the cancellation of a non-synthetic request is forwarded downstream.
    expect(serverTransport!.sendCalls.length).toBe(0);
    expect(clientTransport!.sendCalls.length).toBe(2);
    expect(clientTransport!.sendCalls[0][0]).toEqual(initializeWithElicitation);
    expect(clientTransport!.sendCalls[1][0]).toEqual(passThroughCancelledNotification);
  });
});
import {
isJSONRPCResponse,
isJSONRPCError,
ElicitResultSchema,
CancelledNotification,
InitializeRequest,
isInitializeRequest,
isJSONRPCNotification,
isJSONRPCRequest,
JSONRPCResponseSchema,
} from "@modelcontextprotocol/sdk/types.js";
import { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { ElicitRequest, ElicitResult, JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import winston from "winston";
import { Transport, TransportSendOptions } from "@modelcontextprotocol/sdk/shared/transport.js";
import { CancelledNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
/**
 * Type guard for `notifications/cancelled` messages, validated against the SDK's
 * `CancelledNotificationSchema`.
 * Written as an arrow with an explicit `value is` return type, so no `as any` cast
 * is needed to satisfy the predicate signature.
 */
const isCancelledNotification = (value: unknown): value is CancelledNotification =>
  CancelledNotificationSchema.safeParse(value).success;
/** A transport labelled with the side of the proxy it belongs to (used in logs and error callbacks). */
type NamedTransport<T extends Transport = Transport> = {
  name: 'client' | 'server';
  transport: T;
};

/** The `resolve` / `reject` pair captured from a `new Promise` executor. */
type PromiseHandlers<T> = {
  resolve: (response: T) => void;
  reject: (reason?: any) => void;
};
/**
 * Relays every message between two MCP transports, and lets the proxy inject its
 * own ("synthetic") requests towards the upstream client.
 *
 * - `serverTransport` faces the upstream client (the proxy acts as its server).
 * - `clientTransport` faces the downstream server (the proxy acts as its client).
 *
 * Synthetic requests use negative ids (-1, -2, ...). A response, error, or
 * `notifications/cancelled` that refers to a pending synthetic id is consumed by
 * the proxy and not forwarded. This applies only when it arrives from the
 * transport the synthetic request was sent to.
 *
 * @param clientTransport Transport connected to the downstream server.
 * @param serverTransport Transport connected to the upstream client.
 * @param logger Destination for proxy diagnostics.
 * @param verbose When true, logs every relayed message and every transport error.
 * @param onerror Called with the transport's name whenever a transport reports an error.
 * @returns `elicitInput`, which asks the upstream client for input (declining
 *   automatically if it did not advertise elicitation), and `close`, which closes
 *   both transports.
 * @throws Rejects if either transport's `start()` rejects.
 */
export async function createProxy(
  {
    clientTransport,
    serverTransport,
    logger,
    onerror,
    verbose,
  } : {
    clientTransport: Transport,
    serverTransport: Transport,
    logger: winston.Logger,
    verbose: boolean,
    onerror: (source: NamedTransport['name'], error: any) => void,
  }) : Promise<{
    elicitInput: (params: ElicitRequest['params'], options?: RequestOptions) => Promise<ElicitResult>,
    close: () => Promise<void>,
  }> {
  const client: NamedTransport = {name: 'client', transport: clientTransport};
  const server: NamedTransport = {name: 'server', transport: serverTransport};

  // Synthetic ids count down from -1. Each pending entry records which transport
  // the request was sent to, so only that transport can answer (or cancel) it.
  let nextSyntheticMessageId = -1;
  const syntheticResponseHandlers =
    new Map<number | string, PromiseHandlers<JSONRPCMessage> & {target: NamedTransport['name']}>();

  // Sends a proxy-originated message to `target`. For requests, the returned promise
  // settles with the matching response message. It rejects on an error response, a
  // cancellation, or a failed send.
  const sendSynthetic = async (
    target: NamedTransport,
    factory: (id: number) => JSONRPCMessage,
    options?: TransportSendOptions,
  ): Promise<JSONRPCMessage | undefined> => {
    const id = nextSyntheticMessageId--;
    const message = factory(id);
    if (isJSONRPCNotification(message) || isJSONRPCError(message)) {
      await target.transport.send(message, options);
      return undefined;
    }
    if (isJSONRPCRequest(message)) {
      return new Promise<JSONRPCMessage>((resolve, reject) => {
        // Register before sending so a response delivered while send() is in flight is not missed.
        syntheticResponseHandlers.set(id, {target: target.name, resolve, reject});
        // A failed send would otherwise be an unhandled rejection and leave the caller waiting forever.
        target.transport.send(message, options).catch(error => {
          syntheticResponseHandlers.delete(id);
          reject(error);
        });
      });
    }
    throw new Error(`Unexpected message type: ${JSON.stringify(message)}`);
  };

  // Resolves with the upstream client's `initialize` request once it has been forwarded downstream.
  const initializeRequest = new Promise<InitializeRequest>(resolve => {
    const propagateMessage = (source: NamedTransport, target: NamedTransport) => {
      source.transport.onmessage = async (message, extra) => {
        if (verbose) {
          logger.info(`[proxy]: Message from ${source.name} transport: ${JSON.stringify(message)}; extra: ${JSON.stringify(extra)}`);
        }
        // Handle responses to synthetic (proxy-injected) requests.
        if (isJSONRPCResponse(message) || isJSONRPCError(message)) {
          const pending = syntheticResponseHandlers.get(message.id);
          if (pending && pending.target === source.name) {
            syntheticResponseHandlers.delete(message.id);
            if (isJSONRPCError(message)) {
              pending.reject(new Error(message.error.message));
            } else {
              pending.resolve(message);
            }
            return;
          }
        }
        const isCancellation = isCancelledNotification(message);
        if (isCancellation) {
          const pending = syntheticResponseHandlers.get(message.params.requestId);
          if (pending && pending.target === source.name) {
            syntheticResponseHandlers.delete(message.params.requestId);
            pending.reject(new Error('Request cancelled'));
            return;
          }
        }
        try {
          const relatedRequestId = isCancellation ? message.params.requestId : undefined;
          await target.transport.send(message, {relatedRequestId});
        } catch (error) {
          logger.error(`[proxy]: Error sending message to ${target.name}: ${error}`);
          logger.error(error);
        }
        if (source.name === 'server' && isInitializeRequest(message)) {
          resolve(message);
        }
      };
    };
    propagateMessage(server, client);
    propagateMessage(client, server);
  });

  const addErrorHandler = (transport: NamedTransport) => {
    transport.transport.onerror = async (error: Error) => {
      if (verbose) {
        logger.error(`[proxy]: Error from ${transport.name} transport: ${error.message}`);
      }
      onerror(transport.name, error);
    };
  };
  addErrorHandler(client);
  addErrorHandler(server);

  // Start the downstream side first so messages from the upstream client have somewhere to go.
  // Awaiting makes start-up failures reject createProxy instead of becoming unhandled rejections.
  await client.transport.start();
  await server.transport.start();

  return {
    elicitInput: async (params: ElicitRequest['params'], options?: RequestOptions) => {
      try {
        if (verbose) {
          logger.info(`[proxy]: elicitInput called with params: ${JSON.stringify(params)}, options: ${JSON.stringify(options)}`);
        }
        // The upstream client's capabilities are only known after it has initialized.
        const init = await initializeRequest;
        if (!init.params.capabilities.elicitation) {
          if (verbose) {
            logger.info(`[proxy]: Upstream client does not support elicitation, declining`);
          }
          return <ElicitResult>{action: 'decline'};
        }
        const response = await sendSynthetic(server, id => ({
          jsonrpc: "2.0",
          id,
          ...<ElicitRequest>{
            method: "elicitation/create",
            params,
          },
        }), options);
        if (verbose) {
          logger.info(`[proxy]: elicitInput response: ${JSON.stringify(response)}`);
        }
        return ElicitResultSchema.parse(JSONRPCResponseSchema.parse(response).result);
      } catch (error: any) {
        logger.error(`Error during elicitInput: ${error.message}`);
        throw error;
      }
    },
    close: async () => {
      await Promise.all([
        client.transport.close(),
        server.transport.close(),
      ]);
    },
  }
}
/**
 * Builds a proxy that sits between its own stdio (facing the upstream client) and
 * a child process running `inner.command` with `inner.args` (the downstream server).
 *
 * @param inner Command line of the downstream MCP server to spawn.
 * @param opts Logging and error-reporting options passed through to `createProxy`.
 * @returns The same `elicitInput` / `close` handle that `createProxy` returns.
 */
export async function createStdioProxy(
  inner: {command: string, args: string[]},
  opts: {
    logger: winston.Logger,
    verbose: boolean,
    onerror: (source: NamedTransport['name'], error: any) => void,
  }) : Promise<{
    elicitInput: (params: ElicitRequest['params'], options?: RequestOptions) => Promise<ElicitResult>,
    close: () => Promise<void>,
  }> {
  const { command, args } = inner;
  const downstream = new StdioClientTransport({ command, args });
  const upstream = new StdioServerTransport();
  const proxyOptions = {
    clientTransport: downstream,
    serverTransport: upstream,
    ...opts,
  };
  return createProxy(proxyOptions);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment