Directory structure:
└── modelcontextprotocol-typescript-sdk/
    ├── README.md
    ├── CLAUDE.md
    ├── CODE_OF_CONDUCT.md
    ├── CONTRIBUTING.md
    ├── LICENSE
    ├── SECURITY.md
    ├── eslint.config.mjs
    ├── jest.config.js
    ├── package.json
    ├── tsconfig.cjs.json
    ├── tsconfig.json
    ├── tsconfig.prod.json
    ├── .npmrc
    ├── src/
    │   ├── cli.ts
    │   ├── inMemory.test.ts
    │   ├── inMemory.ts
    │   ├── types.ts
    │   ├── __mocks__/
    │   │   └── pkce-challenge.ts
    │   ├── client/
    │   │   ├── auth.test.ts
    │   │   ├── auth.ts
    │   │   ├── index.test.ts
    │   │   ├── index.ts
    │   │   ├── sse.test.ts
    │   │   ├── sse.ts
    │   │   ├── stdio.test.ts
    │   │   ├── stdio.ts
    │   │   └── websocket.ts
    │   ├── integration-tests/
    │   │   └── process-cleanup.test.ts
    │   ├── server/
    │   │   ├── completable.test.ts
    │   │   ├── completable.ts
    │   │   ├── index.test.ts
    │   │   ├── index.ts
    │   │   ├── mcp.test.ts
    │   │   ├── mcp.ts
    │   │   ├── sse.ts
    │   │   ├── stdio.test.ts
    │   │   ├── stdio.ts
    │   │   └── auth/
    │   │       ├── clients.ts
    │   │       ├── errors.ts
    │   │       ├── provider.ts
    │   │       ├── router.test.ts
    │   │       ├── router.ts
    │   │       ├── types.ts
    │   │       ├── handlers/
    │   │       │   ├── authorize.test.ts
    │   │       │   ├── authorize.ts
    │   │       │   ├── metadata.test.ts
    │   │       │   ├── metadata.ts
    │   │       │   ├── register.test.ts
    │   │       │   ├── register.ts
    │   │       │   ├── revoke.test.ts
    │   │       │   ├── revoke.ts
    │   │       │   ├── token.test.ts
    │   │       │   └── token.ts
    │   │       └── middleware/
    │   │           ├── allowedMethods.test.ts
    │   │           ├── allowedMethods.ts
    │   │           ├── bearerAuth.test.ts
    │   │           ├── bearerAuth.ts
    │   │           ├── clientAuth.test.ts
    │   │           └── clientAuth.ts
    │   └── shared/
    │       ├── auth.ts
    │       ├── protocol.test.ts
    │       ├── protocol.ts
    │       ├── stdio.test.ts
    │       ├── stdio.ts
    │       ├── transport.ts
    │       ├── uriTemplate.test.ts
    │       └── uriTemplate.ts
    └── .github/
        └── workflows/
            └── main.yml
================================================
File: README.md
================================================
# MCP TypeScript SDK ![NPM Version](https://img.shields.io/npm/v/%40modelcontextprotocol%2Fsdk) ![MIT licensed](https://img.shields.io/npm/l/%40modelcontextprotocol%2Fsdk)
## Table of Contents
- [Overview](#overview)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [What is MCP?](#what-is-mcp)
- [Core Concepts](#core-concepts)
  - [Server](#server)
  - [Resources](#resources)
  - [Tools](#tools)
  - [Prompts](#prompts)
- [Running Your Server](#running-your-server)
  - [stdio](#stdio)
  - [HTTP with SSE](#http-with-sse)
  - [Testing and Debugging](#testing-and-debugging)
- [Examples](#examples)
  - [Echo Server](#echo-server)
  - [SQLite Explorer](#sqlite-explorer)
- [Advanced Usage](#advanced-usage)
  - [Low-Level Server](#low-level-server)
  - [Writing MCP Clients](#writing-mcp-clients)
- [Documentation](#documentation)
- [Contributing](#contributing)
- [License](#license)
## Overview
The Model Context Protocol allows applications to provide context for LLMs in a standardized way, separating the concerns of providing context from the actual LLM interaction. This TypeScript SDK implements the full MCP specification, making it easy to:
- Build MCP clients that can connect to any MCP server
- Create MCP servers that expose resources, prompts and tools
- Use standard transports like stdio and SSE
- Handle all MCP protocol messages and lifecycle events
## Installation
```bash
npm install @modelcontextprotocol/sdk
```
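All imports from the package name a subpath ending in `.js`, such as `@modelcontextprotocol/sdk/server/mcp.js`. That follows from the `exports` field in `package.json`, which maps the pattern `./*` to `./dist/esm/*` for `import` and to `./dist/cjs/*` for `require`: whatever the `*` matches is copied verbatim into the target, so the file extension has to be part of the import. The hypothetical `resolveSubpath` function below mimics only that substitution step, as an illustration; Node's real resolver also handles condition ordering, fallbacks, and validation.

```typescript
type Condition = "import" | "require";

// The "exports" entry from package.json, reduced to its single "./*" pattern.
const subpathTargets: Record<Condition, string> = {
  import: "./dist/esm/*",
  require: "./dist/cjs/*",
};

// Hypothetical helper: resolve "./<subpath>" by copying whatever "*" matches
// into the target for the requested condition.
function resolveSubpath(subpath: string, condition: Condition): string | undefined {
  // "./*" only matches subpaths of the form "./<something>".
  if (!subpath.startsWith("./") || subpath.length <= 2) return undefined;
  const matched = subpath.slice(2);
  // A replacer function avoids special handling of "$" in the matched text.
  return subpathTargets[condition].replace("*", () => matched);
}

console.log(resolveSubpath("./server/mcp.js", "import")); // → ./dist/esm/server/mcp.js
console.log(resolveSubpath("./client/index.js", "require")); // → ./dist/cjs/client/index.js
```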
## Quick Start
Let's create a simple MCP server that exposes a calculator tool and some data:
```typescript
import { McpServer, ResourceTemplate } from "@modelcontextprotocol/sdk/server/mcp.js";
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
import { z } from "zod";

// Create an MCP server
const server = new McpServer({
  name: "Demo",
  version: "1.0.0"
});

// Add an addition tool
server.tool("add",
  { a: z.number(), b: z.number() },
  async ({ a, b }) => ({
    content: [{ type: "text", text: String(a + b) }]
  })
);

// Add a dynamic greeting resource
server.resource(
  "greeting",
  new ResourceTemplate("greeting://{name}", { list: undefined }),
  async (uri, { name }) => ({
    contents: [{
      uri: uri.href,
      text: `Hello, ${name}!`
    }]
  })
);

// Start receiving messages on stdin and sending messages on stdout
const transport = new StdioServerTransport();
await server.connect(transport);
```
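Under the hood, a client invokes the `add` tool by sending a JSON-RPC 2.0 request with method `tools/call`, whose `params` carry the tool `name` and its `arguments`; the value returned by the handler becomes the response's `result`, matched to the request by `id`. The sketch below hand-rolls that exchange for a single hard-coded tool to make the message shapes concrete. The `handleToolCall` helper and its types are hypothetical and are not how `McpServer` is implemented.

```typescript
// Hypothetical, simplified message shapes for a JSON-RPC 2.0 tools/call exchange.
interface ToolCallRequest {
  jsonrpc: "2.0";
  id: number;
  method: "tools/call";
  params: { name: string; arguments: Record<string, unknown> };
}

interface ToolCallResponse {
  jsonrpc: "2.0";
  id: number; // echoes the request id so the client can match the reply
  result: { content: { type: "text"; text: string }[] };
}

// Hypothetical handler mirroring the Quick Start "add" tool.
function handleToolCall(req: ToolCallRequest): ToolCallResponse {
  const { name, arguments: args } = req.params;
  const a = args.a;
  const b = args.b;
  if (name !== "add" || typeof a !== "number" || typeof b !== "number") {
    // A real server would answer with a JSON-RPC error; this sketch just throws.
    throw new Error(`Cannot handle call to tool "${name}"`);
  }
  return {
    jsonrpc: "2.0",
    id: req.id,
    result: { content: [{ type: "text", text: String(a + b) }] },
  };
}

const reply = handleToolCall({
  jsonrpc: "2.0",
  id: 1,
  method: "tools/call",
  params: { name: "add", arguments: { a: 2, b: 3 } },
});
console.log(JSON.stringify(reply));
// → {"jsonrpc":"2.0","id":1,"result":{"content":[{"type":"text","text":"5"}]}}
```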
## What is MCP?
The [Model Context Protocol (MCP)](https://modelcontextprotocol.io) lets you build servers that expose data and functionality to LLM applications in a secure, standardized way. Think of it like a web API, but specifically designed for LLM interactions. MCP servers can:
- Expose data through **Resources** (roughly analogous to GET endpoints; they are used to load information into the LLM's context)
- Provide functionality through **Tools** (roughly analogous to POST endpoints; they are used to execute code or otherwise produce a side effect)
- Define interaction patterns through **Prompts** (reusable templates for LLM interactions)
- And more!
## Core Concepts
### Server
`McpServer` is your core interface to the Model Context Protocol. It handles connection management, protocol compliance, and message routing:
```typescript
const server = new McpServer({
  name: "My App",
  version: "1.0.0"
});
```
### Resources
Resources are how you expose data to LLMs. They're similar to GET endpoints in a REST API; they provide data but shouldn't perform significant computation or have side effects:
```typescript
// Static resource
server.resource(
  "config",
  "config://app",
  async (uri) => ({
    contents: [{
      uri: uri.href,
      text: "App configuration here"
    }]
  })
);

// Dynamic resource with parameters
server.resource(
  "user-profile",
  new ResourceTemplate("users://{userId}/profile", { list: undefined }),
  async (uri, { userId }) => ({
    contents: [{
      uri: uri.href,
      text: `Profile data for user ${userId}`
    }]
  })
);
```
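A `ResourceTemplate` such as `users://{userId}/profile` matches concrete URIs and passes the extracted variables to the read callback. The SDK has its own URI template implementation in `src/shared/uriTemplate.ts`; the sketch below is a much smaller, hypothetical `matchUriTemplate` function that only understands simple `{name}` expressions, each matching a single path segment, to show the idea.

```typescript
// Hypothetical matcher supporting only simple "{name}" expressions, each of
// which matches one or more characters up to the next "/".
function matchUriTemplate(template: string, uri: string): Record<string, string> | null {
  // Splitting on a capturing group keeps the variable names at odd indices:
  // "users://{userId}/profile" -> ["users://", "userId", "/profile"]
  const parts = template.split(/\{([^}]+)\}/);
  const names: string[] = [];
  let pattern = "^";
  for (let i = 0; i < parts.length; i++) {
    if (i % 2 === 1) {
      names.push(parts[i]);
      pattern += "([^/]+)";
    } else {
      // Escape regex metacharacters in the literal parts of the template.
      pattern += parts[i].replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
    }
  }
  pattern += "$";

  const match = new RegExp(pattern).exec(uri);
  if (match === null) return null;

  // Map each variable name to its (percent-decoded) captured value.
  const vars: Record<string, string> = {};
  for (let i = 0; i < names.length; i++) {
    vars[names[i]] = decodeURIComponent(match[i + 1]);
  }
  return vars;
}

console.log(matchUriTemplate("users://{userId}/profile", "users://42/profile"));
console.log(matchUriTemplate("users://{userId}/profile", "users://42/settings")); // → null
```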
### Tools
Tools let LLMs take actions through your server. Unlike resources, tools are expected to perform computation and have side effects:
```typescript
// Simple tool with parameters
server.tool(
  "calculate-bmi",
  {
    weightKg: z.number(),
    heightM: z.number()
  },
  async ({ weightKg, heightM }) => ({
    content: [{
      type: "text",
      text: String(weightKg / (heightM * heightM))
    }]
  })
);

// Async tool with external API call
server.tool(
  "fetch-weather",
  { city: z.string() },
  async ({ city }) => {
    // Encode the city so names containing spaces or slashes form a valid URL path
    const response = await fetch(`https://api.weather.com/${encodeURIComponent(city)}`);
    const data = await response.text();
    return {
      content: [{ type: "text", text: data }]
    };
  }
);
```
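`server.tool()` takes a Zod shape and validates incoming arguments against it before the handler runs. As a dependency-free sketch of the same flow, assuming hand-written checks in place of Zod, the hypothetical `callBmiTool` below rejects malformed arguments outright, while a domain failure (a non-positive height) is reported in the result with `isError: true`, the convention the SQLite example uses further down. The BMI is rounded to one decimal place here, unlike the tool above.

```typescript
interface ToolResult {
  content: { type: "text"; text: string }[];
  isError?: boolean;
}

// Hypothetical stand-in for the Zod-validated "calculate-bmi" tool above.
function callBmiTool(args: Record<string, unknown>): ToolResult {
  const { weightKg, heightM } = args;
  // Shape check, analogous to { weightKg: z.number(), heightM: z.number() }.
  if (typeof weightKg !== "number" || typeof heightM !== "number") {
    throw new Error("Invalid arguments: weightKg and heightM must be numbers");
  }
  // Domain failure: report it in the result so the model can see and react to it.
  if (heightM <= 0) {
    return {
      content: [{ type: "text", text: "Error: heightM must be positive" }],
      isError: true,
    };
  }
  const bmi = weightKg / (heightM * heightM);
  return { content: [{ type: "text", text: bmi.toFixed(1) }] };
}

console.log(callBmiTool({ weightKg: 70, heightM: 1.75 }).content[0].text); // → 22.9
```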
### Prompts
Prompts are reusable templates that help LLMs interact with your server effectively:
```typescript
server.prompt(
  "review-code",
  { code: z.string() },
  ({ code }) => ({
    messages: [{
      role: "user",
      content: {
        type: "text",
        text: `Please review this code:\n\n${code}`
      }
    }]
  })
);
```
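A client fetches a prompt with a `prompts/get` request carrying the prompt `name` and a map of string `arguments`; the server fills in the template and returns the resulting `messages`. The sketch below models that lookup with a hypothetical in-memory `promptRegistry` and `getPrompt` function, including a check for required arguments. It is illustrative only and independent of the SDK.

```typescript
interface PromptMessage {
  role: "user" | "assistant";
  content: { type: "text"; text: string };
}

interface PromptDefinition {
  requiredArgs: string[];
  render: (args: Record<string, string>) => PromptMessage[];
}

// Hypothetical registry mirroring the "review-code" prompt above.
const promptRegistry: Record<string, PromptDefinition> = {
  "review-code": {
    requiredArgs: ["code"],
    render: (args) => [
      { role: "user", content: { type: "text", text: `Please review this code:\n\n${args.code}` } },
    ],
  },
};

// Look up a prompt, check its required arguments, and render its messages.
function getPrompt(promptName: string, args: Record<string, string>): { messages: PromptMessage[] } {
  const prompt = Object.prototype.hasOwnProperty.call(promptRegistry, promptName)
    ? promptRegistry[promptName]
    : undefined;
  if (prompt === undefined) throw new Error(`Unknown prompt: ${promptName}`);
  for (const arg of prompt.requiredArgs) {
    if (!Object.prototype.hasOwnProperty.call(args, arg)) {
      throw new Error(`Missing required argument: ${arg}`);
    }
  }
  return { messages: prompt.render(args) };
}

console.log(getPrompt("review-code", { code: "let x = 1;" }).messages[0].content.text);
```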
## Running Your Server
MCP servers in TypeScript need to be connected to a transport to communicate with clients. How you start the server depends on the choice of transport:
### stdio
For command-line tools and direct integrations:
```typescript
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";

const server = new McpServer({
  name: "example-server",
  version: "1.0.0"
});

// ... set up server resources, tools, and prompts ...

const transport = new StdioServerTransport();
await server.connect(transport);
```
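Over stdio, each JSON-RPC message is serialized as one line of JSON terminated by a newline, which is why a stdio server must not write anything else (such as log output) to stdout; stderr is the place for logging. Because reads from a pipe can split or merge lines arbitrarily, the receiving side buffers input until it sees a newline. A minimal sketch of that buffering, using a hypothetical `LineBuffer` class (the SDK has its own buffering in `src/shared/stdio.ts`):

```typescript
// Hypothetical buffer for newline-delimited JSON-RPC messages.
class LineBuffer {
  private buffer = "";

  // Append a raw chunk exactly as it was read from the pipe.
  append(chunk: string): void {
    this.buffer += chunk;
  }

  // Return the next complete message, or null if no full line is buffered yet.
  readMessage(): unknown {
    const newline = this.buffer.indexOf("\n");
    if (newline === -1) return null;
    const line = this.buffer.slice(0, newline).replace(/\r$/, "");
    this.buffer = this.buffer.slice(newline + 1);
    return JSON.parse(line);
  }
}

// Serialize one message for sending: a single line of JSON plus "\n".
function serializeMessage(message: unknown): string {
  return JSON.stringify(message) + "\n";
}

const buf = new LineBuffer();
const wire = serializeMessage({ jsonrpc: "2.0", id: 1, method: "ping" });
buf.append(wire.slice(0, 10)); // the first chunk ends mid-message
console.log(buf.readMessage()); // → null
buf.append(wire.slice(10)); // the rest, including the newline, arrives
console.log(buf.readMessage());
```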
### HTTP with SSE
For remote servers, start a web server with a Server-Sent Events (SSE) endpoint, and a separate endpoint for the client to send its messages to:
```typescript
import express from "express";
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { SSEServerTransport } from "@modelcontextprotocol/sdk/server/sse.js";

const server = new McpServer({
  name: "example-server",
  version: "1.0.0"
});

// ... set up server resources, tools, and prompts ...

const app = express();

// Transport for the most recent SSE connection, declared at module scope so
// the POST handler below can reach it.
let transport: SSEServerTransport | undefined;

app.get("/sse", async (req, res) => {
  transport = new SSEServerTransport("/messages", res);
  await server.connect(transport);
});

app.post("/messages", async (req, res) => {
  // Note: to support multiple simultaneous connections, these messages will
  // need to be routed to a specific matching transport. (This logic isn't
  // implemented here, for simplicity.)
  if (!transport) {
    res.status(400).send("No active SSE connection");
    return;
  }
  await transport.handlePostMessage(req, res);
});

app.listen(3001);
```
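The comment in the POST handler leaves per-connection routing unimplemented. `SSEServerTransport` generates a session ID and advertises its message endpoint to the client with that ID appended as a `sessionId` query parameter, and the test server in `src/cli.ts` uses that parameter to find the matching transport. The bookkeeping can be sketched without Express, assuming a hypothetical generic `SessionRegistry` keyed by session ID:

```typescript
// Hypothetical registry mapping session IDs to per-connection transports.
class SessionRegistry<T> {
  private sessions = new Map<string, T>();

  add(sessionId: string, transport: T): void {
    this.sessions.set(sessionId, transport);
  }

  remove(sessionId: string): void {
    this.sessions.delete(sessionId);
  }

  // Resolve the transport for a POST URL such as "/messages?sessionId=abc".
  lookup(requestUrl: string): T | undefined {
    // A base URL is required because request URLs are relative paths.
    const sessionId = new URL(requestUrl, "http://localhost").searchParams.get("sessionId");
    return sessionId === null ? undefined : this.sessions.get(sessionId);
  }
}

const registry = new SessionRegistry<string>();
registry.add("abc", "transport-for-abc");
console.log(registry.lookup("/messages?sessionId=abc")); // → transport-for-abc
console.log(registry.lookup("/messages?sessionId=zzz")); // → undefined
```

In an Express app, `add` would run in the `GET /sse` handler, `remove` when the connection closes (`src/cli.ts` uses the server's `onclose` callback for this), and the POST handler would respond with 404 when `lookup` finds nothing.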
### Testing and Debugging
To test your server, you can use the [MCP Inspector](https://github.com/modelcontextprotocol/inspector). See its README for more information.
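
For automated tests, the repository also contains `InMemoryTransport` (`src/inMemory.ts`), whose `createLinkedPair()` returns two linked transports, one for a `Client` and one for a `Server`, so both can run in a single process. The sketch below is a simplified, synchronous, hypothetical `LinkedTransport` class that shows the idea: a message sent on one side is delivered to the other side's `onmessage`, closing either side fires `onclose` on both, and sending after close throws "Not connected". As an assumption of this sketch, a message that arrives before the receiver has an `onmessage` handler is queued until `start()`.

```typescript
// Hypothetical, synchronous stand-in for an in-memory transport pair.
class LinkedTransport<M> {
  private peer?: LinkedTransport<M>;
  private queue: M[] = [];

  onmessage?: (message: M) => void;
  onclose?: () => void;

  static createLinkedPair<M>(): [LinkedTransport<M>, LinkedTransport<M>] {
    const a = new LinkedTransport<M>();
    const b = new LinkedTransport<M>();
    a.peer = b;
    b.peer = a;
    return [a, b];
  }

  // Deliver anything that arrived before a handler was attached.
  start(): void {
    const handler = this.onmessage;
    if (!handler) return;
    while (this.queue.length > 0) {
      handler(this.queue.shift() as M);
    }
  }

  send(message: M): void {
    const peer = this.peer;
    if (!peer) throw new Error("Not connected");
    if (peer.onmessage) {
      peer.onmessage(message);
    } else {
      peer.queue.push(message); // held until the peer calls start()
    }
  }

  // Closing either side unlinks both ends and notifies both.
  close(): void {
    const peer = this.peer;
    this.peer = undefined;
    this.onclose?.();
    if (peer) {
      peer.peer = undefined;
      peer.onclose?.();
    }
  }
}

const [clientSide, serverSide] = LinkedTransport.createLinkedPair<string>();
serverSide.onmessage = (m) => console.log("server got", m);
clientSide.send("hello"); // prints "server got hello"
```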
## Examples
### Echo Server
A simple server demonstrating resources, tools, and prompts:
```typescript
import { McpServer, ResourceTemplate } from "@modelcontextprotocol/sdk/server/mcp.js";
import { z } from "zod";

const server = new McpServer({
  name: "Echo",
  version: "1.0.0"
});

server.resource(
  "echo",
  new ResourceTemplate("echo://{message}", { list: undefined }),
  async (uri, { message }) => ({
    contents: [{
      uri: uri.href,
      text: `Resource echo: ${message}`
    }]
  })
);

server.tool(
  "echo",
  { message: z.string() },
  async ({ message }) => ({
    content: [{ type: "text", text: `Tool echo: ${message}` }]
  })
);

server.prompt(
  "echo",
  { message: z.string() },
  ({ message }) => ({
    messages: [{
      role: "user",
      content: {
        type: "text",
        text: `Please process this message: ${message}`
      }
    }]
  })
);
```
### SQLite Explorer
A more complex example showing database integration:
```typescript
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import sqlite3 from "sqlite3";
import { promisify } from "util";
import { z } from "zod";

const server = new McpServer({
  name: "SQLite Explorer",
  version: "1.0.0"
});

// Helper to create DB connection
const getDb = () => {
  const db = new sqlite3.Database("database.db");
  return {
    all: promisify<string, any[]>(db.all.bind(db)),
    close: promisify(db.close.bind(db))
  };
};

server.resource(
  "schema",
  "schema://main",
  async (uri) => {
    const db = getDb();
    try {
      const tables = await db.all(
        "SELECT sql FROM sqlite_master WHERE type='table'"
      );
      return {
        contents: [{
          uri: uri.href,
          text: tables.map((t: {sql: string}) => t.sql).join("\n")
        }]
      };
    } finally {
      await db.close();
    }
  }
);

server.tool(
  "query",
  { sql: z.string() },
  async ({ sql }) => {
    const db = getDb();
    try {
      const results = await db.all(sql);
      return {
        content: [{
          type: "text",
          text: JSON.stringify(results, null, 2)
        }]
      };
    } catch (err: unknown) {
      const error = err as Error;
      return {
        content: [{
          type: "text",
          text: `Error: ${error.message}`
        }],
        isError: true
      };
    } finally {
      await db.close();
    }
  }
);
```
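As written, the `query` tool executes whatever SQL it receives, including `DELETE` or `DROP TABLE`. The more robust safeguard is to open the database read-only, for example by passing `sqlite3.OPEN_READONLY` as the second argument to `new sqlite3.Database(...)`. As an additional, purely illustrative check, the hypothetical `isReadOnlyQuery` helper below accepts only a single statement beginning with `SELECT`; it is a deliberately conservative heuristic, not a security boundary.

```typescript
// Hypothetical, deliberately conservative check: accept exactly one statement
// that starts with SELECT. Not a substitute for a read-only connection.
function isReadOnlyQuery(sql: string): boolean {
  // Drop surrounding whitespace and at most one trailing semicolon.
  const trimmed = sql.trim().replace(/;$/, "");
  // Any remaining semicolon could separate a second statement, so reject it.
  // (This also rejects harmless semicolons inside string literals.)
  if (trimmed.includes(";")) return false;
  return /^select\b/i.test(trimmed);
}

console.log(isReadOnlyQuery("SELECT name FROM users")); // → true
console.log(isReadOnlyQuery("SELECT 1; DROP TABLE users")); // → false
```

Inside the tool, a rejected query could be reported with the same `isError: true` result shape the example already uses for database errors.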
## Advanced Usage
### Low-Level Server
For more control, you can use the low-level Server class directly:
```typescript
import { Server } from "@modelcontextprotocol/sdk/server/index.js";
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
import {
  ListPromptsRequestSchema,
  GetPromptRequestSchema
} from "@modelcontextprotocol/sdk/types.js";

const server = new Server(
  {
    name: "example-server",
    version: "1.0.0"
  },
  {
    capabilities: {
      prompts: {}
    }
  }
);

server.setRequestHandler(ListPromptsRequestSchema, async () => {
  return {
    prompts: [{
      name: "example-prompt",
      description: "An example prompt template",
      arguments: [{
        name: "arg1",
        description: "Example argument",
        required: true
      }]
    }]
  };
});

server.setRequestHandler(GetPromptRequestSchema, async (request) => {
  if (request.params.name !== "example-prompt") {
    throw new Error("Unknown prompt");
  }
  return {
    description: "Example prompt",
    messages: [{
      role: "user",
      content: {
        type: "text",
        text: "Example prompt text"
      }
    }]
  };
});

const transport = new StdioServerTransport();
await server.connect(transport);
```
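`setRequestHandler` registers one handler per request method: the schema passed as the first argument identifies the method (here `prompts/list` and `prompts/get`), and each incoming request is dispatched on its `method` field. The hypothetical `Dispatcher` class below shows the same pattern with plain method strings instead of Zod schemas. Unknown methods receive the "Method not found" error (code -32601) and a throwing handler becomes an "Internal error" (code -32603); those codes come from the JSON-RPC 2.0 specification, while the class itself is only an illustration.

```typescript
type JsonRpcRequest = { jsonrpc: "2.0"; id: number; method: string; params?: unknown };
type JsonRpcResponse =
  | { jsonrpc: "2.0"; id: number; result: unknown }
  | { jsonrpc: "2.0"; id: number; error: { code: number; message: string } };

// Hypothetical dispatcher illustrating per-method request handlers.
class Dispatcher {
  private handlers = new Map<string, (params: unknown) => unknown>();

  setRequestHandler(method: string, handler: (params: unknown) => unknown): void {
    this.handlers.set(method, handler);
  }

  handle(request: JsonRpcRequest): JsonRpcResponse {
    const handler = this.handlers.get(request.method);
    if (!handler) {
      // JSON-RPC 2.0 reserves -32601 for "Method not found".
      return { jsonrpc: "2.0", id: request.id, error: { code: -32601, message: "Method not found" } };
    }
    try {
      return { jsonrpc: "2.0", id: request.id, result: handler(request.params) };
    } catch (err) {
      // -32603 is JSON-RPC 2.0's generic "Internal error" code.
      const message = err instanceof Error ? err.message : String(err);
      return { jsonrpc: "2.0", id: request.id, error: { code: -32603, message } };
    }
  }
}

const dispatcher = new Dispatcher();
dispatcher.setRequestHandler("prompts/list", () => ({ prompts: [{ name: "example-prompt" }] }));
console.log(JSON.stringify(dispatcher.handle({ jsonrpc: "2.0", id: 1, method: "tools/list" })));
// → {"jsonrpc":"2.0","id":1,"error":{"code":-32601,"message":"Method not found"}}
```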
### Writing MCP Clients
The SDK provides a high-level client interface:
```typescript
import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";

const transport = new StdioClientTransport({
  command: "node",
  args: ["server.js"]
});

const client = new Client(
  {
    name: "example-client",
    version: "1.0.0"
  },
  {
    // Client capabilities advertise what the client offers the server
    // (e.g. sampling or roots). Prompts, resources, and tools are server
    // capabilities, reported by the server during initialization.
    capabilities: {}
  }
);

await client.connect(transport);

// List prompts
const prompts = await client.listPrompts();

// Get a prompt
const prompt = await client.getPrompt("example-prompt", {
  arg1: "value"
});

// List resources
const resources = await client.listResources();

// Read a resource
const resource = await client.readResource("file:///example.txt");

// Call a tool
const result = await client.callTool({
  name: "example-tool",
  arguments: {
    arg1: "value"
  }
});
```
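Each of these calls sends a JSON-RPC request with a fresh `id` and completes when the response carrying that same `id` arrives, which lets several requests share one transport concurrently. A minimal sketch of that bookkeeping, using a hypothetical callback-based `RequestTracker` class rather than the SDK's own promise-based implementation:

```typescript
interface RpcResponse {
  id: number;
  result?: unknown;
  error?: { code: number; message: string };
}

// Hypothetical tracker pairing incoming responses with outstanding requests by id.
class RequestTracker {
  private nextId = 0;
  private pending = new Map<number, (response: RpcResponse) => void>();

  // Allocate an id for an outgoing request and remember who is waiting on it.
  track(onResponse: (response: RpcResponse) => void): number {
    const id = this.nextId++;
    this.pending.set(id, onResponse);
    return id;
  }

  // Deliver an incoming response; returns false for unknown or already-settled ids.
  settle(response: RpcResponse): boolean {
    const callback = this.pending.get(response.id);
    if (callback === undefined) return false;
    this.pending.delete(response.id);
    callback(response);
    return true;
  }

  get inFlight(): number {
    return this.pending.size;
  }
}

const tracker = new RequestTracker();
const listId = tracker.track((response) => console.log("tools/list result:", response.result));
// ...send { jsonrpc: "2.0", id: listId, method: "tools/list" } over the transport...
tracker.settle({ id: listId, result: { tools: [] } });
```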
## Documentation
- [Model Context Protocol documentation](https://modelcontextprotocol.io)
- [MCP Specification](https://spec.modelcontextprotocol.io)
- [Example Servers](https://github.com/modelcontextprotocol/servers)
## Contributing
Issues and pull requests are welcome on GitHub at https://github.com/modelcontextprotocol/typescript-sdk.
## License
This project is licensed under the MIT License; see the [LICENSE](LICENSE) file for details.
================================================
File: CLAUDE.md
================================================
# MCP TypeScript SDK Guide
## Build & Test Commands
```
npm run build # Build ESM and CJS versions
npm run lint # Run ESLint
npm test # Run all tests
npx jest path/to/file.test.ts # Run specific test file
npx jest -t "test name" # Run tests matching pattern
```
## Code Style Guidelines
- **TypeScript**: Strict type checking, ES modules, explicit return types
- **Naming**: PascalCase for classes/types, camelCase for functions/variables
- **Files**: Lowercase with hyphens, test files with `.test.ts` suffix
- **Imports**: ES module style, include `.js` extension, group imports logically
- **Error Handling**: Use TypeScript's strict mode, explicit error checking in tests
- **Formatting**: 2-space indentation, semicolons required, single quotes preferred
- **Testing**: Co-locate tests with source files, use descriptive test names
- **Comments**: JSDoc for public APIs, inline comments for complex logic
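
The `.js` extension rule interacts with testing: with `Node16` module resolution, relative imports name the emitted `.js` file, but Jest loads the `.ts` sources, so `jest.config.js` maps `^(\.{1,2}/.*)\.js$` to `$1` and also points `pkce-challenge` at a local mock. The hypothetical `mapModule` helper below replays just those two rules in TypeScript so their effect on a few specifiers can be checked; Jest applies the first pattern that matches, and its actual resolver does considerably more.

```typescript
// Hypothetical re-implementation of the two moduleNameMapper rules in jest.config.js.
const mapperRules: Array<[RegExp, string]> = [
  // The config string "^(\\.{1,2}/.*)\\.js$" is this regex literal.
  [/^(\.{1,2}\/.*)\.js$/, "$1"],
  [/^pkce-challenge$/, "<rootDir>/src/__mocks__/pkce-challenge.ts"],
];

function mapModule(specifier: string): string {
  for (const [pattern, replacement] of mapperRules) {
    if (pattern.test(specifier)) {
      // Only the first matching rule applies.
      return specifier.replace(pattern, replacement);
    }
  }
  return specifier; // no rule matched; resolve normally
}

console.log(mapModule("./types.js")); // → ./types
console.log(mapModule("../shared/protocol.js")); // → ../shared/protocol
console.log(mapModule("zod")); // → zod
```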
## Project Structure
- `/src`: Source code with client, server, and shared modules
- Tests alongside source files with `.test.ts` suffix
- Node.js >= 18 required
================================================
File: CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.

Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
[email protected].

All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.
================================================
File: CONTRIBUTING.md
================================================
# Contributing to MCP TypeScript SDK
We welcome contributions to the Model Context Protocol TypeScript SDK! This document outlines the process for contributing to the project.
## Getting Started
1. Fork the repository
2. Clone your fork: `git clone https://github.com/YOUR-USERNAME/typescript-sdk.git`
3. Install dependencies: `npm install`
4. Build the project: `npm run build`
5. Run tests: `npm test`
## Development Process
1. Create a new branch for your changes
2. Make your changes
3. Run `npm run lint` to ensure code style compliance
4. Run `npm test` to verify all tests pass
5. Submit a pull request
## Pull Request Guidelines
- Follow the existing code style
- Include tests for new functionality
- Update documentation as needed
- Keep changes focused and atomic
- Provide a clear description of changes
## Running Examples
- Start the server: `npm run server`
- Run the client: `npm run client`
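
`npm run client` runs `tsx src/cli.ts client`, and `runClient` in that file chooses a transport from its first argument: `http:` or `https:` URLs use SSE, `ws:` or `wss:` URLs use WebSocket, and anything else is spawned as a stdio command. Extracted into a pure function (the name `chooseTransport` is hypothetical), that decision can be unit-tested without opening any connection:

```typescript
type TransportKind = "sse" | "websocket" | "stdio";

// Hypothetical extraction of the transport decision made in runClient.
function chooseTransport(urlOrCommand: string): TransportKind {
  let url: URL | undefined;
  try {
    url = new URL(urlOrCommand);
  } catch {
    // Not a URL (e.g. "node" or "./server.js"): treat it as a command.
  }
  switch (url?.protocol) {
    case "http:":
    case "https:":
      return "sse";
    case "ws:":
    case "wss:":
      return "websocket";
    default:
      return "stdio";
  }
}

console.log(chooseTransport("http://localhost:3001/sse")); // → sse
console.log(chooseTransport("ws://localhost:3001")); // → websocket
console.log(chooseTransport("node")); // → stdio
```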
## Code of Conduct
This project follows our [Code of Conduct](CODE_OF_CONDUCT.md). Please review it before contributing.
## Reporting Issues
- Use the [GitHub issue tracker](https://github.com/modelcontextprotocol/typescript-sdk/issues)
- Search existing issues before creating a new one
- Provide clear reproduction steps
## Security Issues
Please review our [Security Policy](SECURITY.md) for reporting security vulnerabilities.
## License
By contributing, you agree that your contributions will be licensed under the MIT License.
================================================
File: LICENSE
================================================
MIT License

Copyright (c) 2024 Anthropic, PBC

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
File: SECURITY.md
================================================
# Security Policy
Thank you for helping us keep the SDKs and systems they interact with secure.
## Reporting Security Issues
This SDK is maintained by [Anthropic](https://www.anthropic.com/) as part of the Model Context Protocol project.

The security of our systems and user data is Anthropic’s top priority. We appreciate the work of security researchers acting in good faith in identifying and reporting potential vulnerabilities.

Our security program is managed on HackerOne and we ask that any validated vulnerability in this functionality be reported through their [submission form](https://hackerone.com/anthropic-vdp/reports/new?type=team&report_type=vulnerability).
## Vulnerability Disclosure Program
Our Vulnerability Program Guidelines are defined on our [HackerOne program page](https://hackerone.com/anthropic-vdp).
================================================
File: eslint.config.mjs
================================================
// @ts-check
import eslint from '@eslint/js';
import tseslint from 'typescript-eslint';

export default tseslint.config(
  eslint.configs.recommended,
  ...tseslint.configs.recommended,
  {
    linterOptions: {
      reportUnusedDisableDirectives: false,
    },
    rules: {
      "@typescript-eslint/no-unused-vars": ["error",
        { "argsIgnorePattern": "^_" }
      ]
    }
  }
);
================================================
File: jest.config.js
================================================
import { createDefaultEsmPreset } from "ts-jest";

const defaultEsmPreset = createDefaultEsmPreset();

/** @type {import('ts-jest').JestConfigWithTsJest} **/
export default {
  ...defaultEsmPreset,
  moduleNameMapper: {
    "^(\\.{1,2}/.*)\\.js$": "$1",
    "^pkce-challenge$": "<rootDir>/src/__mocks__/pkce-challenge.ts"
  },
  transformIgnorePatterns: [
    "/node_modules/(?!eventsource)/"
  ],
  testPathIgnorePatterns: ["/node_modules/", "/dist/"],
};
================================================
File: package.json
================================================
{
  "name": "@modelcontextprotocol/sdk",
  "version": "1.6.1",
  "description": "Model Context Protocol implementation for TypeScript",
  "license": "MIT",
  "author": "Anthropic, PBC (https://anthropic.com)",
  "homepage": "https://modelcontextprotocol.io",
  "bugs": "https://github.com/modelcontextprotocol/typescript-sdk/issues",
  "type": "module",
  "repository": {
    "type": "git",
    "url": "git+https://github.com/modelcontextprotocol/typescript-sdk.git"
  },
  "engines": {
    "node": ">=18"
  },
  "keywords": [
    "modelcontextprotocol",
    "mcp"
  ],
  "exports": {
    "./*": {
      "import": "./dist/esm/*",
      "require": "./dist/cjs/*"
    }
  },
  "typesVersions": {
    "*": {
      "*": [
        "./dist/esm/*"
      ]
    }
  },
  "files": [
    "dist"
  ],
  "scripts": {
    "build": "npm run build:esm && npm run build:cjs",
    "build:esm": "tsc -p tsconfig.prod.json && echo '{\"type\": \"module\"}' > dist/esm/package.json",
    "build:cjs": "tsc -p tsconfig.cjs.json && echo '{\"type\": \"commonjs\"}' > dist/cjs/package.json",
    "prepack": "npm run build:esm && npm run build:cjs",
    "lint": "eslint src/",
    "test": "jest",
    "start": "npm run server",
    "server": "tsx watch --clear-screen=false src/cli.ts server",
    "client": "tsx src/cli.ts client"
  },
  "dependencies": {
    "content-type": "^1.0.5",
    "cors": "^2.8.5",
    "eventsource": "^3.0.2",
    "express": "^5.0.1",
    "express-rate-limit": "^7.5.0",
    "pkce-challenge": "^4.1.0",
    "raw-body": "^3.0.0",
    "zod": "^3.23.8",
    "zod-to-json-schema": "^3.24.1"
  },
  "devDependencies": {
    "@eslint/js": "^9.8.0",
    "@jest-mock/express": "^3.0.0",
    "@types/content-type": "^1.1.8",
    "@types/cors": "^2.8.17",
    "@types/eslint__js": "^8.42.3",
    "@types/eventsource": "^1.1.15",
    "@types/express": "^5.0.0",
    "@types/jest": "^29.5.12",
    "@types/node": "^22.0.2",
    "@types/supertest": "^6.0.2",
    "@types/ws": "^8.5.12",
    "eslint": "^9.8.0",
    "jest": "^29.7.0",
    "supertest": "^7.0.0",
    "ts-jest": "^29.2.4",
    "tsx": "^4.16.5",
    "typescript": "^5.5.4",
    "typescript-eslint": "^8.0.0",
    "ws": "^8.18.0"
  },
  "resolutions": {
    "strip-ansi": "6.0.1"
  }
}
================================================
File: tsconfig.cjs.json
================================================
{
  "extends": "./tsconfig.json",
  "compilerOptions": {
    "module": "commonjs",
    "moduleResolution": "node",
    "outDir": "./dist/cjs"
  },
  "exclude": ["**/*.test.ts", "src/__mocks__/**/*"]
}
================================================
File: tsconfig.json
================================================
{
  "compilerOptions": {
    "target": "es2018",
    "module": "Node16",
    "moduleResolution": "Node16",
    "declaration": true,
    "declarationMap": true,
    "sourceMap": true,
    "outDir": "./dist",
    "strict": true,
    "esModuleInterop": true,
    "forceConsistentCasingInFileNames": true,
    "resolveJsonModule": true,
    "isolatedModules": true,
    "skipLibCheck": true
  },
  "include": ["src/**/*"],
  "exclude": ["node_modules", "dist"]
}
================================================
File: tsconfig.prod.json
================================================
{
  "extends": "./tsconfig.json",
  "compilerOptions": {
    "outDir": "./dist/esm"
  },
  "exclude": ["**/*.test.ts", "src/__mocks__/**/*"]
}
================================================
File: .npmrc
================================================
registry = "https://registry.npmjs.org/"
================================================
File: src/cli.ts
================================================
import WebSocket from "ws";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
(global as any).WebSocket = WebSocket;

import express from "express";
import { Client } from "./client/index.js";
import { SSEClientTransport } from "./client/sse.js";
import { StdioClientTransport } from "./client/stdio.js";
import { WebSocketClientTransport } from "./client/websocket.js";
import { Server } from "./server/index.js";
import { SSEServerTransport } from "./server/sse.js";
import { StdioServerTransport } from "./server/stdio.js";
import { ListResourcesResultSchema } from "./types.js";

async function runClient(url_or_command: string, args: string[]) {
  const client = new Client(
    {
      name: "mcp-typescript test client",
      version: "0.1.0",
    },
    {
      capabilities: {
        sampling: {},
      },
    },
  );

  let clientTransport;

  // Treat the argument as a URL if it parses as one; otherwise it is a command.
  let url: URL | undefined = undefined;
  try {
    url = new URL(url_or_command);
  } catch {
    // Not a URL; fall through to the stdio transport.
  }

  if (url?.protocol === "http:" || url?.protocol === "https:") {
    clientTransport = new SSEClientTransport(new URL(url_or_command));
  } else if (url?.protocol === "ws:" || url?.protocol === "wss:") {
    clientTransport = new WebSocketClientTransport(new URL(url_or_command));
  } else {
    clientTransport = new StdioClientTransport({
      command: url_or_command,
      args,
    });
  }

  console.log("Connecting to server...");
  // connect() also performs the MCP initialization handshake.
  await client.connect(clientTransport);
  console.log("Initialized.");

  await client.request({ method: "resources/list" }, ListResourcesResultSchema);

  await client.close();
  console.log("Closed.");
}

async function runServer(port: number | null) {
  if (port !== null) {
    const app = express();

    // One Server instance per SSE connection.
    let servers: Server[] = [];

    app.get("/sse", async (req, res) => {
      console.log("Got new SSE connection");

      const transport = new SSEServerTransport("/message", res);
      const server = new Server(
        {
          name: "mcp-typescript test server",
          version: "0.1.0",
        },
        {
          capabilities: {},
        },
      );

      servers.push(server);

      server.onclose = () => {
        console.log("SSE connection closed");
        servers = servers.filter((s) => s !== server);
      };

      await server.connect(transport);
    });

    app.post("/message", async (req, res) => {
      console.log("Received message");

      // Route the POST to the transport whose session ID matches the query string.
      const sessionId = req.query.sessionId as string;
      const transport = servers
        .map((s) => s.transport as SSEServerTransport)
        .find((t) => t.sessionId === sessionId);
      if (!transport) {
        res.status(404).send("Session not found");
        return;
      }

      await transport.handlePostMessage(req, res);
    });

    app.listen(port, () => {
      console.log(`Server running on http://localhost:${port}/sse`);
    });
  } else {
    const server = new Server(
      {
        name: "mcp-typescript test server",
        version: "0.1.0",
      },
      {
        capabilities: {
          prompts: {},
          resources: {},
          tools: {},
          logging: {},
        },
      },
    );

    const transport = new StdioServerTransport();
    await server.connect(transport);

    // Log to stderr: on this transport, stdout carries the protocol messages.
    console.error("Server running on stdio");
  }
}

const args = process.argv.slice(2);
const command = args[0];
switch (command) {
  case "client":
    if (args.length < 2) {
      console.error("Usage: client <server_url_or_command> [args...]");
      process.exit(1);
    }

    runClient(args[1], args.slice(2)).catch((error) => {
      console.error(error);
      process.exit(1);
    });

    break;

  case "server": {
    const port = args[1] ? parseInt(args[1], 10) : null;
    runServer(port).catch((error) => {
      console.error(error);
      process.exit(1);
    });

    break;
  }

  default:
    console.error("Unrecognized command:", command);
    process.exit(1);
}
================================================
File: src/inMemory.test.ts
================================================
import { InMemoryTransport } from "./inMemory.js";
import { JSONRPCMessage } from "./types.js";
describe("InMemoryTransport", () => {
let clientTransport: InMemoryTransport;
let serverTransport: InMemoryTransport;
beforeEach(() => {
[clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
});
test("should create linked pair", () => {
expect(clientTransport).toBeDefined();
expect(serverTransport).toBeDefined();
});
test("should start without error", async () => {
await expect(clientTransport.start()).resolves.not.toThrow();
await expect(serverTransport.start()).resolves.not.toThrow();
});
test("should send message from client to server", async () => {
const message: JSONRPCMessage = {
jsonrpc: "2.0",
method: "test",
id: 1,
};
let receivedMessage: JSONRPCMessage | undefined;
serverTransport.onmessage = (msg) => {
receivedMessage = msg;
};
await clientTransport.send(message);
expect(receivedMessage).toEqual(message);
});
test("should send message from server to client", async () => {
const message: JSONRPCMessage = {
jsonrpc: "2.0",
method: "test",
id: 1,
};
let receivedMessage: JSONRPCMessage | undefined;
clientTransport.onmessage = (msg) => {
receivedMessage = msg;
};
await serverTransport.send(message);
expect(receivedMessage).toEqual(message);
});
test("should handle close", async () => {
let clientClosed = false;
let serverClosed = false;
clientTransport.onclose = () => {
clientClosed = true;
};
serverTransport.onclose = () => {
serverClosed = true;
};
await clientTransport.close();
expect(clientClosed).toBe(true);
expect(serverClosed).toBe(true);
});
test("should throw error when sending after close", async () => {
await clientTransport.close();
await expect(
clientTransport.send({ jsonrpc: "2.0", method: "test", id: 1 }),
).rejects.toThrow("Not connected");
});
test("should queue messages sent before start", async () => {
const message: JSONRPCMessage = {
jsonrpc: "2.0",
method: "test",
id: 1,
};
let receivedMessage: JSONRPCMessage | undefined;
// Send before the server installs a handler, so the message is queued
await clientTransport.send(message);
serverTransport.onmessage = (msg) => {
receivedMessage = msg;
};
await serverTransport.start();
expect(receivedMessage).toEqual(message);
});
});
================================================
File: src/inMemory.ts
================================================
import { Transport } from "./shared/transport.js";
import { JSONRPCMessage } from "./types.js";
/**
* In-memory transport for creating clients and servers that talk to each other within the same process.
*/
export class InMemoryTransport implements Transport {
private _otherTransport?: InMemoryTransport;
private _messageQueue: JSONRPCMessage[] = [];
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
/**
* Creates a pair of linked in-memory transports that can communicate with each other. One should be passed to a Client and one to a Server.
*/
static createLinkedPair(): [InMemoryTransport, InMemoryTransport] {
const clientTransport = new InMemoryTransport();
const serverTransport = new InMemoryTransport();
clientTransport._otherTransport = serverTransport;
serverTransport._otherTransport = clientTransport;
return [clientTransport, serverTransport];
}
async start(): Promise<void> {
// Process any messages that were queued before start was called
while (this._messageQueue.length > 0) {
const message = this._messageQueue.shift();
if (message) {
this.onmessage?.(message);
}
}
}
async close(): Promise<void> {
const other = this._otherTransport;
this._otherTransport = undefined;
await other?.close();
this.onclose?.();
}
async send(message: JSONRPCMessage): Promise<void> {
if (!this._otherTransport) {
throw new Error("Not connected");
}
if (this._otherTransport.onmessage) {
this._otherTransport.onmessage(message);
} else {
this._otherTransport._messageQueue.push(message);
}
}
}
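The linked-pair design above can be sketched without the SDK. This is a minimal, dependency-free illustration, not the SDK class: `LinkedTransport` and the simplified `Message` type are hypothetical stand-ins for `InMemoryTransport` and `JSONRPCMessage`. Like the original, it buffers a message on the receiving side when that side has no `onmessage` handler yet and flushes the buffer in `start()`; it also keeps a `closed` flag so that each side's `onclose` fires once.

```typescript
// Hypothetical simplified message type standing in for JSONRPCMessage.
type Message = { jsonrpc: "2.0"; method?: string; id?: string | number };

class LinkedTransport {
  private peer?: LinkedTransport;
  private queue: Message[] = [];
  private closed = false;
  onmessage?: (message: Message) => void;
  onclose?: () => void;

  // Wire two transports to each other, one for each end of the connection.
  static createLinkedPair(): [LinkedTransport, LinkedTransport] {
    const a = new LinkedTransport();
    const b = new LinkedTransport();
    a.peer = b;
    b.peer = a;
    return [a, b];
  }

  // Deliver anything that arrived before a handler was installed.
  async start(): Promise<void> {
    while (this.queue.length > 0) {
      const message = this.queue.shift() as Message;
      this.onmessage?.(message);
    }
  }

  async send(message: Message): Promise<void> {
    if (!this.peer) {
      throw new Error("Not connected");
    }
    if (this.peer.onmessage) {
      this.peer.onmessage(message); // deliver immediately
    } else {
      this.peer.queue.push(message); // buffer on the receiving side
    }
  }

  // Closing one side closes the other; the flag stops the mutual recursion.
  async close(): Promise<void> {
    if (this.closed) {
      return;
    }
    this.closed = true;
    const peer = this.peer;
    this.peer = undefined;
    await peer?.close();
    this.onclose?.();
  }
}
```

Because `send()` calls the peer's handler synchronously, a message is normally delivered before `send()`'s promise settles; queuing only happens while the receiver has no handler.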
================================================
File: src/types.ts
================================================
import { z, ZodTypeAny } from "zod";
export const LATEST_PROTOCOL_VERSION = "2024-11-05";
export const SUPPORTED_PROTOCOL_VERSIONS = [
LATEST_PROTOCOL_VERSION,
"2024-10-07",
];
/* JSON-RPC types */
export const JSONRPC_VERSION = "2.0";
/**
* A progress token, used to associate progress notifications with the original request.
*/
export const ProgressTokenSchema = z.union([z.string(), z.number().int()]);
/**
* An opaque token used to represent a cursor for pagination.
*/
export const CursorSchema = z.string();
const BaseRequestParamsSchema = z
.object({
_meta: z.optional(
z
.object({
/**
* If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications.
*/
progressToken: z.optional(ProgressTokenSchema),
})
.passthrough(),
),
})
.passthrough();
export const RequestSchema = z.object({
method: z.string(),
params: z.optional(BaseRequestParamsSchema),
});
const BaseNotificationParamsSchema = z
.object({
/**
* This parameter name is reserved by MCP to allow clients and servers to attach additional metadata to their notifications.
*/
_meta: z.optional(z.object({}).passthrough()),
})
.passthrough();
export const NotificationSchema = z.object({
method: z.string(),
params: z.optional(BaseNotificationParamsSchema),
});
export const ResultSchema = z
.object({
/**
* This result property is reserved by the protocol to allow clients and servers to attach additional metadata to their responses.
*/
_meta: z.optional(z.object({}).passthrough()),
})
.passthrough();
/**
* A uniquely identifying ID for a request in JSON-RPC.
*/
export const RequestIdSchema = z.union([z.string(), z.number().int()]);
/**
* A request that expects a response.
*/
export const JSONRPCRequestSchema = z
.object({
jsonrpc: z.literal(JSONRPC_VERSION),
id: RequestIdSchema,
})
.merge(RequestSchema)
.strict();
/**
* A notification which does not expect a response.
*/
export const JSONRPCNotificationSchema = z
.object({
jsonrpc: z.literal(JSONRPC_VERSION),
})
.merge(NotificationSchema)
.strict();
/**
* A successful (non-error) response to a request.
*/
export const JSONRPCResponseSchema = z
.object({
jsonrpc: z.literal(JSONRPC_VERSION),
id: RequestIdSchema,
result: ResultSchema,
})
.strict();
/**
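* Error codes: the standard JSON-RPC 2.0 codes, plus SDK-specific codes drawn from the range the JSON-RPC specification reserves for implementation-defined server errors (-32000 to -32099).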
*/
export enum ErrorCode {
// SDK error codes
ConnectionClosed = -32000,
RequestTimeout = -32001,
// Standard JSON-RPC error codes
ParseError = -32700,
InvalidRequest = -32600,
MethodNotFound = -32601,
InvalidParams = -32602,
InternalError = -32603,
}
/**
* A response to a request that indicates an error occurred.
*/
export const JSONRPCErrorSchema = z
.object({
jsonrpc: z.literal(JSONRPC_VERSION),
id: RequestIdSchema,
error: z.object({
/**
* The error type that occurred.
*/
code: z.number().int(),
/**
* A short description of the error. The message SHOULD be limited to a concise single sentence.
*/
message: z.string(),
/**
* Additional information about the error. The value of this member is defined by the sender (e.g. detailed error information, nested errors etc.).
*/
data: z.optional(z.unknown()),
}),
})
.strict();
export const JSONRPCMessageSchema = z.union([
JSONRPCRequestSchema,
JSONRPCNotificationSchema,
JSONRPCResponseSchema,
JSONRPCErrorSchema,
]);
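The four members of `JSONRPCMessageSchema` are told apart by which fields are present. The dependency-free sketch below mirrors that discrimination for illustration only; `classifyMessage`, `isRequestId` and `MessageKind` are hypothetical names, and unlike the `.strict()` zod schemas it neither rejects unexpected keys nor validates the contents of `result` and `error`.

```typescript
type MessageKind = "request" | "notification" | "response" | "error" | "invalid";

// Mirrors RequestIdSchema: a string or an integer.
function isRequestId(value: unknown): boolean {
  return typeof value === "string" || (typeof value === "number" && Number.isInteger(value));
}

function classifyMessage(raw: unknown): MessageKind {
  if (typeof raw !== "object" || raw === null) {
    return "invalid";
  }
  const msg = raw as Record<string, unknown>;
  if (msg.jsonrpc !== "2.0") {
    return "invalid";
  }
  const hasId = "id" in msg;
  if (hasId && !isRequestId(msg.id)) {
    return "invalid";
  }
  if (typeof msg.method === "string") {
    // A method with an id expects a reply; without one it is fire-and-forget.
    return hasId ? "request" : "notification";
  }
  if (hasId && "result" in msg && !("error" in msg)) {
    return "response";
  }
  if (hasId && "error" in msg && !("result" in msg)) {
    return "error";
  }
  return "invalid";
}
```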
/* Empty result */
/**
* A response that indicates success but carries no data.
*/
export const EmptyResultSchema = ResultSchema.strict();
/* Cancellation */
/**
* This notification can be sent by either side to indicate that it is cancelling a previously-issued request.
*
* The request SHOULD still be in-flight, but due to communication latency, it is always possible that this notification MAY arrive after the request has already finished.
*
* This notification indicates that the result will be unused, so any associated processing SHOULD cease.
*
* A client MUST NOT attempt to cancel its `initialize` request.
*/
export const CancelledNotificationSchema = NotificationSchema.extend({
method: z.literal("notifications/cancelled"),
params: BaseNotificationParamsSchema.extend({
/**
* The ID of the request to cancel.
*
* This MUST correspond to the ID of a request previously issued in the same direction.
*/
requestId: RequestIdSchema,
/**
* An optional string describing the reason for the cancellation. This MAY be logged or presented to the user.
*/
reason: z.string().optional(),
}),
});
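On the receiving side, the cancellation rules above amount to: abort the matching in-flight request if there is one, and quietly ignore a notification that arrives after the request has already finished. The sketch below assumes a hypothetical `InFlightRequests` registry built on the standard `AbortController`; it is not SDK code, and the notification type is a hand-written simplification of `CancelledNotificationSchema`.

```typescript
type RequestId = string | number;

// Hand-written simplification of CancelledNotificationSchema.
interface CancelledNotification {
  method: "notifications/cancelled";
  params: { requestId: RequestId; reason?: string };
}

class InFlightRequests {
  private controllers = new Map<RequestId, AbortController>();

  // Call when a request arrives; the handler should watch the returned signal.
  begin(id: RequestId): AbortSignal {
    const controller = new AbortController();
    this.controllers.set(id, controller);
    return controller.signal;
  }

  // Call when the handler finishes, whether it succeeded or failed.
  finish(id: RequestId): void {
    this.controllers.delete(id);
  }

  // Returns true if an in-flight request was aborted. A notification that
  // arrives after the request finished finds no entry and is ignored.
  handleCancelled(notification: CancelledNotification): boolean {
    const { requestId } = notification.params;
    const controller = this.controllers.get(requestId);
    if (!controller) {
      return false;
    }
    controller.abort();
    this.controllers.delete(requestId);
    return true;
  }
}
```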
/* Initialization */
/**
* Describes the name and version of an MCP implementation.
*/
export const ImplementationSchema = z
.object({
name: z.string(),
version: z.string(),
})
.passthrough();
/**
* Capabilities a client may support. Known capabilities are defined here, in this schema, but this is not a closed set: any client can define its own, additional capabilities.
*/
export const ClientCapabilitiesSchema = z
.object({
/**
* Experimental, non-standard capabilities that the client supports.
*/
experimental: z.optional(z.object({}).passthrough()),
/**
* Present if the client supports sampling from an LLM.
*/
sampling: z.optional(z.object({}).passthrough()),
/**
* Present if the client supports listing roots.
*/
roots: z.optional(
z
.object({
/**
* Whether the client supports issuing notifications for changes to the roots list.
*/
listChanged: z.optional(z.boolean()),
})
.passthrough(),
),
})
.passthrough();
/**
* This request is sent from the client to the server when it first connects, asking it to begin initialization.
*/
export const InitializeRequestSchema = RequestSchema.extend({
method: z.literal("initialize"),
params: BaseRequestParamsSchema.extend({
/**
* The latest version of the Model Context Protocol that the client supports. The client MAY decide to support older versions as well.
*/
protocolVersion: z.string(),
capabilities: ClientCapabilitiesSchema,
clientInfo: ImplementationSchema,
}),
});
/**
* Capabilities that a server may support. Known capabilities are defined here, in this schema, but this is not a closed set: any server can define its own, additional capabilities.
*/
export const ServerCapabilitiesSchema = z
.object({
/**
* Experimental, non-standard capabilities that the server supports.
*/
experimental: z.optional(z.object({}).passthrough()),
/**
* Present if the server supports sending log messages to the client.
*/
logging: z.optional(z.object({}).passthrough()),
/**
* Present if the server offers any prompt templates.
*/
prompts: z.optional(
z
.object({
/**
* Whether this server supports issuing notifications for changes to the prompt list.
*/
listChanged: z.optional(z.boolean()),
})
.passthrough(),
),
/**
* Present if the server offers any resources to read.
*/
resources: z.optional(
z
.object({
/**
* Whether this server supports clients subscribing to resource updates.
*/
subscribe: z.optional(z.boolean()),
/**
* Whether this server supports issuing notifications for changes to the resource list.
*/
listChanged: z.optional(z.boolean()),
})
.passthrough(),
),
/**
* Present if the server offers any tools to call.
*/
tools: z.optional(
z
.object({
/**
* Whether this server supports issuing notifications for changes to the tool list.
*/
listChanged: z.optional(z.boolean()),
})
.passthrough(),
),
})
.passthrough();
/**
* After receiving an initialize request from the client, the server sends this response.
*/
export const InitializeResultSchema = ResultSchema.extend({
/**
* The version of the Model Context Protocol that the server wants to use. This may not match the version that the client requested. If the client cannot support this version, it MUST disconnect.
*/
protocolVersion: z.string(),
capabilities: ServerCapabilitiesSchema,
serverInfo: ImplementationSchema,
/**
* Instructions describing how to use the server and its features.
*
* This can be used by clients to improve the LLM's understanding of available tools, resources, etc. It can be thought of like a "hint" to the model. For example, this information MAY be added to the system prompt.
*/
instructions: z.optional(z.string()),
});
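The negotiation described by `protocolVersion` can be sketched as two small checks. This assumes one plausible policy consistent with the comments above rather than quoting the SDK: the server echoes the client's requested version when it supports it and otherwise offers its latest, and the client disconnects if the offered version is not one it supports. The constants are copied from the top of this file; `chooseServerVersion` and `clientAccepts` are hypothetical helpers.

```typescript
// Copied from the top of types.ts.
const LATEST_PROTOCOL_VERSION = "2024-11-05";
const SUPPORTED_PROTOCOL_VERSIONS = [LATEST_PROTOCOL_VERSION, "2024-10-07"];

// Server side: accept the client's version if supported, else offer the latest.
function chooseServerVersion(requested: string): string {
  return SUPPORTED_PROTOCOL_VERSIONS.includes(requested) ? requested : LATEST_PROTOCOL_VERSION;
}

// Client side: if this returns false, the client MUST disconnect.
function clientAccepts(offered: string, clientSupported: readonly string[]): boolean {
  return clientSupported.includes(offered);
}
```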
/**
* This notification is sent from the client to the server after initialization has finished.
*/
export const InitializedNotificationSchema = NotificationSchema.extend({
method: z.literal("notifications/initialized"),
});
/* Ping */
/**
* A ping, issued by either the server or the client, to check that the other party is still alive. The receiver must promptly respond, or else may be disconnected.
*/
export const PingRequestSchema = RequestSchema.extend({
method: z.literal("ping"),
});
/* Progress notifications */
export const ProgressSchema = z
.object({
/**
* The progress thus far. This should increase every time progress is made, even if the total is unknown.
*/
progress: z.number(),
/**
* Total number of items to process (or total progress required), if known.
*/
total: z.optional(z.number()),
})
.passthrough();
/**
* An out-of-band notification used to inform the receiver of a progress update for a long-running request.
*/
export const ProgressNotificationSchema = NotificationSchema.extend({
method: z.literal("notifications/progress"),
params: BaseNotificationParamsSchema.merge(ProgressSchema).extend({
/**
* The progress token which was given in the initial request, used to associate this notification with the request that is proceeding.
*/
progressToken: ProgressTokenSchema,
}),
});
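Progress tokens are opaque to the receiver; their job is to route `notifications/progress` back to the request that asked for them through `params._meta.progressToken`. A minimal sketch of the requesting side, using a hypothetical `ProgressRouter` class and hand-written types in place of the zod-inferred ones:

```typescript
type ProgressToken = string | number;

interface Progress {
  progress: number;
  total?: number;
}

interface ProgressNotification {
  method: "notifications/progress";
  params: Progress & { progressToken: ProgressToken };
}

class ProgressRouter {
  private nextToken = 0;
  private handlers = new Map<ProgressToken, (p: Progress) => void>();

  // Allocate a token to place in the outgoing request's params._meta.progressToken.
  register(onProgress: (p: Progress) => void): ProgressToken {
    const token = this.nextToken++;
    this.handlers.set(token, onProgress);
    return token;
  }

  // Route an incoming notification; returns false for unknown or completed tokens.
  dispatch(notification: ProgressNotification): boolean {
    const { progressToken, progress, total } = notification.params;
    const handler = this.handlers.get(progressToken);
    if (!handler) {
      return false;
    }
    handler({ progress, total });
    return true;
  }

  // Stop routing once the request's response has arrived.
  complete(token: ProgressToken): void {
    this.handlers.delete(token);
  }
}
```

A request carrying the token would then look like `{ method: "tools/call", params: { name: "slow-tool", _meta: { progressToken: token } } }`, where `slow-tool` is a made-up tool name.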
/* Pagination */
export const PaginatedRequestSchema = RequestSchema.extend({
params: BaseRequestParamsSchema.extend({
/**
* An opaque token representing the current pagination position.
* If provided, the server should return results starting after this cursor.
*/
cursor: z.optional(CursorSchema),
}).optional(),
});
export const PaginatedResultSchema = ResultSchema.extend({
/**
* An opaque token representing the pagination position after the last returned result.
* If present, there may be more results available.
*/
nextCursor: z.optional(CursorSchema),
});
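A client that wants every item from a paginated list keeps passing back `nextCursor` until it is absent. In this sketch `fetchPage` is a hypothetical callback that issues one list request (for example `tools/list`) and maps the result's array (for example `tools`) into `items`; the `maxPages` guard is an added safety assumption, not part of the protocol.

```typescript
// Hypothetical normalized page: `items` stands in for the result's
// `resources`, `prompts`, `tools`, etc.
interface Page<T> {
  items: T[];
  nextCursor?: string;
}

async function collectAll<T>(
  fetchPage: (cursor?: string) => Promise<Page<T>>,
  maxPages = 1000,
): Promise<T[]> {
  const all: T[] = [];
  let cursor: string | undefined; // undefined means "start from the beginning"
  for (let page = 0; page < maxPages; page++) {
    const result = await fetchPage(cursor);
    all.push(...result.items);
    if (result.nextCursor === undefined) {
      return all; // no cursor: this was the last page
    }
    cursor = result.nextCursor; // opaque; pass it back unchanged
  }
  throw new Error(`Stopped after ${maxPages} pages`);
}
```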
/* Resources */
/**
* The contents of a specific resource or sub-resource.
*/
export const ResourceContentsSchema = z
.object({
/**
* The URI of this resource.
*/
uri: z.string(),
/**
* The MIME type of this resource, if known.
*/
mimeType: z.optional(z.string()),
})
.passthrough();
export const TextResourceContentsSchema = ResourceContentsSchema.extend({
/**
* The text of the item. This must only be set if the item can actually be represented as text (not binary data).
*/
text: z.string(),
});
export const BlobResourceContentsSchema = ResourceContentsSchema.extend({
/**
* A base64-encoded string representing the binary data of the item.
*/
blob: z.string().base64(),
});
/**
* A known resource that the server is capable of reading.
*/
export const ResourceSchema = z
.object({
/**
* The URI of this resource.
*/
uri: z.string(),
/**
* A human-readable name for this resource.
*
* This can be used by clients to populate UI elements.
*/
name: z.string(),
/**
* A description of what this resource represents.
*
* This can be used by clients to improve the LLM's understanding of available resources. It can be thought of like a "hint" to the model.
*/
description: z.optional(z.string()),
/**
* The MIME type of this resource, if known.
*/
mimeType: z.optional(z.string()),
})
.passthrough();
/**
* A template description for resources available on the server.
*/
export const ResourceTemplateSchema = z
.object({
/**
* A URI template (according to RFC 6570) that can be used to construct resource URIs.
*/
uriTemplate: z.string(),
/**
* A human-readable name for the type of resource this template refers to.
*
* This can be used by clients to populate UI elements.
*/
name: z.string(),
/**
* A description of what this template is for.
*
* This can be used by clients to improve the LLM's understanding of available resources. It can be thought of like a "hint" to the model.
*/
description: z.optional(z.string()),
/**
* The MIME type for all resources that match this template. This should only be included if all resources matching this template have the same type.
*/
mimeType: z.optional(z.string()),
})
.passthrough();
/**
* Sent from the client to request a list of resources the server has.
*/
export const ListResourcesRequestSchema = PaginatedRequestSchema.extend({
method: z.literal("resources/list"),
});
/**
* The server's response to a resources/list request from the client.
*/
export const ListResourcesResultSchema = PaginatedResultSchema.extend({
resources: z.array(ResourceSchema),
});
/**
* Sent from the client to request a list of resource templates the server has.
*/
export const ListResourceTemplatesRequestSchema = PaginatedRequestSchema.extend(
{
method: z.literal("resources/templates/list"),
},
);
/**
* The server's response to a resources/templates/list request from the client.
*/
export const ListResourceTemplatesResultSchema = PaginatedResultSchema.extend({
resourceTemplates: z.array(ResourceTemplateSchema),
});
/**
* Sent from the client to the server, to read a specific resource URI.
*/
export const ReadResourceRequestSchema = RequestSchema.extend({
method: z.literal("resources/read"),
params: BaseRequestParamsSchema.extend({
/**
* The URI of the resource to read. The URI can use any protocol; it is up to the server how to interpret it.
*/
uri: z.string(),
}),
});
/**
* The server's response to a resources/read request from the client.
*/
export const ReadResourceResultSchema = ResultSchema.extend({
contents: z.array(
z.union([TextResourceContentsSchema, BlobResourceContentsSchema]),
),
});
/**
* An optional notification from the server to the client, informing it that the list of resources it can read from has changed. This may be issued by servers without any previous subscription from the client.
*/
export const ResourceListChangedNotificationSchema = NotificationSchema.extend({
method: z.literal("notifications/resources/list_changed"),
});
/**
* Sent from the client to request resources/updated notifications from the server whenever a particular resource changes.
*/
export const SubscribeRequestSchema = RequestSchema.extend({
method: z.literal("resources/subscribe"),
params: BaseRequestParamsSchema.extend({
/**
* The URI of the resource to subscribe to. The URI can use any protocol; it is up to the server how to interpret it.
*/
uri: z.string(),
}),
});
/**
* Sent from the client to request cancellation of resources/updated notifications from the server. This should follow a previous resources/subscribe request.
*/
export const UnsubscribeRequestSchema = RequestSchema.extend({
method: z.literal("resources/unsubscribe"),
params: BaseRequestParamsSchema.extend({
/**
* The URI of the resource to unsubscribe from.
*/
uri: z.string(),
}),
});
/**
* A notification from the server to the client, informing it that a resource has changed and may need to be read again. This should only be sent if the client previously sent a resources/subscribe request.
*/
export const ResourceUpdatedNotificationSchema = NotificationSchema.extend({
method: z.literal("notifications/resources/updated"),
params: BaseNotificationParamsSchema.extend({
/**
* The URI of the resource that has been updated. This might be a sub-resource of the one that the client actually subscribed to.
*/
uri: z.string(),
}),
});
/* Prompts */
/**
* Describes an argument that a prompt can accept.
*/
export const PromptArgumentSchema = z
.object({
/**
* The name of the argument.
*/
name: z.string(),
/**
* A human-readable description of the argument.
*/
description: z.optional(z.string()),
/**
* Whether this argument must be provided.
*/
required: z.optional(z.boolean()),
})
.passthrough();
/**
* A prompt or prompt template that the server offers.
*/
export const PromptSchema = z
.object({
/**
* The name of the prompt or prompt template.
*/
name: z.string(),
/**
* An optional description of what this prompt provides.
*/
description: z.optional(z.string()),
/**
* A list of arguments to use for templating the prompt.
*/
arguments: z.optional(z.array(PromptArgumentSchema)),
})
.passthrough();
/**
* Sent from the client to request a list of prompts and prompt templates the server has.
*/
export const ListPromptsRequestSchema = PaginatedRequestSchema.extend({
method: z.literal("prompts/list"),
});
/**
* The server's response to a prompts/list request from the client.
*/
export const ListPromptsResultSchema = PaginatedResultSchema.extend({
prompts: z.array(PromptSchema),
});
/**
* Used by the client to get a prompt provided by the server.
*/
export const GetPromptRequestSchema = RequestSchema.extend({
method: z.literal("prompts/get"),
params: BaseRequestParamsSchema.extend({
/**
* The name of the prompt or prompt template.
*/
name: z.string(),
/**
* Arguments to use for templating the prompt.
*/
arguments: z.optional(z.record(z.string())),
}),
});
/**
* Text provided to or from an LLM.
*/
export const TextContentSchema = z
.object({
type: z.literal("text"),
/**
* The text content of the message.
*/
text: z.string(),
})
.passthrough();
/**
* An image provided to or from an LLM.
*/
export const ImageContentSchema = z
.object({
type: z.literal("image"),
/**
* The base64-encoded image data.
*/
data: z.string().base64(),
/**
* The MIME type of the image. Different providers may support different image types.
*/
mimeType: z.string(),
})
.passthrough();
/**
* The contents of a resource, embedded into a prompt or tool call result.
*/
export const EmbeddedResourceSchema = z
.object({
type: z.literal("resource"),
resource: z.union([TextResourceContentsSchema, BlobResourceContentsSchema]),
})
.passthrough();
/**
* Describes a message returned as part of a prompt.
*/
export const PromptMessageSchema = z
.object({
role: z.enum(["user", "assistant"]),
content: z.union([
TextContentSchema,
ImageContentSchema,
EmbeddedResourceSchema,
]),
})
.passthrough();
/**
* The server's response to a prompts/get request from the client.
*/
export const GetPromptResultSchema = ResultSchema.extend({
/**
* An optional description for the prompt.
*/
description: z.optional(z.string()),
messages: z.array(PromptMessageSchema),
});
/**
* An optional notification from the server to the client, informing it that the list of prompts it offers has changed. This may be issued by servers without any previous subscription from the client.
*/
export const PromptListChangedNotificationSchema = NotificationSchema.extend({
method: z.literal("notifications/prompts/list_changed"),
});
/* Tools */
/**
* Definition for a tool the client can call.
*/
export const ToolSchema = z
.object({
/**
* The name of the tool.
*/
name: z.string(),
/**
* A human-readable description of the tool.
*/
description: z.optional(z.string()),
/**
* A JSON Schema object defining the expected parameters for the tool.
*/
inputSchema: z
.object({
type: z.literal("object"),
properties: z.optional(z.object({}).passthrough()),
})
.passthrough(),
})
.passthrough();
/**
* Sent from the client to request a list of tools the server has.
*/
export const ListToolsRequestSchema = PaginatedRequestSchema.extend({
method: z.literal("tools/list"),
});
/**
* The server's response to a tools/list request from the client.
*/
export const ListToolsResultSchema = PaginatedResultSchema.extend({
tools: z.array(ToolSchema),
});
/**
* The server's response to a tool call.
*/
export const CallToolResultSchema = ResultSchema.extend({
content: z.array(
z.union([TextContentSchema, ImageContentSchema, EmbeddedResourceSchema]),
),
isError: z.boolean().default(false).optional(),
});
/**
* CallToolResultSchema, extended to also accept the `toolResult` result shape used by protocol version 2024-10-07.
*/
export const CompatibilityCallToolResultSchema = CallToolResultSchema.or(
ResultSchema.extend({
toolResult: z.unknown(),
}),
);
/**
* Used by the client to invoke a tool provided by the server.
*/
export const CallToolRequestSchema = RequestSchema.extend({
method: z.literal("tools/call"),
params: BaseRequestParamsSchema.extend({
name: z.string(),
arguments: z.optional(z.record(z.unknown())),
}),
});
/**
* An optional notification from the server to the client, informing it that the list of tools it offers has changed. This may be issued by servers without any previous subscription from the client.
*/
export const ToolListChangedNotificationSchema = NotificationSchema.extend({
method: z.literal("notifications/tools/list_changed"),
});
/* Logging */
/**
* The severity of a log message.
*/
export const LoggingLevelSchema = z.enum([
"debug",
"info",
"notice",
"warning",
"error",
"critical",
"alert",
"emergency",
]);
/**
* A request from the client to the server, to enable or adjust logging.
*/
export const SetLevelRequestSchema = RequestSchema.extend({
method: z.literal("logging/setLevel"),
params: BaseRequestParamsSchema.extend({
/**
* The level of logging that the client wants to receive from the server. The server should send all logs at this level and higher (i.e., more severe) to the client as notifications/logging/message.
*/
level: LoggingLevelSchema,
}),
});
/**
* Notification of a log message passed from server to client. If no logging/setLevel request has been sent from the client, the server MAY decide which messages to send automatically.
*/
export const LoggingMessageNotificationSchema = NotificationSchema.extend({
method: z.literal("notifications/message"),
params: BaseNotificationParamsSchema.extend({
/**
* The severity of this log message.
*/
level: LoggingLevelSchema,
/**
* An optional name of the logger issuing this message.
*/
logger: z.optional(z.string()),
/**
* The data to be logged, such as a string message or an object. Any JSON serializable type is allowed here.
*/
data: z.unknown(),
}),
});
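The `logging/setLevel` rule (send messages at the requested level and above) is an ordinal comparison over `LoggingLevelSchema`, whose values are listed from least to most severe. A minimal sketch; `shouldSend` is a hypothetical helper, and sending everything when no level has been set is only one of the choices a server MAY make.

```typescript
// Copied from LoggingLevelSchema, ordered from least to most severe.
const LOGGING_LEVELS = [
  "debug", "info", "notice", "warning", "error", "critical", "alert", "emergency",
] as const;
type LoggingLevel = (typeof LOGGING_LEVELS)[number];

// `minimumLevel` is the level from the client's last logging/setLevel request,
// or undefined if none has been received.
function shouldSend(messageLevel: LoggingLevel, minimumLevel?: LoggingLevel): boolean {
  if (minimumLevel === undefined) {
    return true; // assumption: send everything until the client sets a level
  }
  return LOGGING_LEVELS.indexOf(messageLevel) >= LOGGING_LEVELS.indexOf(minimumLevel);
}
```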
/* Sampling */
/**
* Hints to use for model selection.
*/
export const ModelHintSchema = z
.object({
/**
* A hint for a model name.
*/
name: z.string().optional(),
})
.passthrough();
/**
* The server's preferences for model selection, requested of the client during sampling.
*/
export const ModelPreferencesSchema = z
.object({
/**
* Optional hints to use for model selection.
*/
hints: z.optional(z.array(ModelHintSchema)),
/**
* How much to prioritize cost when selecting a model.
*/
costPriority: z.optional(z.number().min(0).max(1)),
/**
* How much to prioritize sampling speed (latency) when selecting a model.
*/
speedPriority: z.optional(z.number().min(0).max(1)),
/**
* How much to prioritize intelligence and capabilities when selecting a model.
*/
intelligencePriority: z.optional(z.number().min(0).max(1)),
})
.passthrough();
/**
* Describes a message issued to or received from an LLM API.
*/
export const SamplingMessageSchema = z
.object({
role: z.enum(["user", "assistant"]),
content: z.union([TextContentSchema, ImageContentSchema]),
})
.passthrough();
/**
* A request from the server to sample an LLM via the client. The client has full discretion over which model to select. The client should also inform the user before beginning sampling, to allow them to inspect the request (human in the loop) and decide whether to approve it.
*/
export const CreateMessageRequestSchema = RequestSchema.extend({
method: z.literal("sampling/createMessage"),
params: BaseRequestParamsSchema.extend({
messages: z.array(SamplingMessageSchema),
/**
* An optional system prompt the server wants to use for sampling. The client MAY modify or omit this prompt.
*/
systemPrompt: z.optional(z.string()),
/**
* A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. The client MAY ignore this request.
*/
includeContext: z.optional(z.enum(["none", "thisServer", "allServers"])),
temperature: z.optional(z.number()),
/**
* The maximum number of tokens to sample, as requested by the server. The client MAY choose to sample fewer tokens than requested.
*/
maxTokens: z.number().int(),
stopSequences: z.optional(z.array(z.string())),
/**
* Optional metadata to pass through to the LLM provider. The format of this metadata is provider-specific.
*/
metadata: z.optional(z.object({}).passthrough()),
/**
* The server's preferences for which model to select.
*/
modelPreferences: z.optional(ModelPreferencesSchema),
}),
});
/**
* The client's response to a sampling/createMessage request from the server. The client should inform the user before returning the sampled message, to allow them to inspect the response (human in the loop) and decide whether to allow the server to see it.
*/
export const CreateMessageResultSchema = ResultSchema.extend({
/**
* The name of the model that generated the message.
*/
model: z.string(),
/**
* The reason why sampling stopped.
*/
stopReason: z.optional(
z.enum(["endTurn", "stopSequence", "maxTokens"]).or(z.string()),
),
role: z.enum(["user", "assistant"]),
content: z.discriminatedUnion("type", [
TextContentSchema,
ImageContentSchema,
]),
});
/* Autocomplete */
/**
* A reference to a resource or resource template definition.
*/
export const ResourceReferenceSchema = z
.object({
type: z.literal("ref/resource"),
/**
* The URI or URI template of the resource.
*/
uri: z.string(),
})
.passthrough();
/**
* Identifies a prompt.
*/
export const PromptReferenceSchema = z
.object({
type: z.literal("ref/prompt"),
/**
* The name of the prompt or prompt template.
*/
name: z.string(),
})
.passthrough();
/**
* A request from the client to the server, to ask for completion options.
*/
export const CompleteRequestSchema = RequestSchema.extend({
method: z.literal("completion/complete"),
params: BaseRequestParamsSchema.extend({
ref: z.union([PromptReferenceSchema, ResourceReferenceSchema]),
/**
* Information about the argument being completed.
*/
argument: z
.object({
/**
* The name of the argument.
*/
name: z.string(),
/**
* The value of the argument to use for completion matching.
*/
value: z.string(),
})
.passthrough(),
}),
});
/**
* The server's response to a completion/complete request.
*/
export const CompleteResultSchema = ResultSchema.extend({
completion: z
.object({
/**
* An array of completion values. Must not exceed 100 items.
*/
values: z.array(z.string()).max(100),
/**
* The total number of completion options available. This can exceed the number of values actually sent in the response.
*/
total: z.optional(z.number().int()),
/**
* Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown.
*/
hasMore: z.optional(z.boolean()),
})
.passthrough(),
});
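Since `values` is capped at 100 entries, a server with more matches can truncate the list and use `total` and `hasMore` to say so. A sketch with a hypothetical `buildCompletion` helper; the prefix matching it uses is an assumption, since the protocol leaves the matching strategy to the server.

```typescript
const MAX_COMPLETION_VALUES = 100; // limit stated in CompleteResultSchema

interface CompleteResult {
  completion: { values: string[]; total?: number; hasMore?: boolean };
}

function buildCompletion(candidates: string[], prefix: string): CompleteResult {
  // Prefix matching is an assumption; the protocol leaves matching to the server.
  const matches = candidates.filter((c) => c.startsWith(prefix));
  const values = matches.slice(0, MAX_COMPLETION_VALUES);
  return {
    completion: {
      values,
      total: matches.length, // may exceed values.length
      hasMore: matches.length > values.length,
    },
  };
}
```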
/* Roots */
/**
* Represents a root directory or file that the server can operate on.
*/
export const RootSchema = z
.object({
/**
* The URI identifying the root. This *must* start with file:// for now.
*/
uri: z.string().startsWith("file://"),
/**
* An optional name for the root.
*/
name: z.optional(z.string()),
})
.passthrough();
/**
* Sent from the server to request a list of root URIs from the client.
*/
export const ListRootsRequestSchema = RequestSchema.extend({
method: z.literal("roots/list"),
});
/**
* The client's response to a roots/list request from the server.
*/
export const ListRootsResultSchema = ResultSchema.extend({
roots: z.array(RootSchema),
});
/**
* A notification from the client to the server, informing it that the list of roots has changed.
*/
export const RootsListChangedNotificationSchema = NotificationSchema.extend({
method: z.literal("notifications/roots/list_changed"),
});
/* Client messages */
export const ClientRequestSchema = z.union([
PingRequestSchema,
InitializeRequestSchema,
CompleteRequestSchema,
SetLevelRequestSchema,
GetPromptRequestSchema,
ListPromptsRequestSchema,
ListResourcesRequestSchema,
ListResourceTemplatesRequestSchema,
ReadResourceRequestSchema,
SubscribeRequestSchema,
UnsubscribeRequestSchema,
CallToolRequestSchema,
ListToolsRequestSchema,
]);
export const ClientNotificationSchema = z.union([
CancelledNotificationSchema,
ProgressNotificationSchema,
InitializedNotificationSchema,
RootsListChangedNotificationSchema,
]);
export const ClientResultSchema = z.union([
EmptyResultSchema,
CreateMessageResultSchema,
ListRootsResultSchema,
]);
/* Server messages */
export const ServerRequestSchema = z.union([
PingRequestSchema,
CreateMessageRequestSchema,
ListRootsRequestSchema,
]);
export const ServerNotificationSchema = z.union([
CancelledNotificationSchema,
ProgressNotificationSchema,
LoggingMessageNotificationSchema,
ResourceUpdatedNotificationSchema,
ResourceListChangedNotificationSchema,
ToolListChangedNotificationSchema,
PromptListChangedNotificationSchema,
]);
export const ServerResultSchema = z.union([
EmptyResultSchema,
InitializeResultSchema,
CompleteResultSchema,
GetPromptResultSchema,
ListPromptsResultSchema,
ListResourcesResultSchema,
ListResourceTemplatesResultSchema,
ReadResourceResultSchema,
CallToolResultSchema,
ListToolsResultSchema,
]);
export class McpError extends Error {
constructor(
public readonly code: number,
message: string,
public readonly data?: unknown,
) {
super(`MCP error ${code}: ${message}`);
this.name = "McpError";
}
}
type Primitive = string | number | boolean | bigint | null | undefined;
type Flatten<T> = T extends Primitive
? T
: T extends Array<infer U>
? Array<Flatten<U>>
: T extends Set<infer U>
? Set<Flatten<U>>
: T extends Map<infer K, infer V>
? Map<Flatten<K>, Flatten<V>>
: T extends object
? { [K in keyof T]: Flatten<T[K]> }
: T;
type Infer<Schema extends ZodTypeAny> = Flatten<z.infer<Schema>>;
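`Flatten` recursively rewrites object types (including intersections) through a mapped type, which appears intended to make the types inferred from the extended Zod schemas read as single plain object types. The sketch checks that at compile time. `Flatten` is copied from above so the block stands alone; `Equals` and `Combined` are hypothetical helpers for illustration only.

```typescript
type Primitive = string | number | boolean | bigint | null | undefined;
type Flatten<T> = T extends Primitive
  ? T
  : T extends Array<infer U>
    ? Array<Flatten<U>>
    : T extends Set<infer U>
      ? Set<Flatten<U>>
      : T extends Map<infer K, infer V>
        ? Map<Flatten<K>, Flatten<V>>
        : T extends object
          ? { [K in keyof T]: Flatten<T[K]> }
          : T;

// Hypothetical compile-time equality helper (not part of the SDK):
// evaluates to `true` only when A and B are identical types.
type Equals<A, B> =
  (<G>() => G extends A ? 1 : 2) extends (<G>() => G extends B ? 1 : 2) ? true : false;

// An intersection, similar to what schema extension can produce.
type Combined = { method: string } & { params: { uri: string }[] };

// Compiles only if Flatten turned the intersection into this exact object type.
const flattenedIsPlainObject: Equals<Flatten<Combined>, { method: string; params: { uri: string }[] }> = true;
```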
/* JSON-RPC types */
export type ProgressToken = Infer<typeof ProgressTokenSchema>;
export type Cursor = Infer<typeof CursorSchema>;
export type Request = Infer<typeof RequestSchema>;
export type Notification = Infer<typeof NotificationSchema>;
export type Result = Infer<typeof ResultSchema>;
export type RequestId = Infer<typeof RequestIdSchema>;
export type JSONRPCRequest = Infer<typeof JSONRPCRequestSchema>;
export type JSONRPCNotification = Infer<typeof JSONRPCNotificationSchema>;
export type JSONRPCResponse = Infer<typeof JSONRPCResponseSchema>;
export type JSONRPCError = Infer<typeof JSONRPCErrorSchema>;
export type JSONRPCMessage = Infer<typeof JSONRPCMessageSchema>;
/* Empty result */
export type EmptyResult = Infer<typeof EmptyResultSchema>;
/* Cancellation */
export type CancelledNotification = Infer<typeof CancelledNotificationSchema>;
/* Initialization */
export type Implementation = Infer<typeof ImplementationSchema>;
export type ClientCapabilities = Infer<typeof ClientCapabilitiesSchema>;
export type InitializeRequest = Infer<typeof InitializeRequestSchema>;
export type ServerCapabilities = Infer<typeof ServerCapabilitiesSchema>;
export type InitializeResult = Infer<typeof InitializeResultSchema>;
export type InitializedNotification = Infer<typeof InitializedNotificationSchema>;
/* Ping */
export type PingRequest = Infer<typeof PingRequestSchema>;
/* Progress notifications */
export type Progress = Infer<typeof ProgressSchema>;
export type ProgressNotification = Infer<typeof ProgressNotificationSchema>;
/* Pagination */
export type PaginatedRequest = Infer<typeof PaginatedRequestSchema>;
export type PaginatedResult = Infer<typeof PaginatedResultSchema>;
/* Resources */
export type ResourceContents = Infer<typeof ResourceContentsSchema>;
export type TextResourceContents = Infer<typeof TextResourceContentsSchema>;
export type BlobResourceContents = Infer<typeof BlobResourceContentsSchema>;
export type Resource = Infer<typeof ResourceSchema>;
export type ResourceTemplate = Infer<typeof ResourceTemplateSchema>;
export type ListResourcesRequest = Infer<typeof ListResourcesRequestSchema>;
export type ListResourcesResult = Infer<typeof ListResourcesResultSchema>;
export type ListResourceTemplatesRequest = Infer<typeof ListResourceTemplatesRequestSchema>;
export type ListResourceTemplatesResult = Infer<typeof ListResourceTemplatesResultSchema>;
export type ReadResourceRequest = Infer<typeof ReadResourceRequestSchema>;
export type ReadResourceResult = Infer<typeof ReadResourceResultSchema>;
export type ResourceListChangedNotification = Infer<typeof ResourceListChangedNotificationSchema>;
export type SubscribeRequest = Infer<typeof SubscribeRequestSchema>;
export type UnsubscribeRequest = Infer<typeof UnsubscribeRequestSchema>;
export type ResourceUpdatedNotification = Infer<typeof ResourceUpdatedNotificationSchema>;
/* Prompts */
export type PromptArgument = Infer<typeof PromptArgumentSchema>;
export type Prompt = Infer<typeof PromptSchema>;
export type ListPromptsRequest = Infer<typeof ListPromptsRequestSchema>;
export type ListPromptsResult = Infer<typeof ListPromptsResultSchema>;
export type GetPromptRequest = Infer<typeof GetPromptRequestSchema>;
export type TextContent = Infer<typeof TextContentSchema>;
export type ImageContent = Infer<typeof ImageContentSchema>;
export type EmbeddedResource = Infer<typeof EmbeddedResourceSchema>;
export type PromptMessage = Infer<typeof PromptMessageSchema>;
export type GetPromptResult = Infer<typeof GetPromptResultSchema>;
export type PromptListChangedNotification = Infer<typeof PromptListChangedNotificationSchema>;
/* Tools */
export type Tool = Infer<typeof ToolSchema>;
export type ListToolsRequest = Infer<typeof ListToolsRequestSchema>;
export type ListToolsResult = Infer<typeof ListToolsResultSchema>;
export type CallToolResult = Infer<typeof CallToolResultSchema>;
export type CompatibilityCallToolResult = Infer<typeof CompatibilityCallToolResultSchema>;
export type CallToolRequest = Infer<typeof CallToolRequestSchema>;
export type ToolListChangedNotification = Infer<typeof ToolListChangedNotificationSchema>;
/* Logging */
export type LoggingLevel = Infer<typeof LoggingLevelSchema>;
export type SetLevelRequest = Infer<typeof SetLevelRequestSchema>;
export type LoggingMessageNotification = Infer<typeof LoggingMessageNotificationSchema>;
/* Sampling */
export type SamplingMessage = Infer<typeof SamplingMessageSchema>;
export type CreateMessageRequest = Infer<typeof CreateMessageRequestSchema>;
export type CreateMessageResult = Infer<typeof CreateMessageResultSchema>;
/* Autocomplete */
export type ResourceReference = Infer<typeof ResourceReferenceSchema>;
export type PromptReference = Infer<typeof PromptReferenceSchema>;
export type CompleteRequest = Infer<typeof CompleteRequestSchema>;
export type CompleteResult = Infer<typeof CompleteResultSchema>;
/* Roots */
export type Root = Infer<typeof RootSchema>;
export type ListRootsRequest = Infer<typeof ListRootsRequestSchema>;
export type ListRootsResult = Infer<typeof ListRootsResultSchema>;
export type RootsListChangedNotification = Infer<typeof RootsListChangedNotificationSchema>;
/* Client messages */
export type ClientRequest = Infer<typeof ClientRequestSchema>;
export type ClientNotification = Infer<typeof ClientNotificationSchema>;
export type ClientResult = Infer<typeof ClientResultSchema>;
/* Server messages */
export type ServerRequest = Infer<typeof ServerRequestSchema>;
export type ServerNotification = Infer<typeof ServerNotificationSchema>;
export type ServerResult = Infer<typeof ServerResultSchema>;
================================================
File: src/__mocks__/pkce-challenge.ts
================================================
export default function pkceChallenge() {
return {
code_verifier: "test_verifier",
code_challenge: "test_challenge",
};
}
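This Jest manual mock replaces the `pkce-challenge` package with fixed values, which is why the tests below can assert `code_challenge=test_challenge` and `test_verifier`. For reference, RFC 7636 defines the `S256` method as `code_challenge = BASE64URL(SHA-256(ASCII(code_verifier)))` without padding. The sketch computes it with Node's built-in `node:crypto` (an assumption: Node 16+ for `base64url`), not with `pkce-challenge`; the function names are hypothetical.

```typescript
import { createHash, randomBytes } from "node:crypto";

// RFC 7636 S256: SHA-256 over the ASCII verifier, base64url-encoded.
// Node's "base64url" encoding omits "=" padding, as the RFC requires.
function s256Challenge(codeVerifier: string): string {
  return createHash("sha256").update(codeVerifier, "ascii").digest("base64url");
}

// 32 random bytes encode to 43 base64url characters, the minimum
// verifier length RFC 7636 allows (43..128).
function generateVerifier(): string {
  return randomBytes(32).toString("base64url");
}
```

The test vector from RFC 7636 Appendix B can be used to check the derivation.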
================================================
File: src/client/auth.test.ts
================================================
import {
discoverOAuthMetadata,
startAuthorization,
exchangeAuthorization,
refreshAuthorization,
registerClient,
} from "./auth.js";
// Mock fetch globally
const mockFetch = jest.fn();
global.fetch = mockFetch;
describe("OAuth Authorization", () => {
beforeEach(() => {
mockFetch.mockReset();
});
describe("discoverOAuthMetadata", () => {
const validMetadata = {
issuer: "https://auth.example.com",
authorization_endpoint: "https://auth.example.com/authorize",
token_endpoint: "https://auth.example.com/token",
registration_endpoint: "https://auth.example.com/register",
response_types_supported: ["code"],
code_challenge_methods_supported: ["S256"],
};
it("returns metadata when discovery succeeds", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => validMetadata,
});
const metadata = await discoverOAuthMetadata("https://auth.example.com");
expect(metadata).toEqual(validMetadata);
const calls = mockFetch.mock.calls;
expect(calls.length).toBe(1);
const [url, options] = calls[0];
expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server");
expect(options.headers).toEqual({
"MCP-Protocol-Version": "2024-11-05"
});
});
it("returns metadata when first fetch fails but second without MCP header succeeds", async () => {
// Set up a counter to control behavior
let callCount = 0;
// Mock implementation that changes behavior based on call count
mockFetch.mockImplementation((_url, _options) => {
callCount++;
if (callCount === 1) {
// First call with MCP header - fail with TypeError (simulating CORS error)
// We need to use TypeError specifically because that's what the implementation checks for
return Promise.reject(new TypeError("Network error"));
} else {
// Second call without header - succeed
return Promise.resolve({
ok: true,
status: 200,
json: async () => validMetadata
});
}
});
// Should succeed with the second call
const metadata = await discoverOAuthMetadata("https://auth.example.com");
expect(metadata).toEqual(validMetadata);
// Verify both calls were made
expect(mockFetch).toHaveBeenCalledTimes(2);
// Verify first call had MCP header
expect(mockFetch.mock.calls[0][1]?.headers).toHaveProperty("MCP-Protocol-Version");
});
it("throws an error when all fetch attempts fail", async () => {
// Set up a counter to control behavior
let callCount = 0;
// Mock implementation that changes behavior based on call count
mockFetch.mockImplementation((_url, _options) => {
callCount++;
if (callCount === 1) {
// First call - fail with TypeError
return Promise.reject(new TypeError("First failure"));
} else {
// Second call - fail with different error
return Promise.reject(new Error("Second failure"));
}
});
// Should fail with the second error
await expect(discoverOAuthMetadata("https://auth.example.com"))
.rejects.toThrow("Second failure");
// Verify both calls were made
expect(mockFetch).toHaveBeenCalledTimes(2);
});
it("returns undefined when discovery endpoint returns 404", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
status: 404,
});
const metadata = await discoverOAuthMetadata("https://auth.example.com");
expect(metadata).toBeUndefined();
});
it("throws on non-404 errors", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
status: 500,
});
await expect(
discoverOAuthMetadata("https://auth.example.com")
).rejects.toThrow("HTTP 500");
});
it("validates metadata schema", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => ({
// Missing required fields
issuer: "https://auth.example.com",
}),
});
await expect(
discoverOAuthMetadata("https://auth.example.com")
).rejects.toThrow();
});
});
describe("startAuthorization", () => {
const validMetadata = {
issuer: "https://auth.example.com",
authorization_endpoint: "https://auth.example.com/auth",
token_endpoint: "https://auth.example.com/tkn",
response_types_supported: ["code"],
code_challenge_methods_supported: ["S256"],
};
const validClientInfo = {
client_id: "client123",
client_secret: "secret123",
redirect_uris: ["http://localhost:3000/callback"],
client_name: "Test Client",
};
it("generates authorization URL with PKCE challenge", async () => {
const { authorizationUrl, codeVerifier } = await startAuthorization(
"https://auth.example.com",
{
clientInformation: validClientInfo,
redirectUrl: "http://localhost:3000/callback",
}
);
expect(authorizationUrl.toString()).toMatch(
/^https:\/\/auth\.example\.com\/authorize\?/
);
expect(authorizationUrl.searchParams.get("response_type")).toBe("code");
expect(authorizationUrl.searchParams.get("code_challenge")).toBe("test_challenge");
expect(authorizationUrl.searchParams.get("code_challenge_method")).toBe(
"S256"
);
expect(authorizationUrl.searchParams.get("redirect_uri")).toBe(
"http://localhost:3000/callback"
);
expect(codeVerifier).toBe("test_verifier");
});
it("uses metadata authorization_endpoint when provided", async () => {
const { authorizationUrl } = await startAuthorization(
"https://auth.example.com",
{
metadata: validMetadata,
clientInformation: validClientInfo,
redirectUrl: "http://localhost:3000/callback",
}
);
expect(authorizationUrl.toString()).toMatch(
/^https:\/\/auth\.example\.com\/auth\?/
);
});
it("validates response type support", async () => {
const metadata = {
...validMetadata,
response_types_supported: ["token"], // Does not support 'code'
};
await expect(
startAuthorization("https://auth.example.com", {
metadata,
clientInformation: validClientInfo,
redirectUrl: "http://localhost:3000/callback",
})
).rejects.toThrow(/does not support response type/);
});
it("validates PKCE support", async () => {
const metadata = {
...validMetadata,
response_types_supported: ["code"],
code_challenge_methods_supported: ["plain"], // Does not support 'S256'
};
await expect(
startAuthorization("https://auth.example.com", {
metadata,
clientInformation: validClientInfo,
redirectUrl: "http://localhost:3000/callback",
})
).rejects.toThrow(/does not support code challenge method/);
});
});
describe("exchangeAuthorization", () => {
const validTokens = {
access_token: "access123",
token_type: "Bearer",
expires_in: 3600,
refresh_token: "refresh123",
};
const validClientInfo = {
client_id: "client123",
client_secret: "secret123",
redirect_uris: ["http://localhost:3000/callback"],
client_name: "Test Client",
};
it("exchanges code for tokens", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => validTokens,
});
const tokens = await exchangeAuthorization("https://auth.example.com", {
clientInformation: validClientInfo,
authorizationCode: "code123",
codeVerifier: "verifier123",
});
expect(tokens).toEqual(validTokens);
expect(mockFetch).toHaveBeenCalledWith(
expect.objectContaining({
href: "https://auth.example.com/token",
}),
expect.objectContaining({
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
})
);
const body = mockFetch.mock.calls[0][1].body as URLSearchParams;
expect(body.get("grant_type")).toBe("authorization_code");
expect(body.get("code")).toBe("code123");
expect(body.get("code_verifier")).toBe("verifier123");
expect(body.get("client_id")).toBe("client123");
expect(body.get("client_secret")).toBe("secret123");
});
it("validates token response schema", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => ({
// Missing required fields
access_token: "access123",
}),
});
await expect(
exchangeAuthorization("https://auth.example.com", {
clientInformation: validClientInfo,
authorizationCode: "code123",
codeVerifier: "verifier123",
})
).rejects.toThrow();
});
it("throws on error response", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
status: 400,
});
await expect(
exchangeAuthorization("https://auth.example.com", {
clientInformation: validClientInfo,
authorizationCode: "code123",
codeVerifier: "verifier123",
})
).rejects.toThrow("Token exchange failed");
});
});
describe("refreshAuthorization", () => {
const validTokens = {
access_token: "newaccess123",
token_type: "Bearer",
expires_in: 3600,
refresh_token: "newrefresh123",
};
const validClientInfo = {
client_id: "client123",
client_secret: "secret123",
redirect_uris: ["http://localhost:3000/callback"],
client_name: "Test Client",
};
it("exchanges refresh token for new tokens", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => validTokens,
});
const tokens = await refreshAuthorization("https://auth.example.com", {
clientInformation: validClientInfo,
refreshToken: "refresh123",
});
expect(tokens).toEqual(validTokens);
expect(mockFetch).toHaveBeenCalledWith(
expect.objectContaining({
href: "https://auth.example.com/token",
}),
expect.objectContaining({
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
})
);
const body = mockFetch.mock.calls[0][1].body as URLSearchParams;
expect(body.get("grant_type")).toBe("refresh_token");
expect(body.get("refresh_token")).toBe("refresh123");
expect(body.get("client_id")).toBe("client123");
expect(body.get("client_secret")).toBe("secret123");
});
it("validates token response schema", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => ({
// Missing required fields
access_token: "newaccess123",
}),
});
await expect(
refreshAuthorization("https://auth.example.com", {
clientInformation: validClientInfo,
refreshToken: "refresh123",
})
).rejects.toThrow();
});
it("throws on error response", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
status: 400,
});
await expect(
refreshAuthorization("https://auth.example.com", {
clientInformation: validClientInfo,
refreshToken: "refresh123",
})
).rejects.toThrow("Token refresh failed");
});
});
describe("registerClient", () => {
const validClientMetadata = {
redirect_uris: ["http://localhost:3000/callback"],
client_name: "Test Client",
};
const validClientInfo = {
client_id: "client123",
client_secret: "secret123",
client_id_issued_at: 1612137600,
client_secret_expires_at: 1612224000,
...validClientMetadata,
};
it("registers client and returns client information", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => validClientInfo,
});
const clientInfo = await registerClient("https://auth.example.com", {
clientMetadata: validClientMetadata,
});
expect(clientInfo).toEqual(validClientInfo);
expect(mockFetch).toHaveBeenCalledWith(
expect.objectContaining({
href: "https://auth.example.com/register",
}),
expect.objectContaining({
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(validClientMetadata),
})
);
});
it("validates client information response schema", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => ({
// Missing required fields
client_secret: "secret123",
}),
});
await expect(
registerClient("https://auth.example.com", {
clientMetadata: validClientMetadata,
})
).rejects.toThrow();
});
it("throws when registration endpoint not available in metadata", async () => {
const metadata = {
issuer: "https://auth.example.com",
authorization_endpoint: "https://auth.example.com/authorize",
token_endpoint: "https://auth.example.com/token",
response_types_supported: ["code"],
};
await expect(
registerClient("https://auth.example.com", {
metadata,
clientMetadata: validClientMetadata,
})
).rejects.toThrow(/does not support dynamic client registration/);
});
it("throws on error response", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
status: 400,
});
await expect(
registerClient("https://auth.example.com", {
clientMetadata: validClientMetadata,
})
).rejects.toThrow("Dynamic client registration failed");
});
});
});
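The tests above resolve `mockFetch` with plain objects exposing only `ok`, `status`, and `json`, which is enough because the code under test reads nothing else. If a test needs a more faithful stand-in, one option is a real Fetch API `Response`, which is a global in Node 18+ (an assumption about the runtime). `jsonResponse` below is a hypothetical test helper, not part of the SDK.

```typescript
// Hypothetical test helper: a real Response carrying a JSON body, so `ok`,
// `status`, and `json()` behave as they do against a live server.
function jsonResponse(body: unknown, status = 200): Response {
  return new Response(JSON.stringify(body), {
    status,
    headers: { "Content-Type": "application/json" },
  });
}
```

It would be used as `mockFetch.mockResolvedValueOnce(jsonResponse(validMetadata))`. A `Response` body can only be read once, so each mocked call needs its own instance.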
================================================
File: src/client/auth.ts
================================================
import pkceChallenge from "pkce-challenge";
import { LATEST_PROTOCOL_VERSION } from "../types.js";
import type { OAuthClientMetadata, OAuthClientInformation, OAuthTokens, OAuthMetadata, OAuthClientInformationFull } from "../shared/auth.js";
import { OAuthClientInformationFullSchema, OAuthMetadataSchema, OAuthTokensSchema } from "../shared/auth.js";
/**
* Implements an end-to-end OAuth client to be used with one MCP server.
*
* This client relies upon a concept of an authorized "session," the exact
* meaning of which is application-defined. Tokens, authorization codes, and
* code verifiers should not cross different sessions.
*/
export interface OAuthClientProvider {
/**
* The URL to redirect the user agent to after authorization.
*/
get redirectUrl(): string | URL;
/**
* Metadata about this OAuth client.
*/
get clientMetadata(): OAuthClientMetadata;
/**
* Loads information about this OAuth client, as registered already with the
* server, or returns `undefined` if the client is not registered with the
* server.
*/
clientInformation(): OAuthClientInformation | undefined | Promise<OAuthClientInformation | undefined>;
/**
* If implemented, this permits the OAuth client to dynamically register with
* the server. Client information saved this way should later be read via
* `clientInformation()`.
*
* This method is not required to be implemented if client information is
* statically known (e.g., pre-registered).
*/
saveClientInformation?(clientInformation: OAuthClientInformationFull): void | Promise<void>;
/**
* Loads any existing OAuth tokens for the current session, or returns
* `undefined` if there are no saved tokens.
*/
tokens(): OAuthTokens | undefined | Promise<OAuthTokens | undefined>;
/**
* Stores new OAuth tokens for the current session, after a successful
* authorization.
*/
saveTokens(tokens: OAuthTokens): void | Promise<void>;
/**
* Invoked to redirect the user agent to the given URL to begin the authorization flow.
*/
redirectToAuthorization(authorizationUrl: URL): void | Promise<void>;
/**
* Saves a PKCE code verifier for the current session, before redirecting to
* the authorization flow.
*/
saveCodeVerifier(codeVerifier: string): void | Promise<void>;
/**
* Loads the PKCE code verifier for the current session, necessary to validate
* the authorization result.
*/
codeVerifier(): string | Promise<string>;
}
export type AuthResult = "AUTHORIZED" | "REDIRECT";
export class UnauthorizedError extends Error {
constructor(message?: string) {
super(message ?? "Unauthorized");
}
}
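`OAuthClientProvider` leaves storage to the application. A minimal sketch of an in-memory implementation for a single session (for example a CLI or a test) follows. It is hypothetical: the three interfaces are simplified local stand-ins for the SDK's `OAuthTokens`, `OAuthClientInformation`, and `OAuthClientMetadata`, declaring only the fields used here, and everything is lost when the process exits.

```typescript
// Simplified local stand-ins for the SDK's OAuth types (assumption).
interface OAuthTokens {
  access_token: string;
  token_type: string;
  expires_in?: number;
  refresh_token?: string;
}
interface OAuthClientInformation {
  client_id: string;
  client_secret?: string;
}
interface OAuthClientMetadata {
  redirect_uris: string[];
  client_name?: string;
}

// Hypothetical in-memory provider: one instance represents one session.
class InMemoryOAuthClientProvider {
  private readonly _redirectUrl: string;
  private readonly _clientMetadata: OAuthClientMetadata;
  private _clientInformation?: OAuthClientInformation;
  private _tokens?: OAuthTokens;
  private _codeVerifier?: string;
  /** Last URL passed to redirectToAuthorization, kept for inspection. */
  lastAuthorizationUrl?: URL;

  constructor(redirectUrl: string, clientMetadata: OAuthClientMetadata) {
    this._redirectUrl = redirectUrl;
    this._clientMetadata = clientMetadata;
  }

  get redirectUrl(): string { return this._redirectUrl; }
  get clientMetadata(): OAuthClientMetadata { return this._clientMetadata; }

  clientInformation(): OAuthClientInformation | undefined { return this._clientInformation; }
  saveClientInformation(info: OAuthClientInformation): void { this._clientInformation = info; }

  tokens(): OAuthTokens | undefined { return this._tokens; }
  saveTokens(tokens: OAuthTokens): void { this._tokens = tokens; }

  redirectToAuthorization(url: URL): void {
    // A browser app would navigate here; a CLI might print or open the URL.
    this.lastAuthorizationUrl = url;
  }

  saveCodeVerifier(codeVerifier: string): void { this._codeVerifier = codeVerifier; }
  codeVerifier(): string {
    if (this._codeVerifier === undefined) {
      throw new Error("No PKCE code verifier saved for this session");
    }
    return this._codeVerifier;
  }
}
```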
/**
* Orchestrates the full auth flow with a server.
*
* This can be used as a single entry point for all authorization functionality,
* instead of linking together the other lower-level functions in this module.
*/
export async function auth(
provider: OAuthClientProvider,
{ serverUrl, authorizationCode }: { serverUrl: string | URL; authorizationCode?: string },
): Promise<AuthResult> {
const metadata = await discoverOAuthMetadata(serverUrl);
// Handle client registration if needed
let clientInformation = await Promise.resolve(provider.clientInformation());
if (!clientInformation) {
if (authorizationCode !== undefined) {
throw new Error("Existing OAuth client information is required when exchanging an authorization code");
}
if (!provider.saveClientInformation) {
throw new Error("OAuth client information must be saveable for dynamic registration");
}
const fullInformation = await registerClient(serverUrl, {
metadata,
clientMetadata: provider.clientMetadata,
});
await provider.saveClientInformation(fullInformation);
clientInformation = fullInformation;
}
// Exchange authorization code for tokens
if (authorizationCode !== undefined) {
const codeVerifier = await provider.codeVerifier();
const tokens = await exchangeAuthorization(serverUrl, {
metadata,
clientInformation,
authorizationCode,
codeVerifier,
});
await provider.saveTokens(tokens);
return "AUTHORIZED";
}
const tokens = await provider.tokens();
// Handle token refresh or new authorization
if (tokens?.refresh_token) {
try {
// Attempt to refresh the token
const newTokens = await refreshAuthorization(serverUrl, {
metadata,
clientInformation,
refreshToken: tokens.refresh_token,
});
await provider.saveTokens(newTokens);
return "AUTHORIZED";
} catch (error) {
console.error("Could not refresh OAuth tokens:", error);
}
}
// Start new authorization flow
const { authorizationUrl, codeVerifier } = await startAuthorization(serverUrl, {
metadata,
clientInformation,
redirectUrl: provider.redirectUrl
});
await provider.saveCodeVerifier(codeVerifier);
await provider.redirectToAuthorization(authorizationUrl);
return "REDIRECT";
}
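The order of decisions in `auth()` is: register the client if no information is stored; exchange the code if one was passed (returning `"AUTHORIZED"`); otherwise try a refresh if a refresh token exists; and fall back to a new redirect (returning `"REDIRECT"`). As written, it does not short-circuit on an existing access token without a refresh token. The sketch below is a simplified, side-effect-free model of that ordering for reasoning and testing; `simulateAuth` and its input flags are hypothetical and make no network calls.

```typescript
type AuthStep = "register-client" | "exchange-code" | "refresh-tokens" | "redirect";

// Hypothetical inputs standing in for provider state and call arguments.
interface AuthInputs {
  hasClientInformation: boolean;
  canSaveClientInformation: boolean;
  authorizationCode?: string;
  refreshToken?: string;
  refreshSucceeds?: boolean;
}

// Models the branch order of auth() above, recording which steps would run.
function simulateAuth(input: AuthInputs): { steps: AuthStep[]; result: "AUTHORIZED" | "REDIRECT" } {
  const steps: AuthStep[] = [];
  if (!input.hasClientInformation) {
    if (input.authorizationCode !== undefined) {
      throw new Error("Existing OAuth client information is required when exchanging an authorization code");
    }
    if (!input.canSaveClientInformation) {
      throw new Error("OAuth client information must be saveable for dynamic registration");
    }
    steps.push("register-client");
  }
  if (input.authorizationCode !== undefined) {
    steps.push("exchange-code");
    return { steps, result: "AUTHORIZED" };
  }
  if (input.refreshToken) {
    steps.push("refresh-tokens");
    if (input.refreshSucceeds) return { steps, result: "AUTHORIZED" };
    // auth() logs the refresh error and falls through to a new authorization.
  }
  steps.push("redirect");
  return { steps, result: "REDIRECT" };
}
```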
/**
* Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata.
*
* If the server returns a 404 for the well-known endpoint, this function will
* return `undefined`. Any other errors will be thrown as exceptions.
*/
export async function discoverOAuthMetadata(
serverUrl: string | URL,
opts?: { protocolVersion?: string },
): Promise<OAuthMetadata | undefined> {
const url = new URL("/.well-known/oauth-authorization-server", serverUrl);
let response: Response;
try {
response = await fetch(url, {
headers: {
"MCP-Protocol-Version": opts?.protocolVersion ?? LATEST_PROTOCOL_VERSION
}
});
} catch (error) {
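// Fetch reports CORS failures (like other network errors) as a TypeError.
// The custom MCP-Protocol-Version header can trigger a CORS preflight, so retry once without it.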
if (error instanceof TypeError) {
response = await fetch(url);
} else {
throw error;
}
}
if (response.status === 404) {
return undefined;
}
if (!response.ok) {
throw new Error(
`HTTP ${response.status} trying to load well-known OAuth metadata`,
);
}
return OAuthMetadataSchema.parse(await response.json());
}
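A cross-origin request carrying a non-safelisted header such as `MCP-Protocol-Version` needs a CORS preflight; if the authorization server does not allow that header, `fetch` rejects with a `TypeError`. `discoverOAuthMetadata` handles this by retrying once without the header. The same pattern is sketched below as a standalone, testable helper with an injected fetch function; `fetchWithHeaderFallback` and `FetchInit` are hypothetical names, not SDK exports.

```typescript
// Minimal request-init shape used by this sketch.
type FetchInit = { headers?: Record<string, string> };

// Hypothetical helper: try once with custom headers; on TypeError
// (network-level failure, including a rejected preflight) retry without them.
async function fetchWithHeaderFallback<R>(
  fetchImpl: (url: URL, init?: FetchInit) => Promise<R>,
  url: URL,
  headers: Record<string, string>,
): Promise<R> {
  try {
    return await fetchImpl(url, { headers });
  } catch (error) {
    if (error instanceof TypeError) {
      return fetchImpl(url);
    }
    // Any other error (e.g. thrown by application code) is not retried.
    throw error;
  }
}
```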
/**
* Begins the authorization flow with the given server, by generating a PKCE challenge and constructing the authorization URL.
*/
export async function startAuthorization(
serverUrl: string | URL,
{
metadata,
clientInformation,
redirectUrl,
}: {
metadata?: OAuthMetadata;
clientInformation: OAuthClientInformation;
redirectUrl: string | URL;
},
): Promise<{ authorizationUrl: URL; codeVerifier: string }> {
const responseType = "code";
const codeChallengeMethod = "S256";
let authorizationUrl: URL;
if (metadata) {
authorizationUrl = new URL(metadata.authorization_endpoint);
if (!metadata.response_types_supported.includes(responseType)) {
throw new Error(
`Incompatible auth server: does not support response type ${responseType}`,
);
}
if (
!metadata.code_challenge_methods_supported ||
!metadata.code_challenge_methods_supported.includes(codeChallengeMethod)
) {
throw new Error(
`Incompatible auth server: does not support code challenge method ${codeChallengeMethod}`,
);
}
} else {
authorizationUrl = new URL("/authorize", serverUrl);
}
// Generate PKCE challenge
const challenge = await pkceChallenge();
const codeVerifier = challenge.code_verifier;
const codeChallenge = challenge.code_challenge;
authorizationUrl.searchParams.set("response_type", responseType);
authorizationUrl.searchParams.set("client_id", clientInformation.client_id);
authorizationUrl.searchParams.set("code_challenge", codeChallenge);
authorizationUrl.searchParams.set(
"code_challenge_method",
codeChallengeMethod,
);
authorizationUrl.searchParams.set("redirect_uri", String(redirectUrl));
return { authorizationUrl, codeVerifier };
}
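When no metadata is available, the fallback endpoints (`/authorize`, `/token`, `/register`) are resolved with `new URL(path, serverUrl)`. Because the path is absolute, it replaces whatever path the server URL carries, so an MCP server at `https://mcp.example.com/v1/mcp` falls back to `https://mcp.example.com/authorize`. The sketch shows that resolution and how `searchParams.set` percent-encodes the redirect URI; `fallbackAuthorizeUrl` is a hypothetical name and the hostnames are placeholders.

```typescript
// Absolute path resolved against the server URL: the server URL's path is dropped.
function fallbackAuthorizeUrl(serverUrl: string | URL): URL {
  return new URL("/authorize", serverUrl);
}

const authorizeUrl = fallbackAuthorizeUrl("https://mcp.example.com/v1/mcp");
authorizeUrl.searchParams.set("response_type", "code");
// URLSearchParams serializes as application/x-www-form-urlencoded,
// so ":" and "/" in the redirect URI are percent-encoded.
authorizeUrl.searchParams.set("redirect_uri", "http://localhost:3000/callback");

// For comparison: a relative path keeps the base URL's directory.
const relativeUrl = new URL("authorize", "https://mcp.example.com/v1/");
```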
/**
* Exchanges an authorization code for an access token with the given server.
*/
export async function exchangeAuthorization(
serverUrl: string | URL,
{
metadata,
clientInformation,
authorizationCode,
codeVerifier,
}: {
metadata?: OAuthMetadata;
clientInformation: OAuthClientInformation;
authorizationCode: string;
codeVerifier: string;
},
): Promise<OAuthTokens> {
const grantType = "authorization_code";
let tokenUrl: URL;
if (metadata) {
tokenUrl = new URL(metadata.token_endpoint);
if (
metadata.grant_types_supported &&
!metadata.grant_types_supported.includes(grantType)
) {
throw new Error(
`Incompatible auth server: does not support grant type ${grantType}`,
);
}
} else {
tokenUrl = new URL("/token", serverUrl);
}
// Exchange code for tokens
const params = new URLSearchParams({
grant_type: grantType,
client_id: clientInformation.client_id,
code: authorizationCode,
code_verifier: codeVerifier,
});
if (clientInformation.client_secret) {
params.set("client_secret", clientInformation.client_secret);
}
const response = await fetch(tokenUrl, {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
body: params,
});
if (!response.ok) {
throw new Error(`Token exchange failed: HTTP ${response.status}`);
}
return OAuthTokensSchema.parse(await response.json());
}
/**
* Exchanges a refresh token for an updated access token with the given server.
*/
export async function refreshAuthorization(
serverUrl: string | URL,
{
metadata,
clientInformation,
refreshToken,
}: {
metadata?: OAuthMetadata;
clientInformation: OAuthClientInformation;
refreshToken: string;
},
): Promise<OAuthTokens> {
const grantType = "refresh_token";
let tokenUrl: URL;
if (metadata) {
tokenUrl = new URL(metadata.token_endpoint);
if (
metadata.grant_types_supported &&
!metadata.grant_types_supported.includes(grantType)
) {
throw new Error(
`Incompatible auth server: does not support grant type ${grantType}`,
);
}
} else {
tokenUrl = new URL("/token", serverUrl);
}
// Exchange refresh token
const params = new URLSearchParams({
grant_type: grantType,
client_id: clientInformation.client_id,
refresh_token: refreshToken,
});
if (clientInformation.client_secret) {
params.set("client_secret", clientInformation.client_secret);
}
const response = await fetch(tokenUrl, {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
},
body: params,
});
if (!response.ok) {
throw new Error(`Token refresh failed: HTTP ${response.status}`);
}
return OAuthTokensSchema.parse(await response.json());
}
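`exchangeAuthorization` and `refreshAuthorization` send the same kind of form-encoded body: grant-specific fields, `client_id` always, and `client_secret` only when the client has one. A sketch of a shared builder follows; `buildTokenRequestBody` and `TokenGrant` are hypothetical and not part of the SDK, and parameter order in the output differs slightly from the functions above (servers read form fields by name, not position).

```typescript
// The two grant shapes used by the functions above.
type TokenGrant =
  | { grant_type: "authorization_code"; code: string; code_verifier: string }
  | { grant_type: "refresh_token"; refresh_token: string };

// Hypothetical shared builder for application/x-www-form-urlencoded token requests.
function buildTokenRequestBody(
  grant: TokenGrant,
  client: { client_id: string; client_secret?: string },
): URLSearchParams {
  const fields: Record<string, string> = { ...grant, client_id: client.client_id };
  const params = new URLSearchParams(fields);
  // Confidential clients authenticate with client_secret in the body;
  // public clients (no secret) rely on PKCE alone.
  if (client.client_secret) {
    params.set("client_secret", client.client_secret);
  }
  return params;
}
```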
/**
* Performs OAuth 2.0 Dynamic Client Registration according to RFC 7591.
*/
export async function registerClient(
serverUrl: string | URL,
{
metadata,
clientMetadata,
}: {
metadata?: OAuthMetadata;
clientMetadata: OAuthClientMetadata;
},
): Promise<OAuthClientInformationFull> {
let registrationUrl: URL;
if (metadata) {
if (!metadata.registration_endpoint) {
throw new Error("Incompatible auth server: does not support dynamic client registration");
}
registrationUrl = new URL(metadata.registration_endpoint);
} else {
registrationUrl = new URL("/register", serverUrl);
}
const response = await fetch(registrationUrl, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(clientMetadata),
});
if (!response.ok) {
throw new Error(`Dynamic client registration failed: HTTP ${response.status}`);
}
return OAuthClientInformationFullSchema.parse(await response.json());
}
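The registration response parsed here is what `auth()` passes to `saveClientInformation`. Under RFC 7591, `client_secret_expires_at` is given in seconds since the Unix epoch and is `0` when the secret never expires, so an application caching client information may want to check it before reuse. A minimal sketch, assuming a hypothetical `isClientSecretExpired` helper and a local `RegisteredClient` type covering only the fields used:

```typescript
// Subset of an RFC 7591 registration response (assumption: only these fields).
interface RegisteredClient {
  client_id: string;
  client_secret?: string;
  client_id_issued_at?: number;
  client_secret_expires_at?: number;
}

// Hypothetical helper: true once an issued secret has reached its expiry time.
// Times are seconds since the epoch; a missing value or 0 means no expiry.
function isClientSecretExpired(
  client: RegisteredClient,
  nowSeconds: number = Math.floor(Date.now() / 1000),
): boolean {
  const expiresAt = client.client_secret_expires_at;
  if (!client.client_secret || !expiresAt) {
    return false;
  }
  return nowSeconds >= expiresAt;
}
```

With the values used in the registration test above (`client_id_issued_at: 1612137600`, `client_secret_expires_at: 1612224000`), the secret is valid for one day after issue.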
================================================
File: src/client/index.test.ts
================================================
/* eslint-disable @typescript-eslint/no-unused-vars */
/* eslint-disable no-constant-binary-expression */
/* eslint-disable @typescript-eslint/no-unused-expressions */
import { Client } from "./index.js";
import { z } from "zod";
import {
RequestSchema,
NotificationSchema,
ResultSchema,
LATEST_PROTOCOL_VERSION,
SUPPORTED_PROTOCOL_VERSIONS,
InitializeRequestSchema,
ListResourcesRequestSchema,
ListToolsRequestSchema,
CreateMessageRequestSchema,
ListRootsRequestSchema,
ErrorCode,
} from "../types.js";
import { Transport } from "../shared/transport.js";
import { Server } from "../server/index.js";
import { InMemoryTransport } from "../inMemory.js";
test("should initialize with matching protocol version", async () => {
const clientTransport: Transport = {
start: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
send: jest.fn().mockImplementation((message) => {
if (message.method === "initialize") {
clientTransport.onmessage?.({
jsonrpc: "2.0",
id: message.id,
result: {
protocolVersion: LATEST_PROTOCOL_VERSION,
capabilities: {},
serverInfo: {
name: "test",
version: "1.0",
},
instructions: "test instructions",
},
});
}
return Promise.resolve();
}),
};
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);
await client.connect(clientTransport);
// Should have sent initialize with latest version
expect(clientTransport.send).toHaveBeenCalledWith(
expect.objectContaining({
method: "initialize",
params: expect.objectContaining({
protocolVersion: LATEST_PROTOCOL_VERSION,
}),
}),
);
// Should have the instructions returned
expect(client.getInstructions()).toEqual("test instructions");
});
test("should initialize with supported older protocol version", async () => {
const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1];
const clientTransport: Transport = {
start: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
send: jest.fn().mockImplementation((message) => {
if (message.method === "initialize") {
clientTransport.onmessage?.({
jsonrpc: "2.0",
id: message.id,
result: {
protocolVersion: OLD_VERSION,
capabilities: {},
serverInfo: {
name: "test",
version: "1.0",
},
},
});
}
return Promise.resolve();
}),
};
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);
await client.connect(clientTransport);
// Connection should succeed with the older version
expect(client.getServerVersion()).toEqual({
name: "test",
version: "1.0",
});
// Expect no instructions
expect(client.getInstructions()).toBeUndefined();
});
test("should reject unsupported protocol version", async () => {
const clientTransport: Transport = {
start: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
send: jest.fn().mockImplementation((message) => {
if (message.method === "initialize") {
clientTransport.onmessage?.({
jsonrpc: "2.0",
id: message.id,
result: {
protocolVersion: "invalid-version",
capabilities: {},
serverInfo: {
name: "test",
version: "1.0",
},
},
});
}
return Promise.resolve();
}),
};
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);
await expect(client.connect(clientTransport)).rejects.toThrow(
"Server's protocol version is not supported: invalid-version",
);
expect(clientTransport.close).toHaveBeenCalled();
});
test("should respect server capabilities", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
resources: {},
tools: {},
},
},
);
server.setRequestHandler(InitializeRequestSchema, (_request) => ({
protocolVersion: LATEST_PROTOCOL_VERSION,
capabilities: {
resources: {},
tools: {},
},
serverInfo: {
name: "test",
version: "1.0",
},
}));
server.setRequestHandler(ListResourcesRequestSchema, () => ({
resources: [],
}));
server.setRequestHandler(ListToolsRequestSchema, () => ({
tools: [],
}));
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
enforceStrictCapabilities: true,
},
);
await Promise.all([
client.connect(clientTransport),
server.connect(serverTransport),
]);
// Server supports resources and tools, but not prompts
expect(client.getServerCapabilities()).toEqual({
resources: {},
tools: {},
});
// These should work
await expect(client.listResources()).resolves.not.toThrow();
await expect(client.listTools()).resolves.not.toThrow();
// This should throw because prompts are not supported
await expect(client.listPrompts()).rejects.toThrow(
"Server does not support prompts",
);
});
test("should respect client notification capabilities", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {},
},
);
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
roots: {
listChanged: true,
},
},
},
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
server.connect(serverTransport),
]);
// This should work because the client has the roots.listChanged capability
await expect(client.sendRootsListChanged()).resolves.not.toThrow();
// Create a new client without the roots.listChanged capability
const clientWithoutCapability = new Client(
{
name: "test client without capability",
version: "1.0",
},
{
capabilities: {},
enforceStrictCapabilities: true,
},
);
await clientWithoutCapability.connect(clientTransport);
// This should throw because the client doesn't have the roots.listChanged capability
await expect(clientWithoutCapability.sendRootsListChanged()).rejects.toThrow(
/^Client does not support/,
);
});
test("should respect server notification capabilities", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
logging: {},
resources: {
listChanged: true,
},
},
},
);
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {},
},
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
server.connect(serverTransport),
]);
// These should work because the server has the corresponding capabilities
await expect(
server.sendLoggingMessage({ level: "info", data: "Test" }),
).resolves.not.toThrow();
await expect(server.sendResourceListChanged()).resolves.not.toThrow();
// This should throw because the server doesn't have the tools capability
await expect(server.sendToolListChanged()).rejects.toThrow(
"Server does not support notifying of tool list changes",
);
});
test("should only allow setRequestHandler for declared capabilities", () => {
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);
// This should work because sampling is a declared capability
expect(() => {
client.setRequestHandler(CreateMessageRequestSchema, () => ({
model: "test-model",
role: "assistant",
content: {
type: "text",
text: "Test response",
},
}));
}).not.toThrow();
// This should throw because roots listing is not a declared capability
expect(() => {
client.setRequestHandler(ListRootsRequestSchema, () => ({}));
}).toThrow("Client does not support roots capability");
});
/*
Test that custom request/notification/result schemas can be used with the Client class.
*/
test("should typecheck", () => {
const GetWeatherRequestSchema = RequestSchema.extend({
method: z.literal("weather/get"),
params: z.object({
city: z.string(),
}),
});
const GetForecastRequestSchema = RequestSchema.extend({
method: z.literal("weather/forecast"),
params: z.object({
city: z.string(),
days: z.number(),
}),
});
const WeatherForecastNotificationSchema = NotificationSchema.extend({
method: z.literal("weather/alert"),
params: z.object({
severity: z.enum(["warning", "watch"]),
message: z.string(),
}),
});
const WeatherRequestSchema = GetWeatherRequestSchema.or(
GetForecastRequestSchema,
);
const WeatherNotificationSchema = WeatherForecastNotificationSchema;
const WeatherResultSchema = ResultSchema.extend({
temperature: z.number(),
conditions: z.string(),
});
type WeatherRequest = z.infer<typeof WeatherRequestSchema>;
type WeatherNotification = z.infer<typeof WeatherNotificationSchema>;
type WeatherResult = z.infer<typeof WeatherResultSchema>;
// Create a typed Client for weather data
const weatherClient = new Client<
WeatherRequest,
WeatherNotification,
WeatherResult
>(
{
name: "WeatherClient",
version: "1.0.0",
},
{
capabilities: {
sampling: {},
},
},
);
// Typecheck that only valid weather requests/notifications/results are allowed
false &&
weatherClient.request(
{
method: "weather/get",
params: {
city: "Seattle",
},
},
WeatherResultSchema,
);
false &&
weatherClient.notification({
method: "weather/alert",
params: {
severity: "warning",
message: "Storm approaching",
},
});
});
test("should handle client cancelling a request", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
resources: {},
},
},
);
// Set up server to delay responding to listResources
server.setRequestHandler(
ListResourcesRequestSchema,
async () => {
await new Promise((resolve) => setTimeout(resolve, 1000));
return {
resources: [],
};
},
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {},
},
);
await Promise.all([
client.connect(clientTransport),
server.connect(serverTransport),
]);
// Set up abort controller
const controller = new AbortController();
// Issue request but cancel it immediately
const listResourcesPromise = client.listResources(undefined, {
signal: controller.signal,
});
controller.abort("Cancelled by test");
// Request should be rejected
await expect(listResourcesPromise).rejects.toBe("Cancelled by test");
});
test("should handle request timeout", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
resources: {},
},
},
);
// Set up server with a delayed response
server.setRequestHandler(
ListResourcesRequestSchema,
async (_request, extra) => {
const timer = new Promise((resolve) => {
const timeout = setTimeout(resolve, 100);
extra.signal.addEventListener("abort", () => clearTimeout(timeout));
});
await timer;
return {
resources: [],
};
},
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {},
},
);
await Promise.all([
client.connect(clientTransport),
server.connect(serverTransport),
]);
// A request with a 0 ms timeout should fail immediately with RequestTimeout
await expect(
client.listResources(undefined, { timeout: 0 }),
).rejects.toMatchObject({
code: ErrorCode.RequestTimeout,
});
});
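The timeout test above uses a server handler that clears its own timer when `extra.signal` aborts, so no work outlives a cancelled request. The following is a minimal standalone sketch of that pattern as a reusable helper. `cancellableDelay` is a hypothetical name, not an SDK export. It assumes a runtime with global `AbortController` and `AbortSignal.reason` (Node.js 17.2 or later). Unlike the inline handler in the test, which simply never resolves after an abort, this helper rejects with the abort reason so callers can stop early.

```typescript
/**
 * Hypothetical helper (not an SDK export): resolves after `ms` milliseconds,
 * or rejects with the signal's abort reason if the signal fires first.
 * The pending timer is cleared on abort so no work outlives the request.
 */
function cancellableDelay(ms: number, signal?: AbortSignal): Promise<void> {
  return new Promise<void>((resolve, reject) => {
    // An already-cancelled request should never start its timer.
    if (signal?.aborted) {
      reject(signal.reason);
      return;
    }
    const timer = setTimeout(() => {
      // Finished normally: detach the listener so it cannot fire later.
      signal?.removeEventListener("abort", onAbort);
      resolve();
    }, ms);
    const onAbort = () => {
      clearTimeout(timer);
      reject(signal?.reason);
    };
    signal?.addEventListener("abort", onAbort, { once: true });
  });
}

// Usage inside a request handler (illustrative only):
// server.setRequestHandler(ListResourcesRequestSchema, async (_request, extra) => {
//   await cancellableDelay(100, extra.signal);
//   return { resources: [] };
// });
```

Rejecting on abort means the handler's `return` is never reached for a cancelled request, which mirrors what the client observes in the cancellation test.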
================================================
File: src/client/index.ts
================================================
import {
mergeCapabilities,
Protocol,
ProtocolOptions,
RequestOptions,
} from "../shared/protocol.js";
import { Transport } from "../shared/transport.js";
import {
CallToolRequest,
CallToolResultSchema,
ClientCapabilities,
ClientNotification,
ClientRequest,
ClientResult,
CompatibilityCallToolResultSchema,
CompleteRequest,
CompleteResultSchema,
EmptyResultSchema,
GetPromptRequest,
GetPromptResultSchema,
Implementation,
InitializeResultSchema,
LATEST_PROTOCOL_VERSION,
ListPromptsRequest,
ListPromptsResultSchema,
ListResourcesRequest,
ListResourcesResultSchema,
ListResourceTemplatesRequest,
ListResourceTemplatesResultSchema,
ListToolsRequest,
ListToolsResultSchema,
LoggingLevel,
Notification,
ReadResourceRequest,
ReadResourceResultSchema,
Request,
Result,
ServerCapabilities,
SubscribeRequest,
SUPPORTED_PROTOCOL_VERSIONS,
UnsubscribeRequest,
} from "../types.js";
export type ClientOptions = ProtocolOptions & {
/**
* Capabilities to advertise as being supported by this client.
*/
capabilities?: ClientCapabilities;
};
/**
* An MCP client on top of a pluggable transport.
*
* The client will automatically begin the initialization flow with the server when connect() is called.
*
* To use with custom types, extend the base Request/Notification/Result types and pass them as type parameters:
*
* ```typescript
* // Custom schemas
* const CustomRequestSchema = RequestSchema.extend({...})
* const CustomNotificationSchema = NotificationSchema.extend({...})
* const CustomResultSchema = ResultSchema.extend({...})
*
* // Type aliases
* type CustomRequest = z.infer<typeof CustomRequestSchema>
* type CustomNotification = z.infer<typeof CustomNotificationSchema>
* type CustomResult = z.infer<typeof CustomResultSchema>
*
* // Create typed client
* const client = new Client<CustomRequest, CustomNotification, CustomResult>({
* name: "CustomClient",
* version: "1.0.0"
* })
* ```
*/
export class Client<
RequestT extends Request = Request,
NotificationT extends Notification = Notification,
ResultT extends Result = Result,
> extends Protocol<
ClientRequest | RequestT,
ClientNotification | NotificationT,
ClientResult | ResultT
> {
private _serverCapabilities?: ServerCapabilities;
private _serverVersion?: Implementation;
private _capabilities: ClientCapabilities;
private _instructions?: string;
/**
* Initializes this client with the given name and version information.
*/
constructor(
private _clientInfo: Implementation,
options?: ClientOptions,
) {
super(options);
this._capabilities = options?.capabilities ?? {};
}
/**
* Registers new capabilities. This can only be called before connecting to a transport.
*
* The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization).
*/
public registerCapabilities(capabilities: ClientCapabilities): void {
if (this.transport) {
throw new Error(
"Cannot register capabilities after connecting to transport",
);
}
this._capabilities = mergeCapabilities(this._capabilities, capabilities);
}
protected assertCapability(
capability: keyof ServerCapabilities,
method: string,
): void {
if (!this._serverCapabilities?.[capability]) {
throw new Error(
`Server does not support ${capability} (required for ${method})`,
);
}
}
override async connect(transport: Transport): Promise<void> {
await super.connect(transport);
try {
const result = await this.request(
{
method: "initialize",
params: {
protocolVersion: LATEST_PROTOCOL_VERSION,
capabilities: this._capabilities,
clientInfo: this._clientInfo,
},
},
InitializeResultSchema,
);
if (result === undefined) {
throw new Error(`Server sent invalid initialize result: ${result}`);
}
if (!SUPPORTED_PROTOCOL_VERSIONS.includes(result.protocolVersion)) {
throw new Error(
`Server's protocol version is not supported: ${result.protocolVersion}`,
);
}
this._serverCapabilities = result.capabilities;
this._serverVersion = result.serverInfo;
this._instructions = result.instructions;
await this.notification({
method: "notifications/initialized",
});
} catch (error) {
// Disconnect if initialization fails.
void this.close();
throw error;
}
}
/**
* After initialization has completed, this will be populated with the server's reported capabilities.
*/
getServerCapabilities(): ServerCapabilities | undefined {
return this._serverCapabilities;
}
/**
* After initialization has completed, this will be populated with information about the server's name and version.
*/
getServerVersion(): Implementation | undefined {
return this._serverVersion;
}
/**
* After initialization has completed, this may be populated with the server's instructions, if the server provided any.
*/
getInstructions(): string | undefined {
return this._instructions;
}
protected assertCapabilityForMethod(method: RequestT["method"]): void {
switch (method as ClientRequest["method"]) {
case "logging/setLevel":
if (!this._serverCapabilities?.logging) {
throw new Error(
`Server does not support logging (required for ${method})`,
);
}
break;
case "prompts/get":
case "prompts/list":
if (!this._serverCapabilities?.prompts) {
throw new Error(
`Server does not support prompts (required for ${method})`,
);
}
break;
case "resources/list":
case "resources/templates/list":
case "resources/read":
case "resources/subscribe":
case "resources/unsubscribe":
if (!this._serverCapabilities?.resources) {
throw new Error(
`Server does not support resources (required for ${method})`,
);
}
if (
method === "resources/subscribe" &&
!this._serverCapabilities.resources.subscribe
) {
throw new Error(
`Server does not support resource subscriptions (required for ${method})`,
);
}
break;
case "tools/call":
case "tools/list":
if (!this._serverCapabilities?.tools) {
throw new Error(
`Server does not support tools (required for ${method})`,
);
}
break;
case "completion/complete":
if (!this._serverCapabilities?.prompts) {
throw new Error(
`Server does not support prompts (required for ${method})`,
);
}
break;
case "initialize":
// No specific capability required for initialize
break;
case "ping":
// No specific capability required for ping
break;
}
}
protected assertNotificationCapability(
method: NotificationT["method"],
): void {
switch (method as ClientNotification["method"]) {
case "notifications/roots/list_changed":
if (!this._capabilities.roots?.listChanged) {
throw new Error(
`Client does not support roots list changed notifications (required for ${method})`,
);
}
break;
case "notifications/initialized":
// No specific capability required for initialized
break;
case "notifications/cancelled":
// Cancellation notifications are always allowed
break;
case "notifications/progress":
// Progress notifications are always allowed
break;
}
}
protected assertRequestHandlerCapability(method: string): void {
switch (method) {
case "sampling/createMessage":
if (!this._capabilities.sampling) {
throw new Error(
`Client does not support sampling capability (required for ${method})`,
);
}
break;
case "roots/list":
if (!this._capabilities.roots) {
throw new Error(
`Client does not support roots capability (required for ${method})`,
);
}
break;
case "ping":
// No specific capability required for ping
break;
}
}
async ping(options?: RequestOptions) {
return this.request({ method: "ping" }, EmptyResultSchema, options);
}
async complete(params: CompleteRequest["params"], options?: RequestOptions) {
return this.request(
{ method: "completion/complete", params },
CompleteResultSchema,
options,
);
}
async setLoggingLevel(level: LoggingLevel, options?: RequestOptions) {
return this.request(
{ method: "logging/setLevel", params: { level } },
EmptyResultSchema,
options,
);
}
async getPrompt(
params: GetPromptRequest["params"],
options?: RequestOptions,
) {
return this.request(
{ method: "prompts/get", params },
GetPromptResultSchema,
options,
);
}
async listPrompts(
params?: ListPromptsRequest["params"],
options?: RequestOptions,
) {
return this.request(
{ method: "prompts/list", params },
ListPromptsResultSchema,
options,
);
}
async listResources(
params?: ListResourcesRequest["params"],
options?: RequestOptions,
) {
return this.request(
{ method: "resources/list", params },
ListResourcesResultSchema,
options,
);
}
async listResourceTemplates(
params?: ListResourceTemplatesRequest["params"],
options?: RequestOptions,
) {
return this.request(
{ method: "resources/templates/list", params },
ListResourceTemplatesResultSchema,
options,
);
}
async readResource(
params: ReadResourceRequest["params"],
options?: RequestOptions,
) {
return this.request(
{ method: "resources/read", params },
ReadResourceResultSchema,
options,
);
}
async subscribeResource(
params: SubscribeRequest["params"],
options?: RequestOptions,
) {
return this.request(
{ method: "resources/subscribe", params },
EmptyResultSchema,
options,
);
}
async unsubscribeResource(
params: UnsubscribeRequest["params"],
options?: RequestOptions,
) {
return this.request(
{ method: "resources/unsubscribe", params },
EmptyResultSchema,
options,
);
}
async callTool(
params: CallToolRequest["params"],
resultSchema:
| typeof CallToolResultSchema
| typeof CompatibilityCallToolResultSchema = CallToolResultSchema,
options?: RequestOptions,
) {
return this.request(
{ method: "tools/call", params },
resultSchema,
options,
);
}
async listTools(
params?: ListToolsRequest["params"],
options?: RequestOptions,
) {
return this.request(
{ method: "tools/list", params },
ListToolsResultSchema,
options,
);
}
async sendRootsListChanged() {
return this.notification({ method: "notifications/roots/list_changed" });
}
}
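`assertCapabilityForMethod` above is a switch over method names. The same rule can be expressed as a lookup table, which keeps the method-to-capability mapping in one place. The following is a minimal standalone sketch under stated assumptions: `ServerCapabilitiesLike`, `REQUIRED_CAPABILITY`, and `assertServerSupports` are hypothetical names, the interface models only the fields checked here, and the mapping is transcribed from the switch rather than taken from any SDK export.

```typescript
// Simplified stand-in for the SDK's ServerCapabilities type (assumption:
// only the fields this sketch inspects are modelled).
interface ServerCapabilitiesLike {
  logging?: object;
  prompts?: object;
  resources?: { subscribe?: boolean };
  tools?: object;
}

type CapabilityKey = keyof ServerCapabilitiesLike;

// Method -> required top-level server capability, transcribed from the
// switch in Client.assertCapabilityForMethod(). Methods not listed
// (e.g. "initialize", "ping") need no capability.
const REQUIRED_CAPABILITY = new Map<string, CapabilityKey>([
  ["logging/setLevel", "logging"],
  ["prompts/get", "prompts"],
  ["prompts/list", "prompts"],
  ["completion/complete", "prompts"],
  ["resources/list", "resources"],
  ["resources/templates/list", "resources"],
  ["resources/read", "resources"],
  ["resources/subscribe", "resources"],
  ["resources/unsubscribe", "resources"],
  ["tools/call", "tools"],
  ["tools/list", "tools"],
]);

// Hypothetical table-driven equivalent of the switch, with the same messages.
function assertServerSupports(
  caps: ServerCapabilitiesLike | undefined,
  method: string,
): void {
  const capability = REQUIRED_CAPABILITY.get(method);
  if (capability === undefined) {
    return;
  }
  // An empty object such as `resources: {}` is truthy, so it counts as support.
  if (!caps?.[capability]) {
    throw new Error(
      `Server does not support ${capability} (required for ${method})`,
    );
  }
  // Subscriptions additionally need the nested `resources.subscribe` flag.
  if (method === "resources/subscribe" && !caps?.resources?.subscribe) {
    throw new Error(
      `Server does not support resource subscriptions (required for ${method})`,
    );
  }
}
```

With a table, gating a new method means adding one map entry; `resources/subscribe` still needs its own branch because it depends on a nested flag rather than a top-level capability.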
================================================
File: src/client/sse.test.ts
================================================
import { createServer, type IncomingMessage, type Server } from "http";
import { AddressInfo } from "net";
import { JSONRPCMessage } from "../types.js";
import { SSEClientTransport } from "./sse.js";
import { OAuthClientProvider, UnauthorizedError } from "./auth.js";
import { OAuthTokens } from "../shared/auth.js";
describe("SSEClientTransport", () => {
let server: Server;
let transport: SSEClientTransport;
let baseUrl: URL;
let lastServerRequest: IncomingMessage;
let sendServerMessage: ((message: string) => void) | null = null;
beforeEach((done) => {
// Reset state
lastServerRequest = null as unknown as IncomingMessage;
sendServerMessage = null;
// Create a test server that will receive the EventSource connection
server = createServer((req, res) => {
lastServerRequest = req;
// Send SSE headers
res.writeHead(200, {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
});
// Send the endpoint event
res.write("event: endpoint\n");
res.write(`data: ${baseUrl.href}\n\n`);
// Store reference to send function for tests
sendServerMessage = (message: string) => {
res.write(`data: ${message}\n\n`);
};
// Handle request body for POST endpoints
if (req.method === "POST") {
let body = "";
req.on("data", (chunk) => {
body += chunk;
});
req.on("end", () => {
(req as IncomingMessage & { body: string }).body = body;
res.end();
});
}
});
// Start server on random port
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
done();
});
});
afterEach(async () => {
await transport.close();
await server.close();
jest.clearAllMocks();
});
describe("connection handling", () => {
it("establishes SSE connection and receives endpoint", async () => {
transport = new SSEClientTransport(baseUrl);
await transport.start();
expect(lastServerRequest.headers.accept).toBe("text/event-stream");
expect(lastServerRequest.method).toBe("GET");
});
it("rejects if server returns non-200 status", async () => {
// Create a server that returns 403
await server.close();
server = createServer((req, res) => {
res.writeHead(403);
res.end();
});
await new Promise<void>((resolve) => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
resolve();
});
});
transport = new SSEClientTransport(baseUrl);
await expect(transport.start()).rejects.toThrow();
});
it("closes EventSource connection on close()", async () => {
transport = new SSEClientTransport(baseUrl);
await transport.start();
const closePromise = new Promise((resolve) => {
lastServerRequest.on("close", resolve);
});
await transport.close();
await closePromise;
});
});
describe("message handling", () => {
it("receives and parses JSON-RPC messages", async () => {
const receivedMessages: JSONRPCMessage[] = [];
transport = new SSEClientTransport(baseUrl);
transport.onmessage = (msg) => receivedMessages.push(msg);
await transport.start();
const testMessage: JSONRPCMessage = {
jsonrpc: "2.0",
id: "test-1",
method: "test",
params: { foo: "bar" },
};
sendServerMessage!(JSON.stringify(testMessage));
// Wait for message processing
await new Promise((resolve) => setTimeout(resolve, 50));
expect(receivedMessages).toHaveLength(1);
expect(receivedMessages[0]).toEqual(testMessage);
});
it("handles malformed JSON messages", async () => {
const errors: Error[] = [];
transport = new SSEClientTransport(baseUrl);
transport.onerror = (err) => errors.push(err);
await transport.start();
sendServerMessage!("invalid json");
// Wait for message processing
await new Promise((resolve) => setTimeout(resolve, 50));
expect(errors).toHaveLength(1);
expect(errors[0].message).toMatch(/JSON/);
});
it("handles messages via POST requests", async () => {
transport = new SSEClientTransport(baseUrl);
await transport.start();
const testMessage: JSONRPCMessage = {
jsonrpc: "2.0",
id: "test-1",
method: "test",
params: { foo: "bar" },
};
await transport.send(testMessage);
// Wait for request processing
await new Promise((resolve) => setTimeout(resolve, 50));
expect(lastServerRequest.method).toBe("POST");
expect(lastServerRequest.headers["content-type"]).toBe(
"application/json",
);
expect(
JSON.parse(
(lastServerRequest as IncomingMessage & { body: string }).body,
),
).toEqual(testMessage);
});
it("handles POST request failures", async () => {
// Create a server that returns 500 for POST
await server.close();
server = createServer((req, res) => {
if (req.method === "GET") {
res.writeHead(200, {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
});
res.write("event: endpoint\n");
res.write(`data: ${baseUrl.href}\n\n`);
} else {
res.writeHead(500);
res.end("Internal error");
}
});
await new Promise<void>((resolve) => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
resolve();
});
});
transport = new SSEClientTransport(baseUrl);
await transport.start();
const testMessage: JSONRPCMessage = {
jsonrpc: "2.0",
id: "test-1",
method: "test",
params: {},
};
await expect(transport.send(testMessage)).rejects.toThrow(/500/);
});
});
describe("header handling", () => {
it("uses custom fetch implementation from EventSourceInit to add auth headers", async () => {
const authToken = "Bearer test-token";
// Create a fetch wrapper that adds auth header
const fetchWithAuth = (url: string | URL, init?: RequestInit) => {
const headers = new Headers(init?.headers);
headers.set("Authorization", authToken);
return fetch(url.toString(), { ...init, headers });
};
transport = new SSEClientTransport(baseUrl, {
eventSourceInit: {
fetch: fetchWithAuth,
},
});
await transport.start();
// Verify the auth header was received by the server
expect(lastServerRequest.headers.authorization).toBe(authToken);
});
it("passes custom headers to fetch requests", async () => {
const customHeaders = {
Authorization: "Bearer test-token",
"X-Custom-Header": "custom-value",
};
transport = new SSEClientTransport(baseUrl, {
requestInit: {
headers: customHeaders,
},
});
await transport.start();
// Store original fetch
const originalFetch = global.fetch;
try {
// Mock fetch for the message sending test
global.fetch = jest.fn().mockResolvedValue({
ok: true,
});
const message: JSONRPCMessage = {
jsonrpc: "2.0",
id: "1",
method: "test",
params: {},
};
await transport.send(message);
// Verify fetch was called with correct headers
expect(global.fetch).toHaveBeenCalledWith(
expect.any(URL),
expect.objectContaining({
headers: expect.any(Headers),
}),
);
const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1]
.headers;
expect(calledHeaders.get("Authorization")).toBe(
customHeaders.Authorization,
);
expect(calledHeaders.get("X-Custom-Header")).toBe(
customHeaders["X-Custom-Header"],
);
expect(calledHeaders.get("content-type")).toBe("application/json");
} finally {
// Restore original fetch
global.fetch = originalFetch;
}
});
});
describe("auth handling", () => {
let mockAuthProvider: jest.Mocked<OAuthClientProvider>;
beforeEach(() => {
mockAuthProvider = {
get redirectUrl() { return "http://localhost/callback"; },
get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; },
clientInformation: jest.fn(() => ({ client_id: "test-client-id", client_secret: "test-client-secret" })),
tokens: jest.fn(),
saveTokens: jest.fn(),
redirectToAuthorization: jest.fn(),
saveCodeVerifier: jest.fn(),
codeVerifier: jest.fn(),
};
});
it("attaches auth header from provider on SSE connection", async () => {
mockAuthProvider.tokens.mockResolvedValue({
access_token: "test-token",
token_type: "Bearer"
});
transport = new SSEClientTransport(baseUrl, {
authProvider: mockAuthProvider,
});
await transport.start();
expect(lastServerRequest.headers.authorization).toBe("Bearer test-token");
expect(mockAuthProvider.tokens).toHaveBeenCalled();
});
it("attaches auth header from provider on POST requests", async () => {
mockAuthProvider.tokens.mockResolvedValue({
access_token: "test-token",
token_type: "Bearer"
});
transport = new SSEClientTransport(baseUrl, {
authProvider: mockAuthProvider,
});
await transport.start();
const message: JSONRPCMessage = {
jsonrpc: "2.0",
id: "1",
method: "test",
params: {},
};
await transport.send(message);
expect(lastServerRequest.headers.authorization).toBe("Bearer test-token");
expect(mockAuthProvider.tokens).toHaveBeenCalled();
});
it("attempts auth flow on 401 during SSE connection", async () => {
// Create server that returns 401s
await server.close();
server = createServer((req, res) => {
lastServerRequest = req;
if (req.url !== "/") {
res.writeHead(404).end();
} else {
res.writeHead(401).end();
}
});
await new Promise<void>(resolve => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
resolve();
});
});
transport = new SSEClientTransport(baseUrl, {
authProvider: mockAuthProvider,
});
await expect(() => transport.start()).rejects.toThrow(UnauthorizedError);
expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1);
});
it("attempts auth flow on 401 during POST request", async () => {
// Create server that accepts SSE but returns 401 on POST
await server.close();
server = createServer((req, res) => {
lastServerRequest = req;
switch (req.method) {
case "GET":
if (req.url !== "/") {
res.writeHead(404).end();
return;
}
res.writeHead(200, {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
});
res.write("event: endpoint\n");
res.write(`data: ${baseUrl.href}\n\n`);
break;
case "POST":
res.writeHead(401);
res.end();
break;
}
});
await new Promise<void>(resolve => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
resolve();
});
});
transport = new SSEClientTransport(baseUrl, {
authProvider: mockAuthProvider,
});
await transport.start();
const message: JSONRPCMessage = {
jsonrpc: "2.0",
id: "1",
method: "test",
params: {},
};
await expect(() => transport.send(message)).rejects.toThrow(UnauthorizedError);
expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1);
});
it("respects custom headers when using auth provider", async () => {
mockAuthProvider.tokens.mockResolvedValue({
access_token: "test-token",
token_type: "Bearer"
});
const customHeaders = {
"X-Custom-Header": "custom-value",
};
transport = new SSEClientTransport(baseUrl, {
authProvider: mockAuthProvider,
requestInit: {
headers: customHeaders,
},
});
await transport.start();
const message: JSONRPCMessage = {
jsonrpc: "2.0",
id: "1",
method: "test",
params: {},
};
await transport.send(message);
expect(lastServerRequest.headers.authorization).toBe("Bearer test-token");
expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value");
});
it("refreshes expired token during SSE connection", async () => {
// Mock tokens() to return expired token until saveTokens is called
let currentTokens: OAuthTokens = {
access_token: "expired-token",
token_type: "Bearer",
refresh_token: "refresh-token"
};
mockAuthProvider.tokens.mockImplementation(() => currentTokens);
mockAuthProvider.saveTokens.mockImplementation((tokens) => {
currentTokens = tokens;
});
// Create server that returns 401 for expired token, then accepts new token
await server.close();
let connectionAttempts = 0;
server = createServer((req, res) => {
lastServerRequest = req;
if (req.url === "/token" && req.method === "POST") {
// Handle token refresh request
let body = "";
req.on("data", chunk => { body += chunk; });
req.on("end", () => {
const params = new URLSearchParams(body);
if (params.get("grant_type") === "refresh_token" &&
params.get("refresh_token") === "refresh-token" &&
params.get("client_id") === "test-client-id" &&
params.get("client_secret") === "test-client-secret") {
res.writeHead(200, { "Content-Type": "application/json" });
res.end(JSON.stringify({
access_token: "new-token",
token_type: "Bearer",
refresh_token: "new-refresh-token"
}));
} else {
res.writeHead(400).end();
}
});
return;
}
if (req.url !== "/") {
res.writeHead(404).end();
return;
}
const auth = req.headers.authorization;
if (auth === "Bearer expired-token") {
res.writeHead(401).end();
return;
}
if (auth === "Bearer new-token") {
res.writeHead(200, {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
});
res.write("event: endpoint\n");
res.write(`data: ${baseUrl.href}\n\n`);
connectionAttempts++;
return;
}
res.writeHead(401).end();
});
await new Promise<void>(resolve => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
resolve();
});
});
transport = new SSEClientTransport(baseUrl, {
authProvider: mockAuthProvider,
});
await transport.start();
expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({
access_token: "new-token",
token_type: "Bearer",
refresh_token: "new-refresh-token"
});
expect(connectionAttempts).toBe(1);
expect(lastServerRequest.headers.authorization).toBe("Bearer new-token");
});
it("refreshes expired token during POST request", async () => {
// Mock tokens() to return expired token until saveTokens is called
let currentTokens: OAuthTokens = {
access_token: "expired-token",
token_type: "Bearer",
refresh_token: "refresh-token"
};
mockAuthProvider.tokens.mockImplementation(() => currentTokens);
mockAuthProvider.saveTokens.mockImplementation((tokens) => {
currentTokens = tokens;
});
// Create server that accepts SSE but returns 401 on POST with expired token
await server.close();
let postAttempts = 0;
server = createServer((req, res) => {
lastServerRequest = req;
if (req.url === "/token" && req.method === "POST") {
// Handle token refresh request
let body = "";
req.on("data", chunk => { body += chunk; });
req.on("end", () => {
const params = new URLSearchParams(body);
if (params.get("grant_type") === "refresh_token" &&
params.get("refresh_token") === "refresh-token" &&
params.get("client_id") === "test-client-id" &&
params.get("client_secret") === "test-client-secret") {
res.writeHead(200, { "Content-Type": "application/json" });
res.end(JSON.stringify({
access_token: "new-token",
token_type: "Bearer",
refresh_token: "new-refresh-token"
}));
} else {
res.writeHead(400).end();
}
});
return;
}
switch (req.method) {
case "GET":
if (req.url !== "/") {
res.writeHead(404).end();
return;
}
res.writeHead(200, {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
});
res.write("event: endpoint\n");
res.write(`data: ${baseUrl.href}\n\n`);
break;
case "POST": {
if (req.url !== "/") {
res.writeHead(404).end();
return;
}
const auth = req.headers.authorization;
if (auth === "Bearer expired-token") {
res.writeHead(401).end();
return;
}
if (auth === "Bearer new-token") {
res.writeHead(200).end();
postAttempts++;
return;
}
res.writeHead(401).end();
break;
}
}
});
await new Promise<void>(resolve => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
resolve();
});
});
transport = new SSEClientTransport(baseUrl, {
authProvider: mockAuthProvider,
});
await transport.start();
const message: JSONRPCMessage = {
jsonrpc: "2.0",
id: "1",
method: "test",
params: {},
};
await transport.send(message);
expect(mockAuthProvider.saveTokens).toHaveBeenCalledWith({
access_token: "new-token",
token_type: "Bearer",
refresh_token: "new-refresh-token"
});
expect(postAttempts).toBe(1);
expect(lastServerRequest.headers.authorization).toBe("Bearer new-token");
});
it("redirects to authorization if refresh token flow fails", async () => {
// Mock tokens() to return expired token until saveTokens is called
let currentTokens: OAuthTokens = {
access_token: "expired-token",
token_type: "Bearer",
refresh_token: "refresh-token"
};
mockAuthProvider.tokens.mockImplementation(() => currentTokens);
mockAuthProvider.saveTokens.mockImplementation((tokens) => {
currentTokens = tokens;
});
// Create server that returns 401 for all tokens
await server.close();
server = createServer((req, res) => {
lastServerRequest = req;
if (req.url === "/token" && req.method === "POST") {
// Handle token refresh request - always fail
res.writeHead(400).end();
return;
}
if (req.url !== "/") {
res.writeHead(404).end();
return;
}
res.writeHead(401).end();
});
await new Promise<void>(resolve => {
server.listen(0, "127.0.0.1", () => {
const addr = server.address() as AddressInfo;
baseUrl = new URL(`http://127.0.0.1:${addr.port}`);
resolve();
});
});
transport = new SSEClientTransport(baseUrl, {
authProvider: mockAuthProvider,
});
await expect(() => transport.start()).rejects.toThrow(UnauthorizedError);
expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled();
});
});
});
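The token-endpoint handlers in the refresh tests above accept a refresh only when the form-encoded body carries `grant_type=refresh_token`, the current refresh token, and the client credentials. The sketch below shows how such a body can be built and checked with the standard `URLSearchParams` API. It follows the refresh grant in RFC 6749, section 6, and assumes the client credentials go in the request body, as these tests expect, rather than in an HTTP Basic `Authorization` header. `buildRefreshBody` and `isValidRefresh` are hypothetical helpers, not SDK functions.

```typescript
// Hypothetical helpers illustrating the form body the test token endpoint checks.
interface RefreshRequest {
  refreshToken: string;
  clientId: string;
  clientSecret?: string;
}

// Builds an application/x-www-form-urlencoded body for a refresh_token grant,
// sending the client credentials in the body (an assumption matching the tests).
function buildRefreshBody(req: RefreshRequest): string {
  const params = new URLSearchParams({
    grant_type: "refresh_token",
    refresh_token: req.refreshToken,
    client_id: req.clientId,
  });
  if (req.clientSecret !== undefined) {
    params.set("client_secret", req.clientSecret);
  }
  return params.toString();
}

// Mirrors the server-side check in the tests: every expected field must match.
function isValidRefresh(body: string, expected: RefreshRequest): boolean {
  const params = new URLSearchParams(body);
  return (
    params.get("grant_type") === "refresh_token" &&
    params.get("refresh_token") === expected.refreshToken &&
    params.get("client_id") === expected.clientId &&
    params.get("client_secret") === (expected.clientSecret ?? null)
  );
}

const body = buildRefreshBody({
  refreshToken: "refresh-token",
  clientId: "test-client-id",
  clientSecret: "test-client-secret",
});
console.log(body);
// grant_type=refresh_token&refresh_token=refresh-token&client_id=test-client-id&client_secret=test-client-secret
```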
================================================
File: src/client/sse.ts
================================================
import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource";
import { Transport } from "../shared/transport.js";
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js";
export class SseError extends Error {
constructor(
public readonly code: number | undefined,
message: string | undefined,
public readonly event: ErrorEvent,
) {
super(`SSE error: ${message}`);
}
}
/**
* Configuration options for the `SSEClientTransport`.
*/
export type SSEClientTransportOptions = {
/**
* An OAuth client provider to use for authentication.
*
* When an `authProvider` is specified and the SSE connection is started:
* 1. The connection is attempted with any existing access token from the `authProvider`.
* 2. If the access token has expired, the `authProvider` is used to refresh the token.
* 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`.
*
* After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `SSEClientTransport.finishAuth` with the authorization code before retrying the connection.
*
* If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown.
*
* `UnauthorizedError` might also be thrown when sending any message over the SSE transport. This indicates that the session has expired and the transport must be re-authorized and reconnected.
*/
authProvider?: OAuthClientProvider;
/**
* Customizes the initial SSE request to the server (the request that begins the stream).
*
* NOTE: Setting this property will prevent an `Authorization` header from
* being automatically attached to the SSE request, if an `authProvider` is
* also given. This can be worked around by setting the `Authorization` header
* manually.
*/
eventSourceInit?: EventSourceInit;
/**
* Customizes recurring POST requests to the server.
*/
requestInit?: RequestInit;
};
/**
* Client transport for SSE: this will connect to a server using Server-Sent Events for receiving
* messages and make separate POST requests for sending messages.
*/
export class SSEClientTransport implements Transport {
private _eventSource?: EventSource;
private _endpoint?: URL;
private _abortController?: AbortController;
private _url: URL;
private _eventSourceInit?: EventSourceInit;
private _requestInit?: RequestInit;
private _authProvider?: OAuthClientProvider;
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
constructor(
url: URL,
opts?: SSEClientTransportOptions,
) {
this._url = url;
this._eventSourceInit = opts?.eventSourceInit;
this._requestInit = opts?.requestInit;
this._authProvider = opts?.authProvider;
}
private async _authThenStart(): Promise<void> {
if (!this._authProvider) {
throw new UnauthorizedError("No auth provider");
}
let result: AuthResult;
try {
result = await auth(this._authProvider, { serverUrl: this._url });
} catch (error) {
this.onerror?.(error as Error);
throw error;
}
if (result !== "AUTHORIZED") {
throw new UnauthorizedError();
}
return await this._startOrAuth();
}
private async _commonHeaders(): Promise<HeadersInit> {
const headers: HeadersInit = {};
if (this._authProvider) {
const tokens = await this._authProvider.tokens();
if (tokens) {
headers["Authorization"] = `Bearer ${tokens.access_token}`;
}
}
return headers;
}
private _startOrAuth(): Promise<void> {
return new Promise((resolve, reject) => {
this._eventSource = new EventSource(
this._url.href,
this._eventSourceInit ?? {
fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, {
...init,
headers: {
...headers,
Accept: "text/event-stream"
}
})),
},
);
this._abortController = new AbortController();
this._eventSource.onerror = (event) => {
if (event.code === 401 && this._authProvider) {
this._authThenStart().then(resolve, reject);
return;
}
const error = new SseError(event.code, event.message, event);
reject(error);
this.onerror?.(error);
};
this._eventSource.onopen = () => {
// The connection is open, but we need to wait for the endpoint to be received.
};
this._eventSource.addEventListener("endpoint", (event: Event) => {
const messageEvent = event as MessageEvent;
try {
this._endpoint = new URL(messageEvent.data, this._url);
if (this._endpoint.origin !== this._url.origin) {
throw new Error(
`Endpoint origin does not match connection origin: ${this._endpoint.origin}`,
);
}
} catch (error) {
reject(error);
this.onerror?.(error as Error);
void this.close();
return;
}
resolve();
});
this._eventSource.onmessage = (event: Event) => {
const messageEvent = event as MessageEvent;
let message: JSONRPCMessage;
try {
message = JSONRPCMessageSchema.parse(JSON.parse(messageEvent.data));
} catch (error) {
this.onerror?.(error as Error);
return;
}
this.onmessage?.(message);
};
});
}
async start() {
if (this._eventSource) {
throw new Error(
"SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.",
);
}
return await this._startOrAuth();
}
/**
* Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth.
*/
async finishAuth(authorizationCode: string): Promise<void> {
if (!this._authProvider) {
throw new UnauthorizedError("No auth provider");
}
const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode });
if (result !== "AUTHORIZED") {
throw new UnauthorizedError("Failed to authorize");
}
}
async close(): Promise<void> {
this._abortController?.abort();
this._eventSource?.close();
this.onclose?.();
}
async send(message: JSONRPCMessage): Promise<void> {
if (!this._endpoint) {
throw new Error("Not connected");
}
try {
const commonHeaders = await this._commonHeaders();
const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers });
headers.set("content-type", "application/json");
const init = {
...this._requestInit,
method: "POST",
headers,
body: JSON.stringify(message),
signal: this._abortController?.signal,
};
const response = await fetch(this._endpoint, init);
if (!response.ok) {
if (response.status === 401 && this._authProvider) {
const result = await auth(this._authProvider, { serverUrl: this._url });
if (result !== "AUTHORIZED") {
throw new UnauthorizedError();
}
// Purposely _not_ awaited, so we don't call onerror twice
return this.send(message);
}
const text = await response.text().catch(() => null);
throw new Error(
`Error POSTing to endpoint (HTTP ${response.status}): ${text}`,
);
}
} catch (error) {
this.onerror?.(error as Error);
throw error;
}
}
}
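In `SSEClientTransport`, the `endpoint` event listener resolves the announced endpoint against the connection URL and refuses any endpoint on a different origin. This stops a server from redirecting the POSTed messages, and the bearer token attached to them, to another host. The standalone sketch below reproduces that check using only the WHATWG `URL` class. `resolveEndpoint` is a hypothetical name, and the `/messages?sessionId=abc` path and the port numbers are example values only.

```typescript
// Hypothetical helper mirroring the check in the "endpoint" event listener:
// resolve the announced endpoint against the connection URL and reject any
// endpoint whose origin (scheme, host, port) differs.
function resolveEndpoint(data: string, connectionUrl: URL): URL {
  const endpoint = new URL(data, connectionUrl);
  if (endpoint.origin !== connectionUrl.origin) {
    throw new Error(
      `Endpoint origin does not match connection origin: ${endpoint.origin}`,
    );
  }
  return endpoint;
}

const base = new URL("http://127.0.0.1:3000/sse");

// A relative path resolves against the connection URL.
console.log(resolveEndpoint("/messages?sessionId=abc", base).href);
// http://127.0.0.1:3000/messages?sessionId=abc

// A different port is a different origin, so the endpoint is rejected.
try {
  resolveEndpoint("http://127.0.0.1:4000/messages", base);
} catch (e) {
  console.log((e as Error).message);
  // Endpoint origin does not match connection origin: http://127.0.0.1:4000
}
```

Because the comparison uses `URL.origin`, a change of scheme (for example `https:` instead of `http:`) is also treated as a mismatch.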
================================================
File: src/client/stdio.test.ts
================================================
import { JSONRPCMessage } from "../types.js";
import { StdioClientTransport, StdioServerParameters } from "./stdio.js";
const serverParameters: StdioServerParameters = {
command: "/usr/bin/tee",
};
test("should start then close cleanly", async () => {
const client = new StdioClientTransport(serverParameters);
client.onerror = (error) => {
throw error;
};
let didClose = false;
client.onclose = () => {
didClose = true;
};
await client.start();
expect(didClose).toBeFalsy();
await client.close();
expect(didClose).toBeTruthy();
});
test("should read messages", async () => {
const client = new StdioClientTransport(serverParameters);
client.onerror = (error) => {
throw error;
};
const messages: JSONRPCMessage[] = [
{
jsonrpc: "2.0",
id: 1,
method: "ping",
},
{
jsonrpc: "2.0",
method: "notifications/initialized",
},
];
const readMessages: JSONRPCMessage[] = [];
const finished = new Promise<void>((resolve) => {
client.onmessage = (message) => {
readMessages.push(message);
if (JSON.stringify(message) === JSON.stringify(messages[1])) {
resolve();
}
};
});
await client.start();
await client.send(messages[0]);
await client.send(messages[1]);
await finished;
expect(readMessages).toEqual(messages);
await client.close();
});
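The stdio test uses `/usr/bin/tee` as a stand-in server. This works because the transport writes each JSON-RPC message as one line of JSON, and `tee` echoes every line back unchanged. The sketch below shows that newline-delimited framing under one assumption: `serializeMessage` and `ReadBuffer` (defined in `src/shared/stdio.ts`, outside this excerpt) behave roughly like this. The sketch works on strings rather than `Buffer`s to stay short. Real code has to handle multi-byte characters split across chunks, which is one reason the SDK buffers raw bytes. `serialize` and `LineBuffer` are hypothetical names.

```typescript
// Hypothetical line-framing helpers: one JSON message per line.
type Message = Record<string, unknown>;

function serialize(message: Message): string {
  return JSON.stringify(message) + "\n";
}

class LineBuffer {
  private buffer = "";

  // Chunks may split a message anywhere, so accumulate raw text first.
  append(chunk: string): void {
    this.buffer += chunk;
  }

  // Returns the next complete message, or null if no full line is buffered yet.
  read(): Message | null {
    const index = this.buffer.indexOf("\n");
    if (index === -1) {
      return null;
    }
    const line = this.buffer.slice(0, index).replace(/\r$/, "");
    this.buffer = this.buffer.slice(index + 1);
    return JSON.parse(line) as Message;
  }
}

const wire =
  serialize({ jsonrpc: "2.0", id: 1, method: "ping" }) +
  serialize({ jsonrpc: "2.0", method: "notifications/initialized" });

// Feed the bytes in two arbitrary pieces to simulate chunked stdout reads.
const buf = new LineBuffer();
buf.append(wire.slice(0, 10));
console.log(buf.read()); // null: no complete line yet
buf.append(wire.slice(10));

// Drain every complete message, the way processReadBuffer() loops until null.
const received: Message[] = [];
for (let m = buf.read(); m !== null; m = buf.read()) {
  received.push(m);
}
console.log(received.length); // 2
```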
================================================
File: src/client/stdio.ts
================================================
import { ChildProcess, IOType, spawn } from "node:child_process";
import process from "node:process";
import { Stream } from "node:stream";
import { ReadBuffer, serializeMessage } from "../shared/stdio.js";
import { Transport } from "../shared/transport.js";
import { JSONRPCMessage } from "../types.js";
export type StdioServerParameters = {
/**
* The executable to run to start the server.
*/
command: string;
/**
* Command line arguments to pass to the executable.
*/
args?: string[];
/**
* The environment to use when spawning the process.
*
* If not specified, the result of getDefaultEnvironment() will be used.
*/
env?: Record<string, string>;
/**
* How to handle stderr of the child process. This matches the semantics of Node's `child_process.spawn`.
*
* The default is "inherit", meaning messages to stderr will be printed to the parent process's stderr.
*/
stderr?: IOType | Stream | number;
/**
* The working directory to use when spawning the process.
*
* If not specified, the current working directory will be inherited.
*/
cwd?: string;
};
/**
* Environment variables to inherit by default, if an environment is not explicitly given.
*/
export const DEFAULT_INHERITED_ENV_VARS =
process.platform === "win32"
? [
"APPDATA",
"HOMEDRIVE",
"HOMEPATH",
"LOCALAPPDATA",
"PATH",
"PROCESSOR_ARCHITECTURE",
"SYSTEMDRIVE",
"SYSTEMROOT",
"TEMP",
"USERNAME",
"USERPROFILE",
]
: /* list inspired by the default env inheritance of sudo */
["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"];
/**
* Returns a default environment object including only environment variables deemed safe to inherit.
*/
export function getDefaultEnvironment(): Record<string, string> {
const env: Record<string, string> = {};
for (const key of DEFAULT_INHERITED_ENV_VARS) {
const value = process.env[key];
if (value === undefined) {
continue;
}
if (value.startsWith("()")) {
// Skip functions, which are a security risk.
continue;
}
env[key] = value;
}
return env;
}
/**
* Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout.
*
* This transport is only available in Node.js environments.
*/
export class StdioClientTransport implements Transport {
private _process?: ChildProcess;
private _abortController: AbortController = new AbortController();
private _readBuffer: ReadBuffer = new ReadBuffer();
private _serverParams: StdioServerParameters;
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
constructor(server: StdioServerParameters) {
this._serverParams = server;
}
/**
* Starts the server process and prepares to communicate with it.
*/
async start(): Promise<void> {
if (this._process) {
throw new Error(
"StdioClientTransport already started! If using Client class, note that connect() calls start() automatically."
);
}
return new Promise((resolve, reject) => {
this._process = spawn(
this._serverParams.command,
this._serverParams.args ?? [],
{
env: this._serverParams.env ?? getDefaultEnvironment(),
stdio: ["pipe", "pipe", this._serverParams.stderr ?? "inherit"],
shell: false,
signal: this._abortController.signal,
windowsHide: process.platform === "win32" && isElectron(),
cwd: this._serverParams.cwd,
}
);
this._process.on("error", (error) => {
if (error.name === "AbortError") {
// Expected when close() is called.
this.onclose?.();
return;
}
reject(error);
this.onerror?.(error);
});
this._process.on("spawn", () => {
resolve();
});
this._process.on("close", (_code) => {
this._process = undefined;
this.onclose?.();
});
this._process.stdin?.on("error", (error) => {
this.onerror?.(error);
});
this._process.stdout?.on("data", (chunk) => {
this._readBuffer.append(chunk);
this.processReadBuffer();
});
this._process.stdout?.on("error", (error) => {
this.onerror?.(error);
});
});
}
/**
* The stderr stream of the child process, if `StdioServerParameters.stderr` was set to "pipe" or "overlapped".
*
* This is only available after the process has been started.
*/
get stderr(): Stream | null {
return this._process?.stderr ?? null;
}
private processReadBuffer() {
while (true) {
try {
const message = this._readBuffer.readMessage();
if (message === null) {
break;
}
this.onmessage?.(message);
} catch (error) {
this.onerror?.(error as Error);
}
}
}
async close(): Promise<void> {
this._abortController.abort();
this._process = undefined;
this._readBuffer.clear();
}
send(message: JSONRPCMessage): Promise<void> {
return new Promise((resolve) => {
if (!this._process?.stdin) {
throw new Error("Not connected");
}
const json = serializeMessage(message);
if (this._process.stdin.write(json)) {
resolve();
} else {
this._process.stdin.once("drain", resolve);
}
});
}
}
function isElectron() {
return "type" in process;
}
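`getDefaultEnvironment()` passes only an allow-list of variables to the spawned server and skips values beginning with `()`. Bash uses that form for exported shell functions, which were the injection vector behind Shellshock (CVE-2014-6271). The sketch below isolates that filtering so it can be exercised without touching `process.env`. `filterEnvironment` is a hypothetical helper. The allow-list copies the non-Windows `DEFAULT_INHERITED_ENV_VARS`, and the sample environment is made up.

```typescript
// Simplified, hypothetical version of the filtering in getDefaultEnvironment():
// keep only allow-listed variables and drop values that look like exported
// shell functions (bash exports functions as values starting with "()").
function filterEnvironment(
  source: Record<string, string | undefined>,
  allowList: readonly string[],
): Record<string, string> {
  const env: Record<string, string> = {};
  for (const key of allowList) {
    const value = source[key];
    if (value === undefined || value.startsWith("()")) {
      continue;
    }
    env[key] = value;
  }
  return env;
}

// Fake environment: one function-shaped value and one secret not on the list.
const fakeEnv = {
  HOME: "/home/alice",
  PATH: "/usr/bin:/bin",
  SHELL: "() { :; }; echo injected",
  AWS_SECRET_ACCESS_KEY: "do-not-leak",
};

const allow = ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"];
console.log(filterEnvironment(fakeEnv, allow));
// { HOME: '/home/alice', PATH: '/usr/bin:/bin' }
```

An allow-list, rather than a deny-list, keeps credentials such as cloud keys out of the child process by default. Callers that need more variables can pass an explicit `env` in `StdioServerParameters`.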
================================================
File: src/client/websocket.ts
================================================
import { Transport } from "../shared/transport.js";
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
const SUBPROTOCOL = "mcp";
/**
* Client transport for WebSocket: this will connect to a server over the WebSocket protocol.
*/
export class WebSocketClientTransport implements Transport {
private _socket?: WebSocket;
private _url: URL;
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
constructor(url: URL) {
this._url = url;
}
start(): Promise<void> {
if (this._socket) {
throw new Error(
"WebSocketClientTransport already started! If using Client class, note that connect() calls start() automatically.",
);
}
return new Promise((resolve, reject) => {
this._socket = new WebSocket(this._url, SUBPROTOCOL);
this._socket.onerror = (event) => {
const error =
"error" in event
? (event.error as Error)
: new Error(`WebSocket error: ${JSON.stringify(event)}`);
reject(error);
this.onerror?.(error);
};
this._socket.onopen = () => {
resolve();
};
this._socket.onclose = () => {
this.onclose?.();
};
this._socket.onmessage = (event: MessageEvent) => {
let message: JSONRPCMessage;
try {
message = JSONRPCMessageSchema.parse(JSON.parse(event.data));
} catch (error) {
this.onerror?.(error as Error);
return;
}
this.onmessage?.(message);
};
});
}
async close(): Promise<void> {
this._socket?.close();
}
send(message: JSONRPCMessage): Promise<void> {
return new Promise((resolve, reject) => {
if (!this._socket) {
reject(new Error("Not connected"));
return;
}
this._socket?.send(JSON.stringify(message));
resolve();
});
}
}
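The WebSocket transport's `onmessage` handler follows the same pattern as the SSE transport: parse the frame as JSON, validate it, and send failures to `onerror` instead of throwing. The SDK validates with the zod schema `JSONRPCMessageSchema`. The dependency-free sketch below replaces that schema with a much looser, hypothetical shape check (`parseMessage`) so the control flow can be shown on its own. It is not equivalent to the real schema.

```typescript
// Hypothetical, deliberately loose JSON-RPC 2.0 shape check. The SDK uses the
// zod schema JSONRPCMessageSchema instead; this only illustrates the idea of
// parsing, validating, and routing failures to an error callback.
type JsonRpcMessage = {
  jsonrpc: "2.0";
  id?: string | number;
  method?: string;
  result?: unknown;
  error?: unknown;
};

function parseMessage(raw: string): JsonRpcMessage {
  const value: unknown = JSON.parse(raw); // throws SyntaxError on bad JSON
  if (typeof value !== "object" || value === null || Array.isArray(value)) {
    throw new Error("Message must be a JSON object");
  }
  const obj = value as Record<string, unknown>;
  if (obj.jsonrpc !== "2.0") {
    throw new Error('Missing or invalid "jsonrpc" field');
  }
  const isRequestOrNotification = typeof obj.method === "string";
  const isResponse = "id" in obj && ("result" in obj || "error" in obj);
  if (!isRequestOrNotification && !isResponse) {
    throw new Error("Not a request, notification, or response");
  }
  return obj as JsonRpcMessage;
}

// Same control flow as the transport's onmessage handler: invalid frames go
// to onerror and are dropped; valid messages go to onmessage.
function handleFrame(
  data: string,
  onmessage: (m: JsonRpcMessage) => void,
  onerror: (e: Error) => void,
): void {
  let message: JsonRpcMessage;
  try {
    message = parseMessage(data);
  } catch (error) {
    onerror(error as Error);
    return;
  }
  onmessage(message);
}

const messages: JsonRpcMessage[] = [];
const errors: Error[] = [];
const onMsg = (m: JsonRpcMessage) => { messages.push(m); };
const onErr = (e: Error) => { errors.push(e); };

handleFrame('{"jsonrpc":"2.0","id":1,"method":"ping"}', onMsg, onErr);
handleFrame("not json", onMsg, onErr);
handleFrame('{"jsonrpc":"1.0","method":"ping"}', onMsg, onErr);
console.log(messages.length, errors.length); // 1 2
```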
================================================
File: src/integration-tests/process-cleanup.test.ts
================================================
import { Server } from "../server/index.js";
import { StdioServerTransport } from "../server/stdio.js";
describe("Process cleanup", () => {
jest.setTimeout(5000); // 5 second timeout
it("should exit cleanly after closing transport", async () => {
const server = new Server(
{
name: "test-server",
version: "1.0.0",
},
{
capabilities: {},
}
);
const transport = new StdioServerTransport();
await server.connect(transport);
// Close the transport
await transport.close();
// If we reach here without hanging, the test passes
// The test runner will fail if the process hangs
expect(true).toBe(true);
});
});
================================================
File: src/server/completable.test.ts
================================================
import { z } from "zod";
import { completable } from "./completable.js";
describe("completable", () => {
it("preserves types and values of underlying schema", () => {
const baseSchema = z.string();
const schema = completable(baseSchema, () => []);
expect(schema.parse("test")).toBe("test");
expect(() => schema.parse(123)).toThrow();
});
it("provides access to completion function", async () => {
const completions = ["foo", "bar", "baz"];
const schema = completable(z.string(), () => completions);
expect(await schema._def.complete("")).toEqual(completions);
});
it("allows async completion functions", async () => {
const completions = ["foo", "bar", "baz"];
const schema = completable(z.string(), async () => completions);
expect(await schema._def.complete("")).toEqual(completions);
});
it("passes current value to completion function", async () => {
const schema = completable(z.string(), (value) => [value + "!"]);
expect(await schema._def.complete("test")).toEqual(["test!"]);
});
it("works with number schemas", async () => {
const schema = completable(z.number(), () => [1, 2, 3]);
expect(schema.parse(1)).toBe(1);
expect(await schema._def.complete(0)).toEqual([1, 2, 3]);
});
it("preserves schema description", () => {
const desc = "test description";
const schema = completable(z.string().describe(desc), () => []);
expect(schema.description).toBe(desc);
});
});
================================================
File: src/server/completable.ts
================================================
import {
ZodTypeAny,
ZodTypeDef,
ZodType,
ParseInput,
ParseReturnType,
RawCreateParams,
ZodErrorMap,
ProcessedCreateParams,
} from "zod";
export enum McpZodTypeKind {
Completable = "McpCompletable",
}
export type CompleteCallback<T extends ZodTypeAny = ZodTypeAny> = (
value: T["_input"],
) => T["_input"][] | Promise<T["_input"][]>;
export interface CompletableDef<T extends ZodTypeAny = ZodTypeAny>
extends ZodTypeDef {
type: T;
complete: CompleteCallback<T>;
typeName: McpZodTypeKind.Completable;
}
export class Completable<T extends ZodTypeAny> extends ZodType<
T["_output"],
CompletableDef<T>,
T["_input"]
> {
_parse(input: ParseInput): ParseReturnType<this["_output"]> {
const { ctx } = this._processInputParams(input);
const data = ctx.data;
return this._def.type._parse({
data,
path: ctx.path,
parent: ctx,
});
}
unwrap() {
return this._def.type;
}
static create = <T extends ZodTypeAny>(
type: T,
params: RawCreateParams & {
complete: CompleteCallback<T>;
},
): Completable<T> => {
return new Completable({
type,
typeName: McpZodTypeKind.Completable,
complete: params.complete,
...processCreateParams(params),
});
};
}
/**
* Wraps a Zod type to provide autocompletion capabilities. Useful for, e.g., prompt arguments in MCP.
*/
export function completable<T extends ZodTypeAny>(
schema: T,
complete: CompleteCallback<T>,
): Completable<T> {
return Completable.create(schema, { ...schema._def, complete });
}
// Copied from Zod, which does not export this helper:
// https://github.com/colinhacks/zod/blob/f7ad26147ba291cb3fb257545972a8e00e767470/src/types.ts#L130
function processCreateParams(params: RawCreateParams): ProcessedCreateParams {
if (!params) return {};
const { errorMap, invalid_type_error, required_error, description } = params;
if (errorMap && (invalid_type_error || required_error)) {
throw new Error(
`Can't use "invalid_type_error" or "required_error" in conjunction with custom error map.`,
);
}
if (errorMap) return { errorMap: errorMap, description };
const customMap: ZodErrorMap = (iss, ctx) => {
const { message } = params;
if (iss.code === "invalid_enum_value") {
return { message: message ?? ctx.defaultError };
}
if (typeof ctx.data === "undefined") {
return { message: message ?? required_error ?? ctx.defaultError };
}
if (iss.code !== "invalid_type") return { message: ctx.defaultError };
return { message: message ?? invalid_type_error ?? ctx.defaultError };
};
return { errorMap: customMap, description };
}
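`completable()` leaves parsing to the wrapped schema and attaches a `complete` callback that receives the current input and returns suggestions, either directly or as a promise. The sketch below shows that idea without Zod so it can run on its own. The real `Completable` class extends `ZodType` and therefore still behaves as a Zod schema, which this sketch does not attempt. `withCompletion`, `parseString`, and the language list are hypothetical.

```typescript
// Hypothetical, dependency-free analogue of completable(): attach a completion
// callback to a parser without changing how values are parsed.
type Parser<T> = (input: unknown) => T;
type Complete<T> = (value: T) => T[] | Promise<T[]>;

interface CompletableParser<T> {
  parse: Parser<T>;
  complete: Complete<T>;
}

function withCompletion<T>(parse: Parser<T>, complete: Complete<T>): CompletableParser<T> {
  // Parsing is delegated unchanged; only the completion hook is added.
  return { parse, complete };
}

const parseString: Parser<string> = (input) => {
  if (typeof input !== "string") {
    throw new TypeError(`Expected string, received ${typeof input}`);
  }
  return input;
};

// Suggest language names that start with what the user has typed so far.
const languages = ["python", "perl", "php", "rust", "ruby"];
const languageArg = withCompletion(parseString, (prefix) =>
  languages.filter((name) => name.startsWith(prefix)),
);

console.log(languageArg.parse("rust")); // rust
console.log(languageArg.complete("p")); // [ 'python', 'perl', 'php' ]
```

Prefix filtering is just one possible completion strategy. The callback could also query a database or a remote API and return a promise.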
================================================
File: src/server/index.test.ts
================================================
/* eslint-disable @typescript-eslint/no-unused-vars */
/* eslint-disable no-constant-binary-expression */
/* eslint-disable @typescript-eslint/no-unused-expressions */
import { Server } from "./index.js";
import { z } from "zod";
import {
RequestSchema,
NotificationSchema,
ResultSchema,
LATEST_PROTOCOL_VERSION,
SUPPORTED_PROTOCOL_VERSIONS,
CreateMessageRequestSchema,
ListPromptsRequestSchema,
ListResourcesRequestSchema,
ListToolsRequestSchema,
SetLevelRequestSchema,
ErrorCode,
} from "../types.js";
import { Transport } from "../shared/transport.js";
import { InMemoryTransport } from "../inMemory.js";
import { Client } from "../client/index.js";
test("should accept latest protocol version", async () => {
let sendPromiseResolve: (value: unknown) => void;
const sendPromise = new Promise((resolve) => {
sendPromiseResolve = resolve;
});
const serverTransport: Transport = {
start: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
send: jest.fn().mockImplementation((message) => {
if (message.id === 1 && message.result) {
expect(message.result).toEqual({
protocolVersion: LATEST_PROTOCOL_VERSION,
capabilities: expect.any(Object),
serverInfo: {
name: "test server",
version: "1.0",
},
instructions: "Test instructions",
});
sendPromiseResolve(undefined);
}
return Promise.resolve();
}),
};
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
prompts: {},
resources: {},
tools: {},
logging: {},
},
instructions: "Test instructions",
},
);
await server.connect(serverTransport);
// Simulate initialize request with latest version
serverTransport.onmessage?.({
jsonrpc: "2.0",
id: 1,
method: "initialize",
params: {
protocolVersion: LATEST_PROTOCOL_VERSION,
capabilities: {},
clientInfo: {
name: "test client",
version: "1.0",
},
},
});
await expect(sendPromise).resolves.toBeUndefined();
});
test("should accept supported older protocol version", async () => {
const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1];
let sendPromiseResolve: (value: unknown) => void;
const sendPromise = new Promise((resolve) => {
sendPromiseResolve = resolve;
});
const serverTransport: Transport = {
start: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
send: jest.fn().mockImplementation((message) => {
if (message.id === 1 && message.result) {
expect(message.result).toEqual({
protocolVersion: OLD_VERSION,
capabilities: expect.any(Object),
serverInfo: {
name: "test server",
version: "1.0",
},
});
sendPromiseResolve(undefined);
}
return Promise.resolve();
}),
};
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
prompts: {},
resources: {},
tools: {},
logging: {},
},
},
);
await server.connect(serverTransport);
// Simulate initialize request with older version
serverTransport.onmessage?.({
jsonrpc: "2.0",
id: 1,
method: "initialize",
params: {
protocolVersion: OLD_VERSION,
capabilities: {},
clientInfo: {
name: "test client",
version: "1.0",
},
},
});
await expect(sendPromise).resolves.toBeUndefined();
});
test("should handle unsupported protocol version", async () => {
let sendPromiseResolve: (value: unknown) => void;
const sendPromise = new Promise((resolve) => {
sendPromiseResolve = resolve;
});
const serverTransport: Transport = {
start: jest.fn().mockResolvedValue(undefined),
close: jest.fn().mockResolvedValue(undefined),
send: jest.fn().mockImplementation((message) => {
if (message.id === 1 && message.result) {
expect(message.result).toEqual({
protocolVersion: LATEST_PROTOCOL_VERSION,
capabilities: expect.any(Object),
serverInfo: {
name: "test server",
version: "1.0",
},
});
sendPromiseResolve(undefined);
}
return Promise.resolve();
}),
};
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
prompts: {},
resources: {},
tools: {},
logging: {},
},
},
);
await server.connect(serverTransport);
// Simulate initialize request with unsupported version
serverTransport.onmessage?.({
jsonrpc: "2.0",
id: 1,
method: "initialize",
params: {
protocolVersion: "invalid-version",
capabilities: {},
clientInfo: {
name: "test client",
version: "1.0",
},
},
});
await expect(sendPromise).resolves.toBeUndefined();
});
test("should respect client capabilities", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
prompts: {},
resources: {},
tools: {},
logging: {},
},
enforceStrictCapabilities: true,
},
);
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);
// Implement request handler for sampling/createMessage
client.setRequestHandler(CreateMessageRequestSchema, async (request) => {
// Mock implementation of createMessage
return {
model: "test-model",
role: "assistant",
content: {
type: "text",
text: "This is a test response",
},
};
});
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
server.connect(serverTransport),
]);
expect(server.getClientCapabilities()).toEqual({ sampling: {} });
// This should work because sampling is supported by the client
await expect(
server.createMessage({
messages: [],
maxTokens: 10,
}),
).resolves.not.toThrow();
// This should still throw because roots are not supported by the client
await expect(server.listRoots()).rejects.toThrow(/^Client does not support/);
});
test("should respect server notification capabilities", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
logging: {},
},
enforceStrictCapabilities: true,
},
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await server.connect(serverTransport);
// This should work because logging is supported by the server
await expect(
server.sendLoggingMessage({
level: "info",
data: "Test log message",
}),
).resolves.not.toThrow();
// This should throw because resource notifications are not supported by the server
await expect(
server.sendResourceUpdated({ uri: "test://resource" }),
).rejects.toThrow(/^Server does not support/);
});
test("should only allow setRequestHandler for declared capabilities", () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
prompts: {},
resources: {},
},
},
);
// These should work because the capabilities are declared
expect(() => {
server.setRequestHandler(ListPromptsRequestSchema, () => ({ prompts: [] }));
}).not.toThrow();
expect(() => {
server.setRequestHandler(ListResourcesRequestSchema, () => ({
resources: [],
}));
}).not.toThrow();
// These should throw because the capabilities are not declared
expect(() => {
server.setRequestHandler(ListToolsRequestSchema, () => ({ tools: [] }));
}).toThrow(/^Server does not support tools/);
expect(() => {
server.setRequestHandler(SetLevelRequestSchema, () => ({}));
}).toThrow(/^Server does not support logging/);
});
/*
Test that custom request/notification/result schemas can be used with the Server class.
*/
test("should typecheck", () => {
const GetWeatherRequestSchema = RequestSchema.extend({
method: z.literal("weather/get"),
params: z.object({
city: z.string(),
}),
});
const GetForecastRequestSchema = RequestSchema.extend({
method: z.literal("weather/forecast"),
params: z.object({
city: z.string(),
days: z.number(),
}),
});
const WeatherForecastNotificationSchema = NotificationSchema.extend({
method: z.literal("weather/alert"),
params: z.object({
severity: z.enum(["warning", "watch"]),
message: z.string(),
}),
});
const WeatherRequestSchema = GetWeatherRequestSchema.or(
GetForecastRequestSchema,
);
const WeatherNotificationSchema = WeatherForecastNotificationSchema;
const WeatherResultSchema = ResultSchema.extend({
temperature: z.number(),
conditions: z.string(),
});
type WeatherRequest = z.infer<typeof WeatherRequestSchema>;
type WeatherNotification = z.infer<typeof WeatherNotificationSchema>;
type WeatherResult = z.infer<typeof WeatherResultSchema>;
// Create a typed Server for weather data
const weatherServer = new Server<
WeatherRequest,
WeatherNotification,
WeatherResult
>(
{
name: "WeatherServer",
version: "1.0.0",
},
{
capabilities: {
prompts: {},
resources: {},
tools: {},
logging: {},
},
},
);
// Typecheck that only valid weather requests/notifications/results are allowed
weatherServer.setRequestHandler(GetWeatherRequestSchema, (request) => {
return {
temperature: 72,
conditions: "sunny",
};
});
weatherServer.setNotificationHandler(
WeatherForecastNotificationSchema,
(notification) => {
console.log(`Weather alert: ${notification.params.message}`);
},
);
});
test("should handle server cancelling a request", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);
// Set up client to delay responding to createMessage
client.setRequestHandler(
CreateMessageRequestSchema,
async (_request, extra) => {
await new Promise((resolve) => setTimeout(resolve, 1000));
return {
model: "test",
role: "assistant",
content: {
type: "text",
text: "Test response",
},
};
},
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
server.connect(serverTransport),
]);
// Set up abort controller
const controller = new AbortController();
// Issue request but cancel it immediately
const createMessagePromise = server.createMessage(
{
messages: [],
maxTokens: 10,
},
{
signal: controller.signal,
},
);
controller.abort("Cancelled by test");
// Request should be rejected
await expect(createMessagePromise).rejects.toBe("Cancelled by test");
});
test("should handle request timeout", async () => {
const server = new Server(
{
name: "test server",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);
// Set up client that delays responses
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
sampling: {},
},
},
);
client.setRequestHandler(
CreateMessageRequestSchema,
async (_request, extra) => {
await new Promise((resolve, reject) => {
const timeout = setTimeout(resolve, 100);
extra.signal.addEventListener("abort", () => {
clearTimeout(timeout);
reject(extra.signal.reason);
});
});
return {
model: "test",
role: "assistant",
content: {
type: "text",
text: "Test response",
},
};
},
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
server.connect(serverTransport),
]);
// A request with a 0 ms timeout should fail immediately
await expect(
server.createMessage(
{
messages: [],
maxTokens: 10,
},
{ timeout: 0 },
),
).rejects.toMatchObject({
code: ErrorCode.RequestTimeout,
});
});
================================================
File: src/server/index.ts
================================================
import {
mergeCapabilities,
Protocol,
ProtocolOptions,
RequestOptions,
} from "../shared/protocol.js";
import {
ClientCapabilities,
CreateMessageRequest,
CreateMessageResultSchema,
EmptyResultSchema,
Implementation,
InitializedNotificationSchema,
InitializeRequest,
InitializeRequestSchema,
InitializeResult,
LATEST_PROTOCOL_VERSION,
ListRootsRequest,
ListRootsResultSchema,
LoggingMessageNotification,
Notification,
Request,
ResourceUpdatedNotification,
Result,
ServerCapabilities,
ServerNotification,
ServerRequest,
ServerResult,
SUPPORTED_PROTOCOL_VERSIONS,
} from "../types.js";
export type ServerOptions = ProtocolOptions & {
/**
* Capabilities to advertise as being supported by this server.
*/
capabilities?: ServerCapabilities;
/**
* Optional instructions describing how to use the server and its features.
*/
instructions?: string;
};
/**
* An MCP server on top of a pluggable transport.
*
* This server automatically responds to the initialization flow initiated by the client.
*
* To use with custom types, extend the base Request/Notification/Result types and pass them as type parameters:
*
* ```typescript
* // Custom schemas
* const CustomRequestSchema = RequestSchema.extend({...})
* const CustomNotificationSchema = NotificationSchema.extend({...})
* const CustomResultSchema = ResultSchema.extend({...})
*
* // Type aliases
* type CustomRequest = z.infer<typeof CustomRequestSchema>
* type CustomNotification = z.infer<typeof CustomNotificationSchema>
* type CustomResult = z.infer<typeof CustomResultSchema>
*
* // Create typed server
* const server = new Server<CustomRequest, CustomNotification, CustomResult>({
* name: "CustomServer",
* version: "1.0.0"
* })
* ```
*/
export class Server<
RequestT extends Request = Request,
NotificationT extends Notification = Notification,
ResultT extends Result = Result,
> extends Protocol<
ServerRequest | RequestT,
ServerNotification | NotificationT,
ServerResult | ResultT
> {
private _clientCapabilities?: ClientCapabilities;
private _clientVersion?: Implementation;
private _capabilities: ServerCapabilities;
private _instructions?: string;
/**
* Callback for when initialization has fully completed (i.e., the client has sent an `initialized` notification).
*/
oninitialized?: () => void;
/**
* Initializes this server with the given name and version information.
*/
constructor(
private _serverInfo: Implementation,
options?: ServerOptions,
) {
super(options);
this._capabilities = options?.capabilities ?? {};
this._instructions = options?.instructions;
this.setRequestHandler(InitializeRequestSchema, (request) =>
this._oninitialize(request),
);
this.setNotificationHandler(InitializedNotificationSchema, () =>
this.oninitialized?.(),
);
}
/**
* Registers new capabilities. This can only be called before connecting to a transport.
*
* The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization).
*/
public registerCapabilities(capabilities: ServerCapabilities): void {
if (this.transport) {
throw new Error(
"Cannot register capabilities after connecting to transport",
);
}
this._capabilities = mergeCapabilities(this._capabilities, capabilities);
}
protected assertCapabilityForMethod(method: RequestT["method"]): void {
switch (method as ServerRequest["method"]) {
case "sampling/createMessage":
if (!this._clientCapabilities?.sampling) {
throw new Error(
`Client does not support sampling (required for ${method})`,
);
}
break;
case "roots/list":
if (!this._clientCapabilities?.roots) {
throw new Error(
`Client does not support listing roots (required for ${method})`,
);
}
break;
case "ping":
// No specific capability required for ping
break;
}
}
protected assertNotificationCapability(
method: (ServerNotification | NotificationT)["method"],
): void {
switch (method as ServerNotification["method"]) {
case "notifications/message":
if (!this._capabilities.logging) {
throw new Error(
`Server does not support logging (required for ${method})`,
);
}
break;
case "notifications/resources/updated":
case "notifications/resources/list_changed":
if (!this._capabilities.resources) {
throw new Error(
`Server does not support notifying about resources (required for ${method})`,
);
}
break;
case "notifications/tools/list_changed":
if (!this._capabilities.tools) {
throw new Error(
`Server does not support notifying of tool list changes (required for ${method})`,
);
}
break;
case "notifications/prompts/list_changed":
if (!this._capabilities.prompts) {
throw new Error(
`Server does not support notifying of prompt list changes (required for ${method})`,
);
}
break;
case "notifications/cancelled":
// Cancellation notifications are always allowed
break;
case "notifications/progress":
// Progress notifications are always allowed
break;
}
}
protected assertRequestHandlerCapability(method: string): void {
switch (method) {
case "sampling/createMessage":
if (!this._capabilities.sampling) {
throw new Error(
`Server does not support sampling (required for ${method})`,
);
}
break;
case "logging/setLevel":
if (!this._capabilities.logging) {
throw new Error(
`Server does not support logging (required for ${method})`,
);
}
break;
case "prompts/get":
case "prompts/list":
if (!this._capabilities.prompts) {
throw new Error(
`Server does not support prompts (required for ${method})`,
);
}
break;
case "resources/list":
case "resources/templates/list":
case "resources/read":
if (!this._capabilities.resources) {
throw new Error(
`Server does not support resources (required for ${method})`,
);
}
break;
case "tools/call":
case "tools/list":
if (!this._capabilities.tools) {
throw new Error(
`Server does not support tools (required for ${method})`,
);
}
break;
case "ping":
case "initialize":
// No specific capability required for these methods
break;
}
}
private async _oninitialize(
request: InitializeRequest,
): Promise<InitializeResult> {
const requestedVersion = request.params.protocolVersion;
this._clientCapabilities = request.params.capabilities;
this._clientVersion = request.params.clientInfo;
return {
protocolVersion: SUPPORTED_PROTOCOL_VERSIONS.includes(requestedVersion)
? requestedVersion
: LATEST_PROTOCOL_VERSION,
capabilities: this.getCapabilities(),
serverInfo: this._serverInfo,
...(this._instructions && { instructions: this._instructions }),
};
}
/**
* After initialization has completed, this will be populated with the client's reported capabilities.
*/
getClientCapabilities(): ClientCapabilities | undefined {
return this._clientCapabilities;
}
/**
* After initialization has completed, this will be populated with information about the client's name and version.
*/
getClientVersion(): Implementation | undefined {
return this._clientVersion;
}
private getCapabilities(): ServerCapabilities {
return this._capabilities;
}
async ping() {
return this.request({ method: "ping" }, EmptyResultSchema);
}
async createMessage(
params: CreateMessageRequest["params"],
options?: RequestOptions,
) {
return this.request(
{ method: "sampling/createMessage", params },
CreateMessageResultSchema,
options,
);
}
async listRoots(
params?: ListRootsRequest["params"],
options?: RequestOptions,
) {
return this.request(
{ method: "roots/list", params },
ListRootsResultSchema,
options,
);
}
async sendLoggingMessage(params: LoggingMessageNotification["params"]) {
return this.notification({ method: "notifications/message", params });
}
async sendResourceUpdated(params: ResourceUpdatedNotification["params"]) {
return this.notification({
method: "notifications/resources/updated",
params,
});
}
async sendResourceListChanged() {
return this.notification({
method: "notifications/resources/list_changed",
});
}
async sendToolListChanged() {
return this.notification({ method: "notifications/tools/list_changed" });
}
async sendPromptListChanged() {
return this.notification({ method: "notifications/prompts/list_changed" });
}
}
================================================
File: src/server/mcp.test.ts
================================================
import { McpServer } from "./mcp.js";
import { Client } from "../client/index.js";
import { InMemoryTransport } from "../inMemory.js";
import { z } from "zod";
import {
ListToolsResultSchema,
CallToolResultSchema,
ListResourcesResultSchema,
ListResourceTemplatesResultSchema,
ReadResourceResultSchema,
ListPromptsResultSchema,
GetPromptResultSchema,
CompleteResultSchema,
} from "../types.js";
import { ResourceTemplate } from "./mcp.js";
import { completable } from "./completable.js";
import { UriTemplate } from "../shared/uriTemplate.js";
describe("McpServer", () => {
test("should expose underlying Server instance", () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
expect(mcpServer.server).toBeDefined();
});
test("should allow sending notifications via Server", async () => {
const mcpServer = new McpServer(
{
name: "test server",
version: "1.0",
},
{ capabilities: { logging: {} } },
);
const client = new Client({
name: "test client",
version: "1.0",
});
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
// This should work because we're using the underlying server
await expect(
mcpServer.server.sendLoggingMessage({
level: "info",
data: "Test log message",
}),
).resolves.not.toThrow();
});
});
describe("ResourceTemplate", () => {
test("should create ResourceTemplate with string pattern", () => {
const template = new ResourceTemplate("test://{category}/{id}", {
list: undefined,
});
expect(template.uriTemplate.toString()).toBe("test://{category}/{id}");
expect(template.listCallback).toBeUndefined();
});
test("should create ResourceTemplate with UriTemplate", () => {
const uriTemplate = new UriTemplate("test://{category}/{id}");
const template = new ResourceTemplate(uriTemplate, { list: undefined });
expect(template.uriTemplate).toBe(uriTemplate);
expect(template.listCallback).toBeUndefined();
});
test("should create ResourceTemplate with list callback", async () => {
const list = jest.fn().mockResolvedValue({
resources: [{ name: "Test", uri: "test://example" }],
});
const template = new ResourceTemplate("test://{id}", { list });
expect(template.listCallback).toBe(list);
const abortController = new AbortController();
const result = await template.listCallback?.({
signal: abortController.signal,
});
expect(result?.resources).toHaveLength(1);
expect(list).toHaveBeenCalled();
});
});
describe("tool()", () => {
test("should register zero-argument tool", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.tool("test", async () => ({
content: [
{
type: "text",
text: "Test response",
},
],
}));
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "tools/list",
},
ListToolsResultSchema,
);
expect(result.tools).toHaveLength(1);
expect(result.tools[0].name).toBe("test");
expect(result.tools[0].inputSchema).toEqual({
type: "object",
});
});
test("should register tool with args schema", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.tool(
"test",
{
name: z.string(),
value: z.number(),
},
async ({ name, value }) => ({
content: [
{
type: "text",
text: `${name}: ${value}`,
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "tools/list",
},
ListToolsResultSchema,
);
expect(result.tools).toHaveLength(1);
expect(result.tools[0].name).toBe("test");
expect(result.tools[0].inputSchema).toMatchObject({
type: "object",
properties: {
name: { type: "string" },
value: { type: "number" },
},
});
});
test("should register tool with description", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.tool("test", "Test description", async () => ({
content: [
{
type: "text",
text: "Test response",
},
],
}));
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "tools/list",
},
ListToolsResultSchema,
);
expect(result.tools).toHaveLength(1);
expect(result.tools[0].name).toBe("test");
expect(result.tools[0].description).toBe("Test description");
});
test("should validate tool args", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
tools: {},
},
},
);
mcpServer.tool(
"test",
{
name: z.string(),
value: z.number(),
},
async ({ name, value }) => ({
content: [
{
type: "text",
text: `${name}: ${value}`,
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
await expect(
client.request(
{
method: "tools/call",
params: {
name: "test",
arguments: {
name: "test",
value: "not a number",
},
},
},
CallToolResultSchema,
),
).rejects.toThrow(/Invalid arguments/);
});
test("should prevent duplicate tool registration", () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
mcpServer.tool("test", async () => ({
content: [
{
type: "text",
text: "Test response",
},
],
}));
expect(() => {
mcpServer.tool("test", async () => ({
content: [
{
type: "text",
text: "Test response 2",
},
],
}));
}).toThrow(/already registered/);
});
test("should allow registering multiple tools", () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
// This should succeed
mcpServer.tool("tool1", () => ({ content: [] }));
// This should also succeed and not throw about request handlers
mcpServer.tool("tool2", () => ({ content: [] }));
});
test("should allow client to call server tools", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
tools: {},
},
},
);
mcpServer.tool(
"test",
"Test tool",
{
input: z.string(),
},
async ({ input }) => ({
content: [
{
type: "text",
text: `Processed: ${input}`,
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "tools/call",
params: {
name: "test",
arguments: {
input: "hello",
},
},
},
CallToolResultSchema,
);
expect(result.content).toEqual([
{
type: "text",
text: "Processed: hello",
},
]);
});
test("should handle server tool errors gracefully", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
tools: {},
},
},
);
mcpServer.tool("error-test", async () => {
throw new Error("Tool execution failed");
});
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "tools/call",
params: {
name: "error-test",
},
},
CallToolResultSchema,
);
expect(result.isError).toBe(true);
expect(result.content).toEqual([
{
type: "text",
text: "Tool execution failed",
},
]);
});
test("should throw McpError for invalid tool name", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
tools: {},
},
},
);
mcpServer.tool("test-tool", async () => ({
content: [
{
type: "text",
text: "Test response",
},
],
}));
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
await expect(
client.request(
{
method: "tools/call",
params: {
name: "nonexistent-tool",
},
},
CallToolResultSchema,
),
).rejects.toThrow(/Tool nonexistent-tool not found/);
});
});
describe("resource()", () => {
test("should register resource with uri and readCallback", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.resource("test", "test://resource", async () => ({
contents: [
{
uri: "test://resource",
text: "Test content",
},
],
}));
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "resources/list",
},
ListResourcesResultSchema,
);
expect(result.resources).toHaveLength(1);
expect(result.resources[0].name).toBe("test");
expect(result.resources[0].uri).toBe("test://resource");
});
test("should register resource with metadata", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.resource(
"test",
"test://resource",
{
description: "Test resource",
mimeType: "text/plain",
},
async () => ({
contents: [
{
uri: "test://resource",
text: "Test content",
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "resources/list",
},
ListResourcesResultSchema,
);
expect(result.resources).toHaveLength(1);
expect(result.resources[0].description).toBe("Test resource");
expect(result.resources[0].mimeType).toBe("text/plain");
});
test("should register resource template", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.resource(
"test",
new ResourceTemplate("test://resource/{id}", { list: undefined }),
async () => ({
contents: [
{
uri: "test://resource/123",
text: "Test content",
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "resources/templates/list",
},
ListResourceTemplatesResultSchema,
);
expect(result.resourceTemplates).toHaveLength(1);
expect(result.resourceTemplates[0].name).toBe("test");
expect(result.resourceTemplates[0].uriTemplate).toBe(
"test://resource/{id}",
);
});
test("should register resource template with listCallback", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.resource(
"test",
new ResourceTemplate("test://resource/{id}", {
list: async () => ({
resources: [
{
name: "Resource 1",
uri: "test://resource/1",
},
{
name: "Resource 2",
uri: "test://resource/2",
},
],
}),
}),
async (uri) => ({
contents: [
{
uri: uri.href,
text: "Test content",
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "resources/list",
},
ListResourcesResultSchema,
);
expect(result.resources).toHaveLength(2);
expect(result.resources[0].name).toBe("Resource 1");
expect(result.resources[0].uri).toBe("test://resource/1");
expect(result.resources[1].name).toBe("Resource 2");
expect(result.resources[1].uri).toBe("test://resource/2");
});
test("should pass template variables to readCallback", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.resource(
"test",
new ResourceTemplate("test://resource/{category}/{id}", {
list: undefined,
}),
async (uri, { category, id }) => ({
contents: [
{
uri: uri.href,
text: `Category: ${category}, ID: ${id}`,
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "resources/read",
params: {
uri: "test://resource/books/123",
},
},
ReadResourceResultSchema,
);
expect(result.contents[0].text).toBe("Category: books, ID: 123");
});
test("should prevent duplicate resource registration", () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
mcpServer.resource("test", "test://resource", async () => ({
contents: [
{
uri: "test://resource",
text: "Test content",
},
],
}));
expect(() => {
mcpServer.resource("test2", "test://resource", async () => ({
contents: [
{
uri: "test://resource",
text: "Test content 2",
},
],
}));
}).toThrow(/already registered/);
});
test("should allow registering multiple resources", () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
// This should succeed
mcpServer.resource("resource1", "test://resource1", async () => ({
contents: [
{
uri: "test://resource1",
text: "Test content 1",
},
],
}));
// This should also succeed and not throw about request handlers
mcpServer.resource("resource2", "test://resource2", async () => ({
contents: [
{
uri: "test://resource2",
text: "Test content 2",
},
],
}));
});
test("should prevent duplicate resource template registration", () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
mcpServer.resource(
"test",
new ResourceTemplate("test://resource/{id}", { list: undefined }),
async () => ({
contents: [
{
uri: "test://resource/123",
text: "Test content",
},
],
}),
);
expect(() => {
mcpServer.resource(
"test",
new ResourceTemplate("test://resource/{id}", { list: undefined }),
async () => ({
contents: [
{
uri: "test://resource/123",
text: "Test content 2",
},
],
}),
);
}).toThrow(/already registered/);
});
test("should handle resource read errors gracefully", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.resource("error-test", "test://error", async () => {
throw new Error("Resource read failed");
});
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
await expect(
client.request(
{
method: "resources/read",
params: {
uri: "test://error",
},
},
ReadResourceResultSchema,
),
).rejects.toThrow(/Resource read failed/);
});
test("should throw McpError for invalid resource URI", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.resource("test", "test://resource", async () => ({
contents: [
{
uri: "test://resource",
text: "Test content",
},
],
}));
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
await expect(
client.request(
{
method: "resources/read",
params: {
uri: "test://nonexistent",
},
},
ReadResourceResultSchema,
),
).rejects.toThrow(/Resource test:\/\/nonexistent not found/);
});
test("should support completion of resource template parameters", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
resources: {},
},
},
);
mcpServer.resource(
"test",
new ResourceTemplate("test://resource/{category}", {
list: undefined,
complete: {
category: () => ["books", "movies", "music"],
},
}),
async () => ({
contents: [
{
uri: "test://resource/test",
text: "Test content",
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "completion/complete",
params: {
ref: {
type: "ref/resource",
uri: "test://resource/{category}",
},
argument: {
name: "category",
value: "",
},
},
},
CompleteResultSchema,
);
expect(result.completion.values).toEqual(["books", "movies", "music"]);
expect(result.completion.total).toBe(3);
});
test("should support filtered completion of resource template parameters", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
resources: {},
},
},
);
mcpServer.resource(
"test",
new ResourceTemplate("test://resource/{category}", {
list: undefined,
complete: {
category: (test: string) =>
["books", "movies", "music"].filter((value) =>
value.startsWith(test),
),
},
}),
async () => ({
contents: [
{
uri: "test://resource/test",
text: "Test content",
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "completion/complete",
params: {
ref: {
type: "ref/resource",
uri: "test://resource/{category}",
},
argument: {
name: "category",
value: "m",
},
},
},
CompleteResultSchema,
);
expect(result.completion.values).toEqual(["movies", "music"]);
expect(result.completion.total).toBe(2);
});
});
describe("prompt()", () => {
test("should register zero-argument prompt", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.prompt("test", async () => ({
messages: [
{
role: "assistant",
content: {
type: "text",
text: "Test response",
},
},
],
}));
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "prompts/list",
},
ListPromptsResultSchema,
);
expect(result.prompts).toHaveLength(1);
expect(result.prompts[0].name).toBe("test");
expect(result.prompts[0].arguments).toBeUndefined();
});
test("should register prompt with args schema", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.prompt(
"test",
{
name: z.string(),
value: z.string(),
},
async ({ name, value }) => ({
messages: [
{
role: "assistant",
content: {
type: "text",
text: `${name}: ${value}`,
},
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "prompts/list",
},
ListPromptsResultSchema,
);
expect(result.prompts).toHaveLength(1);
expect(result.prompts[0].name).toBe("test");
expect(result.prompts[0].arguments).toEqual([
{ name: "name", required: true },
{ name: "value", required: true },
]);
});
test("should register prompt with description", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client({
name: "test client",
version: "1.0",
});
mcpServer.prompt("test", "Test description", async () => ({
messages: [
{
role: "assistant",
content: {
type: "text",
text: "Test response",
},
},
],
}));
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "prompts/list",
},
ListPromptsResultSchema,
);
expect(result.prompts).toHaveLength(1);
expect(result.prompts[0].name).toBe("test");
expect(result.prompts[0].description).toBe("Test description");
});
test("should validate prompt args", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
prompts: {},
},
},
);
mcpServer.prompt(
"test",
{
name: z.string(),
value: z.string().min(3),
},
async ({ name, value }) => ({
messages: [
{
role: "assistant",
content: {
type: "text",
text: `${name}: ${value}`,
},
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
await expect(
client.request(
{
method: "prompts/get",
params: {
name: "test",
arguments: {
name: "test",
value: "ab", // Too short
},
},
},
GetPromptResultSchema,
),
).rejects.toThrow(/Invalid arguments/);
});
test("should prevent duplicate prompt registration", () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
mcpServer.prompt("test", async () => ({
messages: [
{
role: "assistant",
content: {
type: "text",
text: "Test response",
},
},
],
}));
expect(() => {
mcpServer.prompt("test", async () => ({
messages: [
{
role: "assistant",
content: {
type: "text",
text: "Test response 2",
},
},
],
}));
}).toThrow(/already registered/);
});
test("should allow registering multiple prompts", () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
// This should succeed
mcpServer.prompt("prompt1", async () => ({
messages: [
{
role: "assistant",
content: {
type: "text",
text: "Test response 1",
},
},
],
}));
// This should also succeed and not throw about request handlers
mcpServer.prompt("prompt2", async () => ({
messages: [
{
role: "assistant",
content: {
type: "text",
text: "Test response 2",
},
},
],
}));
});
test("should allow registering prompts with arguments", () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
// This should succeed
mcpServer.prompt(
"echo",
{ message: z.string() },
({ message }) => ({
messages: [
{
role: "user",
content: {
type: "text",
text: `Please process this message: ${message}`,
},
},
],
}),
);
});
test("should allow registering both resources and prompts with completion handlers", () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
// Register a resource with completion
mcpServer.resource(
"test",
new ResourceTemplate("test://resource/{category}", {
list: undefined,
complete: {
category: () => ["books", "movies", "music"],
},
}),
async () => ({
contents: [
{
uri: "test://resource/test",
text: "Test content",
},
],
}),
);
// Register a prompt with completion
mcpServer.prompt(
"echo",
{ message: completable(z.string(), () => ["hello", "world"]) },
({ message }) => ({
messages: [{
role: "user",
content: {
type: "text",
text: `Please process this message: ${message}`
}
}]
})
);
});
test("should throw McpError for invalid prompt name", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
prompts: {},
},
},
);
mcpServer.prompt("test-prompt", async () => ({
messages: [
{
role: "assistant",
content: {
type: "text",
text: "Test response",
},
},
],
}));
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
await expect(
client.request(
{
method: "prompts/get",
params: {
name: "nonexistent-prompt",
},
},
GetPromptResultSchema,
),
).rejects.toThrow(/Prompt nonexistent-prompt not found/);
});
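test("should throw McpError when completing an argument of an unknown prompt", async () => {
// Added test (sketch): relies on the same imports as the surrounding tests.
// handlePromptCompletion() in mcp.ts looks the prompt up by `ref.name` and throws InvalidParams if it is missing.
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
prompts: {},
},
},
);
mcpServer.prompt(
"test-prompt",
{
name: completable(z.string(), () => ["Alice", "Bob", "Charlie"]),
},
async ({ name }) => ({
messages: [
{
role: "assistant",
content: {
type: "text",
text: `Hello ${name}`,
},
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
await expect(
client.request(
{
method: "completion/complete",
params: {
ref: {
type: "ref/prompt",
name: "nonexistent-prompt",
},
argument: {
name: "name",
value: "",
},
},
},
CompleteResultSchema,
),
).rejects.toThrow(/Prompt nonexistent-prompt not found/);
});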
test("should support completion of prompt arguments", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
prompts: {},
},
},
);
mcpServer.prompt(
"test-prompt",
{
name: completable(z.string(), () => ["Alice", "Bob", "Charlie"]),
},
async ({ name }) => ({
messages: [
{
role: "assistant",
content: {
type: "text",
text: `Hello ${name}`,
},
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "completion/complete",
params: {
ref: {
type: "ref/prompt",
name: "test-prompt",
},
argument: {
name: "name",
value: "",
},
},
},
CompleteResultSchema,
);
expect(result.completion.values).toEqual(["Alice", "Bob", "Charlie"]);
expect(result.completion.total).toBe(3);
});
test("should support filtered completion of prompt arguments", async () => {
const mcpServer = new McpServer({
name: "test server",
version: "1.0",
});
const client = new Client(
{
name: "test client",
version: "1.0",
},
{
capabilities: {
prompts: {},
},
},
);
mcpServer.prompt(
"test-prompt",
{
name: completable(z.string(), (value) =>
["Alice", "Bob", "Charlie"].filter((candidate) => candidate.startsWith(value)),
),
},
async ({ name }) => ({
messages: [
{
role: "assistant",
content: {
type: "text",
text: `Hello ${name}`,
},
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "completion/complete",
params: {
ref: {
type: "ref/prompt",
name: "test-prompt",
},
argument: {
name: "name",
value: "A",
},
},
},
CompleteResultSchema,
);
expect(result.completion.values).toEqual(["Alice"]);
expect(result.completion.total).toBe(1);
});
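test("should cap completion values at 100 and report the full total", async () => {
// Added test (sketch): relies on the same imports as the surrounding tests.
// createCompletionResult() in mcp.ts returns at most 100 values, sets `total` to the
// full number of suggestions, and sets `hasMore` when values were cut off.
const mcpServer = new McpServer({ name: "test server", version: "1.0" });
const client = new Client(
{ name: "test client", version: "1.0" },
{ capabilities: { prompts: {} } },
);
// 150 generated candidates: "name0" through "name149"
const candidates = Array.from({ length: 150 }, (_, i) => `name${i}`);
mcpServer.prompt(
"test-prompt",
{ name: completable(z.string(), () => candidates) },
async ({ name }) => ({
messages: [
{
role: "assistant",
content: { type: "text", text: `Hello ${name}` },
},
],
}),
);
const [clientTransport, serverTransport] =
InMemoryTransport.createLinkedPair();
await Promise.all([
client.connect(clientTransport),
mcpServer.server.connect(serverTransport),
]);
const result = await client.request(
{
method: "completion/complete",
params: {
ref: { type: "ref/prompt", name: "test-prompt" },
argument: { name: "name", value: "" },
},
},
CompleteResultSchema,
);
expect(result.completion.values).toHaveLength(100);
expect(result.completion.values[0]).toBe("name0");
expect(result.completion.values[99]).toBe("name99");
expect(result.completion.total).toBe(150);
expect(result.completion.hasMore).toBe(true);
});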
});
================================================
File: src/server/mcp.ts
================================================
import { Server, ServerOptions } from "./index.js";
import { zodToJsonSchema } from "zod-to-json-schema";
import {
z,
ZodRawShape,
ZodObject,
ZodString,
AnyZodObject,
ZodTypeAny,
ZodType,
ZodTypeDef,
ZodOptional,
} from "zod";
import {
Implementation,
Tool,
ListToolsResult,
CallToolResult,
McpError,
ErrorCode,
CompleteRequest,
CompleteResult,
PromptReference,
ResourceReference,
Resource,
ListResourcesResult,
ListResourceTemplatesRequestSchema,
ReadResourceRequestSchema,
ListToolsRequestSchema,
CallToolRequestSchema,
ListResourcesRequestSchema,
ListPromptsRequestSchema,
GetPromptRequestSchema,
CompleteRequestSchema,
ListPromptsResult,
Prompt,
PromptArgument,
GetPromptResult,
ReadResourceResult,
} from "../types.js";
import { Completable, CompletableDef } from "./completable.js";
import { UriTemplate, Variables } from "../shared/uriTemplate.js";
import { RequestHandlerExtra } from "../shared/protocol.js";
import { Transport } from "../shared/transport.js";
/**
* High-level MCP server that provides a simpler API for working with resources, tools, and prompts.
* For advanced usage (like sending notifications or setting custom request handlers), use the underlying
* Server instance available via the `server` property.
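*
* @example
* // Sketch: assumes `z` is imported from "zod" and `StdioServerTransport` from "./stdio.js".
* const server = new McpServer({ name: "example-server", version: "1.0.0" });
* // Tool arguments are declared as a Zod raw shape and validated before the callback runs.
* server.tool("add", { a: z.number(), b: z.number() }, async ({ a, b }) => ({
*   content: [{ type: "text", text: String(a + b) }],
* }));
* await server.connect(new StdioServerTransport());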
*/
export class McpServer {
/**
* The underlying Server instance, useful for advanced operations like sending notifications.
*/
public readonly server: Server;
private _registeredResources: { [uri: string]: RegisteredResource } = {};
private _registeredResourceTemplates: {
[name: string]: RegisteredResourceTemplate;
} = {};
private _registeredTools: { [name: string]: RegisteredTool } = {};
private _registeredPrompts: { [name: string]: RegisteredPrompt } = {};
constructor(serverInfo: Implementation, options?: ServerOptions) {
this.server = new Server(serverInfo, options);
}
/**
* Attaches to the given transport, starts it, and starts listening for messages.
*
* The `server` object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward.
*/
async connect(transport: Transport): Promise<void> {
return await this.server.connect(transport);
}
/**
* Closes the connection.
*/
async close(): Promise<void> {
await this.server.close();
}
private _toolHandlersInitialized = false;
private setToolRequestHandlers() {
if (this._toolHandlersInitialized) {
return;
}
this.server.assertCanSetRequestHandler(
ListToolsRequestSchema.shape.method.value,
);
this.server.assertCanSetRequestHandler(
CallToolRequestSchema.shape.method.value,
);
this.server.registerCapabilities({
tools: {},
});
this.server.setRequestHandler(
ListToolsRequestSchema,
(): ListToolsResult => ({
tools: Object.entries(this._registeredTools).map(
([name, tool]): Tool => {
return {
name,
description: tool.description,
inputSchema: tool.inputSchema
? (zodToJsonSchema(tool.inputSchema, {
strictUnions: true,
}) as Tool["inputSchema"])
: EMPTY_OBJECT_JSON_SCHEMA,
};
},
),
}),
);
this.server.setRequestHandler(
CallToolRequestSchema,
async (request, extra): Promise<CallToolResult> => {
const tool = this._registeredTools[request.params.name];
if (!tool) {
throw new McpError(
ErrorCode.InvalidParams,
`Tool ${request.params.name} not found`,
);
}
if (tool.inputSchema) {
const parseResult = await tool.inputSchema.safeParseAsync(
request.params.arguments,
);
if (!parseResult.success) {
throw new McpError(
ErrorCode.InvalidParams,
`Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}`,
);
}
const args = parseResult.data;
const cb = tool.callback as ToolCallback<ZodRawShape>;
try {
return await Promise.resolve(cb(args, extra));
} catch (error) {
return {
content: [
{
type: "text",
text: error instanceof Error ? error.message : String(error),
},
],
isError: true,
};
}
} else {
const cb = tool.callback as ToolCallback<undefined>;
try {
return await Promise.resolve(cb(extra));
} catch (error) {
return {
content: [
{
type: "text",
text: error instanceof Error ? error.message : String(error),
},
],
isError: true,
};
}
}
},
);
this._toolHandlersInitialized = true;
}
private _completionHandlerInitialized = false;
private setCompletionRequestHandler() {
if (this._completionHandlerInitialized) {
return;
}
this.server.assertCanSetRequestHandler(
CompleteRequestSchema.shape.method.value,
);
this.server.setRequestHandler(
CompleteRequestSchema,
async (request): Promise<CompleteResult> => {
switch (request.params.ref.type) {
case "ref/prompt":
return this.handlePromptCompletion(request, request.params.ref);
case "ref/resource":
return this.handleResourceCompletion(request, request.params.ref);
default:
throw new McpError(
ErrorCode.InvalidParams,
`Invalid completion reference: ${JSON.stringify(request.params.ref)}`,
);
}
},
);
this._completionHandlerInitialized = true;
}
private async handlePromptCompletion(
request: CompleteRequest,
ref: PromptReference,
): Promise<CompleteResult> {
const prompt = this._registeredPrompts[ref.name];
if (!prompt) {
throw new McpError(
ErrorCode.InvalidParams,
`Prompt ${request.params.ref.name} not found`,
);
}
if (!prompt.argsSchema) {
return EMPTY_COMPLETION_RESULT;
}
const field = prompt.argsSchema.shape[request.params.argument.name];
if (!(field instanceof Completable)) {
return EMPTY_COMPLETION_RESULT;
}
const def: CompletableDef<ZodString> = field._def;
const suggestions = await def.complete(request.params.argument.value);
return createCompletionResult(suggestions);
}
private async handleResourceCompletion(
request: CompleteRequest,
ref: ResourceReference,
): Promise<CompleteResult> {
const template = Object.values(this._registeredResourceTemplates).find(
(t) => t.resourceTemplate.uriTemplate.toString() === ref.uri,
);
if (!template) {
if (this._registeredResources[ref.uri]) {
// Attempting to autocomplete a fixed resource URI is not an error in the spec (but probably should be).
return EMPTY_COMPLETION_RESULT;
}
throw new McpError(
ErrorCode.InvalidParams,
`Resource template ${request.params.ref.uri} not found`,
);
}
const completer = template.resourceTemplate.completeCallback(
request.params.argument.name,
);
if (!completer) {
return EMPTY_COMPLETION_RESULT;
}
const suggestions = await completer(request.params.argument.value);
return createCompletionResult(suggestions);
}
private _resourceHandlersInitialized = false;
private setResourceRequestHandlers() {
if (this._resourceHandlersInitialized) {
return;
}
this.server.assertCanSetRequestHandler(
ListResourcesRequestSchema.shape.method.value,
);
this.server.assertCanSetRequestHandler(
ListResourceTemplatesRequestSchema.shape.method.value,
);
this.server.assertCanSetRequestHandler(
ReadResourceRequestSchema.shape.method.value,
);
this.server.registerCapabilities({
resources: {},
});
this.server.setRequestHandler(
ListResourcesRequestSchema,
async (request, extra) => {
const resources = Object.entries(this._registeredResources).map(
([uri, resource]) => ({
uri,
name: resource.name,
...resource.metadata,
}),
);
const templateResources: Resource[] = [];
for (const template of Object.values(
this._registeredResourceTemplates,
)) {
if (!template.resourceTemplate.listCallback) {
continue;
}
const result = await template.resourceTemplate.listCallback(extra);
for (const resource of result.resources) {
templateResources.push({
...resource,
...template.metadata,
});
}
}
return { resources: [...resources, ...templateResources] };
},
);
this.server.setRequestHandler(
ListResourceTemplatesRequestSchema,
async () => {
const resourceTemplates = Object.entries(
this._registeredResourceTemplates,
).map(([name, template]) => ({
name,
uriTemplate: template.resourceTemplate.uriTemplate.toString(),
...template.metadata,
}));
return { resourceTemplates };
},
);
this.server.setRequestHandler(
ReadResourceRequestSchema,
async (request, extra) => {
const uri = new URL(request.params.uri);
// First check for exact resource match
const resource = this._registeredResources[uri.toString()];
if (resource) {
return resource.readCallback(uri, extra);
}
// Then check templates
for (const template of Object.values(
this._registeredResourceTemplates,
)) {
const variables = template.resourceTemplate.uriTemplate.match(
uri.toString(),
);
if (variables) {
return template.readCallback(uri, variables, extra);
}
}
throw new McpError(
ErrorCode.InvalidParams,
`Resource ${uri} not found`,
);
},
);
this.setCompletionRequestHandler();
this._resourceHandlersInitialized = true;
}
private _promptHandlersInitialized = false;
private setPromptRequestHandlers() {
if (this._promptHandlersInitialized) {
return;
}
this.server.assertCanSetRequestHandler(
ListPromptsRequestSchema.shape.method.value,
);
this.server.assertCanSetRequestHandler(
GetPromptRequestSchema.shape.method.value,
);
this.server.registerCapabilities({
prompts: {},
});
this.server.setRequestHandler(
ListPromptsRequestSchema,
(): ListPromptsResult => ({
prompts: Object.entries(this._registeredPrompts).map(
([name, prompt]): Prompt => {
return {
name,
description: prompt.description,
arguments: prompt.argsSchema
? promptArgumentsFromSchema(prompt.argsSchema)
: undefined,
};
},
),
}),
);
this.server.setRequestHandler(
GetPromptRequestSchema,
async (request, extra): Promise<GetPromptResult> => {
const prompt = this._registeredPrompts[request.params.name];
if (!prompt) {
throw new McpError(
ErrorCode.InvalidParams,
`Prompt ${request.params.name} not found`,
);
}
if (prompt.argsSchema) {
const parseResult = await prompt.argsSchema.safeParseAsync(
request.params.arguments,
);
if (!parseResult.success) {
throw new McpError(
ErrorCode.InvalidParams,
`Invalid arguments for prompt ${request.params.name}: ${parseResult.error.message}`,
);
}
const args = parseResult.data;
const cb = prompt.callback as PromptCallback<PromptArgsRawShape>;
return await Promise.resolve(cb(args, extra));
} else {
const cb = prompt.callback as PromptCallback<undefined>;
return await Promise.resolve(cb(extra));
}
},
);
this.setCompletionRequestHandler();
this._promptHandlersInitialized = true;
}
/**
* Registers a resource `name` at a fixed URI, which will use the given callback to respond to read requests.
*/
resource(name: string, uri: string, readCallback: ReadResourceCallback): void;
/**
* Registers a resource `name` at a fixed URI with metadata, which will use the given callback to respond to read requests.
*/
resource(
name: string,
uri: string,
metadata: ResourceMetadata,
readCallback: ReadResourceCallback,
): void;
/**
* Registers a resource `name` with a template pattern, which will use the given callback to respond to read requests.
*/
resource(
name: string,
template: ResourceTemplate,
readCallback: ReadResourceTemplateCallback,
): void;
/**
* Registers a resource `name` with a template pattern and metadata, which will use the given callback to respond to read requests.
*/
resource(
name: string,
template: ResourceTemplate,
metadata: ResourceMetadata,
readCallback: ReadResourceTemplateCallback,
): void;
resource(
name: string,
uriOrTemplate: string | ResourceTemplate,
...rest: unknown[]
): void {
let metadata: ResourceMetadata | undefined;
if (typeof rest[0] === "object") {
metadata = rest.shift() as ResourceMetadata;
}
const readCallback = rest[0] as
| ReadResourceCallback
| ReadResourceTemplateCallback;
if (typeof uriOrTemplate === "string") {
if (this._registeredResources[uriOrTemplate]) {
throw new Error(`Resource ${uriOrTemplate} is already registered`);
}
this._registeredResources[uriOrTemplate] = {
name,
metadata,
readCallback: readCallback as ReadResourceCallback,
};
} else {
if (this._registeredResourceTemplates[name]) {
throw new Error(`Resource template ${name} is already registered`);
}
this._registeredResourceTemplates[name] = {
resourceTemplate: uriOrTemplate,
metadata,
readCallback: readCallback as ReadResourceTemplateCallback,
};
}
this.setResourceRequestHandlers();
}
/**
* Registers a zero-argument tool `name`, which will run the given function when the client calls it.
*/
tool(name: string, cb: ToolCallback): void;
/**
* Registers a zero-argument tool `name` (with a description), which will run the given function when the client calls it.
*/
tool(name: string, description: string, cb: ToolCallback): void;
/**
* Registers a tool `name` accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments.
*/
tool<Args extends ZodRawShape>(
name: string,
paramsSchema: Args,
cb: ToolCallback<Args>,
): void;
/**
* Registers a tool `name` (with a description) accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client calls it, the function will be run with the parsed and validated arguments.
*/
tool<Args extends ZodRawShape>(
name: string,
description: string,
paramsSchema: Args,
cb: ToolCallback<Args>,
): void;
tool(name: string, ...rest: unknown[]): void {
if (this._registeredTools[name]) {
throw new Error(`Tool ${name} is already registered`);
}
let description: string | undefined;
if (typeof rest[0] === "string") {
description = rest.shift() as string;
}
let paramsSchema: ZodRawShape | undefined;
if (rest.length > 1) {
paramsSchema = rest.shift() as ZodRawShape;
}
const cb = rest[0] as ToolCallback<ZodRawShape | undefined>;
this._registeredTools[name] = {
description,
inputSchema:
paramsSchema === undefined ? undefined : z.object(paramsSchema),
callback: cb,
};
this.setToolRequestHandlers();
}
/**
* Registers a zero-argument prompt `name`, which will run the given function when the client requests it.
*/
prompt(name: string, cb: PromptCallback): void;
/**
* Registers a zero-argument prompt `name` (with a description), which will run the given function when the client requests it.
*/
prompt(name: string, description: string, cb: PromptCallback): void;
/**
* Registers a prompt `name` accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client requests it, the function will be run with the parsed and validated arguments.
*/
prompt<Args extends PromptArgsRawShape>(
name: string,
argsSchema: Args,
cb: PromptCallback<Args>,
): void;
/**
* Registers a prompt `name` (with a description) accepting the given arguments, which must be an object containing named properties associated with Zod schemas. When the client requests it, the function will be run with the parsed and validated arguments.
*/
prompt<Args extends PromptArgsRawShape>(
name: string,
description: string,
argsSchema: Args,
cb: PromptCallback<Args>,
): void;
prompt(name: string, ...rest: unknown[]): void {
if (this._registeredPrompts[name]) {
throw new Error(`Prompt ${name} is already registered`);
}
let description: string | undefined;
if (typeof rest[0] === "string") {
description = rest.shift() as string;
}
let argsSchema: PromptArgsRawShape | undefined;
if (rest.length > 1) {
argsSchema = rest.shift() as PromptArgsRawShape;
}
const cb = rest[0] as PromptCallback<PromptArgsRawShape | undefined>;
this._registeredPrompts[name] = {
description,
argsSchema: argsSchema === undefined ? undefined : z.object(argsSchema),
callback: cb,
};
this.setPromptRequestHandlers();
}
}
/**
* A callback to complete one variable within a resource template's URI template.
*/
export type CompleteResourceTemplateCallback = (
value: string,
) => string[] | Promise<string[]>;
/**
* A resource template combines a URI pattern with optional functionality to enumerate
* all resources matching that pattern.
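*
* @example
* // Sketch: `server` is assumed to be an existing McpServer; the URI pattern is illustrative.
* server.resource(
*   "user-profile",
*   new ResourceTemplate("users://{userId}/profile", { list: undefined }),
*   async (uri, { userId }) => ({
*     contents: [{ uri: uri.href, text: `Profile for user ${userId}` }],
*   }),
* );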
*/
export class ResourceTemplate {
private _uriTemplate: UriTemplate;
constructor(
uriTemplate: string | UriTemplate,
private _callbacks: {
/**
* A callback to list all resources matching this template. This must be specified, even if `undefined`, so that resource listing is not forgotten by accident.
*/
list: ListResourcesCallback | undefined;
/**
* An optional callback to autocomplete variables within the URI template. Useful for clients and users to discover possible values.
*/
complete?: {
[variable: string]: CompleteResourceTemplateCallback;
};
},
) {
this._uriTemplate =
typeof uriTemplate === "string"
? new UriTemplate(uriTemplate)
: uriTemplate;
}
/**
* Gets the URI template pattern.
*/
get uriTemplate(): UriTemplate {
return this._uriTemplate;
}
/**
* Gets the list callback, if one was provided.
*/
get listCallback(): ListResourcesCallback | undefined {
return this._callbacks.list;
}
/**
* Gets the callback for completing a specific URI template variable, if one was provided.
*/
completeCallback(
variable: string,
): CompleteResourceTemplateCallback | undefined {
return this._callbacks.complete?.[variable];
}
}
/**
* Callback for a tool handler registered with McpServer.tool().
*
* Parameters will include tool arguments, if applicable, as well as other request handler context.
*/
export type ToolCallback<Args extends undefined | ZodRawShape = undefined> =
Args extends ZodRawShape
? (
args: z.objectOutputType<Args, ZodTypeAny>,
extra: RequestHandlerExtra,
) => CallToolResult | Promise<CallToolResult>
: (extra: RequestHandlerExtra) => CallToolResult | Promise<CallToolResult>;
type RegisteredTool = {
description?: string;
inputSchema?: AnyZodObject;
callback: ToolCallback<undefined | ZodRawShape>;
};
const EMPTY_OBJECT_JSON_SCHEMA = {
type: "object" as const,
};
/**
* Additional, optional information for annotating a resource.
*/
export type ResourceMetadata = Omit<Resource, "uri" | "name">;
/**
* Callback to list all resources matching a given template.
*/
export type ListResourcesCallback = (
extra: RequestHandlerExtra,
) => ListResourcesResult | Promise<ListResourcesResult>;
/**
* Callback to read a resource at a given URI.
*/
export type ReadResourceCallback = (
uri: URL,
extra: RequestHandlerExtra,
) => ReadResourceResult | Promise<ReadResourceResult>;
type RegisteredResource = {
name: string;
metadata?: ResourceMetadata;
readCallback: ReadResourceCallback;
};
/**
* Callback to read a resource at a given URI, following a filled-in URI template.
*/
export type ReadResourceTemplateCallback = (
uri: URL,
variables: Variables,
extra: RequestHandlerExtra,
) => ReadResourceResult | Promise<ReadResourceResult>;
type RegisteredResourceTemplate = {
resourceTemplate: ResourceTemplate;
metadata?: ResourceMetadata;
readCallback: ReadResourceTemplateCallback;
};
type PromptArgsRawShape = {
[k: string]:
| ZodType<string, ZodTypeDef, string>
| ZodOptional<ZodType<string, ZodTypeDef, string>>;
};
export type PromptCallback<
Args extends undefined | PromptArgsRawShape = undefined,
> = Args extends PromptArgsRawShape
? (
args: z.objectOutputType<Args, ZodTypeAny>,
extra: RequestHandlerExtra,
) => GetPromptResult | Promise<GetPromptResult>
: (extra: RequestHandlerExtra) => GetPromptResult | Promise<GetPromptResult>;
type RegisteredPrompt = {
description?: string;
argsSchema?: ZodObject<PromptArgsRawShape>;
callback: PromptCallback<undefined | PromptArgsRawShape>;
};
function promptArgumentsFromSchema(
schema: ZodObject<PromptArgsRawShape>,
): PromptArgument[] {
return Object.entries(schema.shape).map(
([name, field]): PromptArgument => ({
name,
description: field.description,
required: !field.isOptional(),
}),
);
}
function createCompletionResult(suggestions: string[]): CompleteResult {
return {
completion: {
values: suggestions.slice(0, 100),
total: suggestions.length,
hasMore: suggestions.length > 100,
},
};
}
const EMPTY_COMPLETION_RESULT: CompleteResult = {
completion: {
values: [],
hasMore: false,
},
};
================================================
File: src/server/sse.ts
================================================
import { randomUUID } from "node:crypto";
import { IncomingMessage, ServerResponse } from "node:http";
import { Transport } from "../shared/transport.js";
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import getRawBody from "raw-body";
import contentType from "content-type";
const MAXIMUM_MESSAGE_SIZE = "4mb";
/**
* Server transport for SSE: this will send messages over an SSE connection and receive messages from HTTP POST requests.
*
* This transport is only available in Node.js environments.
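*
* @example
* // Sketch assuming an Express `app` and an McpServer `server`; the routes are illustrative.
* let transport: SSEServerTransport | undefined;
* app.get("/sse", async (_req, res) => {
*   transport = new SSEServerTransport("/messages", res);
*   await server.connect(transport);
* });
* app.post("/messages", async (req, res) => {
*   if (!transport) {
*     res.status(400).send("No active SSE connection");
*     return;
*   }
*   // If a body parser such as express.json() already consumed the body, pass req.body as the third argument.
*   await transport.handlePostMessage(req, res);
* });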
*/
export class SSEServerTransport implements Transport {
private _sseResponse?: ServerResponse;
private _sessionId: string;
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
/**
* Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`.
*/
constructor(
private _endpoint: string,
private res: ServerResponse,
) {
this._sessionId = randomUUID();
}
/**
* Handles the initial SSE connection request.
*
* This should be called when a GET request is made to establish the SSE stream.
*/
async start(): Promise<void> {
if (this._sseResponse) {
throw new Error(
"SSEServerTransport already started! If using Server class, note that connect() calls start() automatically.",
);
}
this.res.writeHead(200, {
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
Connection: "keep-alive",
});
// Send the endpoint event
this.res.write(
`event: endpoint\ndata: ${encodeURI(this._endpoint)}?sessionId=${this._sessionId}\n\n`,
);
this._sseResponse = this.res;
this.res.on("close", () => {
this._sseResponse = undefined;
this.onclose?.();
});
}
/**
* Handles incoming POST messages.
*
* This should be called when a POST request is made to send a message to the server.
*/
async handlePostMessage(
req: IncomingMessage,
res: ServerResponse,
parsedBody?: unknown,
): Promise<void> {
if (!this._sseResponse) {
const message = "SSE connection not established";
res.writeHead(500).end(message);
throw new Error(message);
}
let body: string | unknown;
try {
const ct = contentType.parse(req.headers["content-type"] ?? "");
if (ct.type !== "application/json") {
throw new Error(`Unsupported content-type: ${ct.type}`);
}
body = parsedBody ?? await getRawBody(req, {
limit: MAXIMUM_MESSAGE_SIZE,
encoding: ct.parameters.charset ?? "utf-8",
});
} catch (error) {
res.writeHead(400).end(String(error));
this.onerror?.(error as Error);
return;
}
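try {
await this.handleMessage(typeof body === "string" ? JSON.parse(body) : body);
} catch {
// handleMessage() reports schema validation errors via onerror; a JSON.parse() failure is only reported in the 400 response.
res.writeHead(400).end(`Invalid message: ${typeof body === "string" ? body : JSON.stringify(body)}`);
return;
}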
res.writeHead(202).end("Accepted");
}
/**
* Handle a client message, regardless of how it arrived. This can be used to inform the server of messages that arrive by a means other than HTTP POST.
*/
async handleMessage(message: unknown): Promise<void> {
let parsedMessage: JSONRPCMessage;
try {
parsedMessage = JSONRPCMessageSchema.parse(message);
} catch (error) {
this.onerror?.(error as Error);
throw error;
}
this.onmessage?.(parsedMessage);
}
async close(): Promise<void> {
this._sseResponse?.end();
this._sseResponse = undefined;
this.onclose?.();
}
async send(message: JSONRPCMessage): Promise<void> {
if (!this._sseResponse) {
throw new Error("Not connected");
}
this._sseResponse.write(
`event: message\ndata: ${JSON.stringify(message)}\n\n`,
);
}
/**
* Returns the session ID for this transport.
*
* This can be used to route incoming POST requests.
*/
get sessionId(): string {
return this._sessionId;
}
}
================================================
File: src/server/stdio.test.ts
================================================
import { Readable, Writable } from "node:stream";
import { ReadBuffer, serializeMessage } from "../shared/stdio.js";
import { JSONRPCMessage } from "../types.js";
import { StdioServerTransport } from "./stdio.js";
let input: Readable;
let outputBuffer: ReadBuffer;
let output: Writable;
beforeEach(() => {
input = new Readable({
// We'll use input.push() instead.
read: () => {},
});
outputBuffer = new ReadBuffer();
output = new Writable({
write(chunk, encoding, callback) {
outputBuffer.append(chunk);
callback();
},
});
});
test("should start then close cleanly", async () => {
const server = new StdioServerTransport(input, output);
server.onerror = (error) => {
throw error;
};
let didClose = false;
server.onclose = () => {
didClose = true;
};
await server.start();
expect(didClose).toBeFalsy();
await server.close();
expect(didClose).toBeTruthy();
});
test("should not read until started", async () => {
const server = new StdioServerTransport(input, output);
server.onerror = (error) => {
throw error;
};
let didRead = false;
const readMessage = new Promise((resolve) => {
server.onmessage = (message) => {
didRead = true;
resolve(message);
};
});
const message: JSONRPCMessage = {
jsonrpc: "2.0",
id: 1,
method: "ping",
};
input.push(serializeMessage(message));
expect(didRead).toBeFalsy();
await server.start();
expect(await readMessage).toEqual(message);
});
test("should read multiple messages", async () => {
const server = new StdioServerTransport(input, output);
server.onerror = (error) => {
throw error;
};
const messages: JSONRPCMessage[] = [
{
jsonrpc: "2.0",
id: 1,
method: "ping",
},
{
jsonrpc: "2.0",
method: "notifications/initialized",
},
];
const readMessages: JSONRPCMessage[] = [];
const finished = new Promise<void>((resolve) => {
server.onmessage = (message) => {
readMessages.push(message);
if (JSON.stringify(message) === JSON.stringify(messages[1])) {
resolve();
}
};
});
input.push(serializeMessage(messages[0]));
input.push(serializeMessage(messages[1]));
await server.start();
await finished;
expect(readMessages).toEqual(messages);
});
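test("should write sent messages to output as newline-delimited JSON", async () => {
// Added test (sketch): send() frames each message with serializeMessage(), so the
// ReadBuffer fed by `output` should decode exactly the message that was sent.
const server = new StdioServerTransport(input, output);
server.onerror = (error) => {
throw error;
};
const message: JSONRPCMessage = {
jsonrpc: "2.0",
id: 1,
method: "ping",
};
await server.start();
await server.send(message);
expect(outputBuffer.readMessage()).toEqual(message);
// Nothing else should have been written.
expect(outputBuffer.readMessage()).toBeNull();
await server.close();
});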
================================================
File: src/server/stdio.ts
================================================
import process from "node:process";
import { Readable, Writable } from "node:stream";
import { ReadBuffer, serializeMessage } from "../shared/stdio.js";
import { JSONRPCMessage } from "../types.js";
import { Transport } from "../shared/transport.js";
/**
* Server transport for stdio: this communicates with an MCP client by reading from stdin and writing to stdout (the current process's streams by default).
*
* This transport is only available in Node.js environments.
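*
* @example
* // Sketch: assumes `McpServer` is imported from "./mcp.js".
* const server = new McpServer({ name: "example-server", version: "1.0.0" });
* await server.connect(new StdioServerTransport());
* // stdout carries protocol messages, so write any logging to stderr (e.g. console.error).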
*/
export class StdioServerTransport implements Transport {
private _readBuffer: ReadBuffer = new ReadBuffer();
private _started = false;
constructor(
private _stdin: Readable = process.stdin,
private _stdout: Writable = process.stdout,
) {}
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
// Arrow functions to bind `this` properly, while maintaining function identity.
_ondata = (chunk: Buffer) => {
this._readBuffer.append(chunk);
this.processReadBuffer();
};
_onerror = (error: Error) => {
this.onerror?.(error);
};
/**
* Starts listening for messages on stdin.
*/
async start(): Promise<void> {
if (this._started) {
throw new Error(
"StdioServerTransport already started! If using Server class, note that connect() calls start() automatically.",
);
}
this._started = true;
this._stdin.on("data", this._ondata);
this._stdin.on("error", this._onerror);
}
private processReadBuffer() {
while (true) {
try {
const message = this._readBuffer.readMessage();
if (message === null) {
break;
}
this.onmessage?.(message);
} catch (error) {
this.onerror?.(error as Error);
}
}
}
async close(): Promise<void> {
// Remove our event listeners first
this._stdin.off("data", this._ondata);
this._stdin.off("error", this._onerror);
// Check if we were the only data listener
const remainingDataListeners = this._stdin.listenerCount('data');
if (remainingDataListeners === 0) {
// Only pause stdin if we were the only listener
// This prevents interfering with other parts of the application that might be using stdin
this._stdin.pause();
}
// Clear the buffer and notify closure
this._readBuffer.clear();
this.onclose?.();
}
send(message: JSONRPCMessage): Promise<void> {
return new Promise((resolve) => {
const json = serializeMessage(message);
if (this._stdout.write(json)) {
resolve();
} else {
this._stdout.once("drain", resolve);
}
});
}
}
================================================
File: src/server/auth/clients.ts
================================================
import { OAuthClientInformationFull } from "../../shared/auth.js";
/**
* Stores information about registered OAuth clients for this server.
*/
export interface OAuthRegisteredClientsStore {
/**
* Returns information about a registered client, based on its ID.
*/
getClient(clientId: string): OAuthClientInformationFull | undefined | Promise<OAuthClientInformationFull | undefined>;
/**
* Registers a new client with the server. The client ID and secret will be automatically generated by the library. A modified version of the client information can be returned to reflect specific values enforced by the server.
*
* NOTE: Implementations should NOT delete expired client secrets in-place. Auth middleware provided by this library will automatically check the `client_secret_expires_at` field and reject requests with expired secrets. Any custom logic for authenticating clients should check the `client_secret_expires_at` field as well.
*
* If unimplemented, dynamic client registration is unsupported.
*/
registerClient?(client: OAuthClientInformationFull): OAuthClientInformationFull | Promise<OAuthClientInformationFull>;
}
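The `OAuthRegisteredClientsStore` contract is small: look clients up by ID and, optionally, persist newly registered ones, while leaving expiry enforcement to authentication code. A minimal in-memory sketch follows, assuming a reduced `ClientInfo` shape in place of `OAuthClientInformationFull`. `InMemoryClientsStore` and `isSecretExpired` are hypothetical names. Treating `client_secret_expires_at === 0` as "never expires" follows RFC 7591.

```typescript
// Minimal local stand-in for OAuthClientInformationFull, modeling only the
// fields this sketch touches (an assumption made to keep it self-contained).
interface ClientInfo {
  client_id: string;
  client_secret?: string;
  // Seconds since the epoch; RFC 7591 uses 0 to mean "never expires".
  client_secret_expires_at?: number;
  redirect_uris: string[];
}

// Hypothetical in-memory store following the OAuthRegisteredClientsStore
// contract. The interface allows synchronous returns, so no Promises here.
// Registered entries are never rewritten, even after their secret expires.
class InMemoryClientsStore {
  private readonly clients = new Map<string, ClientInfo>();

  getClient(clientId: string): ClientInfo | undefined {
    return this.clients.get(clientId);
  }

  // The library generates client_id and client_secret before calling this;
  // the store only persists the record it receives and returns it.
  registerClient(client: ClientInfo): ClientInfo {
    this.clients.set(client.client_id, client);
    return client;
  }
}

// Hypothetical check that custom client-authentication code should apply,
// as the doc comment on registerClient asks.
function isSecretExpired(
  client: ClientInfo,
  nowSeconds: number = Math.floor(Date.now() / 1000),
): boolean {
  const expiresAt = client.client_secret_expires_at;
  return expiresAt !== undefined && expiresAt !== 0 && expiresAt < nowSeconds;
}

// Usage
const store = new InMemoryClientsStore();
store.registerClient({
  client_id: "client-123",
  client_secret: "generated-secret",
  client_secret_expires_at: 1_000,
  redirect_uris: ["https://example.com/callback"],
});
const registered = store.getClient("client-123");
```

A production store would back this with a database, but the contract stays the same: return `undefined` for unknown IDs and keep expired records intact so authentication can reject them explicitly.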
================================================
File: src/server/auth/errors.ts
================================================
import { OAuthErrorResponse } from "../../shared/auth.js";
/**
* Base class for all OAuth errors
*/
export class OAuthError extends Error {
constructor(
public readonly errorCode: string,
message: string,
public readonly errorUri?: string
) {
super(message);
this.name = this.constructor.name;
}
/**
* Converts the error to a standard OAuth error response object
*/
toResponseObject(): OAuthErrorResponse {
const response: OAuthErrorResponse = {
error: this.errorCode,
error_description: this.message
};
if (this.errorUri) {
response.error_uri = this.errorUri;
}
return response;
}
}
/**
* Invalid request error - The request is missing a required parameter,
* includes an invalid parameter value, includes a parameter more than once,
* or is otherwise malformed.
*/
export class InvalidRequestError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("invalid_request", message, errorUri);
}
}
/**
* Invalid client error - Client authentication failed (e.g., unknown client, no client
* authentication included, or unsupported authentication method).
*/
export class InvalidClientError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("invalid_client", message, errorUri);
}
}
/**
* Invalid grant error - The provided authorization grant or refresh token is
* invalid, expired, revoked, does not match the redirection URI used in the
* authorization request, or was issued to another client.
*/
export class InvalidGrantError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("invalid_grant", message, errorUri);
}
}
/**
* Unauthorized client error - The authenticated client is not authorized to use
* this authorization grant type.
*/
export class UnauthorizedClientError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("unauthorized_client", message, errorUri);
}
}
/**
* Unsupported grant type error - The authorization grant type is not supported
* by the authorization server.
*/
export class UnsupportedGrantTypeError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("unsupported_grant_type", message, errorUri);
}
}
/**
* Invalid scope error - The requested scope is invalid, unknown, malformed, or
* exceeds the scope granted by the resource owner.
*/
export class InvalidScopeError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("invalid_scope", message, errorUri);
}
}
/**
* Access denied error - The resource owner or authorization server denied the request.
*/
export class AccessDeniedError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("access_denied", message, errorUri);
}
}
/**
* Server error - The authorization server encountered an unexpected condition
* that prevented it from fulfilling the request.
*/
export class ServerError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("server_error", message, errorUri);
}
}
/**
* Temporarily unavailable error - The authorization server is currently unable to
* handle the request due to a temporary overloading or maintenance of the server.
*/
export class TemporarilyUnavailableError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("temporarily_unavailable", message, errorUri);
}
}
/**
* Unsupported response type error - The authorization server does not support
* obtaining an authorization code using this method.
*/
export class UnsupportedResponseTypeError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("unsupported_response_type", message, errorUri);
}
}
/**
* Unsupported token type error - The authorization server does not support
* the requested token type.
*/
export class UnsupportedTokenTypeError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("unsupported_token_type", message, errorUri);
}
}
/**
* Invalid token error - The access token provided is expired, revoked, malformed,
* or invalid for other reasons.
*/
export class InvalidTokenError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("invalid_token", message, errorUri);
}
}
/**
* Method not allowed error - The HTTP method used is not allowed for this endpoint.
* (Custom, non-standard error)
*/
export class MethodNotAllowedError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("method_not_allowed", message, errorUri);
}
}
/**
* Too many requests error - Rate limit exceeded.
* (Custom, non-standard error based on RFC 6585)
*/
export class TooManyRequestsError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("too_many_requests", message, errorUri);
}
}
/**
* Invalid client metadata error - The client metadata is invalid.
* (Custom error for dynamic client registration - RFC 7591)
*/
export class InvalidClientMetadataError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("invalid_client_metadata", message, errorUri);
}
}
/**
* Insufficient scope error - The request requires higher privileges than provided by the access token.
*/
export class InsufficientScopeError extends OAuthError {
constructor(message: string, errorUri?: string) {
super("insufficient_scope", message, errorUri);
}
}
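Each subclass above only fixes the `error` code. Serialization lives in `OAuthError.toResponseObject()`, which omits `error_uri` unless one was supplied. The sketch below restates that pattern in a self-contained form, with fields declared explicitly rather than as parameter properties. The added `Object.setPrototypeOf` call is an assumption that keeps `instanceof` working even under an ES5 compile target. `toErrorBody` and its `server_error` fallback are hypothetical illustrations of how a handler might use these classes.

```typescript
interface OAuthErrorResponse {
  error: string;
  error_description?: string;
  error_uri?: string;
}

class OAuthError extends Error {
  readonly errorCode: string;
  readonly errorUri?: string;

  constructor(errorCode: string, message: string, errorUri?: string) {
    super(message);
    // Keeps `instanceof` reliable when compiled to ES5 (assumption, see lead-in).
    Object.setPrototypeOf(this, new.target.prototype);
    this.errorCode = errorCode;
    this.errorUri = errorUri;
    // Subclasses report their own class name, e.g. "InvalidGrantError".
    this.name = this.constructor.name;
  }

  toResponseObject(): OAuthErrorResponse {
    const response: OAuthErrorResponse = {
      error: this.errorCode,
      error_description: this.message,
    };
    if (this.errorUri) {
      response.error_uri = this.errorUri;
    }
    return response;
  }
}

class InvalidGrantError extends OAuthError {
  constructor(message: string, errorUri?: string) {
    super("invalid_grant", message, errorUri);
  }
}

// Hypothetical helper: turn any thrown value into an OAuth error body,
// falling back to server_error for failures that are not OAuthErrors.
function toErrorBody(err: unknown): OAuthErrorResponse {
  if (err instanceof OAuthError) {
    return err.toResponseObject();
  }
  return { error: "server_error", error_description: "Internal Server Error" };
}

// Usage
const grantError = new InvalidGrantError("Authorization code has expired");
const body = toErrorBody(grantError);
// body is { error: "invalid_grant", error_description: "Authorization code has expired" }
```

Keeping the wire format in the base class means adding a new error type (such as `InsufficientScopeError`) is a three-line subclass with no serialization code of its own.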
================================================
File: src/server/auth/provider.ts
================================================
import { Response } from "express";
import { OAuthRegisteredClientsStore } from "./clients.js";
import { OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from "../../shared/auth.js";
import { AuthInfo } from "./types.js";
export type AuthorizationParams = {
state?: string;
scopes?: string[];
codeChallenge: string;
redirectUri: string;
};
/**
* Implements an end-to-end OAuth server.
*/
export interface OAuthServerProvider {
/**
* A store used to read information about registered OAuth clients.
*/
get clientsStore(): OAuthRegisteredClientsStore;
/**
* Begins the authorization flow, which can either be implemented by this server itself or via redirection to a separate authorization server.
*
* This server must eventually issue a redirect with an authorization response or an error response to the given redirect URI. Per OAuth 2.1:
* - In the successful case, the redirect MUST include the `code` and `state` (if present) query parameters.
* - In the error case, the redirect MUST include the `error` query parameter, and MAY include an optional `error_description` query parameter.
*/
authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise<void>;
/**
* Returns the `codeChallenge` that was used when the indicated authorization began.
*/
challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise<string>;
/**
* Exchanges an authorization code for an access token.
*/
exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise<OAuthTokens>;
/**
* Exchanges a refresh token for an access token.
*/
exchangeRefreshToken(client: OAuthClientInformationFull, refreshToken: string, scopes?: string[]): Promise<OAuthTokens>;
/**
* Verifies an access token and returns information about it.
*/
verifyAccessToken(token: string): Promise<AuthInfo>;
/**
* Revokes an access or refresh token. If unimplemented, token revocation is not supported (not recommended).
*
* If the given token is invalid or already revoked, this method should do nothing.
*/
revokeToken?(client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest): Promise<void>;
}
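The doc comment on `authorize()` fixes the shape of the final redirect: `code` plus `state` on success, and `error` with an optional `error_description` on failure. The sketch below builds both URLs with the standard WHATWG `URL` API. The helper names are hypothetical. Echoing `state` on errors comes from RFC 6749 section 4.1.2.1 rather than from this interface's doc comment.

```typescript
// Hypothetical helpers for the redirect contract documented on
// OAuthServerProvider.authorize(). Only the WHATWG URL API is used.
function authorizationSuccessRedirect(redirectUri: string, code: string, state?: string): string {
  const url = new URL(redirectUri);
  url.searchParams.set("code", code);
  if (state !== undefined) {
    url.searchParams.set("state", state);
  }
  return url.href;
}

function authorizationErrorRedirect(
  redirectUri: string,
  error: string,
  errorDescription?: string,
  state?: string,
): string {
  const url = new URL(redirectUri);
  url.searchParams.set("error", error);
  if (errorDescription !== undefined) {
    url.searchParams.set("error_description", errorDescription);
  }
  // RFC 6749 section 4.1.2.1 also echoes `state` in error responses.
  if (state !== undefined) {
    url.searchParams.set("state", state);
  }
  return url.href;
}

// Usage: existing query parameters on the registered redirect URI are preserved.
const ok = authorizationSuccessRedirect("https://example.com/callback?tenant=a", "auth_code_123", "xyz");
// ok === "https://example.com/callback?tenant=a&code=auth_code_123&state=xyz"
const denied = authorizationErrorRedirect("https://example.com/callback", "access_denied", "User denied access");
```

A provider's `authorize()` would then call `res.redirect(302, ...)` with one of these URLs, as the mock providers in the tests below do.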
================================================
File: src/server/auth/router.test.ts
================================================
import { mcpAuthRouter, AuthRouterOptions } from './router.js';
import { OAuthServerProvider, AuthorizationParams } from './provider.js';
import { OAuthRegisteredClientsStore } from './clients.js';
import { OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '../../shared/auth.js';
import express, { Response } from 'express';
import supertest from 'supertest';
import { AuthInfo } from './types.js';
import { InvalidTokenError } from './errors.js';
describe('MCP Auth Router', () => {
// Setup mock provider with full capabilities
const mockClientStore: OAuthRegisteredClientsStore = {
async getClient(clientId: string): Promise<OAuthClientInformationFull | undefined> {
if (clientId === 'valid-client') {
return {
client_id: 'valid-client',
client_secret: 'valid-secret',
redirect_uris: ['https://example.com/callback']
};
}
return undefined;
},
async registerClient(client: OAuthClientInformationFull): Promise<OAuthClientInformationFull> {
return client;
}
};
const mockProvider: OAuthServerProvider = {
clientsStore: mockClientStore,
async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise<void> {
const redirectUrl = new URL(params.redirectUri);
redirectUrl.searchParams.set('code', 'mock_auth_code');
if (params.state) {
redirectUrl.searchParams.set('state', params.state);
}
res.redirect(302, redirectUrl.toString());
},
async challengeForAuthorizationCode(): Promise<string> {
return 'mock_challenge';
},
async exchangeAuthorizationCode(): Promise<OAuthTokens> {
return {
access_token: 'mock_access_token',
token_type: 'bearer',
expires_in: 3600,
refresh_token: 'mock_refresh_token'
};
},
async exchangeRefreshToken(): Promise<OAuthTokens> {
return {
access_token: 'new_mock_access_token',
token_type: 'bearer',
expires_in: 3600,
refresh_token: 'new_mock_refresh_token'
};
},
async verifyAccessToken(token: string): Promise<AuthInfo> {
if (token === 'valid_token') {
return {
token,
clientId: 'valid-client',
scopes: ['read', 'write'],
expiresAt: Date.now() / 1000 + 3600
};
}
throw new InvalidTokenError('Token is invalid or expired');
},
async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise<void> {
// Success - do nothing in mock
}
};
// Provider without registration and revocation
const mockProviderMinimal: OAuthServerProvider = {
clientsStore: {
async getClient(clientId: string): Promise<OAuthClientInformationFull | undefined> {
if (clientId === 'valid-client') {
return {
client_id: 'valid-client',
client_secret: 'valid-secret',
redirect_uris: ['https://example.com/callback']
};
}
return undefined;
}
},
async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise<void> {
const redirectUrl = new URL(params.redirectUri);
redirectUrl.searchParams.set('code', 'mock_auth_code');
if (params.state) {
redirectUrl.searchParams.set('state', params.state);
}
res.redirect(302, redirectUrl.toString());
},
async challengeForAuthorizationCode(): Promise<string> {
return 'mock_challenge';
},
async exchangeAuthorizationCode(): Promise<OAuthTokens> {
return {
access_token: 'mock_access_token',
token_type: 'bearer',
expires_in: 3600,
refresh_token: 'mock_refresh_token'
};
},
async exchangeRefreshToken(): Promise<OAuthTokens> {
return {
access_token: 'new_mock_access_token',
token_type: 'bearer',
expires_in: 3600,
refresh_token: 'new_mock_refresh_token'
};
},
async verifyAccessToken(token: string): Promise<AuthInfo> {
if (token === 'valid_token') {
return {
token,
clientId: 'valid-client',
scopes: ['read'],
expiresAt: Date.now() / 1000 + 3600
};
}
throw new InvalidTokenError('Token is invalid or expired');
}
};
describe('Router creation', () => {
it('throws error for non-HTTPS issuer URL', () => {
const options: AuthRouterOptions = {
provider: mockProvider,
issuerUrl: new URL('http://auth.example.com')
};
expect(() => mcpAuthRouter(options)).toThrow('Issuer URL must be HTTPS');
});
it('allows localhost HTTP for development', () => {
const options: AuthRouterOptions = {
provider: mockProvider,
issuerUrl: new URL('http://localhost:3000')
};
expect(() => mcpAuthRouter(options)).not.toThrow();
});
it('throws error for issuer URL with fragment', () => {
const options: AuthRouterOptions = {
provider: mockProvider,
issuerUrl: new URL('https://auth.example.com#fragment')
};
expect(() => mcpAuthRouter(options)).toThrow('Issuer URL must not have a fragment');
});
it('throws error for issuer URL with query string', () => {
const options: AuthRouterOptions = {
provider: mockProvider,
issuerUrl: new URL('https://auth.example.com?param=value')
};
expect(() => mcpAuthRouter(options)).toThrow('Issuer URL must not have a query string');
});
it('successfully creates router with valid options', () => {
const options: AuthRouterOptions = {
provider: mockProvider,
issuerUrl: new URL('https://auth.example.com')
};
expect(() => mcpAuthRouter(options)).not.toThrow();
});
});
describe('Metadata endpoint', () => {
let app: express.Express;
beforeEach(() => {
// Setup full-featured router
app = express();
const options: AuthRouterOptions = {
provider: mockProvider,
issuerUrl: new URL('https://auth.example.com'),
serviceDocumentationUrl: new URL('https://docs.example.com')
};
app.use(mcpAuthRouter(options));
});
it('returns complete metadata for full-featured router', async () => {
const response = await supertest(app)
.get('/.well-known/oauth-authorization-server');
expect(response.status).toBe(200);
// Verify essential fields
expect(response.body.issuer).toBe('https://auth.example.com/');
expect(response.body.authorization_endpoint).toBe('https://auth.example.com/authorize');
expect(response.body.token_endpoint).toBe('https://auth.example.com/token');
expect(response.body.registration_endpoint).toBe('https://auth.example.com/register');
expect(response.body.revocation_endpoint).toBe('https://auth.example.com/revoke');
// Verify supported features
expect(response.body.response_types_supported).toEqual(['code']);
expect(response.body.grant_types_supported).toEqual(['authorization_code', 'refresh_token']);
expect(response.body.code_challenge_methods_supported).toEqual(['S256']);
expect(response.body.token_endpoint_auth_methods_supported).toEqual(['client_secret_post']);
expect(response.body.revocation_endpoint_auth_methods_supported).toEqual(['client_secret_post']);
// Verify optional fields
expect(response.body.service_documentation).toBe('https://docs.example.com/');
});
it('returns minimal metadata for minimal router', async () => {
// Setup minimal router
const minimalApp = express();
const options: AuthRouterOptions = {
provider: mockProviderMinimal,
issuerUrl: new URL('https://auth.example.com')
};
minimalApp.use(mcpAuthRouter(options));
const response = await supertest(minimalApp)
.get('/.well-known/oauth-authorization-server');
expect(response.status).toBe(200);
// Verify essential endpoints
expect(response.body.issuer).toBe('https://auth.example.com/');
expect(response.body.authorization_endpoint).toBe('https://auth.example.com/authorize');
expect(response.body.token_endpoint).toBe('https://auth.example.com/token');
// Verify missing optional endpoints
expect(response.body.registration_endpoint).toBeUndefined();
expect(response.body.revocation_endpoint).toBeUndefined();
expect(response.body.revocation_endpoint_auth_methods_supported).toBeUndefined();
expect(response.body.service_documentation).toBeUndefined();
});
});
describe('Endpoint routing', () => {
let app: express.Express;
beforeEach(() => {
// Setup full-featured router
app = express();
const options: AuthRouterOptions = {
provider: mockProvider,
issuerUrl: new URL('https://auth.example.com')
};
app.use(mcpAuthRouter(options));
});
it('routes to authorization endpoint', async () => {
const response = await supertest(app)
.get('/authorize')
.query({
client_id: 'valid-client',
response_type: 'code',
code_challenge: 'challenge123',
code_challenge_method: 'S256'
});
expect(response.status).toBe(302);
const location = new URL(response.header.location);
expect(location.searchParams.has('code')).toBe(true);
});
it('routes to token endpoint', async () => {
// Setup verifyChallenge mock for token handler
jest.mock('pkce-challenge', () => ({
verifyChallenge: jest.fn().mockResolvedValue(true)
}));
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'authorization_code',
code: 'valid_code',
code_verifier: 'valid_verifier'
});
// The request will fail in testing due to mocking limitations,
// but we can verify the route was matched
expect(response.status).not.toBe(404);
});
it('routes to registration endpoint', async () => {
const response = await supertest(app)
.post('/register')
.send({
redirect_uris: ['https://example.com/callback']
});
// The request will fail in testing due to mocking limitations,
// but we can verify the route was matched
expect(response.status).not.toBe(404);
});
it('routes to revocation endpoint', async () => {
const response = await supertest(app)
.post('/revoke')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
token: 'token_to_revoke'
});
// The request will fail in testing due to mocking limitations,
// but we can verify the route was matched
expect(response.status).not.toBe(404);
});
it('excludes endpoints for unsupported features', async () => {
// Setup minimal router
const minimalApp = express();
const options: AuthRouterOptions = {
provider: mockProviderMinimal,
issuerUrl: new URL('https://auth.example.com')
};
minimalApp.use(mcpAuthRouter(options));
// Registration should not be available
const regResponse = await supertest(minimalApp)
.post('/register')
.send({
redirect_uris: ['https://example.com/callback']
});
expect(regResponse.status).toBe(404);
// Revocation should not be available
const revokeResponse = await supertest(minimalApp)
.post('/revoke')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
token: 'token_to_revoke'
});
expect(revokeResponse.status).toBe(404);
});
});
});
================================================
File: src/server/auth/router.ts
================================================
import express, { RequestHandler } from "express";
import { clientRegistrationHandler, ClientRegistrationHandlerOptions } from "./handlers/register.js";
import { tokenHandler, TokenHandlerOptions } from "./handlers/token.js";
import { authorizationHandler, AuthorizationHandlerOptions } from "./handlers/authorize.js";
import { revocationHandler, RevocationHandlerOptions } from "./handlers/revoke.js";
import { metadataHandler } from "./handlers/metadata.js";
import { OAuthServerProvider } from "./provider.js";
export type AuthRouterOptions = {
/**
* A provider implementing the actual authorization logic for this router.
*/
provider: OAuthServerProvider;
/**
* The authorization server's issuer identifier, which is a URL that uses the "https" scheme and has no query or fragment components.
*/
issuerUrl: URL;
/**
* An optional URL of a page containing human-readable information that developers might want or need to know when using the authorization server.
*/
serviceDocumentationUrl?: URL;
// Individual options per route
authorizationOptions?: Omit<AuthorizationHandlerOptions, "provider">;
clientRegistrationOptions?: Omit<ClientRegistrationHandlerOptions, "clientsStore">;
revocationOptions?: Omit<RevocationHandlerOptions, "provider">;
tokenOptions?: Omit<TokenHandlerOptions, "provider">;
};
/**
* Installs standard MCP authorization endpoints, including dynamic client registration and token revocation (if supported). Also advertises standard authorization server metadata, for easier discovery of supported configurations by clients.
*
* By default, rate limiting is applied to all endpoints to prevent abuse.
*
* This router MUST be installed at the application root, like so:
*
* const app = express();
* app.use(mcpAuthRouter(...));
*/
export function mcpAuthRouter(options: AuthRouterOptions): RequestHandler {
const issuer = options.issuerUrl;
// RFC 8414 requires the "https" scheme and has no localhost exemption; plain HTTP is allowed for localhost and 127.0.0.1 here only to ease local testing
if (issuer.protocol !== "https:" && issuer.hostname !== "localhost" && issuer.hostname !== "127.0.0.1") {
throw new Error("Issuer URL must be HTTPS");
}
if (issuer.hash) {
throw new Error("Issuer URL must not have a fragment");
}
if (issuer.search) {
throw new Error("Issuer URL must not have a query string");
}
const authorization_endpoint = "/authorize";
const token_endpoint = "/token";
const registration_endpoint = options.provider.clientsStore.registerClient ? "/register" : undefined;
const revocation_endpoint = options.provider.revokeToken ? "/revoke" : undefined;
const metadata = {
issuer: issuer.href,
service_documentation: options.serviceDocumentationUrl?.href,
authorization_endpoint: new URL(authorization_endpoint, issuer).href,
response_types_supported: ["code"],
code_challenge_methods_supported: ["S256"],
token_endpoint: new URL(token_endpoint, issuer).href,
token_endpoint_auth_methods_supported: ["client_secret_post"],
grant_types_supported: ["authorization_code", "refresh_token"],
revocation_endpoint: revocation_endpoint ? new URL(revocation_endpoint, issuer).href : undefined,
revocation_endpoint_auth_methods_supported: revocation_endpoint ? ["client_secret_post"] : undefined,
registration_endpoint: registration_endpoint ? new URL(registration_endpoint, issuer).href : undefined,
};
const router = express.Router();
router.use(
authorization_endpoint,
authorizationHandler({ provider: options.provider, ...options.authorizationOptions })
);
router.use(
token_endpoint,
tokenHandler({ provider: options.provider, ...options.tokenOptions })
);
router.use("/.well-known/oauth-authorization-server", metadataHandler(metadata));
if (registration_endpoint) {
router.use(
registration_endpoint,
clientRegistrationHandler({
clientsStore: options.provider.clientsStore,
...options.clientRegistrationOptions,
})
);
}
if (revocation_endpoint) {
router.use(
revocation_endpoint,
revocationHandler({ provider: options.provider, ...options.revocationOptions })
);
}
return router;
}
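`mcpAuthRouter` validates the issuer URL, advertises the optional `/register` and `/revoke` endpoints only when the provider implements `registerClient` or `revokeToken`, and resolves every endpoint against the issuer. The dependency-free sketch below restates that logic so it can be checked without Express. `ProviderShape`, `validateIssuer` and `buildMetadata` are hypothetical names, and the metadata object is trimmed to a subset of the fields above.

```typescript
// Hypothetical, dependency-free restatement of the issuer checks and
// metadata derivation in mcpAuthRouter(); not the SDK's exported API.
interface ProviderShape {
  clientsStore: { registerClient?: (...args: never[]) => unknown };
  revokeToken?: (...args: never[]) => unknown;
}

function validateIssuer(issuer: URL): void {
  // RFC 8414 requires https; localhost is exempted here only for local testing.
  const isLocal = issuer.hostname === "localhost" || issuer.hostname === "127.0.0.1";
  if (issuer.protocol !== "https:" && !isLocal) {
    throw new Error("Issuer URL must be HTTPS");
  }
  if (issuer.hash) {
    throw new Error("Issuer URL must not have a fragment");
  }
  if (issuer.search) {
    throw new Error("Issuer URL must not have a query string");
  }
}

function buildMetadata(provider: ProviderShape, issuer: URL) {
  validateIssuer(issuer);
  // Optional endpoints are advertised only when the provider implements them.
  const canRegister = Boolean(provider.clientsStore.registerClient);
  const canRevoke = Boolean(provider.revokeToken);
  return {
    issuer: issuer.href,
    authorization_endpoint: new URL("/authorize", issuer).href,
    token_endpoint: new URL("/token", issuer).href,
    registration_endpoint: canRegister ? new URL("/register", issuer).href : undefined,
    revocation_endpoint: canRevoke ? new URL("/revoke", issuer).href : undefined,
    response_types_supported: ["code"],
    code_challenge_methods_supported: ["S256"],
  };
}

// Usage
const minimal = buildMetadata({ clientsStore: {} }, new URL("https://auth.example.com"));
// minimal.issuer === "https://auth.example.com/" (URL normalizes the empty path to "/")
// minimal.registration_endpoint === undefined
```

Because the endpoint paths begin with `/`, `new URL()` resolves them against the issuer's origin, so any path component on the issuer is dropped. For example, `new URL("/authorize", "https://auth.example.com/tenant")` yields `https://auth.example.com/authorize`. This matches the doc comment's requirement that the router be mounted at the application root.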
================================================
File: src/server/auth/types.ts
================================================
/**
* Information about a validated access token, provided to request handlers.
*/
export interface AuthInfo {
/**
* The access token.
*/
token: string;
/**
* The client ID associated with this token.
*/
clientId: string;
/**
* Scopes associated with this token.
*/
scopes: string[];
/**
* When the token expires (in seconds since epoch).
*/
expiresAt?: number;
}
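`AuthInfo.expiresAt` is in seconds since the epoch, while `Date.now()` returns milliseconds, so any expiry check must convert units. The sketch below shows that conversion together with a required-scope check, the two conditions that `InvalidTokenError` and `InsufficientScopeError` in `errors.ts` describe. `isExpired` and `hasRequiredScopes` are hypothetical helpers, not functions exported by the SDK.

```typescript
// Local copy of the AuthInfo shape above, so the sketch is self-contained.
interface AuthInfo {
  token: string;
  clientId: string;
  scopes: string[];
  expiresAt?: number; // seconds since the epoch
}

// Hypothetical helper. Date.now() returns milliseconds, while expiresAt
// is expressed in seconds, so the current time is divided by 1000.
function isExpired(info: AuthInfo, nowMs: number = Date.now()): boolean {
  return info.expiresAt !== undefined && info.expiresAt < nowMs / 1000;
}

// Hypothetical helper: true when every required scope was granted to the token.
function hasRequiredScopes(info: AuthInfo, required: string[]): boolean {
  return required.every((scope) => info.scopes.includes(scope));
}

// Usage
const info: AuthInfo = {
  token: "valid_token",
  clientId: "valid-client",
  scopes: ["read", "write"],
  expiresAt: 1_700_000_000,
};
```

A token without `expiresAt` is treated as non-expiring here. That is a choice made for this sketch, and a stricter server might reject such tokens instead.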
================================================
File: src/server/auth/handlers/authorize.test.ts
================================================
import { authorizationHandler, AuthorizationHandlerOptions } from './authorize.js';
import { OAuthServerProvider, AuthorizationParams } from '../provider.js';
import { OAuthRegisteredClientsStore } from '../clients.js';
import { OAuthClientInformationFull, OAuthTokens } from '../../../shared/auth.js';
import express, { Response } from 'express';
import supertest from 'supertest';
import { AuthInfo } from '../types.js';
import { InvalidTokenError } from '../errors.js';
describe('Authorization Handler', () => {
// Mock client data
const validClient: OAuthClientInformationFull = {
client_id: 'valid-client',
client_secret: 'valid-secret',
redirect_uris: ['https://example.com/callback'],
scope: 'profile email'
};
const multiRedirectClient: OAuthClientInformationFull = {
client_id: 'multi-redirect-client',
client_secret: 'valid-secret',
redirect_uris: [
'https://example.com/callback1',
'https://example.com/callback2'
],
scope: 'profile email'
};
// Mock client store
const mockClientStore: OAuthRegisteredClientsStore = {
async getClient(clientId: string): Promise<OAuthClientInformationFull | undefined> {
if (clientId === 'valid-client') {
return validClient;
} else if (clientId === 'multi-redirect-client') {
return multiRedirectClient;
}
return undefined;
}
};
// Mock provider
const mockProvider: OAuthServerProvider = {
clientsStore: mockClientStore,
async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise<void> {
// Mock implementation - redirects to redirectUri with code and state
const redirectUrl = new URL(params.redirectUri);
redirectUrl.searchParams.set('code', 'mock_auth_code');
if (params.state) {
redirectUrl.searchParams.set('state', params.state);
}
res.redirect(302, redirectUrl.toString());
},
async challengeForAuthorizationCode(): Promise<string> {
return 'mock_challenge';
},
async exchangeAuthorizationCode(): Promise<OAuthTokens> {
return {
access_token: 'mock_access_token',
token_type: 'bearer',
expires_in: 3600,
refresh_token: 'mock_refresh_token'
};
},
async exchangeRefreshToken(): Promise<OAuthTokens> {
return {
access_token: 'new_mock_access_token',
token_type: 'bearer',
expires_in: 3600,
refresh_token: 'new_mock_refresh_token'
};
},
async verifyAccessToken(token: string): Promise<AuthInfo> {
if (token === 'valid_token') {
return {
token,
clientId: 'valid-client',
scopes: ['read', 'write'],
expiresAt: Date.now() / 1000 + 3600
};
}
throw new InvalidTokenError('Token is invalid or expired');
},
async revokeToken(): Promise<void> {
// Do nothing in mock
}
};
// Setup express app with handler
let app: express.Express;
let options: AuthorizationHandlerOptions;
beforeEach(() => {
app = express();
options = { provider: mockProvider };
const handler = authorizationHandler(options);
app.use('/authorize', handler);
});
describe('HTTP method validation', () => {
it('rejects non-GET/POST methods', async () => {
const response = await supertest(app)
.put('/authorize')
.query({ client_id: 'valid-client' });
expect(response.status).toBe(405); // Method not allowed response from handler
});
});
describe('Client validation', () => {
it('requires client_id parameter', async () => {
const response = await supertest(app)
.get('/authorize');
expect(response.status).toBe(400);
expect(response.text).toContain('client_id');
});
it('validates that client exists', async () => {
const response = await supertest(app)
.get('/authorize')
.query({ client_id: 'nonexistent-client' });
expect(response.status).toBe(400);
});
});
describe('Redirect URI validation', () => {
it('uses the only redirect_uri if client has just one and none provided', async () => {
const response = await supertest(app)
.get('/authorize')
.query({
client_id: 'valid-client',
response_type: 'code',
code_challenge: 'challenge123',
code_challenge_method: 'S256'
});
expect(response.status).toBe(302);
const location = new URL(response.header.location);
expect(location.origin + location.pathname).toBe('https://example.com/callback');
});
it('requires redirect_uri if client has multiple', async () => {
const response = await supertest(app)
.get('/authorize')
.query({
client_id: 'multi-redirect-client',
response_type: 'code',
code_challenge: 'challenge123',
code_challenge_method: 'S256'
});
expect(response.status).toBe(400);
});
it('validates redirect_uri against client registered URIs', async () => {
const response = await supertest(app)
.get('/authorize')
.query({
client_id: 'valid-client',
redirect_uri: 'https://malicious.com/callback',
response_type: 'code',
code_challenge: 'challenge123',
code_challenge_method: 'S256'
});
expect(response.status).toBe(400);
});
it('accepts valid redirect_uri that client registered with', async () => {
const response = await supertest(app)
.get('/authorize')
.query({
client_id: 'valid-client',
redirect_uri: 'https://example.com/callback',
response_type: 'code',
code_challenge: 'challenge123',
code_challenge_method: 'S256'
});
expect(response.status).toBe(302);
const location = new URL(response.header.location);
expect(location.origin + location.pathname).toBe('https://example.com/callback');
});
});
describe('Authorization request validation', () => {
it('requires response_type=code', async () => {
const response = await supertest(app)
.get('/authorize')
.query({
client_id: 'valid-client',
redirect_uri: 'https://example.com/callback',
response_type: 'token', // invalid - we only support code flow
code_challenge: 'challenge123',
code_challenge_method: 'S256'
});
expect(response.status).toBe(302);
const location = new URL(response.header.location);
expect(location.searchParams.get('error')).toBe('invalid_request');
});
it('requires code_challenge parameter', async () => {
const response = await supertest(app)
.get('/authorize')
.query({
client_id: 'valid-client',
redirect_uri: 'https://example.com/callback',
response_type: 'code',
code_challenge_method: 'S256'
// Missing code_challenge
});
expect(response.status).toBe(302);
const location = new URL(response.header.location);
expect(location.searchParams.get('error')).toBe('invalid_request');
});
it('requires code_challenge_method=S256', async () => {
const response = await supertest(app)
.get('/authorize')
.query({
client_id: 'valid-client',
redirect_uri: 'https://example.com/callback',
response_type: 'code',
code_challenge: 'challenge123',
code_challenge_method: 'plain' // Only S256 is supported
});
expect(response.status).toBe(302);
const location = new URL(response.header.location);
expect(location.searchParams.get('error')).toBe('invalid_request');
});
});
describe('Scope validation', () => {
it('validates requested scopes against client registered scopes', async () => {
const response = await supertest(app)
.get('/authorize')
.query({
client_id: 'valid-client',
redirect_uri: 'https://example.com/callback',
response_type: 'code',
code_challenge: 'challenge123',
code_challenge_method: 'S256',
scope: 'profile email admin' // 'admin' not in client scopes
});
expect(response.status).toBe(302);
const location = new URL(response.header.location);
expect(location.searchParams.get('error')).toBe('invalid_scope');
});
it('accepts valid scopes subset', async () => {
const response = await supertest(app)
.get('/authorize')
.query({
client_id: 'valid-client',
redirect_uri: 'https://example.com/callback',
response_type: 'code',
code_challenge: 'challenge123',
code_challenge_method: 'S256',
scope: 'profile' // subset of client scopes
});
expect(response.status).toBe(302);
const location = new URL(response.header.location);
expect(location.searchParams.has('code')).toBe(true);
});
});
describe('Successful authorization', () => {
it('handles successful authorization with all parameters', async () => {
const response = await supertest(app)
.get('/authorize')
.query({
client_id: 'valid-client',
redirect_uri: 'https://example.com/callback',
response_type: 'code',
code_challenge: 'challenge123',
code_challenge_method: 'S256',
scope: 'profile email',
state: 'xyz789'
});
expect(response.status).toBe(302);
const location = new URL(response.header.location);
expect(location.origin + location.pathname).toBe('https://example.com/callback');
expect(location.searchParams.get('code')).toBe('mock_auth_code');
expect(location.searchParams.get('state')).toBe('xyz789');
});
it('preserves state parameter in response', async () => {
const response = await supertest(app)
.get('/authorize')
.query({
client_id: 'valid-client',
redirect_uri: 'https://example.com/callback',
response_type: 'code',
code_challenge: 'challenge123',
code_challenge_method: 'S256',
state: 'state-value-123'
});
expect(response.status).toBe(302);
const location = new URL(response.header.location);
expect(location.searchParams.get('state')).toBe('state-value-123');
});
it('handles POST requests the same as GET', async () => {
const response = await supertest(app)
.post('/authorize')
.type('form')
.send({
client_id: 'valid-client',
response_type: 'code',
code_challenge: 'challenge123',
code_challenge_method: 'S256'
});
expect(response.status).toBe(302);
const location = new URL(response.header.location);
expect(location.searchParams.has('code')).toBe(true);
});
});
});
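The tests above drive the endpoint with `code_challenge_method: 'S256'` and a placeholder challenge string. For reference, here is a minimal sketch of how a client could derive a real S256 challenge per RFC 7636, and how a server could check it. It uses only `node:crypto`. The helper names (`generateCodeVerifier`, `s256Challenge`, `verifyS256`) are hypothetical and not SDK exports. The token handler tests later in this dump mock `verifyChallenge` from the `pkce-challenge` package rather than computing hashes directly.

```typescript
import { createHash, randomBytes, timingSafeEqual } from "node:crypto";

// Base64url without padding (RFC 7636, Appendix A): swap "+/" for "-_" and drop "=".
function base64url(buf: Buffer): string {
  return buf.toString("base64").replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
}

// A code_verifier must be 43-128 characters from the unreserved set.
// 32 random bytes encode to exactly 43 base64url characters.
function generateCodeVerifier(): string {
  return base64url(randomBytes(32));
}

// S256 method: code_challenge = BASE64URL(SHA256(ASCII(code_verifier))).
function s256Challenge(verifier: string): string {
  return base64url(createHash("sha256").update(verifier, "ascii").digest());
}

// Server side (hypothetical): recompute the challenge from the presented verifier
// and compare in constant time. timingSafeEqual throws on a length mismatch,
// so the lengths are checked first.
function verifyS256(verifier: string, expectedChallenge: string): boolean {
  const actual = Buffer.from(s256Challenge(verifier));
  const expected = Buffer.from(expectedChallenge);
  return actual.length === expected.length && timingSafeEqual(actual, expected);
}
```

The client sends `s256Challenge(verifier)` as `code_challenge` on `/authorize` and keeps `verifier` secret until it calls the token endpoint.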
================================================
File: src/server/auth/handlers/authorize.ts
================================================
import { RequestHandler } from "express";
import { z } from "zod";
import express from "express";
import { OAuthServerProvider } from "../provider.js";
import { rateLimit, Options as RateLimitOptions } from "express-rate-limit";
import { allowedMethods } from "../middleware/allowedMethods.js";
import {
InvalidRequestError,
InvalidClientError,
InvalidScopeError,
ServerError,
TooManyRequestsError,
OAuthError
} from "../errors.js";
export type AuthorizationHandlerOptions = {
provider: OAuthServerProvider;
/**
* Rate limiting configuration for the authorization endpoint.
* Set to false to disable rate limiting for this endpoint.
*/
rateLimit?: Partial<RateLimitOptions> | false;
};
// Parameters that must be validated in order to issue redirects.
const ClientAuthorizationParamsSchema = z.object({
client_id: z.string(),
redirect_uri: z.string().optional().refine((value) => value === undefined || URL.canParse(value), { message: "redirect_uri must be a valid URL" }),
});
// Parameters that must be validated for a successful authorization request. Failure can be reported to the redirect URI.
const RequestAuthorizationParamsSchema = z.object({
response_type: z.literal("code"),
code_challenge: z.string(),
code_challenge_method: z.literal("S256"),
scope: z.string().optional(),
state: z.string().optional(),
});
export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: AuthorizationHandlerOptions): RequestHandler {
// Create a router to apply middleware
const router = express.Router();
router.use(allowedMethods(["GET", "POST"]));
router.use(express.urlencoded({ extended: false }));
// Apply rate limiting unless explicitly disabled
if (rateLimitConfig !== false) {
router.use(rateLimit({
windowMs: 15 * 60 * 1000, // 15 minutes
max: 100, // 100 requests per windowMs
standardHeaders: true,
legacyHeaders: false,
message: new TooManyRequestsError('You have exceeded the rate limit for authorization requests').toResponseObject(),
...rateLimitConfig
}));
}
router.all("/", async (req, res) => {
res.setHeader('Cache-Control', 'no-store');
// In the authorization flow, errors are split into two categories:
// 1. Pre-redirect errors (direct response with 400)
// 2. Post-redirect errors (redirect with error parameters)
// Phase 1: Validate client_id and redirect_uri. Any errors here must be direct responses.
let client_id, redirect_uri, client;
try {
const result = ClientAuthorizationParamsSchema.safeParse(req.method === 'POST' ? req.body : req.query);
if (!result.success) {
throw new InvalidRequestError(result.error.message);
}
client_id = result.data.client_id;
redirect_uri = result.data.redirect_uri;
client = await provider.clientsStore.getClient(client_id);
if (!client) {
throw new InvalidClientError("Invalid client_id");
}
if (redirect_uri !== undefined) {
if (!client.redirect_uris.includes(redirect_uri)) {
throw new InvalidRequestError("Unregistered redirect_uri");
}
} else if (client.redirect_uris.length === 1) {
redirect_uri = client.redirect_uris[0];
} else {
throw new InvalidRequestError("redirect_uri must be specified when client has multiple registered URIs");
}
} catch (error) {
// Pre-redirect errors - return direct response
//
// These are displayed directly in the user agent, so they don't strictly need
// to be JSON encoded. However, they all represent exceptional situations
// (arguably "programmer error"), so rendering a friendlier HTML page wouldn't
// help the user much anyway.
if (error instanceof OAuthError) {
const status = error instanceof ServerError ? 500 : 400;
res.status(status).json(error.toResponseObject());
} else {
console.error("Unexpected error looking up client:", error);
const serverError = new ServerError("Internal Server Error");
res.status(500).json(serverError.toResponseObject());
}
return;
}
// Phase 2: Validate other parameters. Any errors here should go into redirect responses.
let state;
try {
// Parse and validate authorization parameters
const parseResult = RequestAuthorizationParamsSchema.safeParse(req.method === 'POST' ? req.body : req.query);
if (!parseResult.success) {
throw new InvalidRequestError(parseResult.error.message);
}
const { scope, code_challenge } = parseResult.data;
state = parseResult.data.state;
// Validate scopes
let requestedScopes: string[] = [];
if (scope !== undefined) {
requestedScopes = scope.split(" ");
const allowedScopes = new Set(client.scope?.split(" "));
// Check each requested scope against allowed scopes
for (const requestedScope of requestedScopes) {
if (!allowedScopes.has(requestedScope)) {
throw new InvalidScopeError(`Client was not registered with scope ${requestedScope}`);
}
}
}
// All validation passed, proceed with authorization
await provider.authorize(client, {
state,
scopes: requestedScopes,
redirectUri: redirect_uri,
codeChallenge: code_challenge,
}, res);
} catch (error) {
// Post-redirect errors - redirect with error parameters
if (error instanceof OAuthError) {
res.redirect(302, createErrorRedirect(redirect_uri, error, state));
} else {
console.error("Unexpected error during authorization:", error);
const serverError = new ServerError("Internal Server Error");
res.redirect(302, createErrorRedirect(redirect_uri, serverError, state));
}
}
});
return router;
}
/**
* Helper function to create redirect URL with error parameters
*/
function createErrorRedirect(redirectUri: string, error: OAuthError, state?: string): string {
const errorUrl = new URL(redirectUri);
errorUrl.searchParams.set("error", error.errorCode);
errorUrl.searchParams.set("error_description", error.message);
if (error.errorUri) {
errorUrl.searchParams.set("error_uri", error.errorUri);
}
if (state) {
errorUrl.searchParams.set("state", state);
}
return errorUrl.href;
}
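The handler above separates errors that must be returned directly, before the redirect URI is trusted, from errors that are reported back to the client's redirect URI. Below is a dependency-free sketch of the same decision logic. It assumes a trimmed-down client type (`RegisteredClient`, hypothetical, not the SDK's `OAuthClientInformationFull`) and hypothetical helper names, and it throws plain `Error`s instead of the SDK's `OAuthError` subclasses.

```typescript
// Hypothetical shape: only the client fields this sketch needs.
interface RegisteredClient {
  redirect_uris: string[];
  scope?: string; // space-delimited, as in RFC 7591 client metadata
}

// Phase 1: resolve the redirect URI. Errors here must NOT be redirected,
// because no redirect target has been validated yet.
function resolveRedirectUri(client: RegisteredClient, requested?: string): string {
  if (requested !== undefined) {
    if (!client.redirect_uris.includes(requested)) {
      throw new Error("Unregistered redirect_uri");
    }
    return requested;
  }
  if (client.redirect_uris.length === 1) {
    return client.redirect_uris[0];
  }
  throw new Error("redirect_uri must be specified when client has multiple registered URIs");
}

// Phase 2 check: every requested scope must be one the client registered.
// Returns the first offending scope, or undefined if the request is a valid subset.
function findUnregisteredScope(client: RegisteredClient, scope?: string): string | undefined {
  if (scope === undefined) return undefined;
  const allowed = new Set(client.scope?.split(" "));
  return scope.split(" ").find((s) => !allowed.has(s));
}

// Phase 2 errors are reported to the already-validated redirect URI,
// echoing `state` so the client can correlate the response.
function errorRedirect(redirectUri: string, errorCode: string, description: string, state?: string): string {
  const url = new URL(redirectUri);
  url.searchParams.set("error", errorCode);
  url.searchParams.set("error_description", description);
  if (state) url.searchParams.set("state", state);
  return url.href;
}
```

Keeping phase 1 strictly before any redirect is what prevents the endpoint from being used as an open redirector: an attacker-supplied `redirect_uri` never receives a `Location` header.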
================================================
File: src/server/auth/handlers/metadata.test.ts
================================================
import { metadataHandler } from './metadata.js';
import { OAuthMetadata } from '../../../shared/auth.js';
import express from 'express';
import supertest from 'supertest';
describe('Metadata Handler', () => {
const exampleMetadata: OAuthMetadata = {
issuer: 'https://auth.example.com',
authorization_endpoint: 'https://auth.example.com/authorize',
token_endpoint: 'https://auth.example.com/token',
registration_endpoint: 'https://auth.example.com/register',
revocation_endpoint: 'https://auth.example.com/revoke',
scopes_supported: ['profile', 'email'],
response_types_supported: ['code'],
grant_types_supported: ['authorization_code', 'refresh_token'],
token_endpoint_auth_methods_supported: ['client_secret_basic'],
code_challenge_methods_supported: ['S256']
};
let app: express.Express;
beforeEach(() => {
// Setup express app with metadata handler
app = express();
app.use('/.well-known/oauth-authorization-server', metadataHandler(exampleMetadata));
});
it('requires GET method', async () => {
const response = await supertest(app)
.post('/.well-known/oauth-authorization-server')
.send({});
expect(response.status).toBe(405);
expect(response.headers.allow).toBe('GET');
expect(response.body).toEqual({
error: "method_not_allowed",
error_description: "The method POST is not allowed for this endpoint"
});
});
it('returns the metadata object', async () => {
const response = await supertest(app)
.get('/.well-known/oauth-authorization-server');
expect(response.status).toBe(200);
expect(response.body).toEqual(exampleMetadata);
});
it('includes CORS headers in response', async () => {
const response = await supertest(app)
.get('/.well-known/oauth-authorization-server')
.set('Origin', 'https://example.com');
expect(response.header['access-control-allow-origin']).toBe('*');
});
it('supports OPTIONS preflight requests', async () => {
const response = await supertest(app)
.options('/.well-known/oauth-authorization-server')
.set('Origin', 'https://example.com')
.set('Access-Control-Request-Method', 'GET');
expect(response.status).toBe(204);
expect(response.header['access-control-allow-origin']).toBe('*');
});
it('works with minimal metadata', async () => {
// Setup a new express app with minimal metadata
const minimalApp = express();
const minimalMetadata: OAuthMetadata = {
issuer: 'https://auth.example.com',
authorization_endpoint: 'https://auth.example.com/authorize',
token_endpoint: 'https://auth.example.com/token',
response_types_supported: ['code']
};
minimalApp.use('/.well-known/oauth-authorization-server', metadataHandler(minimalMetadata));
const response = await supertest(minimalApp)
.get('/.well-known/oauth-authorization-server');
expect(response.status).toBe(200);
expect(response.body).toEqual(minimalMetadata);
});
});
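Each handler's tests expect an unsupported method to produce a 405 with an `Allow` header and a `method_not_allowed` body. The `allowedMethods` middleware itself is not shown in this part of the dump. The following is a framework-free sketch of the response shape those assertions check. `checkAllowedMethod` and `MethodNotAllowed` are hypothetical names, and this is not the SDK's implementation.

```typescript
// Hypothetical description of the 405 response the tests assert on.
interface MethodNotAllowed {
  status: 405;
  headers: { Allow: string };
  body: { error: "method_not_allowed"; error_description: string };
}

// Returns undefined when the method is permitted, otherwise the 405 response.
// HTTP semantics (RFC 9110) require a 405 response to list the supported methods in Allow.
function checkAllowedMethod(method: string, allowed: string[]): MethodNotAllowed | undefined {
  if (allowed.includes(method)) {
    return undefined;
  }
  return {
    status: 405,
    headers: { Allow: allowed.join(", ") },
    body: {
      error: "method_not_allowed",
      error_description: `The method ${method} is not allowed for this endpoint`,
    },
  };
}
```

In these handlers `cors()` is registered before `allowedMethods`. That ordering explains why an `OPTIONS` preflight still gets a 204: by default the `cors` middleware answers preflight requests itself, so they never reach the method check.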
================================================
File: src/server/auth/handlers/metadata.ts
================================================
import express, { RequestHandler } from "express";
import { OAuthMetadata } from "../../../shared/auth.js";
import cors from 'cors';
import { allowedMethods } from "../middleware/allowedMethods.js";
export function metadataHandler(metadata: OAuthMetadata): RequestHandler {
// Nested router so we can configure middleware and restrict HTTP method
const router = express.Router();
// Allow CORS requests from any origin so that web-based MCP clients can reach this endpoint
router.use(cors());
router.use(allowedMethods(['GET']));
router.get("/", (req, res) => {
res.status(200).json(metadata);
});
return router;
}
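The handler serves a fixed metadata object, but where it is mounted matters to clients. The tests mount it at `/.well-known/oauth-authorization-server`, which is the default well-known path defined by RFC 8414. Below is a short sketch of how a client could derive that URL from an issuer identifier, including the RFC's rule that the well-known segment goes between the host and any path component. `authorizationServerMetadataUrl` is a hypothetical helper, not an SDK export.

```typescript
// Build the RFC 8414 metadata URL for an issuer identifier by inserting the
// well-known suffix between the host and any path component of the issuer.
function authorizationServerMetadataUrl(issuer: string): string {
  const url = new URL(issuer);
  // RFC 8414: the issuer identifier has no query or fragment component.
  if (url.search || url.hash) {
    throw new Error("issuer must not contain a query or fragment");
  }
  // Treat "https://auth.example.com/" and "https://auth.example.com" the same.
  const path = url.pathname === "/" ? "" : url.pathname.replace(/\/$/, "");
  return `${url.origin}/.well-known/oauth-authorization-server${path}`;
}
```

For example, the issuer `https://example.com/issuer1` maps to `https://example.com/.well-known/oauth-authorization-server/issuer1`, not to `https://example.com/issuer1/.well-known/oauth-authorization-server`.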
================================================
File: src/server/auth/handlers/register.test.ts
================================================
import { clientRegistrationHandler, ClientRegistrationHandlerOptions } from './register.js';
import { OAuthRegisteredClientsStore } from '../clients.js';
import { OAuthClientInformationFull, OAuthClientMetadata } from '../../../shared/auth.js';
import express from 'express';
import supertest from 'supertest';
describe('Client Registration Handler', () => {
// Mock client store with registration support
const mockClientStoreWithRegistration: OAuthRegisteredClientsStore = {
async getClient(_clientId: string): Promise<OAuthClientInformationFull | undefined> {
return undefined;
},
async registerClient(client: OAuthClientInformationFull): Promise<OAuthClientInformationFull> {
// Return the client info as-is in the mock
return client;
}
};
// Mock client store without registration support
const mockClientStoreWithoutRegistration: OAuthRegisteredClientsStore = {
async getClient(_clientId: string): Promise<OAuthClientInformationFull | undefined> {
return undefined;
}
// No registerClient method
};
describe('Handler creation', () => {
it('throws error if client store does not support registration', () => {
const options: ClientRegistrationHandlerOptions = {
clientsStore: mockClientStoreWithoutRegistration
};
expect(() => clientRegistrationHandler(options)).toThrow('does not support registering clients');
});
it('creates handler if client store supports registration', () => {
const options: ClientRegistrationHandlerOptions = {
clientsStore: mockClientStoreWithRegistration
};
expect(() => clientRegistrationHandler(options)).not.toThrow();
});
});
describe('Request handling', () => {
let app: express.Express;
let spyRegisterClient: jest.SpyInstance;
beforeEach(() => {
// Setup express app with registration handler
app = express();
const options: ClientRegistrationHandlerOptions = {
clientsStore: mockClientStoreWithRegistration,
clientSecretExpirySeconds: 86400 // 1 day for testing
};
app.use('/register', clientRegistrationHandler(options));
// Spy on the registerClient method
spyRegisterClient = jest.spyOn(mockClientStoreWithRegistration, 'registerClient');
});
afterEach(() => {
spyRegisterClient.mockRestore();
});
it('requires POST method', async () => {
const response = await supertest(app)
.get('/register')
.send({
redirect_uris: ['https://example.com/callback']
});
expect(response.status).toBe(405);
expect(response.headers.allow).toBe('POST');
expect(response.body).toEqual({
error: "method_not_allowed",
error_description: "The method GET is not allowed for this endpoint"
});
expect(spyRegisterClient).not.toHaveBeenCalled();
});
it('validates required client metadata', async () => {
const response = await supertest(app)
.post('/register')
.send({
// Missing redirect_uris (required)
client_name: 'Test Client'
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_client_metadata');
expect(spyRegisterClient).not.toHaveBeenCalled();
});
it('validates redirect URIs format', async () => {
const response = await supertest(app)
.post('/register')
.send({
redirect_uris: ['invalid-url'] // Invalid URL format
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_client_metadata');
expect(response.body.error_description).toContain('redirect_uris');
expect(spyRegisterClient).not.toHaveBeenCalled();
});
it('successfully registers client with minimal metadata', async () => {
const clientMetadata: OAuthClientMetadata = {
redirect_uris: ['https://example.com/callback']
};
const response = await supertest(app)
.post('/register')
.send(clientMetadata);
expect(response.status).toBe(201);
// Verify the generated client information
expect(response.body.client_id).toBeDefined();
expect(response.body.client_secret).toBeDefined();
expect(response.body.client_id_issued_at).toBeDefined();
expect(response.body.client_secret_expires_at).toBeDefined();
expect(response.body.redirect_uris).toEqual(['https://example.com/callback']);
// Verify client was registered
expect(spyRegisterClient).toHaveBeenCalledTimes(1);
});
it('sets client_secret to undefined for token_endpoint_auth_method=none', async () => {
const clientMetadata: OAuthClientMetadata = {
redirect_uris: ['https://example.com/callback'],
token_endpoint_auth_method: 'none'
};
const response = await supertest(app)
.post('/register')
.send(clientMetadata);
expect(response.status).toBe(201);
expect(response.body.client_secret).toBeUndefined();
expect(response.body.client_secret_expires_at).toBeUndefined();
});
it('sets client_secret_expires_at for confidential clients only', async () => {
// Confidential client (token_endpoint_auth_method is not 'none'): gets a secret and an expiry
const confidentialClientMetadata: OAuthClientMetadata = {
redirect_uris: ['https://example.com/callback'],
token_endpoint_auth_method: 'client_secret_basic'
};
const confidentialResponse = await supertest(app)
.post('/register')
.send(confidentialClientMetadata);
expect(confidentialResponse.status).toBe(201);
expect(confidentialResponse.body.client_secret).toBeDefined();
expect(confidentialResponse.body.client_secret_expires_at).toBeDefined();
// Public client (token_endpoint_auth_method is 'none'): no secret, so no expiry
const publicClientMetadata: OAuthClientMetadata = {
redirect_uris: ['https://example.com/callback'],
token_endpoint_auth_method: 'none'
};
const publicResponse = await supertest(app)
.post('/register')
.send(publicClientMetadata);
expect(publicResponse.status).toBe(201);
expect(publicResponse.body.client_secret).toBeUndefined();
expect(publicResponse.body.client_secret_expires_at).toBeUndefined();
});
it('sets expiry based on clientSecretExpirySeconds', async () => {
// Create handler with custom expiry time
const customApp = express();
const options: ClientRegistrationHandlerOptions = {
clientsStore: mockClientStoreWithRegistration,
clientSecretExpirySeconds: 3600 // 1 hour
};
customApp.use('/register', clientRegistrationHandler(options));
const response = await supertest(customApp)
.post('/register')
.send({
redirect_uris: ['https://example.com/callback']
});
expect(response.status).toBe(201);
// Verify the expiration time (~1 hour from now)
const issuedAt = response.body.client_id_issued_at;
const expiresAt = response.body.client_secret_expires_at;
expect(expiresAt - issuedAt).toBe(3600);
});
it('sets no expiry when clientSecretExpirySeconds=0', async () => {
// Create handler with no expiry
const customApp = express();
const options: ClientRegistrationHandlerOptions = {
clientsStore: mockClientStoreWithRegistration,
clientSecretExpirySeconds: 0 // No expiry
};
customApp.use('/register', clientRegistrationHandler(options));
const response = await supertest(customApp)
.post('/register')
.send({
redirect_uris: ['https://example.com/callback']
});
expect(response.status).toBe(201);
expect(response.body.client_secret_expires_at).toBe(0);
});
it('handles client with all metadata fields', async () => {
const fullClientMetadata: OAuthClientMetadata = {
redirect_uris: ['https://example.com/callback'],
token_endpoint_auth_method: 'client_secret_basic',
grant_types: ['authorization_code', 'refresh_token'],
response_types: ['code'],
client_name: 'Test Client',
client_uri: 'https://example.com',
logo_uri: 'https://example.com/logo.png',
scope: 'profile email',
contacts: ['dev@example.com'],
tos_uri: 'https://example.com/tos',
policy_uri: 'https://example.com/privacy',
jwks_uri: 'https://example.com/jwks',
software_id: 'test-software',
software_version: '1.0.0'
};
const response = await supertest(app)
.post('/register')
.send(fullClientMetadata);
expect(response.status).toBe(201);
// Verify all metadata was preserved
Object.entries(fullClientMetadata).forEach(([key, value]) => {
expect(response.body[key]).toEqual(value);
});
});
it('includes CORS headers in response', async () => {
const response = await supertest(app)
.post('/register')
.set('Origin', 'https://example.com')
.send({
redirect_uris: ['https://example.com/callback']
});
expect(response.header['access-control-allow-origin']).toBe('*');
});
});
});
================================================
File: src/server/auth/handlers/register.ts
================================================
import express, { RequestHandler } from "express";
import { OAuthClientInformationFull, OAuthClientMetadataSchema } from "../../../shared/auth.js";
import crypto from 'node:crypto';
import cors from 'cors';
import { OAuthRegisteredClientsStore } from "../clients.js";
import { rateLimit, Options as RateLimitOptions } from "express-rate-limit";
import { allowedMethods } from "../middleware/allowedMethods.js";
import {
InvalidClientMetadataError,
ServerError,
TooManyRequestsError,
OAuthError
} from "../errors.js";
export type ClientRegistrationHandlerOptions = {
/**
* A store used to save information about dynamically registered OAuth clients.
*/
clientsStore: OAuthRegisteredClientsStore;
/**
* The number of seconds after which to expire issued client secrets, or 0 to prevent expiration of client secrets (not recommended).
*
* If not set, defaults to 30 days.
*/
clientSecretExpirySeconds?: number;
/**
* Rate limiting configuration for the client registration endpoint.
* Set to false to disable rate limiting for this endpoint.
* Registration endpoints are particularly sensitive to abuse and should be rate limited.
*/
rateLimit?: Partial<RateLimitOptions> | false;
};
const DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS = 30 * 24 * 60 * 60; // 30 days
export function clientRegistrationHandler({
clientsStore,
clientSecretExpirySeconds = DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS,
rateLimit: rateLimitConfig
}: ClientRegistrationHandlerOptions): RequestHandler {
if (!clientsStore.registerClient) {
throw new Error("Client registration store does not support registering clients");
}
// Nested router so we can configure middleware and restrict HTTP method
const router = express.Router();
// Allow CORS requests from any origin so that web-based MCP clients can reach this endpoint
router.use(cors());
router.use(allowedMethods(["POST"]));
router.use(express.json());
// Apply rate limiting unless explicitly disabled - stricter limits for registration
if (rateLimitConfig !== false) {
router.use(rateLimit({
windowMs: 60 * 60 * 1000, // 1 hour
max: 20, // 20 requests per hour - stricter as registration is sensitive
standardHeaders: true,
legacyHeaders: false,
message: new TooManyRequestsError('You have exceeded the rate limit for client registration requests').toResponseObject(),
...rateLimitConfig
}));
}
router.post("/", async (req, res) => {
res.setHeader('Cache-Control', 'no-store');
try {
const parseResult = OAuthClientMetadataSchema.safeParse(req.body);
if (!parseResult.success) {
throw new InvalidClientMetadataError(parseResult.error.message);
}
const clientMetadata = parseResult.data;
const isPublicClient = clientMetadata.token_endpoint_auth_method === 'none';
// Generate client credentials
const clientId = crypto.randomUUID();
const clientSecret = isPublicClient
? undefined
: crypto.randomBytes(32).toString('hex');
const clientIdIssuedAt = Math.floor(Date.now() / 1000);
// Calculate client secret expiry time
const clientsDoExpire = clientSecretExpirySeconds > 0;
const secretExpiryTime = clientsDoExpire ? clientIdIssuedAt + clientSecretExpirySeconds : 0;
const clientSecretExpiresAt = isPublicClient ? undefined : secretExpiryTime;
let clientInfo: OAuthClientInformationFull = {
...clientMetadata,
client_id: clientId,
client_secret: clientSecret,
client_id_issued_at: clientIdIssuedAt,
client_secret_expires_at: clientSecretExpiresAt,
};
clientInfo = await clientsStore.registerClient!(clientInfo);
res.status(201).json(clientInfo);
} catch (error) {
if (error instanceof OAuthError) {
const status = error instanceof ServerError ? 500 : 400;
res.status(status).json(error.toResponseObject());
} else {
console.error("Unexpected error registering client:", error);
const serverError = new ServerError("Internal Server Error");
res.status(500).json(serverError.toResponseObject());
}
}
});
return router;
}
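The credential and expiry rules in the registration handler are easy to misread, particularly the difference between `0` ("never expires") and `undefined` ("no secret was issued"). Below they are restated as a pure function with an injectable clock so they can be tested deterministically. `issueCredentials` and `IssuedCredentials` are hypothetical names. This mirrors the handler's logic but is not an SDK export.

```typescript
import { randomBytes, randomUUID } from "node:crypto";

// Hypothetical minimal shape of the issued credentials (field names follow RFC 7591).
interface IssuedCredentials {
  client_id: string;
  client_secret?: string;
  client_id_issued_at: number; // seconds since the Unix epoch
  client_secret_expires_at?: number; // seconds since the epoch, or 0 for "never"
}

function issueCredentials(
  tokenEndpointAuthMethod: string | undefined,
  expirySeconds: number,
  nowMs: number = Date.now(),
): IssuedCredentials {
  // "none" marks a public client: it gets no secret, and therefore no secret expiry.
  const isPublicClient = tokenEndpointAuthMethod === "none";
  const issuedAt = Math.floor(nowMs / 1000);
  // RFC 7591: client_secret_expires_at = 0 means the secret does not expire.
  const expiresAt = expirySeconds > 0 ? issuedAt + expirySeconds : 0;
  return {
    client_id: randomUUID(),
    client_secret: isPublicClient ? undefined : randomBytes(32).toString("hex"),
    client_id_issued_at: issuedAt,
    client_secret_expires_at: isPublicClient ? undefined : expiresAt,
  };
}
```

As in the handler, an omitted `token_endpoint_auth_method` is treated as a confidential client, so it receives a secret.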
================================================
File: src/server/auth/handlers/revoke.test.ts
================================================
import { revocationHandler, RevocationHandlerOptions } from './revoke.js';
import { OAuthServerProvider, AuthorizationParams } from '../provider.js';
import { OAuthRegisteredClientsStore } from '../clients.js';
import { OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '../../../shared/auth.js';
import express, { Response } from 'express';
import supertest from 'supertest';
import { AuthInfo } from '../types.js';
import { InvalidTokenError } from '../errors.js';
describe('Revocation Handler', () => {
// Mock client data
const validClient: OAuthClientInformationFull = {
client_id: 'valid-client',
client_secret: 'valid-secret',
redirect_uris: ['https://example.com/callback']
};
// Mock client store
const mockClientStore: OAuthRegisteredClientsStore = {
async getClient(clientId: string): Promise<OAuthClientInformationFull | undefined> {
if (clientId === 'valid-client') {
return validClient;
}
return undefined;
}
};
// Mock provider with revocation capability
const mockProviderWithRevocation: OAuthServerProvider = {
clientsStore: mockClientStore,
async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise<void> {
res.redirect('https://example.com/callback?code=mock_auth_code');
},
async challengeForAuthorizationCode(): Promise<string> {
return 'mock_challenge';
},
async exchangeAuthorizationCode(): Promise<OAuthTokens> {
return {
access_token: 'mock_access_token',
token_type: 'bearer',
expires_in: 3600,
refresh_token: 'mock_refresh_token'
};
},
async exchangeRefreshToken(): Promise<OAuthTokens> {
return {
access_token: 'new_mock_access_token',
token_type: 'bearer',
expires_in: 3600,
refresh_token: 'new_mock_refresh_token'
};
},
async verifyAccessToken(token: string): Promise<AuthInfo> {
if (token === 'valid_token') {
return {
token,
clientId: 'valid-client',
scopes: ['read', 'write'],
expiresAt: Date.now() / 1000 + 3600
};
}
throw new InvalidTokenError('Token is invalid or expired');
},
async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise<void> {
// Success - do nothing in mock
}
};
// Mock provider without revocation capability
const mockProviderWithoutRevocation: OAuthServerProvider = {
clientsStore: mockClientStore,
async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise<void> {
res.redirect('https://example.com/callback?code=mock_auth_code');
},
async challengeForAuthorizationCode(): Promise<string> {
return 'mock_challenge';
},
async exchangeAuthorizationCode(): Promise<OAuthTokens> {
return {
access_token: 'mock_access_token',
token_type: 'bearer',
expires_in: 3600,
refresh_token: 'mock_refresh_token'
};
},
async exchangeRefreshToken(): Promise<OAuthTokens> {
return {
access_token: 'new_mock_access_token',
token_type: 'bearer',
expires_in: 3600,
refresh_token: 'new_mock_refresh_token'
};
},
async verifyAccessToken(token: string): Promise<AuthInfo> {
if (token === 'valid_token') {
return {
token,
clientId: 'valid-client',
scopes: ['read', 'write'],
expiresAt: Date.now() / 1000 + 3600
};
}
throw new InvalidTokenError('Token is invalid or expired');
}
// No revokeToken method
};
describe('Handler creation', () => {
it('throws error if provider does not support token revocation', () => {
const options: RevocationHandlerOptions = { provider: mockProviderWithoutRevocation };
expect(() => revocationHandler(options)).toThrow('does not support revoking tokens');
});
it('creates handler if provider supports token revocation', () => {
const options: RevocationHandlerOptions = { provider: mockProviderWithRevocation };
expect(() => revocationHandler(options)).not.toThrow();
});
});
describe('Request handling', () => {
let app: express.Express;
let spyRevokeToken: jest.SpyInstance;
beforeEach(() => {
// Setup express app with revocation handler
app = express();
const options: RevocationHandlerOptions = { provider: mockProviderWithRevocation };
app.use('/revoke', revocationHandler(options));
// Spy on the revokeToken method
spyRevokeToken = jest.spyOn(mockProviderWithRevocation, 'revokeToken');
});
afterEach(() => {
spyRevokeToken.mockRestore();
});
it('requires POST method', async () => {
const response = await supertest(app)
.get('/revoke')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
token: 'token_to_revoke'
});
expect(response.status).toBe(405);
expect(response.headers.allow).toBe('POST');
expect(response.body).toEqual({
error: "method_not_allowed",
error_description: "The method GET is not allowed for this endpoint"
});
expect(spyRevokeToken).not.toHaveBeenCalled();
});
it('requires token parameter', async () => {
const response = await supertest(app)
.post('/revoke')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret'
// Missing token
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_request');
expect(spyRevokeToken).not.toHaveBeenCalled();
});
it('authenticates client before revoking token', async () => {
const response = await supertest(app)
.post('/revoke')
.type('form')
.send({
client_id: 'invalid-client',
client_secret: 'wrong-secret',
token: 'token_to_revoke'
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_client');
expect(spyRevokeToken).not.toHaveBeenCalled();
});
it('successfully revokes token', async () => {
const response = await supertest(app)
.post('/revoke')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
token: 'token_to_revoke'
});
expect(response.status).toBe(200);
expect(response.body).toEqual({}); // Empty response on success
expect(spyRevokeToken).toHaveBeenCalledTimes(1);
expect(spyRevokeToken).toHaveBeenCalledWith(validClient, {
token: 'token_to_revoke'
});
});
it('accepts optional token_type_hint', async () => {
const response = await supertest(app)
.post('/revoke')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
token: 'token_to_revoke',
token_type_hint: 'refresh_token'
});
expect(response.status).toBe(200);
expect(spyRevokeToken).toHaveBeenCalledWith(validClient, {
token: 'token_to_revoke',
token_type_hint: 'refresh_token'
});
});
it('includes CORS headers in response', async () => {
const response = await supertest(app)
.post('/revoke')
.type('form')
.set('Origin', 'https://example.com')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
token: 'token_to_revoke'
});
expect(response.header['access-control-allow-origin']).toBe('*');
});
});
});
================================================
File: src/server/auth/handlers/revoke.ts
================================================
import { OAuthServerProvider } from "../provider.js";
import express, { RequestHandler } from "express";
import cors from "cors";
import { authenticateClient } from "../middleware/clientAuth.js";
import { OAuthTokenRevocationRequestSchema } from "../../../shared/auth.js";
import { rateLimit, Options as RateLimitOptions } from "express-rate-limit";
import { allowedMethods } from "../middleware/allowedMethods.js";
import {
InvalidRequestError,
ServerError,
TooManyRequestsError,
OAuthError
} from "../errors.js";
export type RevocationHandlerOptions = {
provider: OAuthServerProvider;
/**
* Rate limiting configuration for the token revocation endpoint.
* Set to false to disable rate limiting for this endpoint.
*/
rateLimit?: Partial<RateLimitOptions> | false;
};
export function revocationHandler({ provider, rateLimit: rateLimitConfig }: RevocationHandlerOptions): RequestHandler {
if (!provider.revokeToken) {
throw new Error("Auth provider does not support revoking tokens");
}
// Nested router so we can configure middleware and restrict HTTP method
const router = express.Router();
// Allow CORS requests from any origin so that web-based MCP clients can reach this endpoint
router.use(cors());
router.use(allowedMethods(["POST"]));
router.use(express.urlencoded({ extended: false }));
// Apply rate limiting unless explicitly disabled
if (rateLimitConfig !== false) {
router.use(rateLimit({
windowMs: 15 * 60 * 1000, // 15 minutes
max: 50, // 50 requests per windowMs
standardHeaders: true,
legacyHeaders: false,
message: new TooManyRequestsError('You have exceeded the rate limit for token revocation requests').toResponseObject(),
...rateLimitConfig
}));
}
// Authenticate and extract client details
router.use(authenticateClient({ clientsStore: provider.clientsStore }));
router.post("/", async (req, res) => {
res.setHeader('Cache-Control', 'no-store');
try {
const parseResult = OAuthTokenRevocationRequestSchema.safeParse(req.body);
if (!parseResult.success) {
throw new InvalidRequestError(parseResult.error.message);
}
const client = req.client;
if (!client) {
// This should never happen
console.error("Missing client information after authentication");
throw new ServerError("Internal Server Error");
}
await provider.revokeToken!(client, parseResult.data);
res.status(200).json({});
} catch (error) {
if (error instanceof OAuthError) {
const status = error instanceof ServerError ? 500 : 400;
res.status(status).json(error.toResponseObject());
} else {
console.error("Unexpected error revoking token:", error);
const serverError = new ServerError("Internal Server Error");
res.status(500).json(serverError.toResponseObject());
}
}
});
return router;
}
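Because `express.urlencoded()` runs before the route, `OAuthTokenRevocationRequestSchema` only sees an already-decoded object. Below is a framework-free sketch of that parsing step as described in RFC 7009, section 2.1: `token` is required and `token_type_hint` is optional. `parseRevocationBody` and its types are hypothetical. The SDK validates with the zod schema instead.

```typescript
// Hypothetical parsed shape, mirroring RFC 7009, section 2.1.
interface RevocationRequest {
  token: string;
  token_type_hint?: string;
}

interface RevocationParseError {
  error: "invalid_request";
  error_description: string;
}

// Parse an application/x-www-form-urlencoded revocation body.
// Returns an OAuth-style error object instead of throwing.
function parseRevocationBody(body: string): RevocationRequest | RevocationParseError {
  // URLSearchParams decodes %XX escapes and treats "+" as a space.
  const params = new URLSearchParams(body);
  const token = params.get("token");
  if (!token) {
    return { error: "invalid_request", error_description: "token is required" };
  }
  const hint = params.get("token_type_hint");
  return hint ? { token, token_type_hint: hint } : { token };
}
```

RFC 7009, section 2.2, also specifies that the server responds with HTTP 200 both when the token was revoked and when the client submitted an invalid token. This is consistent with the handler returning an empty `{}` body with status 200 after `revokeToken` resolves.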
================================================
File: src/server/auth/handlers/token.test.ts
================================================
import { tokenHandler, TokenHandlerOptions } from './token.js';
import { OAuthServerProvider, AuthorizationParams } from '../provider.js';
import { OAuthRegisteredClientsStore } from '../clients.js';
import { OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '../../../shared/auth.js';
import express, { Response } from 'express';
import supertest from 'supertest';
import * as pkceChallenge from 'pkce-challenge';
import { InvalidGrantError, InvalidTokenError } from '../errors.js';
import { AuthInfo } from '../types.js';
// Mock pkce-challenge
jest.mock('pkce-challenge', () => ({
verifyChallenge: jest.fn().mockImplementation(async (verifier, challenge) => {
return verifier === 'valid_verifier' && challenge === 'mock_challenge';
})
}));
describe('Token Handler', () => {
// Mock client data
const validClient: OAuthClientInformationFull = {
client_id: 'valid-client',
client_secret: 'valid-secret',
redirect_uris: ['https://example.com/callback']
};
// Mock client store
const mockClientStore: OAuthRegisteredClientsStore = {
async getClient(clientId: string): Promise<OAuthClientInformationFull | undefined> {
if (clientId === 'valid-client') {
return validClient;
}
return undefined;
}
};
// Mock provider
let mockProvider: OAuthServerProvider;
let app: express.Express;
beforeEach(() => {
// Create fresh mocks for each test
mockProvider = {
clientsStore: mockClientStore,
async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise<void> {
res.redirect('https://example.com/callback?code=mock_auth_code');
},
async challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise<string> {
if (authorizationCode === 'valid_code') {
return 'mock_challenge';
} else if (authorizationCode === 'expired_code') {
throw new InvalidGrantError('The authorization code has expired');
}
throw new InvalidGrantError('The authorization code is invalid');
},
async exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise<OAuthTokens> {
if (authorizationCode === 'valid_code') {
return {
access_token: 'mock_access_token',
token_type: 'bearer',
expires_in: 3600,
refresh_token: 'mock_refresh_token'
};
}
throw new InvalidGrantError('The authorization code is invalid or has expired');
},
async exchangeRefreshToken(client: OAuthClientInformationFull, refreshToken: string, scopes?: string[]): Promise<OAuthTokens> {
if (refreshToken === 'valid_refresh_token') {
const response: OAuthTokens = {
access_token: 'new_mock_access_token',
token_type: 'bearer',
expires_in: 3600,
refresh_token: 'new_mock_refresh_token'
};
if (scopes) {
response.scope = scopes.join(' ');
}
return response;
}
throw new InvalidGrantError('The refresh token is invalid or has expired');
},
async verifyAccessToken(token: string): Promise<AuthInfo> {
if (token === 'valid_token') {
return {
token,
clientId: 'valid-client',
scopes: ['read', 'write'],
expiresAt: Date.now() / 1000 + 3600
};
}
throw new InvalidTokenError('Token is invalid or expired');
},
async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise<void> {
// Do nothing in mock
}
};
// Mock PKCE verification
(pkceChallenge.verifyChallenge as jest.Mock).mockImplementation(
async (verifier: string, challenge: string) => {
return verifier === 'valid_verifier' && challenge === 'mock_challenge';
}
);
// Setup express app with token handler
app = express();
const options: TokenHandlerOptions = { provider: mockProvider };
app.use('/token', tokenHandler(options));
});
describe('Basic request validation', () => {
it('requires POST method', async () => {
const response = await supertest(app)
.get('/token')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'authorization_code'
});
expect(response.status).toBe(405);
expect(response.headers.allow).toBe('POST');
expect(response.body).toEqual({
error: "method_not_allowed",
error_description: "The method GET is not allowed for this endpoint"
});
});
it('requires grant_type parameter', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret'
// Missing grant_type
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_request');
});
it('rejects unsupported grant types', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'password' // Unsupported grant type
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('unsupported_grant_type');
});
});
describe('Client authentication', () => {
it('requires valid client credentials', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'invalid-client',
client_secret: 'wrong-secret',
grant_type: 'authorization_code'
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_client');
});
it('accepts valid client credentials', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'authorization_code',
code: 'valid_code',
code_verifier: 'valid_verifier'
});
expect(response.status).toBe(200);
});
});
describe('Authorization code grant', () => {
it('requires code parameter', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'authorization_code',
// Missing code
code_verifier: 'valid_verifier'
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_request');
});
it('requires code_verifier parameter', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'authorization_code',
code: 'valid_code'
// Missing code_verifier
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_request');
});
it('verifies code_verifier against challenge', async () => {
// Setup invalid verifier
(pkceChallenge.verifyChallenge as jest.Mock).mockResolvedValueOnce(false);
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'authorization_code',
code: 'valid_code',
code_verifier: 'invalid_verifier'
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_grant');
expect(response.body.error_description).toContain('code_verifier');
});
it('rejects expired or invalid authorization codes', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'authorization_code',
code: 'expired_code',
code_verifier: 'valid_verifier'
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_grant');
});
it('returns tokens for valid code exchange', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'authorization_code',
code: 'valid_code',
code_verifier: 'valid_verifier'
});
expect(response.status).toBe(200);
expect(response.body.access_token).toBe('mock_access_token');
expect(response.body.token_type).toBe('bearer');
expect(response.body.expires_in).toBe(3600);
expect(response.body.refresh_token).toBe('mock_refresh_token');
});
});
describe('Refresh token grant', () => {
it('requires refresh_token parameter', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'refresh_token'
// Missing refresh_token
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_request');
});
it('rejects invalid refresh tokens', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'refresh_token',
refresh_token: 'invalid_refresh_token'
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_grant');
});
it('returns new tokens for valid refresh token', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'refresh_token',
refresh_token: 'valid_refresh_token'
});
expect(response.status).toBe(200);
expect(response.body.access_token).toBe('new_mock_access_token');
expect(response.body.token_type).toBe('bearer');
expect(response.body.expires_in).toBe(3600);
expect(response.body.refresh_token).toBe('new_mock_refresh_token');
});
it('respects requested scopes on refresh', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'refresh_token',
refresh_token: 'valid_refresh_token',
scope: 'profile email'
});
expect(response.status).toBe(200);
expect(response.body.scope).toBe('profile email');
});
});
describe('CORS support', () => {
it('includes CORS headers in response', async () => {
const response = await supertest(app)
.post('/token')
.type('form')
.set('Origin', 'https://example.com')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
grant_type: 'authorization_code',
code: 'valid_code',
code_verifier: 'valid_verifier'
});
expect(response.header['access-control-allow-origin']).toBe('*');
});
});
});
================================================
File: src/server/auth/handlers/token.ts
================================================
import { z } from "zod";
import express, { RequestHandler } from "express";
import { OAuthServerProvider } from "../provider.js";
import cors from "cors";
import { verifyChallenge } from "pkce-challenge";
import { authenticateClient } from "../middleware/clientAuth.js";
import { rateLimit, Options as RateLimitOptions } from "express-rate-limit";
import { allowedMethods } from "../middleware/allowedMethods.js";
import {
InvalidRequestError,
InvalidGrantError,
UnsupportedGrantTypeError,
ServerError,
TooManyRequestsError,
OAuthError
} from "../errors.js";
export type TokenHandlerOptions = {
provider: OAuthServerProvider;
/**
* Rate limiting configuration for the token endpoint.
* Set to false to disable rate limiting for this endpoint.
*/
rateLimit?: Partial<RateLimitOptions> | false;
};
const TokenRequestSchema = z.object({
grant_type: z.string(),
});
const AuthorizationCodeGrantSchema = z.object({
code: z.string(),
code_verifier: z.string(),
});
const RefreshTokenGrantSchema = z.object({
refresh_token: z.string(),
scope: z.string().optional(),
});
export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHandlerOptions): RequestHandler {
// Nested router so we can configure middleware and restrict HTTP method
const router = express.Router();
// Allow requests from any origin so that web-based MCP clients can reach this endpoint
router.use(cors());
router.use(allowedMethods(["POST"]));
router.use(express.urlencoded({ extended: false }));
// Apply rate limiting unless explicitly disabled
if (rateLimitConfig !== false) {
router.use(rateLimit({
windowMs: 15 * 60 * 1000, // 15 minutes
max: 50, // 50 requests per windowMs
standardHeaders: true,
legacyHeaders: false,
message: new TooManyRequestsError('You have exceeded the rate limit for token requests').toResponseObject(),
...rateLimitConfig
}));
}
// Authenticate and extract client details
router.use(authenticateClient({ clientsStore: provider.clientsStore }));
router.post("/", async (req, res) => {
res.setHeader('Cache-Control', 'no-store');
try {
const parseResult = TokenRequestSchema.safeParse(req.body);
if (!parseResult.success) {
throw new InvalidRequestError(parseResult.error.message);
}
const { grant_type } = parseResult.data;
const client = req.client;
if (!client) {
// This should never happen
console.error("Missing client information after authentication");
throw new ServerError("Internal Server Error");
}
switch (grant_type) {
case "authorization_code": {
const parseResult = AuthorizationCodeGrantSchema.safeParse(req.body);
if (!parseResult.success) {
throw new InvalidRequestError(parseResult.error.message);
}
const { code, code_verifier } = parseResult.data;
// Verify PKCE challenge
const codeChallenge = await provider.challengeForAuthorizationCode(client, code);
if (!(await verifyChallenge(code_verifier, codeChallenge))) {
throw new InvalidGrantError("code_verifier does not match the challenge");
}
const tokens = await provider.exchangeAuthorizationCode(client, code);
res.status(200).json(tokens);
break;
}
case "refresh_token": {
const parseResult = RefreshTokenGrantSchema.safeParse(req.body);
if (!parseResult.success) {
throw new InvalidRequestError(parseResult.error.message);
}
const { refresh_token, scope } = parseResult.data;
const scopes = scope?.split(" ");
const tokens = await provider.exchangeRefreshToken(client, refresh_token, scopes);
res.status(200).json(tokens);
break;
}
// Not supported right now
//case "client_credentials":
default:
throw new UnsupportedGrantTypeError(
"The grant type is not supported by this authorization server."
);
}
} catch (error) {
if (error instanceof OAuthError) {
const status = error instanceof ServerError ? 500 : 400;
res.status(status).json(error.toResponseObject());
} else {
console.error("Unexpected error exchanging token:", error);
const serverError = new ServerError("Internal Server Error");
res.status(500).json(serverError.toResponseObject());
}
}
});
return router;
}
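`tokenHandler` validates the form body in two stages. It first checks `grant_type`, then checks the parameters specific to that grant, and it splits `scope` on spaces before calling `exchangeRefreshToken`. The following sketch reproduces that dispatch without zod or Express and returns the OAuth error code the handler would produce (`invalid_request` or `unsupported_grant_type`). `parseTokenRequest` and its result types are hypothetical names for illustration. PKCE verification and the provider calls are deliberately left out.

```typescript
// Hypothetical result types illustrating the handler's validation order.
type TokenRequest =
  | { grantType: "authorization_code"; code: string; codeVerifier: string }
  | { grantType: "refresh_token"; refreshToken: string; scopes: string[] | undefined };

type ParseResult =
  | { ok: true; request: TokenRequest }
  | { ok: false; error: "invalid_request" | "unsupported_grant_type"; description: string };

function parseTokenRequest(body: Record<string, unknown>): ParseResult {
  // Form fields arrive as strings; anything else is treated as missing.
  const field = (key: string): string | undefined => {
    const value = body[key];
    return typeof value === "string" ? value : undefined;
  };

  // Stage 1: grant_type must be present.
  const grantType = field("grant_type");
  if (grantType === undefined) {
    return { ok: false, error: "invalid_request", description: "grant_type is required" };
  }

  // Stage 2: grant-specific parameters.
  switch (grantType) {
    case "authorization_code": {
      const code = field("code");
      const codeVerifier = field("code_verifier");
      if (code === undefined || codeVerifier === undefined) {
        return { ok: false, error: "invalid_request", description: "code and code_verifier are required" };
      }
      return { ok: true, request: { grantType: "authorization_code", code, codeVerifier } };
    }
    case "refresh_token": {
      const refreshToken = field("refresh_token");
      if (refreshToken === undefined) {
        return { ok: false, error: "invalid_request", description: "refresh_token is required" };
      }
      // Space-delimited scope string becomes an array, as in tokenHandler.
      const scopes = field("scope")?.split(" ");
      return { ok: true, request: { grantType: "refresh_token", refreshToken, scopes } };
    }
    default:
      return {
        ok: false,
        error: "unsupported_grant_type",
        description: "The grant type is not supported by this authorization server.",
      };
  }
}
```

Because validation runs in two passes, a request with an unknown `grant_type` is rejected as `unsupported_grant_type` before any grant-specific fields are examined.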
================================================
File: src/server/auth/middleware/allowedMethods.test.ts
================================================
import { allowedMethods } from "./allowedMethods.js";
import express, { Request, Response } from "express";
import request from "supertest";
describe("allowedMethods", () => {
let app: express.Express;
beforeEach(() => {
app = express();
// Set up a test router with a GET handler and 405 middleware
const router = express.Router();
router.get("/test", (req, res) => {
res.status(200).send("GET success");
});
// Add method not allowed middleware for all other methods
router.all("/test", allowedMethods(["GET"]));
app.use(router);
});
test("allows specified HTTP method", async () => {
const response = await request(app).get("/test");
expect(response.status).toBe(200);
expect(response.text).toBe("GET success");
});
test("returns 405 for unspecified HTTP methods", async () => {
const methods = ["post", "put", "delete", "patch"];
for (const method of methods) {
// @ts-expect-error - dynamic method call
const response = await request(app)[method]("/test");
expect(response.status).toBe(405);
expect(response.body).toEqual({
error: "method_not_allowed",
error_description: `The method ${method.toUpperCase()} is not allowed for this endpoint`
});
}
});
test("includes Allow header with specified methods", async () => {
const response = await request(app).post("/test");
expect(response.headers.allow).toBe("GET");
});
test("works with multiple allowed methods", async () => {
const multiMethodApp = express();
const router = express.Router();
router.get("/multi", (req: Request, res: Response) => {
res.status(200).send("GET");
});
router.post("/multi", (req: Request, res: Response) => {
res.status(200).send("POST");
});
router.all("/multi", allowedMethods(["GET", "POST"]));
multiMethodApp.use(router);
// Allowed methods should work
const getResponse = await request(multiMethodApp).get("/multi");
expect(getResponse.status).toBe(200);
const postResponse = await request(multiMethodApp).post("/multi");
expect(postResponse.status).toBe(200);
// Unallowed methods should return 405
const putResponse = await request(multiMethodApp).put("/multi");
expect(putResponse.status).toBe(405);
expect(putResponse.headers.allow).toBe("GET, POST");
});
});
================================================
File: src/server/auth/middleware/allowedMethods.ts
================================================
import { RequestHandler } from "express";
import { MethodNotAllowedError } from "../errors.js";
/**
* Middleware to handle unsupported HTTP methods with a 405 Method Not Allowed response.
*
* @param allowedMethods Array of allowed HTTP methods for this endpoint (e.g., ['GET', 'POST'])
* @returns Express middleware that responds with 405 Method Not Allowed (and an `Allow` header) if the request method is not in the allowed list
*/
export function allowedMethods(allowedMethods: string[]): RequestHandler {
return (req, res, next) => {
if (allowedMethods.includes(req.method)) {
next();
return;
}
const error = new MethodNotAllowedError(`The method ${req.method} is not allowed for this endpoint`);
res.status(405)
.set('Allow', allowedMethods.join(', '))
.json(error.toResponseObject());
};
}
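`allowedMethods` short-circuits with a 405 response and an `Allow` header that lists the permitted methods. Expressing that decision as a pure function makes the exact response easy to assert on. `checkMethod` and `MethodCheck` are hypothetical names. The body shape matches what `allowedMethods.test.ts` expects from `MethodNotAllowedError.toResponseObject()`.

```typescript
// Hypothetical result type: either let the request through or describe the 405.
type MethodCheck =
  | { allowed: true }
  | {
      allowed: false;
      status: 405;
      allow: string; // value for the Allow response header
      body: { error: string; error_description: string };
    };

function checkMethod(method: string, allowedMethods: readonly string[]): MethodCheck {
  // Exact, case-sensitive match, like the middleware's includes() check.
  if (allowedMethods.includes(method)) {
    return { allowed: true };
  }
  return {
    allowed: false,
    status: 405,
    allow: allowedMethods.join(", "),
    body: {
      error: "method_not_allowed",
      error_description: `The method ${method} is not allowed for this endpoint`,
    },
  };
}
```

Because the comparison is an exact `includes` check, method names should be listed in uppercase, which is how Express reports `req.method` for standard methods.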
================================================
File: src/server/auth/middleware/bearerAuth.test.ts
================================================
import { Request, Response } from "express";
import { requireBearerAuth } from "./bearerAuth.js";
import { AuthInfo } from "../types.js";
import { InsufficientScopeError, InvalidTokenError, OAuthError, ServerError } from "../errors.js";
import { OAuthServerProvider } from "../provider.js";
import { OAuthRegisteredClientsStore } from "../clients.js";
// Mock provider
const mockVerifyAccessToken = jest.fn();
const mockProvider: OAuthServerProvider = {
clientsStore: {} as OAuthRegisteredClientsStore,
authorize: jest.fn(),
challengeForAuthorizationCode: jest.fn(),
exchangeAuthorizationCode: jest.fn(),
exchangeRefreshToken: jest.fn(),
verifyAccessToken: mockVerifyAccessToken,
};
describe("requireBearerAuth middleware", () => {
let mockRequest: Partial<Request>;
let mockResponse: Partial<Response>;
let nextFunction: jest.Mock;
beforeEach(() => {
mockRequest = {
headers: {},
};
mockResponse = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
set: jest.fn().mockReturnThis(),
};
nextFunction = jest.fn();
jest.clearAllMocks();
});
it("should call next when token is valid", async () => {
const validAuthInfo: AuthInfo = {
token: "valid-token",
clientId: "client-123",
scopes: ["read", "write"],
};
mockVerifyAccessToken.mockResolvedValue(validAuthInfo);
mockRequest.headers = {
authorization: "Bearer valid-token",
};
const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token");
expect(mockRequest.auth).toEqual(validAuthInfo);
expect(nextFunction).toHaveBeenCalled();
expect(mockResponse.status).not.toHaveBeenCalled();
expect(mockResponse.json).not.toHaveBeenCalled();
});
it("should reject expired tokens", async () => {
const expiredAuthInfo: AuthInfo = {
token: "expired-token",
clientId: "client-123",
scopes: ["read", "write"],
expiresAt: Math.floor(Date.now() / 1000) - 100, // Token expired 100 seconds ago
};
mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo);
mockRequest.headers = {
authorization: "Bearer expired-token",
};
const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
expect(mockVerifyAccessToken).toHaveBeenCalledWith("expired-token");
expect(mockResponse.status).toHaveBeenCalledWith(401);
expect(mockResponse.set).toHaveBeenCalledWith(
"WWW-Authenticate",
expect.stringContaining('Bearer error="invalid_token"')
);
expect(mockResponse.json).toHaveBeenCalledWith(
expect.objectContaining({ error: "invalid_token", error_description: "Token has expired" })
);
expect(nextFunction).not.toHaveBeenCalled();
});
it("should accept non-expired tokens", async () => {
const nonExpiredAuthInfo: AuthInfo = {
token: "valid-token",
clientId: "client-123",
scopes: ["read", "write"],
expiresAt: Math.floor(Date.now() / 1000) + 3600, // Token expires in an hour
};
mockVerifyAccessToken.mockResolvedValue(nonExpiredAuthInfo);
mockRequest.headers = {
authorization: "Bearer valid-token",
};
const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token");
expect(mockRequest.auth).toEqual(nonExpiredAuthInfo);
expect(nextFunction).toHaveBeenCalled();
expect(mockResponse.status).not.toHaveBeenCalled();
expect(mockResponse.json).not.toHaveBeenCalled();
});
it("should require specific scopes when configured", async () => {
const authInfo: AuthInfo = {
token: "valid-token",
clientId: "client-123",
scopes: ["read"],
};
mockVerifyAccessToken.mockResolvedValue(authInfo);
mockRequest.headers = {
authorization: "Bearer valid-token",
};
const middleware = requireBearerAuth({
provider: mockProvider,
requiredScopes: ["read", "write"]
});
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token");
expect(mockResponse.status).toHaveBeenCalledWith(403);
expect(mockResponse.set).toHaveBeenCalledWith(
"WWW-Authenticate",
expect.stringContaining('Bearer error="insufficient_scope"')
);
expect(mockResponse.json).toHaveBeenCalledWith(
expect.objectContaining({ error: "insufficient_scope", error_description: "Insufficient scope" })
);
expect(nextFunction).not.toHaveBeenCalled();
});
it("should accept token with all required scopes", async () => {
const authInfo: AuthInfo = {
token: "valid-token",
clientId: "client-123",
scopes: ["read", "write", "admin"],
};
mockVerifyAccessToken.mockResolvedValue(authInfo);
mockRequest.headers = {
authorization: "Bearer valid-token",
};
const middleware = requireBearerAuth({
provider: mockProvider,
requiredScopes: ["read", "write"]
});
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token");
expect(mockRequest.auth).toEqual(authInfo);
expect(nextFunction).toHaveBeenCalled();
expect(mockResponse.status).not.toHaveBeenCalled();
expect(mockResponse.json).not.toHaveBeenCalled();
});
it("should return 401 when no Authorization header is present", async () => {
const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
expect(mockVerifyAccessToken).not.toHaveBeenCalled();
expect(mockResponse.status).toHaveBeenCalledWith(401);
expect(mockResponse.set).toHaveBeenCalledWith(
"WWW-Authenticate",
expect.stringContaining('Bearer error="invalid_token"')
);
expect(mockResponse.json).toHaveBeenCalledWith(
expect.objectContaining({ error: "invalid_token", error_description: "Missing Authorization header" })
);
expect(nextFunction).not.toHaveBeenCalled();
});
it("should return 401 when Authorization header format is invalid", async () => {
mockRequest.headers = {
authorization: "InvalidFormat",
};
const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
expect(mockVerifyAccessToken).not.toHaveBeenCalled();
expect(mockResponse.status).toHaveBeenCalledWith(401);
expect(mockResponse.set).toHaveBeenCalledWith(
"WWW-Authenticate",
expect.stringContaining('Bearer error="invalid_token"')
);
expect(mockResponse.json).toHaveBeenCalledWith(
expect.objectContaining({
error: "invalid_token",
error_description: "Invalid Authorization header format, expected 'Bearer TOKEN'"
})
);
expect(nextFunction).not.toHaveBeenCalled();
});
it("should return 401 when token verification fails with InvalidTokenError", async () => {
mockRequest.headers = {
authorization: "Bearer invalid-token",
};
mockVerifyAccessToken.mockRejectedValue(new InvalidTokenError("Token expired"));
const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
expect(mockVerifyAccessToken).toHaveBeenCalledWith("invalid-token");
expect(mockResponse.status).toHaveBeenCalledWith(401);
expect(mockResponse.set).toHaveBeenCalledWith(
"WWW-Authenticate",
expect.stringContaining('Bearer error="invalid_token"')
);
expect(mockResponse.json).toHaveBeenCalledWith(
expect.objectContaining({ error: "invalid_token", error_description: "Token expired" })
);
expect(nextFunction).not.toHaveBeenCalled();
});
it("should return 403 when access token has insufficient scopes", async () => {
mockRequest.headers = {
authorization: "Bearer valid-token",
};
mockVerifyAccessToken.mockRejectedValue(new InsufficientScopeError("Required scopes: read, write"));
const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token");
expect(mockResponse.status).toHaveBeenCalledWith(403);
expect(mockResponse.set).toHaveBeenCalledWith(
"WWW-Authenticate",
expect.stringContaining('Bearer error="insufficient_scope"')
);
expect(mockResponse.json).toHaveBeenCalledWith(
expect.objectContaining({ error: "insufficient_scope", error_description: "Required scopes: read, write" })
);
expect(nextFunction).not.toHaveBeenCalled();
});
it("should return 500 when a ServerError occurs", async () => {
mockRequest.headers = {
authorization: "Bearer valid-token",
};
mockVerifyAccessToken.mockRejectedValue(new ServerError("Internal server issue"));
const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token");
expect(mockResponse.status).toHaveBeenCalledWith(500);
expect(mockResponse.json).toHaveBeenCalledWith(
expect.objectContaining({ error: "server_error", error_description: "Internal server issue" })
);
expect(nextFunction).not.toHaveBeenCalled();
});
it("should return 400 for generic OAuthError", async () => {
mockRequest.headers = {
authorization: "Bearer valid-token",
};
mockVerifyAccessToken.mockRejectedValue(new OAuthError("custom_error", "Some OAuth error"));
const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token");
expect(mockResponse.status).toHaveBeenCalledWith(400);
expect(mockResponse.json).toHaveBeenCalledWith(
expect.objectContaining({ error: "custom_error", error_description: "Some OAuth error" })
);
expect(nextFunction).not.toHaveBeenCalled();
});
it("should return 500 when unexpected error occurs", async () => {
mockRequest.headers = {
authorization: "Bearer valid-token",
};
mockVerifyAccessToken.mockRejectedValue(new Error("Unexpected error"));
const middleware = requireBearerAuth({ provider: mockProvider });
await middleware(mockRequest as Request, mockResponse as Response, nextFunction);
expect(mockVerifyAccessToken).toHaveBeenCalledWith("valid-token");
expect(mockResponse.status).toHaveBeenCalledWith(500);
expect(mockResponse.json).toHaveBeenCalledWith(
expect.objectContaining({ error: "server_error", error_description: "Internal Server Error" })
);
expect(nextFunction).not.toHaveBeenCalled();
});
});
================================================
File: src/server/auth/middleware/bearerAuth.ts
================================================
import { RequestHandler } from "express";
import { InsufficientScopeError, InvalidTokenError, OAuthError, ServerError } from "../errors.js";
import { OAuthServerProvider } from "../provider.js";
import { AuthInfo } from "../types.js";
export type BearerAuthMiddlewareOptions = {
/**
* A provider used to verify tokens.
*/
provider: OAuthServerProvider;
/**
* Optional scopes that the token must have.
*/
requiredScopes?: string[];
};
declare module "express-serve-static-core" {
interface Request {
/**
* Information about the validated access token, if the `requireBearerAuth` middleware was used.
*/
auth?: AuthInfo;
}
}
/**
* Middleware that requires a valid Bearer token in the Authorization header.
*
* Validates the token with the provider's `verifyAccessToken`, checks any required scopes and the token's expiry,
* and on success attaches the resulting auth info to `req.auth`.
*/
export function requireBearerAuth({ provider, requiredScopes = [] }: BearerAuthMiddlewareOptions): RequestHandler {
return async (req, res, next) => {
try {
const authHeader = req.headers.authorization;
if (!authHeader) {
throw new InvalidTokenError("Missing Authorization header");
}
const [type, token] = authHeader.split(' ');
if (type.toLowerCase() !== 'bearer' || !token) {
throw new InvalidTokenError("Invalid Authorization header format, expected 'Bearer TOKEN'");
}
const authInfo = await provider.verifyAccessToken(token);
// Check if token has the required scopes (if any)
if (requiredScopes.length > 0) {
const hasAllScopes = requiredScopes.every(scope =>
authInfo.scopes.includes(scope)
);
if (!hasAllScopes) {
throw new InsufficientScopeError("Insufficient scope");
}
}
// Check if the token is expired
if (!!authInfo.expiresAt && authInfo.expiresAt < Date.now() / 1000) {
throw new InvalidTokenError("Token has expired");
}
req.auth = authInfo;
next();
} catch (error) {
if (error instanceof InvalidTokenError) {
res.set("WWW-Authenticate", `Bearer error="${error.errorCode}", error_description="${error.message}"`);
res.status(401).json(error.toResponseObject());
} else if (error instanceof InsufficientScopeError) {
res.set("WWW-Authenticate", `Bearer error="${error.errorCode}", error_description="${error.message}"`);
res.status(403).json(error.toResponseObject());
} else if (error instanceof ServerError) {
res.status(500).json(error.toResponseObject());
} else if (error instanceof OAuthError) {
res.status(400).json(error.toResponseObject());
} else {
console.error("Unexpected error authenticating bearer token:", error);
const serverError = new ServerError("Internal Server Error");
res.status(500).json(serverError.toResponseObject());
}
}
};
}
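`requireBearerAuth` applies its checks in a fixed order: header present, `Bearer <token>` format, token verification, required scopes, and finally expiry. It responds with 401 for `invalid_token` and 403 for `insufficient_scope`. The sketch below reproduces that order synchronously so it can be tested without Express or a provider. `authorizeBearer`, `BearerResult`, `fail` and the `lookup` callback are hypothetical. `lookup` stands in for the async `provider.verifyAccessToken`, and it returns `undefined` instead of throwing. `nowSeconds` is injected to make the expiry check deterministic.

```typescript
// Same shape as the SDK's AuthInfo for the fields used here.
interface AuthInfo {
  token: string;
  clientId: string;
  scopes: string[];
  expiresAt?: number; // seconds since the Unix epoch
}

type BearerResult =
  | { ok: true; auth: AuthInfo }
  | { ok: false; status: 401 | 403; error: string; description: string; wwwAuthenticate: string };

function fail(status: 401 | 403, error: string, description: string): BearerResult {
  return {
    ok: false,
    status,
    error,
    description,
    // Same header format the middleware sets; quotes in `description` are not escaped.
    wwwAuthenticate: `Bearer error="${error}", error_description="${description}"`,
  };
}

function authorizeBearer(
  authHeader: string | undefined,
  lookup: (token: string) => AuthInfo | undefined,
  requiredScopes: string[] = [],
  nowSeconds: number = Date.now() / 1000,
): BearerResult {
  if (!authHeader) {
    return fail(401, "invalid_token", "Missing Authorization header");
  }
  const [type, token] = authHeader.split(" ");
  // The scheme name is matched case-insensitively.
  if (type.toLowerCase() !== "bearer" || !token) {
    return fail(401, "invalid_token", "Invalid Authorization header format, expected 'Bearer TOKEN'");
  }
  const auth = lookup(token);
  if (!auth) {
    return fail(401, "invalid_token", "Token is invalid");
  }
  // Scopes are checked before expiry, as in the middleware.
  if (!requiredScopes.every((scope) => auth.scopes.includes(scope))) {
    return fail(403, "insufficient_scope", "Insufficient scope");
  }
  if (auth.expiresAt && auth.expiresAt < nowSeconds) {
    return fail(401, "invalid_token", "Token has expired");
  }
  return { ok: true, auth };
}
```

Because scopes are checked first, a token that lacks a required scope is reported as `insufficient_scope` (403) even if it has also expired.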
================================================
File: src/server/auth/middleware/clientAuth.test.ts
================================================
import { authenticateClient, ClientAuthenticationMiddlewareOptions } from './clientAuth.js';
import { OAuthRegisteredClientsStore } from '../clients.js';
import { OAuthClientInformationFull } from '../../../shared/auth.js';
import express from 'express';
import supertest from 'supertest';
describe('clientAuth middleware', () => {
// Mock client store
const mockClientStore: OAuthRegisteredClientsStore = {
async getClient(clientId: string): Promise<OAuthClientInformationFull | undefined> {
if (clientId === 'valid-client') {
return {
client_id: 'valid-client',
client_secret: 'valid-secret',
redirect_uris: ['https://example.com/callback']
};
} else if (clientId === 'expired-client') {
// Client with no secret
return {
client_id: 'expired-client',
redirect_uris: ['https://example.com/callback']
};
} else if (clientId === 'client-with-expired-secret') {
// Client with an expired secret
return {
client_id: 'client-with-expired-secret',
client_secret: 'expired-secret',
client_secret_expires_at: Math.floor(Date.now() / 1000) - 3600, // Expired 1 hour ago
redirect_uris: ['https://example.com/callback']
};
}
return undefined;
}
};
// Setup Express app with middleware
let app: express.Express;
let options: ClientAuthenticationMiddlewareOptions;
beforeEach(() => {
app = express();
app.use(express.json());
options = {
clientsStore: mockClientStore
};
// Setup route with client auth
app.post('/protected', authenticateClient(options), (req, res) => {
res.status(200).json({ success: true, client: req.client });
});
});
it('authenticates valid client credentials', async () => {
const response = await supertest(app)
.post('/protected')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret'
});
expect(response.status).toBe(200);
expect(response.body.success).toBe(true);
expect(response.body.client.client_id).toBe('valid-client');
});
it('rejects invalid client_id', async () => {
const response = await supertest(app)
.post('/protected')
.send({
client_id: 'non-existent-client',
client_secret: 'some-secret'
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_client');
expect(response.body.error_description).toBe('Invalid client_id');
});
it('rejects invalid client_secret', async () => {
const response = await supertest(app)
.post('/protected')
.send({
client_id: 'valid-client',
client_secret: 'wrong-secret'
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_client');
expect(response.body.error_description).toBe('Invalid client_secret');
});
it('rejects missing client_id', async () => {
const response = await supertest(app)
.post('/protected')
.send({
client_secret: 'valid-secret'
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_request');
});
it('allows missing client_secret if client has none', async () => {
const response = await supertest(app)
.post('/protected')
.send({
client_id: 'expired-client'
});
// Since the client has no secret, this should pass without providing one
expect(response.status).toBe(200);
});
it('rejects request when client secret has expired', async () => {
const response = await supertest(app)
.post('/protected')
.send({
client_id: 'client-with-expired-secret',
client_secret: 'expired-secret'
});
expect(response.status).toBe(400);
expect(response.body.error).toBe('invalid_client');
expect(response.body.error_description).toBe('Client secret has expired');
});
it('handles malformed request body', async () => {
const response = await supertest(app)
.post('/protected')
.send('not-json-format');
expect(response.status).toBe(400);
});
// Testing request with extra fields to ensure they're ignored
it('ignores extra fields in request', async () => {
const response = await supertest(app)
.post('/protected')
.send({
client_id: 'valid-client',
client_secret: 'valid-secret',
extra_field: 'should be ignored'
});
expect(response.status).toBe(200);
});
});
================================================
File: src/server/auth/middleware/clientAuth.ts
================================================
import { z } from "zod";
import { RequestHandler } from "express";
import { OAuthRegisteredClientsStore } from "../clients.js";
import { OAuthClientInformationFull } from "../../../shared/auth.js";
import { InvalidRequestError, InvalidClientError, ServerError, OAuthError } from "../errors.js";
export type ClientAuthenticationMiddlewareOptions = {
/**
* A store used to read information about registered OAuth clients.
*/
clientsStore: OAuthRegisteredClientsStore;
}
const ClientAuthenticatedRequestSchema = z.object({
client_id: z.string(),
client_secret: z.string().optional(),
});
declare module "express-serve-static-core" {
interface Request {
/**
* The authenticated client for this request, if the `authenticateClient` middleware was used.
*/
client?: OAuthClientInformationFull;
}
}
export function authenticateClient({ clientsStore }: ClientAuthenticationMiddlewareOptions): RequestHandler {
return async (req, res, next) => {
try {
const result = ClientAuthenticatedRequestSchema.safeParse(req.body);
if (!result.success) {
throw new InvalidRequestError(String(result.error));
}
const { client_id, client_secret } = result.data;
const client = await clientsStore.getClient(client_id);
if (!client) {
throw new InvalidClientError("Invalid client_id");
}
// If client has a secret, validate it
if (client.client_secret) {
// Check if client_secret is required but not provided
if (!client_secret) {
throw new InvalidClientError("Client secret is required");
}
// Check if client_secret matches
if (client.client_secret !== client_secret) {
throw new InvalidClientError("Invalid client_secret");
}
// Check if client_secret has expired
if (client.client_secret_expires_at && client.client_secret_expires_at < Math.floor(Date.now() / 1000)) {
throw new InvalidClientError("Client secret has expired");
}
}
req.client = client;
next();
} catch (error) {
if (error instanceof OAuthError) {
const status = error instanceof ServerError ? 500 : 400;
res.status(status).json(error.toResponseObject());
} else {
console.error("Unexpected error authenticating client:", error);
const serverError = new ServerError("Internal Server Error");
res.status(500).json(serverError.toResponseObject());
}
}
}
}
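The checks in `authenticateClient` do not depend on Express, so they can be restated as a pure function and unit-tested directly. The sketch below is hypothetical. `ClientRecord`, `AuthResult`, and `checkClientCredentials` are illustrative names. The client store is reduced to a synchronous lookup. The `invalid_request` descriptions stand in for the Zod error text that the real middleware produces. The checks run in the same order as in the middleware.

```typescript
// Hypothetical, reduced view of a registered client: only the fields the middleware reads.
interface ClientRecord {
  client_id: string;
  client_secret?: string;
  client_secret_expires_at?: number; // seconds since the Unix epoch; 0 or absent means no expiry
}

// Hypothetical result type: the authenticated client, or an OAuth error code and description.
type AuthResult =
  | { ok: true; client: ClientRecord }
  | { ok: false; error: "invalid_request" | "invalid_client"; description: string };

function checkClientCredentials(
  lookup: (clientId: string) => ClientRecord | undefined,
  body: { client_id?: unknown; client_secret?: unknown },
  nowSeconds: number = Math.floor(Date.now() / 1000),
): AuthResult {
  // Shape check (done with Zod in the middleware): client_id is required, client_secret is an optional string.
  if (typeof body.client_id !== "string") {
    return { ok: false, error: "invalid_request", description: "client_id must be a string" };
  }
  if (body.client_secret !== undefined && typeof body.client_secret !== "string") {
    return { ok: false, error: "invalid_request", description: "client_secret must be a string" };
  }
  const client = lookup(body.client_id);
  if (!client) {
    return { ok: false, error: "invalid_client", description: "Invalid client_id" };
  }
  // Secret checks apply only to clients that were registered with a secret.
  if (client.client_secret) {
    if (!body.client_secret) {
      return { ok: false, error: "invalid_client", description: "Client secret is required" };
    }
    if (client.client_secret !== body.client_secret) {
      return { ok: false, error: "invalid_client", description: "Invalid client_secret" };
    }
    if (client.client_secret_expires_at && client.client_secret_expires_at < nowSeconds) {
      return { ok: false, error: "invalid_client", description: "Client secret has expired" };
    }
  }
  return { ok: true, client };
}
```

A client registered without a secret (a public client) passes on `client_id` alone. This matches the "allows missing client_secret if client has none" test in `clientAuth.test.ts`.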
================================================
File: src/shared/auth.ts
================================================
import { z } from "zod";
/**
* RFC 8414 OAuth 2.0 Authorization Server Metadata
*/
export const OAuthMetadataSchema = z
.object({
issuer: z.string(),
authorization_endpoint: z.string(),
token_endpoint: z.string(),
registration_endpoint: z.string().optional(),
scopes_supported: z.array(z.string()).optional(),
response_types_supported: z.array(z.string()),
response_modes_supported: z.array(z.string()).optional(),
grant_types_supported: z.array(z.string()).optional(),
token_endpoint_auth_methods_supported: z.array(z.string()).optional(),
token_endpoint_auth_signing_alg_values_supported: z
.array(z.string())
.optional(),
service_documentation: z.string().optional(),
revocation_endpoint: z.string().optional(),
revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(),
revocation_endpoint_auth_signing_alg_values_supported: z
.array(z.string())
.optional(),
introspection_endpoint: z.string().optional(),
introspection_endpoint_auth_methods_supported: z
.array(z.string())
.optional(),
introspection_endpoint_auth_signing_alg_values_supported: z
.array(z.string())
.optional(),
code_challenge_methods_supported: z.array(z.string()).optional(),
})
.passthrough();
/**
* OAuth 2.1 token response
*/
export const OAuthTokensSchema = z
.object({
access_token: z.string(),
token_type: z.string(),
expires_in: z.number().optional(),
scope: z.string().optional(),
refresh_token: z.string().optional(),
})
.strip();
/**
* OAuth 2.1 error response
*/
export const OAuthErrorResponseSchema = z
.object({
error: z.string(),
error_description: z.string().optional(),
error_uri: z.string().optional(),
});
/**
* RFC 7591 OAuth 2.0 Dynamic Client Registration metadata
*/
export const OAuthClientMetadataSchema = z.object({
redirect_uris: z.array(z.string()).refine((uris) => uris.every((uri) => URL.canParse(uri)), { message: "redirect_uris must contain valid URLs" }),
token_endpoint_auth_method: z.string().optional(),
grant_types: z.array(z.string()).optional(),
response_types: z.array(z.string()).optional(),
client_name: z.string().optional(),
client_uri: z.string().optional(),
logo_uri: z.string().optional(),
scope: z.string().optional(),
contacts: z.array(z.string()).optional(),
tos_uri: z.string().optional(),
policy_uri: z.string().optional(),
jwks_uri: z.string().optional(),
jwks: z.any().optional(),
software_id: z.string().optional(),
software_version: z.string().optional(),
}).strip();
/**
* RFC 7591 OAuth 2.0 Dynamic Client Registration client information
*/
export const OAuthClientInformationSchema = z.object({
client_id: z.string(),
client_secret: z.string().optional(),
client_id_issued_at: z.number().optional(),
client_secret_expires_at: z.number().optional(),
}).strip();
/**
* RFC 7591 OAuth 2.0 Dynamic Client Registration full response (client information plus metadata)
*/
export const OAuthClientInformationFullSchema = OAuthClientMetadataSchema.merge(OAuthClientInformationSchema);
/**
* RFC 7591 OAuth 2.0 Dynamic Client Registration error response
*/
export const OAuthClientRegistrationErrorSchema = z.object({
error: z.string(),
error_description: z.string().optional(),
}).strip();
/**
* RFC 7009 OAuth 2.0 Token Revocation request
*/
export const OAuthTokenRevocationRequestSchema = z.object({
token: z.string(),
token_type_hint: z.string().optional(),
}).strip();
export type OAuthMetadata = z.infer<typeof OAuthMetadataSchema>;
export type OAuthTokens = z.infer<typeof OAuthTokensSchema>;
export type OAuthErrorResponse = z.infer<typeof OAuthErrorResponseSchema>;
export type OAuthClientMetadata = z.infer<typeof OAuthClientMetadataSchema>;
export type OAuthClientInformation = z.infer<typeof OAuthClientInformationSchema>;
export type OAuthClientInformationFull = z.infer<typeof OAuthClientInformationFullSchema>;
export type OAuthClientRegistrationError = z.infer<typeof OAuthClientRegistrationErrorSchema>;
export type OAuthTokenRevocationRequest = z.infer<typeof OAuthTokenRevocationRequestSchema>;
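`OAuthClientMetadataSchema` accepts `redirect_uris` only if every entry passes `URL.canParse`, which means every entry must parse as an absolute URL when no base is given. A minimal dependency-free equivalent wraps the `URL` constructor in `try`/`catch`, which is useful on runtimes that lack `URL.canParse`. `canParseUrl` and `redirectUrisAreValid` are hypothetical helpers, not SDK exports.

```typescript
// Returns true when the string parses as an absolute URL, the same test
// URL.canParse(uri) performs when it is called without a base.
function canParseUrl(uri: string): boolean {
  try {
    new URL(uri);
    return true;
  } catch {
    return false;
  }
}

// Hypothetical helper mirroring the refine() on redirect_uris in OAuthClientMetadataSchema.
function redirectUrisAreValid(uris: string[]): boolean {
  return uris.every(canParseUrl);
}
```

Relative references such as `/callback` are rejected because there is no base URL to resolve them against. Custom schemes such as `myapp://callback` are accepted. Because the check uses `Array.prototype.every`, an empty `redirect_uris` array also passes this refinement.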
================================================
File: src/shared/protocol.test.ts
================================================
import { ZodType, z } from "zod";
import {
ClientCapabilities,
ErrorCode,
McpError,
Notification,
Request,
Result,
ServerCapabilities,
} from "../types.js";
import { Protocol, mergeCapabilities } from "./protocol.js";
import { Transport } from "./transport.js";
// Mock Transport class
class MockTransport implements Transport {
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: unknown) => void;
async start(): Promise<void> {}
async close(): Promise<void> {
this.onclose?.();
}
async send(_message: unknown): Promise<void> {}
}
describe("protocol tests", () => {
let protocol: Protocol<Request, Notification, Result>;
let transport: MockTransport;
beforeEach(() => {
transport = new MockTransport();
protocol = new (class extends Protocol<Request, Notification, Result> {
protected assertCapabilityForMethod(): void {}
protected assertNotificationCapability(): void {}
protected assertRequestHandlerCapability(): void {}
})();
});
test("should throw a timeout error if the request exceeds the timeout", async () => {
await protocol.connect(transport);
const request = { method: "example", params: {} };
const mockSchema: ZodType<{ result: string }> = z.object({
result: z.string(),
});
// Capture the rejection instead of using try/catch, so the test fails rather than passing silently if no error is thrown.
const error = await protocol
.request(request, mockSchema, { timeout: 0 })
.catch((e: unknown) => e);
expect(error).toBeInstanceOf(McpError);
expect((error as McpError).code).toBe(ErrorCode.RequestTimeout);
});
test("should invoke onclose when the connection is closed", async () => {
const oncloseMock = jest.fn();
protocol.onclose = oncloseMock;
await protocol.connect(transport);
await transport.close();
expect(oncloseMock).toHaveBeenCalled();
});
describe("progress notification timeout behavior", () => {
beforeEach(() => {
jest.useFakeTimers();
});
afterEach(() => {
jest.useRealTimers();
});
test("should reset timeout when progress notification is received", async () => {
await protocol.connect(transport);
const request = { method: "example", params: {} };
const mockSchema: ZodType<{ result: string }> = z.object({
result: z.string(),
});
const onProgressMock = jest.fn();
const requestPromise = protocol.request(request, mockSchema, {
timeout: 1000,
resetTimeoutOnProgress: true,
onprogress: onProgressMock,
});
jest.advanceTimersByTime(800);
if (transport.onmessage) {
transport.onmessage({
jsonrpc: "2.0",
method: "notifications/progress",
params: {
progressToken: 0,
progress: 50,
total: 100,
},
});
}
await Promise.resolve();
expect(onProgressMock).toHaveBeenCalledWith({
progress: 50,
total: 100,
});
jest.advanceTimersByTime(800);
if (transport.onmessage) {
transport.onmessage({
jsonrpc: "2.0",
id: 0,
result: { result: "success" },
});
}
await Promise.resolve();
await expect(requestPromise).resolves.toEqual({ result: "success" });
});
test("should respect maxTotalTimeout", async () => {
await protocol.connect(transport);
const request = { method: "example", params: {} };
const mockSchema: ZodType<{ result: string }> = z.object({
result: z.string(),
});
const onProgressMock = jest.fn();
const requestPromise = protocol.request(request, mockSchema, {
timeout: 1000,
maxTotalTimeout: 150,
resetTimeoutOnProgress: true,
onprogress: onProgressMock,
});
// First progress notification should work
jest.advanceTimersByTime(80);
if (transport.onmessage) {
transport.onmessage({
jsonrpc: "2.0",
method: "notifications/progress",
params: {
progressToken: 0,
progress: 50,
total: 100,
},
});
}
await Promise.resolve();
expect(onProgressMock).toHaveBeenCalledWith({
progress: 50,
total: 100,
});
jest.advanceTimersByTime(80);
if (transport.onmessage) {
transport.onmessage({
jsonrpc: "2.0",
method: "notifications/progress",
params: {
progressToken: 0,
progress: 75,
total: 100,
},
});
}
await expect(requestPromise).rejects.toThrow("Maximum total timeout exceeded");
expect(onProgressMock).toHaveBeenCalledTimes(1);
});
test("should timeout if no progress received within timeout period", async () => {
await protocol.connect(transport);
const request = { method: "example", params: {} };
const mockSchema: ZodType<{ result: string }> = z.object({
result: z.string(),
});
const requestPromise = protocol.request(request, mockSchema, {
timeout: 100,
resetTimeoutOnProgress: true,
});
jest.advanceTimersByTime(101);
await expect(requestPromise).rejects.toThrow("Request timed out");
});
test("should handle multiple progress notifications correctly", async () => {
await protocol.connect(transport);
const request = { method: "example", params: {} };
const mockSchema: ZodType<{ result: string }> = z.object({
result: z.string(),
});
const onProgressMock = jest.fn();
const requestPromise = protocol.request(request, mockSchema, {
timeout: 1000,
resetTimeoutOnProgress: true,
onprogress: onProgressMock,
});
// Simulate multiple progress updates
for (let i = 1; i <= 3; i++) {
jest.advanceTimersByTime(800);
if (transport.onmessage) {
transport.onmessage({
jsonrpc: "2.0",
method: "notifications/progress",
params: {
progressToken: 0,
progress: i * 25,
total: 100,
},
});
}
await Promise.resolve();
expect(onProgressMock).toHaveBeenNthCalledWith(i, {
progress: i * 25,
total: 100,
});
}
if (transport.onmessage) {
transport.onmessage({
jsonrpc: "2.0",
id: 0,
result: { result: "success" },
});
}
await Promise.resolve();
await expect(requestPromise).resolves.toEqual({ result: "success" });
});
});
});
describe("mergeCapabilities", () => {
it("should merge client capabilities", () => {
const base: ClientCapabilities = {
sampling: {},
roots: {
listChanged: true,
},
};
const additional: ClientCapabilities = {
experimental: {
feature: true,
},
roots: {
newProp: true,
},
};
const merged = mergeCapabilities(base, additional);
expect(merged).toEqual({
sampling: {},
roots: {
listChanged: true,
newProp: true,
},
experimental: {
feature: true,
},
});
});
it("should merge server capabilities", () => {
const base: ServerCapabilities = {
logging: {},
prompts: {
listChanged: true,
},
};
const additional: ServerCapabilities = {
resources: {
subscribe: true,
},
prompts: {
newProp: true,
},
};
const merged = mergeCapabilities(base, additional);
expect(merged).toEqual({
logging: {},
prompts: {
listChanged: true,
newProp: true,
},
resources: {
subscribe: true,
},
});
});
it("should override existing values with additional values", () => {
const base: ServerCapabilities = {
prompts: {
listChanged: false,
},
};
const additional: ServerCapabilities = {
prompts: {
listChanged: true,
},
};
const merged = mergeCapabilities(base, additional);
expect(merged.prompts!.listChanged).toBe(true);
});
it("should handle empty objects", () => {
const base = {};
const additional = {};
const merged = mergeCapabilities(base, additional);
expect(merged).toEqual({});
});
});
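The tests above only exercise one level of nesting. `mergeCapabilities` spreads each top-level capability object over the existing one, so anything nested deeper is replaced rather than merged. The sketch below restates the same `reduce` over a generic `Record<string, unknown>` and shows that replacement. The generic type is an assumption made to keep the example free of the SDK's capability types, and `mergeOneLevel` is a hypothetical name.

```typescript
type Capabilities = Record<string, unknown>;

// Same algorithm as mergeCapabilities: start from a shallow copy of `base`. For each
// top-level key in `additional`, spread object values over the existing entry (one
// level deep). Any other value overwrites the existing entry.
function mergeOneLevel(base: Capabilities, additional: Capabilities): Capabilities {
  return Object.entries(additional).reduce<Capabilities>(
    (acc, [key, value]) => {
      if (value && typeof value === "object") {
        acc[key] = acc[key] ? { ...(acc[key] as object), ...(value as object) } : value;
      } else {
        acc[key] = value;
      }
      return acc;
    },
    { ...base },
  );
}

// `featureA` sits two levels down, so the entry from `additional` replaces it wholesale
// and `enabled` is dropped.
const merged = mergeOneLevel(
  { experimental: { featureA: { enabled: true, level: 1 } } },
  { experimental: { featureA: { level: 2 } } },
);
console.log(JSON.stringify(merged)); // {"experimental":{"featureA":{"level":2}}}
```

Callers that need a deep merge of nested experimental settings would have to combine those settings themselves before passing them in.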
================================================
File: src/shared/protocol.ts
================================================
import { ZodLiteral, ZodObject, ZodType, z } from "zod";
import {
CancelledNotificationSchema,
ClientCapabilities,
ErrorCode,
JSONRPCError,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
McpError,
Notification,
PingRequestSchema,
Progress,
ProgressNotification,
ProgressNotificationSchema,
Request,
RequestId,
Result,
ServerCapabilities,
} from "../types.js";
import { Transport } from "./transport.js";
/**
* Callback for progress notifications.
*/
export type ProgressCallback = (progress: Progress) => void;
/**
* Additional initialization options.
*/
export type ProtocolOptions = {
/**
* Whether to restrict emitted requests to only those that the remote side has indicated that they can handle, through their advertised capabilities.
*
* Note that this DOES NOT affect checking of _local_ side capabilities, as it is considered a logic error to mis-specify those.
*
* Currently this defaults to false, for backwards compatibility with SDK versions that did not advertise capabilities correctly. In the future, this will default to true.
*/
enforceStrictCapabilities?: boolean;
};
/**
* The default request timeout, in milliseconds.
*/
export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60000;
/**
* Options that can be given per request.
*/
export type RequestOptions = {
/**
* If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked.
*/
onprogress?: ProgressCallback;
/**
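* Can be used to cancel an in-flight request. When the signal is aborted, a `notifications/cancelled` notification is sent and request() rejects with the signal's abort reason (a DOMException named AbortError if abort() was called without a reason).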
*/
signal?: AbortSignal;
/**
* A timeout (in milliseconds) for this request. If exceeded, an McpError with code `RequestTimeout` will be raised from request().
*
* If not specified, `DEFAULT_REQUEST_TIMEOUT_MSEC` will be used as the timeout.
*/
timeout?: number;
/**
* If true, receiving a progress notification will reset the request timeout.
* This is useful for long-running operations that send periodic progress updates.
* Default: false
*/
resetTimeoutOnProgress?: boolean;
/**
* Maximum total time (in milliseconds) to wait for a response.
* This limit is checked when a progress notification arrives: if the total elapsed time has reached it, an McpError with code `RequestTimeout` is raised instead of the timeout being reset.
* If not specified, there is no maximum total timeout.
*/
maxTotalTimeout?: number;
};
/**
* Extra data given to request handlers.
*/
export type RequestHandlerExtra = {
/**
* An abort signal used to communicate if the request was cancelled from the sender's side.
*/
signal: AbortSignal;
};
/**
* Information about a request's timeout state
*/
type TimeoutInfo = {
timeoutId: ReturnType<typeof setTimeout>;
startTime: number;
timeout: number;
maxTotalTimeout?: number;
onTimeout: () => void;
};
/**
* Implements MCP protocol framing on top of a pluggable transport, including
* features like request/response linking, notifications, and progress.
*/
export abstract class Protocol<
SendRequestT extends Request,
SendNotificationT extends Notification,
SendResultT extends Result,
> {
private _transport?: Transport;
private _requestMessageId = 0;
private _requestHandlers: Map<
string,
(
request: JSONRPCRequest,
extra: RequestHandlerExtra,
) => Promise<SendResultT>
> = new Map();
private _requestHandlerAbortControllers: Map<RequestId, AbortController> =
new Map();
private _notificationHandlers: Map<
string,
(notification: JSONRPCNotification) => Promise<void>
> = new Map();
private _responseHandlers: Map<
number,
(response: JSONRPCResponse | Error) => void
> = new Map();
private _progressHandlers: Map<number, ProgressCallback> = new Map();
private _timeoutInfo: Map<number, TimeoutInfo> = new Map();
/**
* Callback for when the connection is closed for any reason.
*
* This is invoked when close() is called as well.
*/
onclose?: () => void;
/**
* Callback for when an error occurs.
*
* Note that errors are not necessarily fatal; they are used for reporting any kind of exceptional condition out of band.
*/
onerror?: (error: Error) => void;
/**
* A handler to invoke for any request types that do not have their own handler installed.
*/
fallbackRequestHandler?: (request: Request) => Promise<SendResultT>;
/**
* A handler to invoke for any notification types that do not have their own handler installed.
*/
fallbackNotificationHandler?: (notification: Notification) => Promise<void>;
constructor(private _options?: ProtocolOptions) {
this.setNotificationHandler(CancelledNotificationSchema, (notification) => {
const controller = this._requestHandlerAbortControllers.get(
notification.params.requestId,
);
controller?.abort(notification.params.reason);
});
this.setNotificationHandler(ProgressNotificationSchema, (notification) => {
this._onprogress(notification as unknown as ProgressNotification);
});
this.setRequestHandler(
PingRequestSchema,
// Automatic pong by default.
(_request) => ({}) as SendResultT,
);
}
private _setupTimeout(
messageId: number,
timeout: number,
maxTotalTimeout: number | undefined,
onTimeout: () => void
) {
this._timeoutInfo.set(messageId, {
timeoutId: setTimeout(onTimeout, timeout),
startTime: Date.now(),
timeout,
maxTotalTimeout,
onTimeout
});
}
private _resetTimeout(messageId: number): boolean {
const info = this._timeoutInfo.get(messageId);
if (!info) return false;
const totalElapsed = Date.now() - info.startTime;
if (info.maxTotalTimeout && totalElapsed >= info.maxTotalTimeout) {
this._timeoutInfo.delete(messageId);
throw new McpError(
ErrorCode.RequestTimeout,
"Maximum total timeout exceeded",
{ maxTotalTimeout: info.maxTotalTimeout, totalElapsed }
);
}
clearTimeout(info.timeoutId);
info.timeoutId = setTimeout(info.onTimeout, info.timeout);
return true;
}
private _cleanupTimeout(messageId: number) {
const info = this._timeoutInfo.get(messageId);
if (info) {
clearTimeout(info.timeoutId);
this._timeoutInfo.delete(messageId);
}
}
/**
* Attaches to the given transport, starts it, and starts listening for messages.
*
* The Protocol object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward.
*/
async connect(transport: Transport): Promise<void> {
this._transport = transport;
this._transport.onclose = () => {
this._onclose();
};
this._transport.onerror = (error: Error) => {
this._onerror(error);
};
this._transport.onmessage = (message) => {
if (!("method" in message)) {
this._onresponse(message);
} else if ("id" in message) {
this._onrequest(message);
} else {
this._onnotification(message);
}
};
await this._transport.start();
}
private _onclose(): void {
const responseHandlers = this._responseHandlers;
this._responseHandlers = new Map();
this._progressHandlers.clear();
this._transport = undefined;
this.onclose?.();
const error = new McpError(ErrorCode.ConnectionClosed, "Connection closed");
for (const handler of responseHandlers.values()) {
handler(error);
}
}
private _onerror(error: Error): void {
this.onerror?.(error);
}
private _onnotification(notification: JSONRPCNotification): void {
const handler =
this._notificationHandlers.get(notification.method) ??
this.fallbackNotificationHandler;
// Ignore notifications not being subscribed to.
if (handler === undefined) {
return;
}
// Starting from Promise.resolve() turns a synchronous throw in the handler into a rejection, so the .catch() below reports it too.
Promise.resolve()
.then(() => handler(notification))
.catch((error) =>
this._onerror(
new Error(`Uncaught error in notification handler: ${error}`),
),
);
}
private _onrequest(request: JSONRPCRequest): void {
const handler =
this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler;
if (handler === undefined) {
this._transport
?.send({
jsonrpc: "2.0",
id: request.id,
error: {
code: ErrorCode.MethodNotFound,
message: "Method not found",
},
})
.catch((error) =>
this._onerror(
new Error(`Failed to send an error response: ${error}`),
),
);
return;
}
const abortController = new AbortController();
this._requestHandlerAbortControllers.set(request.id, abortController);
// Starting from Promise.resolve() turns a synchronous throw in the handler into a rejection, which is then sent back as an error response.
Promise.resolve()
.then(() => handler(request, { signal: abortController.signal }))
.then(
(result) => {
if (abortController.signal.aborted) {
return;
}
return this._transport?.send({
result,
jsonrpc: "2.0",
id: request.id,
});
},
(error) => {
if (abortController.signal.aborted) {
return;
}
return this._transport?.send({
jsonrpc: "2.0",
id: request.id,
error: {
code: Number.isSafeInteger(error["code"])
? error["code"]
: ErrorCode.InternalError,
message: error.message ?? "Internal error",
},
});
},
)
.catch((error) =>
this._onerror(new Error(`Failed to send response: ${error}`)),
)
.finally(() => {
this._requestHandlerAbortControllers.delete(request.id);
});
}
private _onprogress(notification: ProgressNotification): void {
const { progressToken, ...params } = notification.params;
const messageId = Number(progressToken);
const handler = this._progressHandlers.get(messageId);
if (!handler) {
this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`));
return;
}
const responseHandler = this._responseHandlers.get(messageId);
if (this._timeoutInfo.has(messageId) && responseHandler) {
try {
this._resetTimeout(messageId);
} catch (error) {
responseHandler(error as Error);
return;
}
}
handler(params);
}
private _onresponse(response: JSONRPCResponse | JSONRPCError): void {
const messageId = Number(response.id);
const handler = this._responseHandlers.get(messageId);
if (handler === undefined) {
this._onerror(
new Error(
`Received a response for an unknown message ID: ${JSON.stringify(response)}`,
),
);
return;
}
this._responseHandlers.delete(messageId);
this._progressHandlers.delete(messageId);
this._cleanupTimeout(messageId);
if ("result" in response) {
handler(response);
} else {
const error = new McpError(
response.error.code,
response.error.message,
response.error.data,
);
handler(error);
}
}
get transport(): Transport | undefined {
return this._transport;
}
/**
* Closes the connection.
*/
async close(): Promise<void> {
await this._transport?.close();
}
/**
* A method to check if a capability is supported by the remote side, for the given method to be called.
*
* This should be implemented by subclasses.
*/
protected abstract assertCapabilityForMethod(
method: SendRequestT["method"],
): void;
/**
* A method to check if a notification is supported by the local side, for the given method to be sent.
*
* This should be implemented by subclasses.
*/
protected abstract assertNotificationCapability(
method: SendNotificationT["method"],
): void;
/**
* A method to check if a request handler is supported by the local side, for the given method to be handled.
*
* This should be implemented by subclasses.
*/
protected abstract assertRequestHandlerCapability(method: string): void;
/**
* Sends a request and waits for a response.
*
* Do not use this method to emit notifications! Use notification() instead.
*/
request<T extends ZodType<object>>(
request: SendRequestT,
resultSchema: T,
options?: RequestOptions,
): Promise<z.infer<T>> {
return new Promise((resolve, reject) => {
if (!this._transport) {
reject(new Error("Not connected"));
return;
}
if (this._options?.enforceStrictCapabilities === true) {
this.assertCapabilityForMethod(request.method);
}
options?.signal?.throwIfAborted();
const messageId = this._requestMessageId++;
const jsonrpcRequest: JSONRPCRequest = {
...request,
jsonrpc: "2.0",
id: messageId,
};
if (options?.onprogress) {
this._progressHandlers.set(messageId, options.onprogress);
jsonrpcRequest.params = {
...request.params,
_meta: { progressToken: messageId },
};
}
const cancel = (reason: unknown) => {
this._responseHandlers.delete(messageId);
this._progressHandlers.delete(messageId);
this._cleanupTimeout(messageId);
this._transport
?.send({
jsonrpc: "2.0",
method: "notifications/cancelled",
params: {
requestId: messageId,
reason: String(reason),
},
})
.catch((error) =>
this._onerror(new Error(`Failed to send cancellation: ${error}`)),
);
reject(reason);
};
this._responseHandlers.set(messageId, (response) => {
if (options?.signal?.aborted) {
return;
}
if (response instanceof Error) {
return reject(response);
}
try {
const result = resultSchema.parse(response.result);
resolve(result);
} catch (error) {
reject(error);
}
});
options?.signal?.addEventListener("abort", () => {
cancel(options?.signal?.reason);
});
const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC;
const timeoutHandler = () => cancel(new McpError(
ErrorCode.RequestTimeout,
"Request timed out",
{ timeout }
));
this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler);
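this._transport.send(jsonrpcRequest).catch((error) => {
// Sending failed, so no response is expected: drop every handler registered for this message.
this._responseHandlers.delete(messageId);
this._progressHandlers.delete(messageId);
this._cleanupTimeout(messageId);
reject(error);
});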
});
}
/**
* Emits a notification, which is a one-way message that does not expect a response.
*/
async notification(notification: SendNotificationT): Promise<void> {
if (!this._transport) {
throw new Error("Not connected");
}
this.assertNotificationCapability(notification.method);
const jsonrpcNotification: JSONRPCNotification = {
...notification,
jsonrpc: "2.0",
};
await this._transport.send(jsonrpcNotification);
}
/**
* Registers a handler to invoke when this protocol object receives a request with the given method.
*
* Note that this will replace any previous request handler for the same method.
*/
setRequestHandler<
T extends ZodObject<{
method: ZodLiteral<string>;
}>,
>(
requestSchema: T,
handler: (
request: z.infer<T>,
extra: RequestHandlerExtra,
) => SendResultT | Promise<SendResultT>,
): void {
const method = requestSchema.shape.method.value;
this.assertRequestHandlerCapability(method);
this._requestHandlers.set(method, (request, extra) =>
Promise.resolve(handler(requestSchema.parse(request), extra)),
);
}
/**
* Removes the request handler for the given method.
*/
removeRequestHandler(method: string): void {
this._requestHandlers.delete(method);
}
/**
* Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed.
*/
assertCanSetRequestHandler(method: string): void {
if (this._requestHandlers.has(method)) {
throw new Error(
`A request handler for ${method} already exists, which would be overridden`,
);
}
}
/**
* Registers a handler to invoke when this protocol object receives a notification with the given method.
*
* Note that this will replace any previous notification handler for the same method.
*/
setNotificationHandler<
T extends ZodObject<{
method: ZodLiteral<string>;
}>,
>(
notificationSchema: T,
handler: (notification: z.infer<T>) => void | Promise<void>,
): void {
this._notificationHandlers.set(
notificationSchema.shape.method.value,
(notification) =>
Promise.resolve(handler(notificationSchema.parse(notification))),
);
}
/**
* Removes the notification handler for the given method.
*/
removeNotificationHandler(method: string): void {
this._notificationHandlers.delete(method);
}
}
export function mergeCapabilities<
T extends ServerCapabilities | ClientCapabilities,
>(base: T, additional: T): T {
return Object.entries(additional).reduce(
(acc, [key, value]) => {
if (value && typeof value === "object") {
acc[key] = acc[key] ? { ...acc[key], ...value } : value;
} else {
acc[key] = value;
}
return acc;
},
{ ...base },
);
}
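The timeout bookkeeping in `_setupTimeout`, `_resetTimeout`, and `_onprogress` can be modelled as a pure function of timestamps, which makes its behaviour easy to check without real timers. The sketch below is hypothetical: `TimeoutState`, `startTimeout`, and `onProgress` are not SDK names. It models the code as written. `_onprogress` resets the timer for any request that has a timeout registered, and in this file the `resetTimeoutOnProgress` option is declared in `RequestOptions` but never read. `maxTotalTimeout` is consulted only when a progress notification arrives.

```typescript
// Hypothetical model of the per-request state kept in Protocol._timeoutInfo.
// Times are plain millisecond numbers so the logic can be checked without timers.
interface TimeoutState {
  startTime: number; // when the request was sent
  timeout: number; // per-request timeout, re-armed on each progress notification
  maxTotalTimeout?: number; // optional overall budget, checked only on progress
  deadline: number; // when the currently armed timer would fire
}

// Mirrors _setupTimeout: arm a timer that fires `timeout` ms after the request is sent.
function startTimeout(now: number, timeout: number, maxTotalTimeout?: number): TimeoutState {
  return { startTime: now, timeout, maxTotalTimeout, deadline: now + timeout };
}

// Mirrors _resetTimeout as called from _onprogress: fail once the total budget is
// spent, otherwise re-arm the timer for a full `timeout` starting at `now`.
function onProgress(state: TimeoutState, now: number): TimeoutState | "max-total-exceeded" {
  const totalElapsed = now - state.startTime;
  // As in the original, a maxTotalTimeout of 0 or undefined means "no overall limit".
  if (state.maxTotalTimeout && totalElapsed >= state.maxTotalTimeout) {
    return "max-total-exceeded";
  }
  return { ...state, deadline: now + state.timeout };
}
```

With `timeout: 1000` and `maxTotalTimeout: 150`, the initial deadline is 1000 ms because the timer is never shortened to fit the overall budget. A request that receives no progress at all therefore fails with "Request timed out" after 1000 ms, not after 150 ms.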
================================================
File: src/shared/stdio.test.ts
================================================
import { JSONRPCMessage } from "../types.js";
import { ReadBuffer } from "./stdio.js";
const testMessage: JSONRPCMessage = {
jsonrpc: "2.0",
method: "foobar",
};
test("should have no messages after initialization", () => {
const readBuffer = new ReadBuffer();
expect(readBuffer.readMessage()).toBeNull();
});
test("should only yield a message after a newline", () => {
const readBuffer = new ReadBuffer();
readBuffer.append(Buffer.from(JSON.stringify(testMessage)));
expect(readBuffer.readMessage()).toBeNull();
readBuffer.append(Buffer.from("\n"));
expect(readBuffer.readMessage()).toEqual(testMessage);
expect(readBuffer.readMessage()).toBeNull();
});
test("should be reusable after clearing", () => {
const readBuffer = new ReadBuffer();
readBuffer.append(Buffer.from("foobar"));
readBuffer.clear();
expect(readBuffer.readMessage()).toBeNull();
readBuffer.append(Buffer.from(JSON.stringify(testMessage)));
readBuffer.append(Buffer.from("\n"));
expect(readBuffer.readMessage()).toEqual(testMessage);
});
================================================
File: src/shared/stdio.ts
================================================
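import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";

/**
 * Buffers a continuous stdio stream into discrete JSON-RPC messages.
 *
 * Messages are framed as newline-delimited JSON: each message is a single
 * line of JSON followed by "\n".
 */
export class ReadBuffer {
  private _buffer?: Buffer;

  /**
   * Appends a chunk of raw bytes read from the stream.
   */
  append(chunk: Buffer): void {
    this._buffer = this._buffer ? Buffer.concat([this._buffer, chunk]) : chunk;
  }

  /**
   * Returns the next complete message, or null if no full line has been
   * buffered yet. The line is removed from the buffer before parsing, so if
   * it is not valid JSON or does not match the JSON-RPC message schema, this
   * throws and the offending line is discarded.
   */
  readMessage(): JSONRPCMessage | null {
    if (!this._buffer) {
      return null;
    }

    const index = this._buffer.indexOf("\n");
    if (index === -1) {
      return null;
    }

    const line = this._buffer.toString("utf8", 0, index);
    this._buffer = this._buffer.subarray(index + 1);
    return deserializeMessage(line);
  }

  /**
   * Discards any buffered, partially received data.
   */
  clear(): void {
    this._buffer = undefined;
  }
}

/**
 * Parses one line of JSON and validates it against the JSON-RPC message schema.
 */
export function deserializeMessage(line: string): JSONRPCMessage {
  return JSONRPCMessageSchema.parse(JSON.parse(line));
}

/**
 * Serializes a message as a single line of JSON terminated by "\n", the
 * framing that ReadBuffer expects.
 */
export function serializeMessage(message: JSONRPCMessage): string {
  return JSON.stringify(message) + "\n";
}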
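`ReadBuffer` implements newline-delimited framing: data accumulates until a `\n` arrives, and only the bytes before it are decoded and parsed. The following is a minimal sketch of the same framing on already-decoded strings, so that it runs without the SDK. `LineFramer`, `Message`, and `serialize` are hypothetical names, and unlike `ReadBuffer` the sketch performs no schema validation.

```typescript
// Hypothetical, simplified message shape; the SDK validates against JSONRPCMessageSchema instead.
type Message = { jsonrpc: "2.0"; method?: string; id?: number | string };

// Hypothetical string-based analogue of ReadBuffer (no schema validation).
class LineFramer {
  private buffer = "";

  // Accumulate a chunk that may contain zero, one, or several newlines.
  append(chunk: string): void {
    this.buffer += chunk;
  }

  // Return the next complete line parsed as JSON, or null if no newline has arrived yet.
  readMessage(): Message | null {
    const index = this.buffer.indexOf("\n");
    if (index === -1) return null;
    const line = this.buffer.slice(0, index);
    this.buffer = this.buffer.slice(index + 1);
    return JSON.parse(line) as Message;
  }
}

// Mirrors serializeMessage: one JSON document per line.
function serialize(message: Message): string {
  return JSON.stringify(message) + "\n";
}

// Usage: two messages arrive split across arbitrary chunk boundaries.
const wire =
  serialize({ jsonrpc: "2.0", method: "ping" }) +
  serialize({ jsonrpc: "2.0", method: "pong" });
const framer = new LineFramer();
const received: Message[] = [];
for (const chunk of [wire.slice(0, 10), wire.slice(10, 35), wire.slice(35)]) {
  framer.append(chunk);
  // Drain every complete message the latest chunk made available.
  let msg: Message | null;
  while ((msg = framer.readMessage()) !== null) {
    received.push(msg);
  }
}
console.log(received.map((m) => m.method)); // [ 'ping', 'pong' ]
```

Buffering raw bytes, as `ReadBuffer` does, has one advantage over this string version: a multi-byte UTF-8 character split across two chunks is never decoded in halves, because the newline byte `0x0A` cannot occur inside a multi-byte sequence.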
================================================
File: src/shared/transport.ts
================================================
import { JSONRPCMessage } from "../types.js";

/**
 * Describes the minimal contract for an MCP transport that a client or server can communicate over.
 */
export interface Transport {
  /**
   * Starts processing messages on the transport, including any connection steps that might need to be taken.
   *
   * This method should only be called after callbacks are installed, or else messages may be lost.
   *
   * NOTE: This method should not be called explicitly when using Client, Server, or Protocol classes, as they will implicitly call start().
   */
  start(): Promise<void>;

  /**
   * Sends a JSON-RPC message (request, notification, or response).
   */
  send(message: JSONRPCMessage): Promise<void>;

  /**
   * Closes the connection.
   */
  close(): Promise<void>;

  /**
   * Callback for when the connection is closed for any reason.
   *
   * This should be invoked when close() is called as well.
   */
  onclose?: () => void;

  /**
   * Callback for when an error occurs.
   *
   * Note that errors are not necessarily fatal; they are used for reporting any kind of exceptional condition out of band.
   */
  onerror?: (error: Error) => void;

  /**
   * Callback for when a message (request, notification, or response) is received over the connection.
   */
  onmessage?: (message: JSONRPCMessage) => void;
}
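The contract above is small enough that an in-memory implementation fits in a few dozen lines. The sketch below is hypothetical and is not the SDK's own in-memory transport: `LinkedTransport`, `createLinkedPair`, and the simplified `Message` type are illustrative names, and the interface is redeclared locally so the example is self-contained. It queues messages that arrive before `start()` (one way to avoid the message loss the doc comment warns about) and fires `onclose` on both ends when either side calls `close()`.

```typescript
// Hypothetical, simplified stand-in for JSONRPCMessage.
type Message = { jsonrpc: "2.0"; method?: string; id?: number };

// Local copy of the Transport contract, redeclared so this file runs on its own.
interface Transport {
  start(): Promise<void>;
  send(message: Message): Promise<void>;
  close(): Promise<void>;
  onclose?: () => void;
  onerror?: (error: Error) => void;
  onmessage?: (message: Message) => void;
}

// Hypothetical in-memory transport: each instance delivers directly to its peer.
class LinkedTransport implements Transport {
  peer?: LinkedTransport;
  private started = false;
  private closed = false;
  private pending: Message[] = [];

  onclose?: () => void;
  onerror?: (error: Error) => void;
  onmessage?: (message: Message) => void;

  async start(): Promise<void> {
    this.started = true;
    // Flush anything the peer sent before we were ready.
    const queued = this.pending;
    this.pending = [];
    for (const message of queued) this.onmessage?.(message);
  }

  async send(message: Message): Promise<void> {
    if (this.closed || !this.peer) throw new Error("Transport is closed");
    this.peer.receive(message);
  }

  // Deliver immediately once started; otherwise queue so nothing is lost.
  receive(message: Message): void {
    if (this.started) this.onmessage?.(message);
    else this.pending.push(message);
  }

  async close(): Promise<void> {
    if (this.closed) return;
    this.closed = true;
    const peer = this.peer;
    this.peer = undefined;
    this.onclose?.(); // the contract asks for onclose on an explicit close() too
    await peer?.close(); // propagate so the other end also observes the close
  }
}

// Hypothetical helper returning two transports wired to each other.
function createLinkedPair(): [LinkedTransport, LinkedTransport] {
  const a = new LinkedTransport();
  const b = new LinkedTransport();
  a.peer = b;
  b.peer = a;
  return [a, b];
}

// Usage
const [client, server] = createLinkedPair();
server.onmessage = (m) => console.log("server received", m.method);
void client.send({ jsonrpc: "2.0", method: "ping" }); // queued: server not started yet
void server.start(); // flushes the queue and logs "server received ping"
```

Delivery here is synchronous because neither `send` nor `start` awaits anything before handing the message over; a transport backed by real I/O would deliver asynchronously.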
================================================
File: src/shared/uriTemplate.test.ts
================================================
import { UriTemplate } from "./uriTemplate.js";

describe("UriTemplate", () => {
  describe("isTemplate", () => {
    it("should return true for strings containing template expressions", () => {
      expect(UriTemplate.isTemplate("{foo}")).toBe(true);
      expect(UriTemplate.isTemplate("/users/{id}")).toBe(true);
      expect(UriTemplate.isTemplate("http://example.com/{path}/{file}")).toBe(true);
      expect(UriTemplate.isTemplate("/search{?q,limit}")).toBe(true);
    });

    it("should return false for strings without template expressions", () => {
      expect(UriTemplate.isTemplate("")).toBe(false);
      expect(UriTemplate.isTemplate("plain string")).toBe(false);
      expect(UriTemplate.isTemplate("http://example.com/foo/bar")).toBe(false);
      expect(UriTemplate.isTemplate("{}")).toBe(false); // Empty braces don't count
      expect(UriTemplate.isTemplate("{ }")).toBe(false); // Just whitespace doesn't count
    });
  });

  describe("simple string expansion", () => {
    it("should expand simple string variables", () => {
      const template = new UriTemplate("http://example.com/users/{username}");
      expect(template.expand({ username: "fred" })).toBe(
        "http://example.com/users/fred",
      );
    });

    it("should handle multiple variables", () => {
      const template = new UriTemplate("{x,y}");
      expect(template.expand({ x: "1024", y: "768" })).toBe("1024,768");
    });

    it("should encode reserved characters", () => {
      const template = new UriTemplate("{var}");
      expect(template.expand({ var: "value with spaces" })).toBe(
        "value%20with%20spaces",
      );
    });
  });

  describe("reserved expansion", () => {
    it("should not encode reserved characters with + operator", () => {
      const template = new UriTemplate("{+path}/here");
      expect(template.expand({ path: "/foo/bar" })).toBe("/foo/bar/here");
    });
  });

  describe("fragment expansion", () => {
    it("should add # prefix and not encode reserved chars", () => {
      const template = new UriTemplate("X{#var}");
      expect(template.expand({ var: "/test" })).toBe("X#/test");
    });
  });

  describe("label expansion", () => {
    it("should add . prefix", () => {
      const template = new UriTemplate("X{.var}");
      expect(template.expand({ var: "test" })).toBe("X.test");
    });
  });

  describe("path expansion", () => {
    it("should add / prefix", () => {
      const template = new UriTemplate("X{/var}");
      expect(template.expand({ var: "test" })).toBe("X/test");
    });
  });

  describe("query expansion", () => {
    it("should add ? prefix and name=value format", () => {
      const template = new UriTemplate("X{?var}");
      expect(template.expand({ var: "test" })).toBe("X?var=test");
    });
  });

  describe("form continuation expansion", () => {
    it("should add & prefix and name=value format", () => {
      const template = new UriTemplate("X{&var}");
      expect(template.expand({ var: "test" })).toBe("X&var=test");
    });
  });

  describe("matching", () => {
    it("should match simple strings and extract variables", () => {
      const template = new UriTemplate("http://example.com/users/{username}");
      const match = template.match("http://example.com/users/fred");
      expect(match).toEqual({ username: "fred" });
    });

    it("should match multiple variables", () => {
      const template = new UriTemplate("/users/{username}/posts/{postId}");
      const match = template.match("/users/fred/posts/123");
      expect(match).toEqual({ username: "fred", postId: "123" });
    });

    it("should return null for non-matching URIs", () => {
      const template = new UriTemplate("/users/{username}");
      const match = template.match("/posts/123");
      expect(match).toBeNull();
    });

    it("should handle exploded arrays", () => {
      const template = new UriTemplate("{/list*}");
      const match = template.match("/red,green,blue");
      expect(match).toEqual({ list: ["red", "green", "blue"] });
    });
  });

  describe("edge cases", () => {
    it("should handle empty variables", () => {
      const template = new UriTemplate("{empty}");
      expect(template.expand({})).toBe("");
      expect(template.expand({ empty: "" })).toBe("");
    });

    it("should handle undefined variables", () => {
      const template = new UriTemplate("{a}{b}{c}");
      expect(template.expand({ b: "2" })).toBe("2");
    });

    it("should handle special characters in variable names", () => {
      const template = new UriTemplate("{$var_name}");
      expect(template.expand({ "$var_name": "value" })).toBe("value");
    });
  });

  describe("complex patterns", () => {
    it("should handle nested path segments", () => {
      const template = new UriTemplate("/api/{version}/{resource}/{id}");
      expect(template.expand({
        version: "v1",
        resource: "users",
        id: "123"
      })).toBe("/api/v1/users/123");
    });

    it("should handle query parameters with arrays", () => {
      const template = new UriTemplate("/search{?tags*}");
      expect(template.expand({
        tags: ["nodejs", "typescript", "testing"]
      })).toBe("/search?tags=nodejs,typescript,testing");
    });

    it("should handle multiple query parameters", () => {
      const template = new UriTemplate("/search{?q,page,limit}");
      expect(template.expand({
        q: "test",
        page: "1",
        limit: "10"
      })).toBe("/search?q=test&page=1&limit=10");
    });
  });

  describe("matching complex patterns", () => {
    it("should match nested path segments", () => {
      const template = new UriTemplate("/api/{version}/{resource}/{id}");
      const match = template.match("/api/v1/users/123");
      expect(match).toEqual({
        version: "v1",
        resource: "users",
        id: "123"
      });
    });

    it("should match query parameters", () => {
      const template = new UriTemplate("/search{?q}");
      const match = template.match("/search?q=test");
      expect(match).toEqual({ q: "test" });
    });

    it("should match multiple query parameters", () => {
      const template = new UriTemplate("/search{?q,page}");
      const match = template.match("/search?q=test&page=1");
      expect(match).toEqual({ q: "test", page: "1" });
    });

    it("should handle partial matches correctly", () => {
      const template = new UriTemplate("/users/{id}");
      expect(template.match("/users/123/extra")).toBeNull();
      expect(template.match("/users")).toBeNull();
    });
  });

  describe("security and edge cases", () => {
    it("should handle extremely long input strings", () => {
      const longString = "x".repeat(100000);
      const template = new UriTemplate(`/api/{param}`);
      expect(template.expand({ param: longString })).toBe(`/api/${longString}`);
      expect(template.match(`/api/${longString}`)).toEqual({ param: longString });
    });

    it("should handle deeply nested template expressions", () => {
      const template = new UriTemplate("{a}{b}{c}{d}{e}{f}{g}{h}{i}{j}".repeat(1000));
      expect(() => template.expand({
        a: "1", b: "2", c: "3", d: "4", e: "5",
        f: "6", g: "7", h: "8", i: "9", j: "0"
      })).not.toThrow();
    });

    it("should handle malformed template expressions", () => {
      expect(() => new UriTemplate("{unclosed")).toThrow();
      expect(() => new UriTemplate("{}")).not.toThrow();
      expect(() => new UriTemplate("{,}")).not.toThrow();
      expect(() => new UriTemplate("{a}{")).toThrow();
    });

    it("should handle pathological regex patterns", () => {
      const template = new UriTemplate("/api/{param}");
      // Create a string that could cause catastrophic backtracking
      const input = "/api/" + "a".repeat(100000);
      expect(() => template.match(input)).not.toThrow();
    });

    it("should handle invalid UTF-8 sequences", () => {
      const template = new UriTemplate("/api/{param}");
      const invalidUtf8 = "���";
      expect(() => template.expand({ param: invalidUtf8 })).not.toThrow();
      expect(() => template.match(`/api/${invalidUtf8}`)).not.toThrow();
    });

    it("should handle template/URI length mismatches", () => {
      const template = new UriTemplate("/api/{param}");
      expect(template.match("/api/")).toBeNull();
      expect(template.match("/api")).toBeNull();
      expect(template.match("/api/value/extra")).toBeNull();
    });

    it("should handle repeated operators", () => {
      const template = new UriTemplate("{?a}{?b}{?c}");
      expect(template.expand({ a: "1", b: "2", c: "3" })).toBe("?a=1&b=2&c=3");
    });

    it("should handle overlapping variable names", () => {
      const template = new UriTemplate("{var}{vara}");
      expect(template.expand({ var: "1", vara: "2" })).toBe("12");
    });

    it("should handle empty segments", () => {
      const template = new UriTemplate("///{a}////{b}////");
      expect(template.expand({ a: "1", b: "2" })).toBe("///1////2////");
      expect(template.match("///1////2////")).toEqual({ a: "1", b: "2" });
    });

    it("should handle maximum template expression limit", () => {
      // Create a template with many expressions
      const expressions = Array(10000).fill("{param}").join("");
      expect(() => new UriTemplate(expressions)).not.toThrow();
    });

    it("should handle maximum variable name length", () => {
      const longName = "a".repeat(10000);
      const template = new UriTemplate(`{${longName}}`);
      const vars: Record<string, string> = {};
      vars[longName] = "value";
      expect(() => template.expand(vars)).not.toThrow();
    });
  });
});
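The matching tests exercise the reverse direction of expansion: the template is turned into an anchored regular expression with one capture group per variable, and the groups are read back by name. A stripped-down sketch for simple `{name}` expressions only (no operators, no exploded lists) could look like the following. `templateToRegExp` and `matchTemplate` are hypothetical helpers; the sketch reuses the `([^/,]+)` capture and the literal-text escaping that `uriTemplate.ts` uses for simple expressions, but it is an illustration rather than the SDK code.

```typescript
// Escape regex metacharacters in literal template text.
function escapeRegExp(text: string): string {
  return text.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
}

// Build an anchored regex: literal text is escaped, and each {name} becomes a
// capture group that cannot cross a "/" or "," boundary.
function templateToRegExp(template: string): { regex: RegExp; names: string[] } {
  const names: string[] = [];
  const expression = /\{([^}]+)\}/g;
  let pattern = "^";
  let last = 0;
  let m: RegExpExecArray | null;
  while ((m = expression.exec(template)) !== null) {
    pattern += escapeRegExp(template.slice(last, m.index)) + "([^/,]+)";
    names.push(m[1]);
    last = m.index + m[0].length;
  }
  pattern += escapeRegExp(template.slice(last)) + "$";
  return { regex: new RegExp(pattern), names };
}

// Return the extracted variables, or null if the URI does not fit the template.
function matchTemplate(template: string, uri: string): Record<string, string> | null {
  const { regex, names } = templateToRegExp(template);
  const m = uri.match(regex);
  if (!m) return null;
  const result: Record<string, string> = {};
  for (let i = 0; i < names.length; i++) {
    result[names[i]] = m[i + 1];
  }
  return result;
}

console.log(matchTemplate("/users/{username}/posts/{postId}", "/users/fred/posts/123"));
// { username: 'fred', postId: '123' }
console.log(matchTemplate("/users/{id}", "/users/123/extra")); // null
```

Each capture is a single negated character class, so simple expressions like these compile to patterns without nested quantifiers, and the `^`/`$` anchors are what make partial matches such as `/users/123/extra` fail.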
================================================
File: src/shared/uriTemplate.ts
================================================
// Claude-authored implementation of a subset of RFC 6570 URI Templates
// (the ";" operator and ":n" prefix modifiers are not supported).
export type Variables = Record<string, string | string[]>;

const MAX_TEMPLATE_LENGTH = 1000000; // 1M characters
const MAX_VARIABLE_LENGTH = 1000000; // 1M characters
const MAX_TEMPLATE_EXPRESSIONS = 10000;
const MAX_REGEX_LENGTH = 1000000; // 1M characters
export class UriTemplate {
  /**
   * Returns true if the given string contains any URI template expressions.
   * A template expression is a sequence of characters enclosed in curly braces,
   * like {foo} or {?bar}.
   */
  static isTemplate(str: string): boolean {
    // Look for any sequence of characters between curly braces
    // that isn't just whitespace
    return /\{[^}\s]+\}/.test(str);
  }

  private static validateLength(
    str: string,
    max: number,
    context: string,
  ): void {
    if (str.length > max) {
      throw new Error(
        `${context} exceeds maximum length of ${max} characters (got ${str.length})`,
      );
    }
  }

  private readonly template: string;
  private readonly parts: Array<
    | string
    | { name: string; operator: string; names: string[]; exploded: boolean }
  >;

  constructor(template: string) {
    UriTemplate.validateLength(template, MAX_TEMPLATE_LENGTH, "Template");
    this.template = template;
    this.parts = this.parse(template);
  }

  toString(): string {
    return this.template;
  }

  private parse(
    template: string,
  ): Array<
    | string
    | { name: string; operator: string; names: string[]; exploded: boolean }
  > {
    const parts: Array<
      | string
      | { name: string; operator: string; names: string[]; exploded: boolean }
    > = [];
    let currentText = "";
    let i = 0;
    let expressionCount = 0;

    while (i < template.length) {
      if (template[i] === "{") {
        if (currentText) {
          parts.push(currentText);
          currentText = "";
        }
        const end = template.indexOf("}", i);
        if (end === -1) throw new Error("Unclosed template expression");

        expressionCount++;
        if (expressionCount > MAX_TEMPLATE_EXPRESSIONS) {
          throw new Error(
            `Template contains too many expressions (max ${MAX_TEMPLATE_EXPRESSIONS})`,
          );
        }

        const expr = template.slice(i + 1, end);
        const operator = this.getOperator(expr);
        const exploded = expr.includes("*");
        const names = this.getNames(expr);
        const name = names[0];

        // Validate variable name length
        for (const name of names) {
          UriTemplate.validateLength(
            name,
            MAX_VARIABLE_LENGTH,
            "Variable name",
          );
        }

        parts.push({ name, operator, names, exploded });
        i = end + 1;
      } else {
        currentText += template[i];
        i++;
      }
    }

    if (currentText) {
      parts.push(currentText);
    }

    return parts;
  }

  private getOperator(expr: string): string {
    const operators = ["+", "#", ".", "/", "?", "&"];
    return operators.find((op) => expr.startsWith(op)) || "";
  }

  private getNames(expr: string): string[] {
    const operator = this.getOperator(expr);
    return expr
      .slice(operator.length)
      .split(",")
      .map((name) => name.replace("*", "").trim())
      .filter((name) => name.length > 0);
  }

  private encodeValue(value: string, operator: string): string {
    UriTemplate.validateLength(value, MAX_VARIABLE_LENGTH, "Variable value");
    if (operator === "+" || operator === "#") {
      return encodeURI(value);
    }
    return encodeURIComponent(value);
  }

  private expandPart(
    part: {
      name: string;
      operator: string;
      names: string[];
      exploded: boolean;
    },
    variables: Variables,
  ): string {
    if (part.operator === "?" || part.operator === "&") {
      const pairs = part.names
        .map((name) => {
          const value = variables[name];
          if (value === undefined) return "";
          const encoded = Array.isArray(value)
            ? value.map((v) => this.encodeValue(v, part.operator)).join(",")
            : this.encodeValue(value.toString(), part.operator);
          return `${name}=${encoded}`;
        })
        .filter((pair) => pair.length > 0);

      if (pairs.length === 0) return "";
      const separator = part.operator === "?" ? "?" : "&";
      return separator + pairs.join("&");
    }

    if (part.names.length > 1) {
      const values = part.names
        .map((name) => variables[name])
        .filter((v) => v !== undefined);
      if (values.length === 0) return "";
      return values.map((v) => (Array.isArray(v) ? v[0] : v)).join(",");
    }

    const value = variables[part.name];
    if (value === undefined) return "";

    const values = Array.isArray(value) ? value : [value];
    const encoded = values.map((v) => this.encodeValue(v, part.operator));

    switch (part.operator) {
      case "":
        return encoded.join(",");
      case "+":
        return encoded.join(",");
      case "#":
        return "#" + encoded.join(",");
      case ".":
        return "." + encoded.join(".");
      case "/":
        return "/" + encoded.join("/");
      default:
        return encoded.join(",");
    }
  }

  expand(variables: Variables): string {
    let result = "";
    let hasQueryParam = false;

    for (const part of this.parts) {
      if (typeof part === "string") {
        result += part;
        continue;
      }

      const expanded = this.expandPart(part, variables);
      if (!expanded) continue;

      // Convert ? to & if we already have a query parameter
      if ((part.operator === "?" || part.operator === "&") && hasQueryParam) {
        result += expanded.replace("?", "&");
      } else {
        result += expanded;
      }

      if (part.operator === "?" || part.operator === "&") {
        hasQueryParam = true;
      }
    }

    return result;
  }

  private escapeRegExp(str: string): string {
    return str.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
  }

  private partToRegExp(part: {
    name: string;
    operator: string;
    names: string[];
    exploded: boolean;
  }): Array<{ pattern: string; name: string }> {
    const patterns: Array<{ pattern: string; name: string }> = [];

    // Validate variable name length for matching
    for (const name of part.names) {
      UriTemplate.validateLength(name, MAX_VARIABLE_LENGTH, "Variable name");
    }

    if (part.operator === "?" || part.operator === "&") {
      for (let i = 0; i < part.names.length; i++) {
        const name = part.names[i];
        const prefix = i === 0 ? "\\" + part.operator : "&";
        patterns.push({
          pattern: prefix + this.escapeRegExp(name) + "=([^&]+)",
          name,
        });
      }
      return patterns;
    }

    let pattern: string;
    const name = part.name;

    switch (part.operator) {
      case "":
        pattern = part.exploded ? "([^/]+(?:,[^/]+)*)" : "([^/,]+)";
        break;
      case "+":
      case "#":
        pattern = "(.+)";
        break;
      case ".":
        pattern = "\\.([^/,]+)";
        break;
      case "/":
        pattern = "/" + (part.exploded ? "([^/]+(?:,[^/]+)*)" : "([^/,]+)");
        break;
      default:
        pattern = "([^/]+)";
    }

    patterns.push({ pattern, name });
    return patterns;
  }

  match(uri: string): Variables | null {
    UriTemplate.validateLength(uri, MAX_TEMPLATE_LENGTH, "URI");
    let pattern = "^";
    const names: Array<{ name: string; exploded: boolean }> = [];

    for (const part of this.parts) {
      if (typeof part === "string") {
        pattern += this.escapeRegExp(part);
      } else {
        const patterns = this.partToRegExp(part);
        for (const { pattern: partPattern, name } of patterns) {
          pattern += partPattern;
          names.push({ name, exploded: part.exploded });
        }
      }
    }

    pattern += "$";
    UriTemplate.validateLength(
      pattern,
      MAX_REGEX_LENGTH,
      "Generated regex pattern",
    );
    const regex = new RegExp(pattern);
    const match = uri.match(regex);

    if (!match) return null;

    const result: Variables = {};
    for (let i = 0; i < names.length; i++) {
      const { name, exploded } = names[i];
      const value = match[i + 1];
      const cleanName = name.replace("*", "");

      if (exploded && value.includes(",")) {
        result[cleanName] = value.split(",");
      } else {
        result[cleanName] = value;
      }
    }

    return result;
  }
}
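`expandPart` handles each operator in its own branch. RFC 6570 expresses the same rules as a table of per-operator settings: a prefix for the first value, a separator between values, whether values are emitted as `name=value` pairs, and whether reserved characters pass through unencoded. A table-driven sketch for plain string values (no lists, no `:n` prefix modifiers) follows; `OPERATORS`, `expandExpression`, and `expand` are hypothetical names, and reserved-character handling is approximated with `encodeURI`, as `encodeValue` does above.

```typescript
// One row of the RFC 6570 operator table.
type OperatorSpec = {
  first: string; // emitted before the first value
  sep: string; // emitted between values
  named: boolean; // emit name=value pairs
  ifEmpty: string; // appended to the name when a named value is ""
  allowReserved: boolean; // leave reserved characters unencoded
};

const OPERATORS: Record<string, OperatorSpec> = {
  "": { first: "", sep: ",", named: false, ifEmpty: "", allowReserved: false },
  "+": { first: "", sep: ",", named: false, ifEmpty: "", allowReserved: true },
  "#": { first: "#", sep: ",", named: false, ifEmpty: "", allowReserved: true },
  ".": { first: ".", sep: ".", named: false, ifEmpty: "", allowReserved: false },
  "/": { first: "/", sep: "/", named: false, ifEmpty: "", allowReserved: false },
  ";": { first: ";", sep: ";", named: true, ifEmpty: "", allowReserved: false },
  "?": { first: "?", sep: "&", named: true, ifEmpty: "=", allowReserved: false },
  "&": { first: "&", sep: "&", named: true, ifEmpty: "=", allowReserved: false },
};

// Expand the inside of one {...} expression, e.g. "?q,page".
function expandExpression(expr: string, vars: Record<string, string | undefined>): string {
  const op = /^[+#./;?&]/.test(expr) ? expr[0] : "";
  const spec = OPERATORS[op];
  const encode = (value: string): string =>
    spec.allowReserved ? encodeURI(value) : encodeURIComponent(value);

  const pieces: string[] = [];
  for (const name of expr.slice(op.length).split(",")) {
    const value = vars[name];
    if (value === undefined) continue; // undefined variables are skipped entirely
    if (!spec.named) {
      pieces.push(encode(value));
    } else if (value === "") {
      pieces.push(name + spec.ifEmpty);
    } else {
      pieces.push(`${name}=${encode(value)}`);
    }
  }
  return pieces.length === 0 ? "" : spec.first + pieces.join(spec.sep);
}

// Replace every {...} expression in the template with its expansion.
function expand(template: string, vars: Record<string, string | undefined>): string {
  return template.replace(/\{([^}]+)\}/g, (_match: string, expr: string) =>
    expandExpression(expr, vars),
  );
}

console.log(expand("X{#var}", { var: "/test" })); // X#/test
console.log(expand("/search{?q,page}", { q: "a b", page: "1" })); // /search?q=a%20b&page=1
console.log(expand("{/a,b}", { a: "x", b: "y" })); // /x/y
```

One difference from the class above: for a multi-variable expression such as `{/a,b}` the table applies the operator's prefix and separator to every value (`/x/y`), whereas `expandPart` joins the values of multi-name, non-query expressions with `,` and adds neither a prefix nor encoding.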
================================================
File: .github/workflows/main.yml
================================================
on:
  push:
    branches:
      - main
  pull_request:
  release:
    types: [published]

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-node@v4
        with:
          node-version: 18
          cache: npm

      - run: npm ci
      - run: npm run build
      - run: npm test
      - run: npm run lint

  publish:
    runs-on: ubuntu-latest
    if: github.event_name == 'release'
    environment: release
    needs: build

    permissions:
      contents: read
      id-token: write

    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-node@v4
        with:
          node-version: 18
          cache: npm
          registry-url: 'https://registry.npmjs.org'

      - run: npm ci

      # TODO: Add --provenance once the repo is public
      - run: npm publish --access public
        env:
          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}