Skip to content

Instantly share code, notes, and snippets.

@ochafik
Created August 27, 2025 15:48
Show Gist options
  • Select an option

  • Save ochafik/f0a105300f1d0bc007b21769a07f9381 to your computer and use it in GitHub Desktop.

Select an option

Save ochafik/f0a105300f1d0bc007b21769a07f9381 to your computer and use it in GitHub Desktop.
MCP Proxy
import { afterEach, describe, it, expect } from "bun:test";
import winston from "winston";
import z from "zod";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { CallToolRequest, ClientCapabilities, ElicitRequest, ElicitRequestSchema, PingRequestSchema } from "@modelcontextprotocol/sdk/types.js";
import { createProxyServer } from "./proxy.js";
import { CleanupFunction } from "../shared/cleanup-utils.js";
import { getPort, registerStreamableHttpMcpServer, ServerFactory, setupExpressServer } from "../shared/mcp-utils.js";
// Debug-level, colourised console logger shared by the test servers, the proxy and the clients.
const logger = winston.createLogger({
  level: 'debug',
  transports: [new winston.transports.Console({})],
  format: winston.format.combine(
    winston.format.colorize(),
    winston.format.timestamp(),
    winston.format.errors({ stack: true }),
    winston.format.simple(),
  ),
});
/** Proxy configuration: listening port, MCP route path and optional upstream endpoint. */
export type Config = {
  port: number;
  mcpPath: string;
  proxiedServerEndpoint?: URL;
};
describe('Proxy server', () => {
// Teardown callbacks registered by the helpers below (servers, proxies, clients).
const cleanups: CleanupFunction[] = [];
afterEach(async () => {
  // Run in reverse registration order (LIFO) so clients and proxies are shut
  // down before the servers they are connected to. splice(0) also empties the
  // list so the next test starts with no pending cleanups.
  for (const cleanup of cleanups.splice(0).reverse()) {
    try {
      await cleanup();
    } catch (error) {
      console.error('Error during cleanup:', error);
    }
  }
});
/**
 * Starts an Express app on a random port that serves `createServer` as a
 * Streamable HTTP MCP server at `/mcp`, and registers its shutdown in `cleanups`.
 *
 * @param serverName Name passed to the MCP HTTP registration helper.
 * @param createServer Factory building the MCP server for each session.
 * @returns The URL of the target server's MCP endpoint.
 */
async function setupTargetServer(serverName: string, createServer: ServerFactory<McpServer>) {
  const { close, server } = await setupExpressServer({
    logger,
    port: 0, // Use a random port
    callback: async (app, addCleanup) => {
      // Awaited for consistency with proxy.ts, which awaits this helper; if it
      // returns a Promise, destructuring without await would yield undefined.
      const { cleanup } = await registerStreamableHttpMcpServer(serverName, app, {
        logger,
        path: '/mcp',
        createServer,
      });
      addCleanup(cleanup);
    },
  });
  cleanups.push(close);
  const targetEndpoint = new URL(`http://localhost:${getPort(server)}/mcp`);
  console.log(`Target server running: ${targetEndpoint}`);
  return { targetEndpoint };
}
/**
 * Creates an MCP client with the given capabilities and connects it to
 * `endpoint` over Streamable HTTP. The client is registered in `cleanups` so
 * it is closed after the test.
 *
 * @param endpoint MCP endpoint to connect to.
 * @param capabilities Client capabilities to advertise during initialization.
 * @param cb Optional hook run before connecting, e.g. to install request
 *   handlers (such as elicitation) that must exist before the server calls them.
 * @returns The connected client.
 */
async function createClient(endpoint: URL, capabilities: ClientCapabilities, cb?: (client: Client) => Promise<void>) {
  const client = new Client({
    name: 'target-client',
    version: '1.0.0',
  }, {
    capabilities,
  });
  client.setRequestHandler(PingRequestSchema, async (_, extra) => {
    logger.info(`Client received MCP ping request from server: ${extra.requestId}`);
    return {};
  });
  if (cb) {
    await cb(client);
  }
  await client.connect(new StreamableHTTPClientTransport(endpoint));
  // Close the client after the test. Otherwise its HTTP session stays open and
  // may outlive the servers it talks to.
  cleanups.push(() => client.close());
  return client;
}
// Zod schemas for the echo tool: `{ message }` goes in, `{ echo }` comes out.
const EchoInputSchema = z.object({ message: z.string() });
const EchoOutputSchema = z.object({ echo: z.string() });
/**
 * Starts a target MCP server named `echo-server` that exposes a single `echo`
 * tool. The tool returns an empty `content` list and the structured result
 * `{ echo: "Echo: <message>" }`.
 *
 * @returns The target server's endpoint (see `setupTargetServer`).
 */
async function setupEchoMcpServer() {
  return setupTargetServer('echo-server', async () => {
    const echoServer = new McpServer(
      { name: 'echo-server', version: '1.0.0' },
      { instructions: "Echoes messages" },
    );
    echoServer.registerTool(
      'echo',
      { inputSchema: EchoInputSchema.shape, outputSchema: EchoOutputSchema.shape },
      async ({ message }) => {
        const echo = `Echo: ${message}`;
        return { content: [], structuredContent: { echo } };
      },
    );
    return echoServer;
  });
}
// async function setupClassifierServer(classifyMcpTool: (args: ClassifierInput) => Promise<ClassifierOutput>) {
// classifyMcpTool ??= async _inputs => {
// return {
// decision: 'allow',
// };
// };
// return setupTargetServer('classifier-server', async _initialize => {
// const server = new McpServer({
// name: 'classifier-server',
// version: '1.0.0',
// }, {
// capabilities: {
// },
// instructions: "Echoes messages",
// });
// server.registerTool(
// 'classify_mcp_tool',
// {
// inputSchema: ClassifierInputSchema.shape,
// outputSchema: ClassifierOutputSchema.shape,
// },
// async (args: ClassifierInput, _extra) => {
// return <ClassifierOutput>{
// content: [],
// structuredContent: await classifyMcpTool(args),
// };
// }
// );
// return server;
// });
// }
/**
 * Starts the proxy on a random localhost port. The proxy forwards to
 * `proxiedServerEndpoint` over Streamable HTTP, and its shutdown is
 * registered in `cleanups`.
 *
 * @param proxiedServerEndpoint MCP endpoint of the target server.
 * @returns The proxy's endpoint, its net server and its `elicitInput` helper.
 * @throws If the proxy did not expose an underlying net server.
 */
async function setupProxyServer(proxiedServerEndpoint: URL) {
  const mcpPath = '/test-mcp';
  const proxy = await createProxyServer(
    { type: 'http', port: 0, host: 'localhost', path: mcpPath },
    { type: 'http', endpoint: proxiedServerEndpoint },
    logger,
  );
  cleanups.push(proxy.close);
  const { netServer, elicitInput } = proxy;
  if (netServer === undefined || netServer === null) {
    throw new Error('Proxy server did not create a net server');
  }
  const port = getPort(netServer);
  const proxyEndpoint = new URL(`http://localhost:${port}${mcpPath}`);
  console.log(`Proxy server running: ${proxyEndpoint}`);
  return { proxyEndpoint, netServer, elicitInput };
}
// The proxy must be transparent: listing and calling tools through it yields
// exactly what the target server returns when called directly.
it('should proxy tool calls target server', async () => {
  const { targetEndpoint } = await setupEchoMcpServer();
  const { proxyEndpoint } = await setupProxyServer(targetEndpoint);
  const direct = await createClient(targetEndpoint, {});
  const viaProxy = await createClient(proxyEndpoint, {});

  const directTools = await direct.listTools();
  expect(await viaProxy.listTools()).toEqual(directTools);

  const echoCall: CallToolRequest['params'] = { name: 'echo', arguments: { message: 'Hello, world!' } };
  const directResult = await direct.callTool(echoCall);
  expect(await viaProxy.callTool(echoCall)).toEqual(directResult);
});
// Elicitations issued through the proxy's `elicitInput` must reach an
// elicitation-capable client and return that client's answer.
it('should send elicitations to client', async () => {
  const { targetEndpoint } = await setupEchoMcpServer();
  const { proxyEndpoint, elicitInput } = await setupProxyServer(targetEndpoint);
  console.log(`Proxy endpoint: ${proxyEndpoint}`);
  const elicitRequests: ElicitRequest[] = [];
  // Connect an elicitation-capable client to the proxy that accepts every request.
  await createClient(proxyEndpoint, { elicitation: {} }, async client => {
    client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
      logger.info(`Received elicit request: ${JSON.stringify(request)}, extra: ${JSON.stringify(extra)}`);
      elicitRequests.push(request);
      return {
        action: 'accept',
        content: {
          testField: 'yay',
        }
      }
    });
  });
  const elicitParams: ElicitRequest['params'] = {
    message: 'Test elicit input',
    requestedSchema: {
      'type': 'object',
      'properties': {
        'testField': {
          'type': 'string',
        },
      },
      required: ['testField'],
    },
  };
  console.log(`Elicit params: ${JSON.stringify(elicitParams)}`);
  const result = await elicitInput(elicitParams);
  console.log(`Elicit result: ${JSON.stringify(result)}`);
  // `toBe` compares by identity and can never match a fresh object literal.
  // Compare structurally against the full ElicitResult returned by the client.
  expect(result).toEqual({ action: 'accept', content: { testField: 'yay' } });
  expect(elicitRequests).toHaveLength(1);
  expect(elicitRequests[0].params.message).toBe('Test elicit input');
  expect(elicitRequests[0].params.requestedSchema).toEqual(elicitParams.requestedSchema);
});
// it('should call the classifier w/ tool calls', async () => {
// const classificationInputs: ClassifierInput[] = [];
// const proxiedServerEndpoint = await setupEchoMcpServer();
// const classifierEndpoint = await setupClassifierServer(async inputs => {
// classificationInputs.push(inputs);
// return {
// decision: 'allow',
// };
// });
// const { proxyEndpoint } = await setupProxyServer({proxiedServerEndpoint, classifierEndpoint});
// const targetClient = await createClient(proxiedServerEndpoint, { elicitation: false });
// const proxyClient = await createClient(proxyEndpoint, { elicitation: false });
// const params: CallToolRequest['params'] = {name: 'echo', arguments: { message: 'Hello, world!' }};
// const targetTools = await targetClient.listTools();
// const proxyTools = await proxyClient.listTools();
// expect(proxyTools).toEqual(targetTools);
// expect(proxyTools.tools).toHaveLength(1);
// expect(proxyTools.tools[0].name).toBe('echo');
// // expect(proxyTools.tools[0].inputSchema).toEqual(EchoInputSchema.shape);
// // expect(proxyTools.tools[0].outputSchema).toEqual(EchoOutputSchema.shape);
// const targetResult = await targetClient.callTool(params);
// const proxyResult = await proxyClient.callTool(params);
// expect(proxyResult).toEqual(targetResult);
// expect(classificationInputs).toHaveLength(1);
// expect(classificationInputs[0].request.name).toEqual('echo');
// expect(classificationInputs[0].request.arguments).toEqual({ message: 'Hello, world!' });
// // expect(classificationInputs[0]).toEqual({
// // name: 'echo',
// // tool: {
// // name: 'echo',
// // inputSchema: EchoInputSchema.shape,
// // outputSchema: EchoOutputSchema.shape,
// // },
// // arguments: { message: 'Hello, world!' },
// // });
// });
});
import net from "net";
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import {
ListPromptsRequestSchema,
GetPromptRequestSchema,
ListToolsRequestSchema,
CallToolRequestSchema,
CallToolResultSchema,
CompatibilityCallToolResultSchema,
ListResourcesRequestSchema,
ListResourceTemplatesRequestSchema,
ReadResourceRequestSchema,
ElicitRequestSchema,
CreateMessageRequest,
PingRequestSchema,
LoggingLevelSchema,
SetLevelRequestSchema,
LoggingMessageNotificationSchema,
CompleteRequestSchema,
ResourceUpdatedNotificationSchema,
ResourceListChangedNotificationSchema,
PromptListChangedNotificationSchema,
CreateMessageResultSchema,
CreateMessageRequestSchema,
ListRootsRequestSchema,
RootsListChangedNotificationSchema,
JSONRPCMessageSchema,
InitializeRequestSchema,
SubscribeRequestSchema,
UnsubscribeRequestSchema,
ToolListChangedNotificationSchema,
} from "@modelcontextprotocol/sdk/types.js";
import { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { ElicitRequest, ElicitResult, InitializeRequest, isJSONRPCRequest, JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { registerStreamableHttpMcpServer, SessionData, setupExpressServer } from "../shared/mcp-utils.js";
import winston from "winston";
// NOTE(review): `ca` appears unused in this file (likely an accidental auto-import); confirm and remove.
import { ca } from "zod/v4/locales";
/** Client-facing side of the proxy served over HTTP at `http://host:port/path`. */
export type OuterMcpHttpConfig = {
  type: 'http';
  host: string;
  port: number;
  path: string;
};

/** Client-facing side of the proxy served over this process's stdin/stdout. */
export type OuterMcpStdioConfig = {
  type: 'stdio';
};

/** Proxied MCP server reached over HTTP at `endpoint`. */
export type InnerMcpHttpConfig = {
  type: 'http';
  endpoint: URL;
};

/** Proxied MCP server launched as `command args...` and reached over its stdio. */
export type InnerMcpStdioConfig = {
  type: 'stdio';
  command: string;
  args: string[];
};

/** Either outer transport, discriminated by `type`. */
export type OuterMcpConfig = OuterMcpStdioConfig | OuterMcpHttpConfig;
/** Either inner transport, discriminated by `type`. */
export type InnerMcpConfig = InnerMcpStdioConfig | InnerMcpHttpConfig;
/**
 * Reads one newline-terminated line from `input` without consuming anything
 * past the newline. The rest of the stream stays buffered for the next reader
 * (here, the MCP stdio transport).
 *
 * Bytes are collected and decoded as UTF-8 only once the full line is
 * available, so multi-byte characters are not broken into replacement
 * characters. The exact line `conda activate py311\n` is skipped.
 *
 * @param input Stream to read from; defaults to `process.stdin`.
 * @returns The line, including its trailing `\n`.
 * @throws Rejects if the stream ends or emits an error before a full line arrives.
 */
function readLine(input: NodeJS.ReadableStream = process.stdin): Promise<string> {
  const NEWLINE = 0x0a;
  return new Promise<string>((resolve, reject) => {
    let pending: Buffer[] = [];
    const detach = () => {
      input.removeListener('readable', onReadable);
      input.removeListener('end', onEnd);
      input.removeListener('error', onError);
    };
    const onReadable = () => {
      let chunk: string | Buffer | null;
      // Read one unit at a time so no data beyond the newline leaves the stream.
      while ((chunk = input.read(1)) !== null) {
        const bytes = typeof chunk === 'string' ? Buffer.from(chunk, 'utf8') : chunk;
        pending.push(bytes);
        if (bytes[bytes.length - 1] !== NEWLINE) {
          continue;
        }
        const line = Buffer.concat(pending).toString('utf8');
        pending = [];
        if (line === 'conda activate py311\n') {
          continue; // Ignore this line, it's just noise from vscode.
        }
        detach();
        resolve(line);
        return;
      }
    };
    const onEnd = () => {
      detach();
      reject(new Error('Input ended before a complete line was read'));
    };
    const onError = (error: unknown) => {
      detach();
      reject(error);
    };
    input.on('readable', onReadable);
    input.on('end', onEnd);
    input.on('error', onError);
  });
}
/**
 * Starts an MCP proxy that exposes the `inner` server (stdio subprocess or
 * Streamable HTTP endpoint) through the `outer` transport (this process's
 * stdio, or an HTTP server).
 *
 * For each outer session, a fresh inner `Client` is connected. It mirrors the
 * outer client's identity and capabilities, and requests and notifications
 * are forwarded in both directions. Only features advertised by the
 * respective peer are wired up.
 *
 * @param outer How clients reach the proxy.
 * @param inner How the proxy reaches the proxied server.
 * @param logger Logger for connection errors; also passed to the HTTP helpers.
 * @returns An object with three members:
 *   - `netServer` — the HTTP server (HTTP mode only).
 *   - `elicitInput` — asks the connected client(s) for input.
 *   - `close` — shuts the proxy down.
 * @throws If the first stdio message is not a valid `initialize` request, if
 *   the inner server cannot be reached (stdio mode), or if the HTTP server
 *   fails to start.
 */
export async function createProxyServer(outer: OuterMcpConfig, inner: InnerMcpConfig, logger: winston.Logger) : Promise<{
  netServer?: net.Server
  elicitInput: (params: ElicitRequest['params'], options?: RequestOptions) => Promise<ElicitResult>,
  close: () => Promise<void>,
}> {
  /**
   * Builds the outer-facing `Server` for one session. It connects a new inner
   * `Client` and installs forwarding handlers between the two.
   */
  async function createServer(initialize: InitializeRequest['params']): Promise<Server> {
    const clientCapabilities = initialize.capabilities;
    const client = new Client({
      name: initialize.clientInfo.name,
      title: initialize.clientInfo.title,
      version: initialize.clientInfo.version,
    }, {
      capabilities: {
        // Copy capabilities explicitly to avoid new ones creeping in.
        sampling: clientCapabilities.sampling ? {} : undefined,
        elicitation: clientCapabilities.elicitation ? {} : undefined, // Enable elicitation on proxy server if client supports it.
        // roots.listChanged must be advertised for client.sendRootsListChanged() to be permitted.
        roots: clientCapabilities.roots ? {
          listChanged: clientCapabilities.roots.listChanged ? true : undefined,
        } : undefined,
      },
    });
    try {
      if (inner.type === 'stdio') {
        await client.connect(new StdioClientTransport({command: inner.command, args: inner.args}));
      } else {
        await client.connect(new StreamableHTTPClientTransport(inner.endpoint));
      }
    } catch (error) {
      logger.error('Error connecting to inner MCP server:');
      logger.error(inner);
      logger.error(error);
      throw new Error(`Could not connect to inner MCP server: ${error}`);
    }
    const serverVersion = client.getServerVersion();
    const serverCapabilities = client.getServerCapabilities();
    const server = new Server({
      name: serverVersion?.name ?? '?',
      title: serverVersion?.title,
      version: serverVersion?.version ?? '?',
    }, {
      capabilities: {
        // Copy capabilities explicitly to avoid new ones creeping in.
        completions: serverCapabilities?.completions,
        // NOTE(review): elicitation is a client capability in the MCP spec; copied verbatim here.
        elicitation: serverCapabilities?.elicitation,
        logging: serverCapabilities?.logging ? {} : undefined,
        prompts: serverCapabilities?.prompts ? {
          listChanged: serverCapabilities.prompts.listChanged ? true : undefined,
        } : undefined,
        resources: serverCapabilities?.resources ? {
          subscribe: serverCapabilities.resources.subscribe ? true : undefined,
          listChanged: serverCapabilities.resources.listChanged ? true : undefined,
        } : undefined,
        tools: serverCapabilities?.tools ? {
          listChanged: serverCapabilities.tools.listChanged ? true : undefined,
        } : undefined,
      },
      instructions: client.getInstructions(),
    });

    // Closing the outer session also closes the inner connection (and stops
    // the subprocess when the inner server is stdio). Otherwise every session
    // would leave one inner client behind.
    server.onclose = () => {
      client.close().catch(error => logger.error(`Error closing inner MCP client: ${error}`));
    };

    // Wire ping handlers.
    server.setRequestHandler(PingRequestSchema, async (_request, _extra) => {
      return client.ping(); // TODO: remap related ids: {relatedRequestId: extra.requestId});
    });
    client.setRequestHandler(PingRequestSchema, async (_request, _extra) => {
      return server.ping();
    });

    if (serverCapabilities?.prompts) {
      server.setRequestHandler(ListPromptsRequestSchema, (request, extra) => {
        return client.listPrompts(request.params, extra);
      });
      server.setRequestHandler(GetPromptRequestSchema, (request, extra) => {
        return client.getPrompt(request.params, extra);
      });
      if (serverCapabilities.prompts.listChanged) {
        client.setNotificationHandler(PromptListChangedNotificationSchema, _ => {
          return server.sendPromptListChanged();
        });
      }
    }

    if (serverCapabilities?.tools) {
      server.setRequestHandler(ListToolsRequestSchema, (request, extra) => {
        return client.listTools(request.params, extra);
      });
      server.setRequestHandler(CallToolRequestSchema, (request, extra) => {
        return client.callTool(request.params, CallToolResultSchema, extra);
      });
      // The proxy advertises tools.listChanged in this case, so it must relay the notification.
      if (serverCapabilities.tools.listChanged) {
        client.setNotificationHandler(ToolListChangedNotificationSchema, _ => {
          return server.sendToolListChanged();
        });
      }
    }

    if (serverCapabilities?.logging) {
      server.setRequestHandler(SetLevelRequestSchema, (request, extra) => {
        return client.setLoggingLevel(request.params.level, extra);
      });
      client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => {
        return server.sendLoggingMessage(notification.params);
      });
    }

    if (serverCapabilities?.completions) {
      server.setRequestHandler(CompleteRequestSchema, (request, extra) => {
        return client.complete(request.params, extra);
      });
    }

    if (serverCapabilities?.resources) {
      server.setRequestHandler(ListResourcesRequestSchema, (request, extra) => {
        return client.listResources(request.params, extra);
      });
      server.setRequestHandler(ListResourceTemplatesRequestSchema, (request, extra) => {
        return client.listResourceTemplates(request.params, extra);
      });
      server.setRequestHandler(ReadResourceRequestSchema, (request, extra) => {
        return client.readResource(request.params, extra);
      });
      // resources.subscribe is advertised to the outer client, so subscribe and
      // unsubscribe requests must be forwarded for it to work.
      if (serverCapabilities.resources.subscribe) {
        server.setRequestHandler(SubscribeRequestSchema, (request, extra) => {
          return client.subscribeResource(request.params, extra);
        });
        server.setRequestHandler(UnsubscribeRequestSchema, (request, extra) => {
          return client.unsubscribeResource(request.params, extra);
        });
      }
      // Relay update notifications for subscribed resources. There is no
      // `resources.update` capability, so this was previously never registered.
      client.setNotificationHandler(ResourceUpdatedNotificationSchema, notification => {
        return server.sendResourceUpdated(notification.params);
      });
      if (serverCapabilities.resources.listChanged) {
        client.setNotificationHandler(ResourceListChangedNotificationSchema, _ => {
          return server.sendResourceListChanged();
        });
      }
    }

    if (clientCapabilities.elicitation) {
      client.setRequestHandler(ElicitRequestSchema, async (request, extra) => {
        return server.elicitInput(request.params, extra);
      });
    }

    if (clientCapabilities.sampling) {
      client.setRequestHandler(CreateMessageRequestSchema, async (request, extra) => {
        return server.createMessage(request.params, extra);
      });
    }

    if (clientCapabilities.roots) {
      client.setRequestHandler(ListRootsRequestSchema, async (request, extra) => {
        return server.listRoots(request.params, extra);
      });
      // Roots changes flow from the outer client, through the proxy, to the inner server.
      if (clientCapabilities.roots.listChanged) {
        server.setNotificationHandler(RootsListChangedNotificationSchema, _ => {
          return client.sendRootsListChanged();
        });
      }
    }

    return server;
  }

  if (outer.type === 'stdio') {
    // The inner client mirrors the outer client's identity and capabilities,
    // so the outer `initialize` request is read before anything is built.
    const line = await readLine();
    let message: JSONRPCMessage;
    try {
      message = JSONRPCMessageSchema.parse(JSON.parse(line));
    } catch (error) {
      throw new Error(`Could not parse initial message: ${error}: ${line}`);
    }
    if (!isJSONRPCRequest(message) || message.method !== 'initialize') {
      throw new Error(`Expected initialize message, got: ${JSON.stringify(message)}`);
    }
    // Validate rather than cast, so a malformed request fails with a schema
    // error instead of a TypeError on `clientInfo` inside createServer.
    const initialize = InitializeRequestSchema.parse(message);
    const server = await createServer(initialize.params);
    const outerTransport = new StdioServerTransport();
    await server.connect(outerTransport);
    if (!outerTransport.onmessage) {
      throw new Error('transport.onmessage should have been defined by the server');
    }
    // Replay the initialize request that readLine already consumed from stdin.
    outerTransport.onmessage(message);
    return {
      elicitInput: async (params: ElicitRequest['params'], options?: RequestOptions) => {
        return server.elicitInput(params, options);
      },
      close: () => server.close(),
    };
  }

  // HTTP: one Server (and one inner Client) per Streamable HTTP session, tracked in `transports`.
  const transports = new Map<string, SessionData<Server>>();
  const {server: netServer, close} = await setupExpressServer({
    port: outer.port,
    host: outer.host,
    callback: async (app, addCleanup) => {
      const {cleanup} = await registerStreamableHttpMcpServer('proxy', app, {
        path: outer.path,
        transports,
        createServer,
        logger,
      });
      addCleanup(cleanup);
    },
    logger,
  });
  return {
    netServer,
    // Sends the elicitation to every session whose client supports elicitation
    // and resolves with the first successful answer; other requests are left
    // outstanding. Returns `{ action: 'cancel' }` when no such session exists,
    // and rejects with an AggregateError if every session fails.
    elicitInput: async (params: ElicitRequest['params'], options?: RequestOptions): Promise<ElicitResult> => {
      const capableSessions = Array.from(transports.values())
        .filter(d => d.server.getClientCapabilities()?.elicitation);
      if (capableSessions.length === 0) {
        return { action: 'cancel' };
      }
      return Promise.any(capableSessions.map(d => d.server.elicitInput(params, options)));
    },
    close,
  };
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment