diff --git a/README.md b/README.md index 846ba5232..5cc1f4b29 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ - [Tools](#tools) - [Prompts](#prompts) - [Completions](#completions) + - [Sampling](#sampling) - [Running Your Server](#running-your-server) - [stdio](#stdio) - [Streamable HTTP](#streamable-http) @@ -44,6 +45,8 @@ The Model Context Protocol allows applications to provide context for LLMs in a npm install @modelcontextprotocol/sdk ``` +> ⚠️ MCP requires Node v18.x up to work fine. + ## Quick Start Let's create a simple MCP server that exposes a calculator tool and some data: @@ -382,6 +385,68 @@ import { getDisplayName } from "@modelcontextprotocol/sdk/shared/metadataUtils.j const displayName = getDisplayName(tool); ``` +### Sampling + +MCP servers can request LLM completions from connected clients that support sampling. + +```typescript +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import { z } from "zod"; + +const mcpServer = new McpServer({ + name: "tools-with-sample-server", + version: "1.0.0", +}); + +// Tool that uses LLM sampling to summarize any text +mcpServer.registerTool( + "summarize", + { + description: "Summarize any text using an LLM", + inputSchema: { + text: z.string().describe("Text to summarize"), + }, + }, + async ({ text }) => { + // Call the LLM through MCP sampling + const response = await mcpServer.server.createMessage({ + messages: [ + { + role: "user", + content: { + type: "text", + text: `Please summarize the following text concisely:\n\n${text}`, + }, + }, + ], + maxTokens: 500, + }); + + return { + content: [ + { + type: "text", + text: response.content.type === "text" ? response.content.text : "Unable to generate summary", + }, + ], + }; + } +); + +async function main() { + const transport = new StdioServerTransport(); + await mcpServer.connect(transport); + console.log("MCP server is running..."); +} + +main().catch((error) => { + console.error("Server error:", error); + process.exit(1); +}); +``` + + ## Running Your Server MCP servers in TypeScript need to be connected to a transport to communicate with clients. How you start the server depends on the choice of transport: @@ -444,7 +509,11 @@ app.post('/mcp', async (req, res) => { onsessioninitialized: (sessionId) => { // Store the transport by session ID transports[sessionId] = transport; - } + }, + // DNS rebinding protection is disabled by default for backwards compatibility. If you are running this server + // locally, make sure to set: + // enableDnsRebindingProtection: true, + // allowedHosts: ['127.0.0.1'], }); // Clean up transport when closed @@ -584,8 +653,17 @@ app.delete('/mcp', async (req: Request, res: Response) => { // Start the server const PORT = 3000; -app.listen(PORT, () => { - console.log(`MCP Stateless Streamable HTTP Server listening on port ${PORT}`); +setupServer().then(() => { + app.listen(PORT, (error) => { + if (error) { + console.error('Failed to start server:', error); + process.exit(1); + } + console.log(`MCP Stateless Streamable HTTP Server listening on port ${PORT}`); + }); +}).catch(error => { + console.error('Failed to set up the server:', error); + process.exit(1); }); ``` @@ -596,6 +674,22 @@ This stateless approach is useful for: - RESTful scenarios where each request is independent - Horizontally scaled deployments without shared session state +#### DNS Rebinding Protection + +The Streamable HTTP transport includes DNS rebinding protection to prevent security vulnerabilities. By default, this protection is **disabled** for backwards compatibility. + +**Important**: If you are running this server locally, enable DNS rebinding protection: + +```typescript +const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + enableDnsRebindingProtection: true, + + allowedHosts: ['127.0.0.1', ...], + allowedOrigins: ['https://yourdomain.com', 'https://www.yourdomain.com'] +}); +``` + ### Testing and Debugging To test your server, you can use the [MCP Inspector](https://github.com/modelcontextprotocol/inspector). See its README for more information. diff --git a/package-lock.json b/package-lock.json index 9f1d43a33..16b90a3b7 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.13.2", + "version": "1.13.3", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.13.2", + "version": "1.13.3", "license": "MIT", "dependencies": { "ajv": "^6.12.6", @@ -14,6 +14,7 @@ "cors": "^2.8.5", "cross-spawn": "^7.0.5", "eventsource": "^3.0.2", + "eventsource-parser": "^3.0.0", "express": "^5.0.1", "express-rate-limit": "^7.5.0", "pkce-challenge": "^5.0.0", diff --git a/package.json b/package.json index 8feb10aff..e50619668 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.13.2", + "version": "1.13.3", "description": "Model Context Protocol implementation for TypeScript", "license": "MIT", "author": "Anthropic, PBC (https://anthropic.com)", @@ -54,6 +54,7 @@ "cors": "^2.8.5", "cross-spawn": "^7.0.5", "eventsource": "^3.0.2", + "eventsource-parser": "^3.0.0", "express": "^5.0.1", "express-rate-limit": "^7.5.0", "pkce-challenge": "^5.0.0", diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index 8e77c0a5b..8155e1342 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -403,6 +403,19 @@ describe("OAuth Authorization", () => { expect(mockFetch).toHaveBeenCalledTimes(2); }); + it("returns undefined when both CORS requests fail in fetchWithCorsRetry", async () => { + // fetchWithCorsRetry tries with headers (fails with CORS), then retries without headers (also fails with CORS) + // simulating a 404 w/o headers set. We want this to return undefined, not throw TypeError + mockFetch.mockImplementation(() => { + // Both the initial request with headers and retry without headers fail with CORS TypeError + return Promise.reject(new TypeError("Failed to fetch")); + }); + + // This should return undefined (the desired behavior after the fix) + const metadata = await discoverOAuthMetadata("https://auth.example.com/path"); + expect(metadata).toBeUndefined(); + }); + it("returns undefined when discovery endpoint returns 404", async () => { mockFetch.mockResolvedValueOnce({ ok: false, diff --git a/src/client/auth.ts b/src/client/auth.ts index 376905743..71101a428 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -292,25 +292,24 @@ export async function discoverOAuthProtectedResourceMetadata( return OAuthProtectedResourceMetadataSchema.parse(await response.json()); } -/** - * Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata. - * - * If the server returns a 404 for the well-known endpoint, this function will - * return `undefined`. Any other errors will be thrown as exceptions. - */ /** * Helper function to handle fetch with CORS retry logic */ async function fetchWithCorsRetry( url: URL, - headers: Record, -): Promise { + headers?: Record, +): Promise { try { return await fetch(url, { headers }); } catch (error) { - // CORS errors come back as TypeError, retry without headers if (error instanceof TypeError) { - return await fetch(url); + if (headers) { + // CORS errors come back as TypeError, retry without headers + return fetchWithCorsRetry(url) + } else { + // We're getting CORS errors on retry too, return undefined + return undefined + } } throw error; } @@ -334,7 +333,7 @@ function buildWellKnownPath(pathname: string): string { async function tryMetadataDiscovery( url: URL, protocolVersion: string, -): Promise { +): Promise { const headers = { "MCP-Protocol-Version": protocolVersion }; @@ -344,10 +343,16 @@ async function tryMetadataDiscovery( /** * Determines if fallback to root discovery should be attempted */ -function shouldAttemptFallback(response: Response, pathname: string): boolean { - return response.status === 404 && pathname !== '/'; +function shouldAttemptFallback(response: Response | undefined, pathname: string): boolean { + return !response || response.status === 404 && pathname !== '/'; } +/** + * Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata. + * + * If the server returns a 404 for the well-known endpoint, this function will + * return `undefined`. Any other errors will be thrown as exceptions. + */ export async function discoverOAuthMetadata( authorizationServerUrl: string | URL, opts?: { protocolVersion?: string }, @@ -362,18 +367,10 @@ export async function discoverOAuthMetadata( // If path-aware discovery fails with 404, try fallback to root discovery if (shouldAttemptFallback(response, issuer.pathname)) { - try { - const rootUrl = new URL("/.well-known/oauth-authorization-server", issuer); - response = await tryMetadataDiscovery(rootUrl, protocolVersion); - - if (response.status === 404) { - return undefined; - } - } catch { - // If fallback fails, return undefined - return undefined; - } - } else if (response.status === 404) { + const rootUrl = new URL("/.well-known/oauth-authorization-server", issuer); + response = await tryMetadataDiscovery(rootUrl, protocolVersion); + } + if (!response || response.status === 404) { return undefined; } diff --git a/src/client/stdio.test.ts b/src/client/stdio.test.ts index 646f9ea5d..b21324469 100644 --- a/src/client/stdio.test.ts +++ b/src/client/stdio.test.ts @@ -59,3 +59,12 @@ test("should read messages", async () => { await client.close(); }); + +test("should return child process pid", async () => { + const client = new StdioClientTransport(serverParameters); + + await client.start(); + expect(client.pid).not.toBeNull(); + await client.close(); + expect(client.pid).toBeNull(); +}); diff --git a/src/client/stdio.ts b/src/client/stdio.ts index 9e35293d3..e9c9fa8f0 100644 --- a/src/client/stdio.ts +++ b/src/client/stdio.ts @@ -56,6 +56,7 @@ export const DEFAULT_INHERITED_ENV_VARS = "TEMP", "USERNAME", "USERPROFILE", + "PROGRAMFILES", ] : /* list inspired by the default env inheritance of sudo */ ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]; @@ -184,6 +185,15 @@ export class StdioClientTransport implements Transport { return this._process?.stderr ?? null; } + /** + * The child process pid spawned by this transport. + * + * This is only available after the transport has been started. + */ + get pid(): number | null { + return this._process?.pid ?? null; + } + private processReadBuffer() { while (true) { try { diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index f748a2be3..11dfe7d41 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -476,6 +476,37 @@ describe("StreamableHTTPClientTransport", () => { expect(global.fetch).toHaveBeenCalledTimes(2); }); + it("should always send specified custom headers (Headers class)", async () => { + const requestInit = { + headers: new Headers({ + "X-Custom-Header": "CustomValue" + }) + }; + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + requestInit: requestInit + }); + + let actualReqInit: RequestInit = {}; + + ((global.fetch as jest.Mock)).mockImplementation( + async (_url, reqInit) => { + actualReqInit = reqInit; + return new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } }); + } + ); + + await transport.start(); + + await transport["_startOrAuthSse"]({}); + expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("CustomValue"); + + (requestInit.headers as Headers).set("X-Custom-Header","SecondCustomValue"); + + await transport.send({ jsonrpc: "2.0", method: "test", params: {} } as JSONRPCMessage); + expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("SecondCustomValue"); + + expect(global.fetch).toHaveBeenCalledTimes(2); + }); it("should have exponential backoff with configurable maxRetries", () => { // This test verifies the maxRetries and backoff calculation directly diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 4117bb1b4..730784422 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -178,9 +178,12 @@ export class StreamableHTTPClientTransport implements Transport { headers["mcp-protocol-version"] = this._protocolVersion; } - return new Headers( - { ...headers, ...this._requestInit?.headers } - ); + const extraHeaders = this._normalizeHeaders(this._requestInit?.headers); + + return new Headers({ + ...headers, + ...extraHeaders, + }); } @@ -246,6 +249,20 @@ export class StreamableHTTPClientTransport implements Transport { } + private _normalizeHeaders(headers: HeadersInit | undefined): Record { + if (!headers) return {}; + + if (headers instanceof Headers) { + return Object.fromEntries(headers.entries()); + } + + if (Array.isArray(headers)) { + return Object.fromEntries(headers); + } + + return { ...headers as Record }; + } + /** * Schedule a reconnection attempt with exponential backoff * diff --git a/src/examples/server/mcpServerOutputSchema.ts b/src/examples/server/mcpServerOutputSchema.ts index 75bfe6900..de3b363ed 100644 --- a/src/examples/server/mcpServerOutputSchema.ts +++ b/src/examples/server/mcpServerOutputSchema.ts @@ -43,7 +43,14 @@ server.registerTool( void country; // Simulate weather API call const temp_c = Math.round((Math.random() * 35 - 5) * 10) / 10; - const conditions = ["sunny", "cloudy", "rainy", "stormy", "snowy"][Math.floor(Math.random() * 5)]; + const conditionCandidates = [ + "sunny", + "cloudy", + "rainy", + "stormy", + "snowy", + ] as const; + const conditions = conditionCandidates[Math.floor(Math.random() * conditionCandidates.length)]; const structuredContent = { temperature: { @@ -77,4 +84,4 @@ async function main() { main().catch((error) => { console.error("Server error:", error); process.exit(1); -}); \ No newline at end of file +}); diff --git a/src/examples/server/toolWithSampleServer.ts b/src/examples/server/toolWithSampleServer.ts new file mode 100644 index 000000000..44e5cecbb --- /dev/null +++ b/src/examples/server/toolWithSampleServer.ts @@ -0,0 +1,57 @@ + +// Run with: npx tsx src/examples/server/toolWithSampleServer.ts + +import { McpServer } from "../../server/mcp.js"; +import { StdioServerTransport } from "../../server/stdio.js"; +import { z } from "zod"; + +const mcpServer = new McpServer({ + name: "tools-with-sample-server", + version: "1.0.0", +}); + +// Tool that uses LLM sampling to summarize any text +mcpServer.registerTool( + "summarize", + { + description: "Summarize any text using an LLM", + inputSchema: { + text: z.string().describe("Text to summarize"), + }, + }, + async ({ text }) => { + // Call the LLM through MCP sampling + const response = await mcpServer.server.createMessage({ + messages: [ + { + role: "user", + content: { + type: "text", + text: `Please summarize the following text concisely:\n\n${text}`, + }, + }, + ], + maxTokens: 500, + }); + + return { + content: [ + { + type: "text", + text: response.content.type === "text" ? response.content.text : "Unable to generate summary", + }, + ], + }; + } +); + +async function main() { + const transport = new StdioServerTransport(); + await mcpServer.connect(transport); + console.log("MCP server is running..."); +} + +main().catch((error) => { + console.error("Server error:", error); + process.exit(1); +}); \ No newline at end of file diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 0764ffe88..dc96a1b0f 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -267,6 +267,7 @@ describe("tool()", () => { expect(result.tools[0].name).toBe("test"); expect(result.tools[0].inputSchema).toEqual({ type: "object", + properties: {}, }); // Adding the tool before the connection was established means no notification was sent @@ -1311,7 +1312,7 @@ describe("tool()", () => { resultType: "structured", // Missing required 'timestamp' field someExtraField: "unexpected" // Extra field not in schema - }, + } as unknown as { processedInput: string; resultType: string; timestamp: string }, // Type assertion to bypass TypeScript validation for testing purposes }) ); diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 67da78ffb..a5624e153 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -169,7 +169,7 @@ export class McpServer { } const args = parseResult.data; - const cb = tool.callback as ToolCallback; + const cb = tool.callback as ToolCallback; try { result = await Promise.resolve(cb(args, extra)); } catch (error) { @@ -184,7 +184,7 @@ export class McpServer { }; } } else { - const cb = tool.callback as ToolCallback; + const cb = tool.callback as ToolCallback; try { result = await Promise.resolve(cb(extra)); } catch (error) { @@ -613,6 +613,18 @@ export class McpServer { * Registers a resource with a config object and callback. * For static resources, use a URI string. For dynamic resources, use a ResourceTemplate. */ + registerResource( + name: string, + uriOrTemplate: string, + config: ResourceMetadata, + readCallback: ReadResourceCallback + ): RegisteredResource; + registerResource( + name: string, + uriOrTemplate: ResourceTemplate, + config: ResourceMetadata, + readCallback: ReadResourceTemplateCallback + ): RegisteredResourceTemplate; registerResource( name: string, uriOrTemplate: string | ResourceTemplate, @@ -760,7 +772,7 @@ export class McpServer { inputSchema: ZodRawShape | undefined, outputSchema: ZodRawShape | undefined, annotations: ToolAnnotations | undefined, - callback: ToolCallback + callback: ToolCallback ): RegisteredTool { const registeredTool: RegisteredTool = { title, @@ -917,7 +929,7 @@ export class McpServer { outputSchema?: OutputArgs; annotations?: ToolAnnotations; }, - cb: ToolCallback + cb: ToolCallback ): RegisteredTool { if (this._registeredTools[name]) { throw new Error(`Tool ${name} is already registered`); @@ -932,7 +944,7 @@ export class McpServer { inputSchema, outputSchema, annotations, - cb as ToolCallback + cb as ToolCallback ); } @@ -1126,6 +1138,16 @@ export class ResourceTemplate { } } +/** + * Type helper to create a strongly-typed CallToolResult with structuredContent + */ +type TypedCallToolResult = + OutputArgs extends ZodRawShape + ? CallToolResult & { + structuredContent?: z.objectOutputType; + } + : CallToolResult; + /** * Callback for a tool handler registered with Server.tool(). * @@ -1136,13 +1158,21 @@ export class ResourceTemplate { * - `content` if the tool does not have an outputSchema * - Both fields are optional but typically one should be provided */ -export type ToolCallback = - Args extends ZodRawShape +export type ToolCallback< + InputArgs extends undefined | ZodRawShape = undefined, + OutputArgs extends undefined | ZodRawShape = undefined +> = InputArgs extends ZodRawShape ? ( - args: z.objectOutputType, - extra: RequestHandlerExtra, - ) => CallToolResult | Promise - : (extra: RequestHandlerExtra) => CallToolResult | Promise; + args: z.objectOutputType, + extra: RequestHandlerExtra + ) => + | TypedCallToolResult + | Promise> + : ( + extra: RequestHandlerExtra + ) => + | TypedCallToolResult + | Promise>; export type RegisteredTool = { title?: string; @@ -1150,26 +1180,29 @@ export type RegisteredTool = { inputSchema?: AnyZodObject; outputSchema?: AnyZodObject; annotations?: ToolAnnotations; - callback: ToolCallback; + callback: ToolCallback; enabled: boolean; enable(): void; disable(): void; - update( - updates: { - name?: string | null, - title?: string, - description?: string, - paramsSchema?: InputArgs, - outputSchema?: OutputArgs, - annotations?: ToolAnnotations, - callback?: ToolCallback, - enabled?: boolean - }): void - remove(): void + update< + InputArgs extends ZodRawShape, + OutputArgs extends ZodRawShape + >(updates: { + name?: string | null; + title?: string; + description?: string; + paramsSchema?: InputArgs; + outputSchema?: OutputArgs; + annotations?: ToolAnnotations; + callback?: ToolCallback + enabled?: boolean + }): void; + remove(): void; }; const EMPTY_OBJECT_JSON_SCHEMA = { type: "object" as const, + properties: {}, }; // Helper to check if an object is a Zod schema (ZodRawShape) diff --git a/src/server/sse.test.ts b/src/server/sse.test.ts index 32c894f07..a7f180961 100644 --- a/src/server/sse.test.ts +++ b/src/server/sse.test.ts @@ -453,4 +453,264 @@ describe('SSEServerTransport', () => { expect.stringContaining(`data: /messages?sessionId=${transport.sessionId}`)); }); }); -}); \ No newline at end of file + + describe('DNS rebinding protection', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('Host header validation', () => { + it('should accept requests with allowed host headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000', 'example.com'], + enableDnsRebindingProtection: true, + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + host: 'localhost:3000', + 'content-type': 'application/json', + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should reject requests with disallowed host headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + enableDnsRebindingProtection: true, + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + host: 'evil.com', + 'content-type': 'application/json', + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: evil.com'); + }); + + it('should reject requests without host header when allowedHosts is configured', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + enableDnsRebindingProtection: true, + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + 'content-type': 'application/json', + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: undefined'); + }); + }); + + describe('Origin header validation', () => { + it('should accept requests with allowed origin headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedOrigins: ['http://localhost:3000', 'https://example.com'], + enableDnsRebindingProtection: true, + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + origin: 'http://localhost:3000', + 'content-type': 'application/json', + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should reject requests with disallowed origin headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: true, + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + origin: 'http://evil.com', + 'content-type': 'application/json', + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com'); + }); + }); + + describe('Content-Type validation', () => { + it('should accept requests with application/json content-type', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + 'content-type': 'application/json', + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should accept requests with application/json with charset', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + 'content-type': 'application/json; charset=utf-8', + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should reject requests with non-application/json content-type when protection is enabled', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + 'content-type': 'text/plain', + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400); + expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain'); + }); + }); + + describe('enableDnsRebindingProtection option', () => { + it('should skip all validations when enableDnsRebindingProtection is false', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: false, + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + host: 'evil.com', + origin: 'http://evil.com', + 'content-type': 'text/plain', + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + // Should pass even with invalid headers because protection is disabled + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400); + // The error should be from content-type parsing, not DNS rebinding protection + expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain'); + }); + }); + + describe('Combined validations', () => { + it('should validate both host and origin when both are configured', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: true, + }); + await transport.start(); + + // Valid host, invalid origin + const mockReq1 = createMockRequest({ + headers: { + host: 'localhost:3000', + origin: 'http://evil.com', + 'content-type': 'application/json', + } + }); + const mockHandleRes1 = createMockResponse(); + + await transport.handlePostMessage(mockReq1, mockHandleRes1, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes1.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes1.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com'); + + // Invalid host, valid origin + const mockReq2 = createMockRequest({ + headers: { + host: 'evil.com', + origin: 'http://localhost:3000', + 'content-type': 'application/json', + } + }); + const mockHandleRes2 = createMockResponse(); + + await transport.handlePostMessage(mockReq2, mockHandleRes2, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes2.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes2.end).toHaveBeenCalledWith('Invalid Host header: evil.com'); + + // Both valid + const mockReq3 = createMockRequest({ + headers: { + host: 'localhost:3000', + origin: 'http://localhost:3000', + 'content-type': 'application/json', + } + }); + const mockHandleRes3 = createMockResponse(); + + await transport.handlePostMessage(mockReq3, mockHandleRes3, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes3.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes3.end).toHaveBeenCalledWith('Accepted'); + }); + }); + }); +}); diff --git a/src/server/sse.ts b/src/server/sse.ts index 978ce29fa..e07256867 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -9,6 +9,29 @@ import { URL } from 'url'; const MAXIMUM_MESSAGE_SIZE = "4mb"; +/** + * Configuration options for SSEServerTransport. + */ +export interface SSEServerTransportOptions { + /** + * List of allowed host header values for DNS rebinding protection. + * If not specified, host validation is disabled. + */ + allowedHosts?: string[]; + + /** + * List of allowed origin header values for DNS rebinding protection. + * If not specified, origin validation is disabled. + */ + allowedOrigins?: string[]; + + /** + * Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + * Default is false for backwards compatibility. + */ + enableDnsRebindingProtection?: boolean; +} + /** * Server transport for SSE: this will send messages over an SSE connection and receive messages from HTTP POST requests. * @@ -17,6 +40,7 @@ const MAXIMUM_MESSAGE_SIZE = "4mb"; export class SSEServerTransport implements Transport { private _sseResponse?: ServerResponse; private _sessionId: string; + private _options: SSEServerTransportOptions; onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; @@ -27,8 +51,39 @@ export class SSEServerTransport implements Transport { constructor( private _endpoint: string, private res: ServerResponse, + options?: SSEServerTransportOptions, ) { this._sessionId = randomUUID(); + this._options = options || {enableDnsRebindingProtection: false}; + } + + /** + * Validates request headers for DNS rebinding protection. + * @returns Error message if validation fails, undefined if validation passes. + */ + private validateRequestHeaders(req: IncomingMessage): string | undefined { + // Skip validation if protection is not enabled + if (!this._options.enableDnsRebindingProtection) { + return undefined; + } + + // Validate Host header if allowedHosts is configured + if (this._options.allowedHosts && this._options.allowedHosts.length > 0) { + const hostHeader = req.headers.host; + if (!hostHeader || !this._options.allowedHosts.includes(hostHeader)) { + return `Invalid Host header: ${hostHeader}`; + } + } + + // Validate Origin header if allowedOrigins is configured + if (this._options.allowedOrigins && this._options.allowedOrigins.length > 0) { + const originHeader = req.headers.origin; + if (!originHeader || !this._options.allowedOrigins.includes(originHeader)) { + return `Invalid Origin header: ${originHeader}`; + } + } + + return undefined; } /** @@ -85,6 +140,15 @@ export class SSEServerTransport implements Transport { res.writeHead(500).end(message); throw new Error(message); } + + // Validate request headers for DNS rebinding protection + const validationError = this.validateRequestHeaders(req); + if (validationError) { + res.writeHead(403).end(validationError); + this.onerror?.(new Error(validationError)); + return; + } + const authInfo: AuthInfo | undefined = req.auth; const requestInfo: RequestInfo = { headers: req.headers }; diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index ce5c7446a..502435ead 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -1,5 +1,5 @@ import { createServer, type Server, IncomingMessage, ServerResponse } from "node:http"; -import { AddressInfo } from "node:net"; +import { createServer as netCreateServer, AddressInfo } from "node:net"; import { randomUUID } from "node:crypto"; import { EventStore, StreamableHTTPServerTransport, EventId, StreamId } from "./streamableHttp.js"; import { McpServer } from "./mcp.js"; @@ -7,6 +7,20 @@ import { CallToolResult, JSONRPCMessage } from "../types.js"; import { z } from "zod"; import { AuthInfo } from "./auth/types.js"; +async function getFreePort() { + return new Promise(res => { + const srv = netCreateServer(); + srv.listen(0, () => { + const address = srv.address()! + if (typeof address === "string") { + throw new Error("Unexpected address type: " + typeof address); + } + const port = (address as AddressInfo).port; + srv.close((_err) => res(port)) + }); + }) +} + /** * Test server configuration for StreamableHTTPServerTransport tests */ @@ -363,7 +377,7 @@ describe("StreamableHTTPServerTransport", () => { return { content: [{ type: "text", text: `Hello, ${name}!` }, { type: "text", text: `${JSON.stringify(requestInfo)}` }] }; } ); - + const toolCallMessage: JSONRPCMessage = { jsonrpc: "2.0", method: "tools/call", @@ -814,7 +828,7 @@ describe("StreamableHTTPServerTransport", () => { // Send request with matching protocol version const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, sessionId); - + expect(response.status).toBe(200); }); @@ -832,7 +846,7 @@ describe("StreamableHTTPServerTransport", () => { }, body: JSON.stringify(TEST_MESSAGES.toolsList), }); - + expect(response.status).toBe(200); }); @@ -850,7 +864,7 @@ describe("StreamableHTTPServerTransport", () => { }, body: JSON.stringify(TEST_MESSAGES.toolsList), }); - + expect(response.status).toBe(400); const errorData = await response.json(); expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); @@ -858,13 +872,13 @@ describe("StreamableHTTPServerTransport", () => { it("should accept when protocol version differs from negotiated version", async () => { sessionId = await initializeServer(); - + // Spy on console.warn to verify warning is logged const warnSpy = jest.spyOn(console, 'warn').mockImplementation(); // Send request with different but supported protocol version const response = await fetch(baseUrl, { - method: "POST", + method: "POST", headers: { "Content-Type": "application/json", Accept: "application/json, text/event-stream", @@ -873,10 +887,10 @@ describe("StreamableHTTPServerTransport", () => { }, body: JSON.stringify(TEST_MESSAGES.toolsList), }); - + // Request should still succeed expect(response.status).toBe(200); - + warnSpy.mockRestore(); }); @@ -892,7 +906,7 @@ describe("StreamableHTTPServerTransport", () => { "mcp-protocol-version": "invalid-version", }, }); - + expect(response.status).toBe(400); const errorData = await response.json(); expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); @@ -909,7 +923,7 @@ describe("StreamableHTTPServerTransport", () => { "mcp-protocol-version": "invalid-version", }, }); - + expect(response.status).toBe(400); const errorData = await response.json(); expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); @@ -951,12 +965,12 @@ describe("StreamableHTTPServerTransport with AuthInfo", () => { method: "tools/call", params: { name: "profile", - arguments: {active: true}, + arguments: { active: true }, }, id: "call-1", }; - const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId, {'authorization': 'Bearer test-token'}); + const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId, { 'authorization': 'Bearer test-token' }); expect(response.status).toBe(200); const text = await readSSEEvent(response); @@ -978,7 +992,7 @@ describe("StreamableHTTPServerTransport with AuthInfo", () => { id: "call-1", }); }); - + it("should calls tool without authInfo when it is optional", async () => { sessionId = await initializeServer(); @@ -987,7 +1001,7 @@ describe("StreamableHTTPServerTransport with AuthInfo", () => { method: "tools/call", params: { name: "profile", - arguments: {active: false}, + arguments: { active: false }, }, id: "call-1", }; @@ -1471,7 +1485,7 @@ describe("StreamableHTTPServerTransport in stateless mode", () => { // Open first SSE stream const stream1 = await fetch(baseUrl, { method: "GET", - headers: { + headers: { Accept: "text/event-stream", "mcp-protocol-version": "2025-03-26" }, @@ -1481,11 +1495,282 @@ describe("StreamableHTTPServerTransport in stateless mode", () => { // Open second SSE stream - should still be rejected, stateless mode still only allows one const stream2 = await fetch(baseUrl, { method: "GET", - headers: { + headers: { Accept: "text/event-stream", "mcp-protocol-version": "2025-03-26" }, }); expect(stream2.status).toBe(409); // Conflict - only one stream allowed }); -}); \ No newline at end of file +}); + +// Test DNS rebinding protection +describe("StreamableHTTPServerTransport DNS rebinding protection", () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + + afterEach(async () => { + if (server && transport) { + await stopTestServer({ server, transport }); + } + }); + + describe("Host header validation", () => { + it("should accept requests with allowed host headers", async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['localhost'], + enableDnsRebindingProtection: true, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + // Note: fetch() automatically sets Host header to match the URL + // Since we're connecting to localhost:3001 and that's in allowedHosts, this should work + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + expect(response.status).toBe(200); + }); + + it("should reject requests with disallowed host headers", async () => { + // Test DNS rebinding protection by creating a server that only allows example.com + // but we're connecting via localhost, so it should be rejected + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['example.com:3001'], + enableDnsRebindingProtection: true, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + expect(response.status).toBe(403); + const body = await response.json(); + expect(body.error.message).toContain("Invalid Host header:"); + }); + + it("should reject GET requests with disallowed host headers", async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['example.com:3001'], + enableDnsRebindingProtection: true, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: "GET", + headers: { + Accept: "text/event-stream", + }, + }); + + expect(response.status).toBe(403); + }); + }); + + describe("Origin header validation", () => { + it("should accept requests with allowed origin headers", async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedOrigins: ['http://localhost:3000', 'https://example.com'], + enableDnsRebindingProtection: true, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + Origin: "http://localhost:3000", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + expect(response.status).toBe(200); + }); + + it("should reject requests with disallowed origin headers", async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: true, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + Origin: "http://evil.com", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + expect(response.status).toBe(403); + const body = await response.json(); + expect(body.error.message).toBe("Invalid Origin header: http://evil.com"); + }); + }); + + describe("enableDnsRebindingProtection option", () => { + it("should skip all validations when enableDnsRebindingProtection is false", async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['localhost'], + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: false, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + Host: "evil.com", + Origin: "http://evil.com", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + // Should pass even with invalid headers because protection is disabled + expect(response.status).toBe(200); + }); + }); + + describe("Combined validations", () => { + it("should validate both host and origin when both are configured", async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['localhost'], + allowedOrigins: ['http://localhost:3001'], + enableDnsRebindingProtection: true, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + // Test with invalid origin (host will be automatically correct via fetch) + const response1 = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + Origin: "http://evil.com", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + expect(response1.status).toBe(403); + const body1 = await response1.json(); + expect(body1.error.message).toBe("Invalid Origin header: http://evil.com"); + + // Test with valid origin + const response2 = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + Origin: "http://localhost:3001", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + expect(response2.status).toBe(200); + }); + }); +}); + +/** + * Helper to create test server with DNS rebinding protection options + */ +async function createTestServerWithDnsProtection(config: { + sessionIdGenerator: (() => string) | undefined; + allowedHosts?: string[]; + allowedOrigins?: string[]; + enableDnsRebindingProtection?: boolean; +}): Promise<{ + server: Server; + transport: StreamableHTTPServerTransport; + mcpServer: McpServer; + baseUrl: URL; +}> { + const mcpServer = new McpServer( + { name: "test-server", version: "1.0.0" }, + { capabilities: { logging: {} } } + ); + + const port = await getFreePort(); + + if (config.allowedHosts) { + config.allowedHosts = config.allowedHosts.map(host => { + if (host.includes(':')) { + return host; + } + return `localhost:${port}`; + }); + } + + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: config.sessionIdGenerator, + allowedHosts: config.allowedHosts, + allowedOrigins: config.allowedOrigins, + enableDnsRebindingProtection: config.enableDnsRebindingProtection, + }); + + await mcpServer.connect(transport); + + const httpServer = createServer(async (req, res) => { + if (req.method === "POST") { + let body = ""; + req.on("data", (chunk) => (body += chunk)); + req.on("end", async () => { + const parsedBody = JSON.parse(body); + await transport.handleRequest(req as IncomingMessage & { auth?: AuthInfo }, res, parsedBody); + }); + } else { + await transport.handleRequest(req as IncomingMessage & { auth?: AuthInfo }, res); + } + }); + + await new Promise((resolve) => { + httpServer.listen(port, () => resolve()); + }); + + const serverUrl = new URL(`http://localhost:${port}/`); + + return { + server: httpServer, + transport, + mcpServer, + baseUrl: serverUrl, + }; +} \ No newline at end of file diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 677da45ea..022d1a474 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -61,6 +61,24 @@ export interface StreamableHTTPServerTransportOptions { * If provided, resumability will be enabled, allowing clients to reconnect and resume messages */ eventStore?: EventStore; + + /** + * List of allowed host header values for DNS rebinding protection. + * If not specified, host validation is disabled. + */ + allowedHosts?: string[]; + + /** + * List of allowed origin header values for DNS rebinding protection. + * If not specified, origin validation is disabled. + */ + allowedOrigins?: string[]; + + /** + * Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + * Default is false for backwards compatibility. + */ + enableDnsRebindingProtection?: boolean; } /** @@ -109,6 +127,9 @@ export class StreamableHTTPServerTransport implements Transport { private _standaloneSseStreamId: string = '_GET_stream'; private _eventStore?: EventStore; private _onsessioninitialized?: (sessionId: string) => void; + private _allowedHosts?: string[]; + private _allowedOrigins?: string[]; + private _enableDnsRebindingProtection: boolean; sessionId?: string; onclose?: () => void; @@ -120,6 +141,9 @@ export class StreamableHTTPServerTransport implements Transport { this._enableJsonResponse = options.enableJsonResponse ?? false; this._eventStore = options.eventStore; this._onsessioninitialized = options.onsessioninitialized; + this._allowedHosts = options.allowedHosts; + this._allowedOrigins = options.allowedOrigins; + this._enableDnsRebindingProtection = options.enableDnsRebindingProtection ?? false; } /** @@ -133,10 +157,54 @@ export class StreamableHTTPServerTransport implements Transport { this._started = true; } + /** + * Validates request headers for DNS rebinding protection. + * @returns Error message if validation fails, undefined if validation passes. + */ + private validateRequestHeaders(req: IncomingMessage): string | undefined { + // Skip validation if protection is not enabled + if (!this._enableDnsRebindingProtection) { + return undefined; + } + + // Validate Host header if allowedHosts is configured + if (this._allowedHosts && this._allowedHosts.length > 0) { + const hostHeader = req.headers.host; + if (!hostHeader || !this._allowedHosts.includes(hostHeader)) { + return `Invalid Host header: ${hostHeader}`; + } + } + + // Validate Origin header if allowedOrigins is configured + if (this._allowedOrigins && this._allowedOrigins.length > 0) { + const originHeader = req.headers.origin; + if (!originHeader || !this._allowedOrigins.includes(originHeader)) { + return `Invalid Origin header: ${originHeader}`; + } + } + + return undefined; + } + /** * Handles an incoming HTTP request, whether GET or POST */ async handleRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise { + // Validate request headers for DNS rebinding protection + const validationError = this.validateRequestHeaders(req); + if (validationError) { + res.writeHead(403).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: validationError + }, + id: null + })); + this.onerror?.(new Error(validationError)); + return; + } + if (req.method === "POST") { await this.handlePostRequest(req, res, parsedBody); } else if (req.method === "GET") {