diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index 1b9fb071..91422de0 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -324,6 +324,7 @@ describe("OAuth Authorization", () => { metadata: undefined, clientInformation: validClientInfo, redirectUrl: "http://localhost:3000/callback", + resource: new URL("https://api.example.com/mcp-server"), } ); @@ -338,6 +339,7 @@ describe("OAuth Authorization", () => { expect(authorizationUrl.searchParams.get("redirect_uri")).toBe( "http://localhost:3000/callback" ); + expect(authorizationUrl.searchParams.get("resource")).toBe("https://api.example.com/mcp-server"); expect(codeVerifier).toBe("test_verifier"); }); @@ -465,6 +467,7 @@ describe("OAuth Authorization", () => { authorizationCode: "code123", codeVerifier: "verifier123", redirectUri: "http://localhost:3000/callback", + resource: new URL("https://api.example.com/mcp-server"), }); expect(tokens).toEqual(validTokens); @@ -487,6 +490,7 @@ describe("OAuth Authorization", () => { expect(body.get("client_id")).toBe("client123"); expect(body.get("client_secret")).toBe("secret123"); expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback"); + expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); }); it("validates token response schema", async () => { @@ -554,6 +558,7 @@ describe("OAuth Authorization", () => { const tokens = await refreshAuthorization("https://auth.example.com", { clientInformation: validClientInfo, refreshToken: "refresh123", + resource: new URL("https://api.example.com/mcp-server"), }); expect(tokens).toEqual(validTokensWithNewRefreshToken); @@ -574,6 +579,7 @@ describe("OAuth Authorization", () => { expect(body.get("refresh_token")).toBe("refresh123"); expect(body.get("client_id")).toBe("client123"); expect(body.get("client_secret")).toBe("secret123"); + expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); }); it("exchanges refresh token for new tokens and keep existing refresh token if none is returned", async () => { @@ -807,5 +813,236 @@ describe("OAuth Authorization", () => { "https://resource.example.com/.well-known/oauth-authorization-server" ); }); + + it("passes resource parameter through authorization flow", async () => { + // Mock successful metadata discovery + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for authorization flow + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); + (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); + (mockProvider.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); + + // Call auth without authorization code (should trigger redirect) + const result = await auth(mockProvider, { + serverUrl: "https://api.example.com/mcp-server", + }); + + expect(result).toBe("REDIRECT"); + + // Verify the authorization URL includes the resource parameter + expect(mockProvider.redirectToAuthorization).toHaveBeenCalledWith( + expect.objectContaining({ + searchParams: expect.any(URLSearchParams), + }) + ); + + const redirectCall = (mockProvider.redirectToAuthorization as jest.Mock).mock.calls[0]; + const authUrl: URL = redirectCall[0]; + expect(authUrl.searchParams.get("resource")).toBe("https://api.example.com/mcp-server"); + }); + + it("includes resource in token exchange when authorization code is provided", async () => { + // Mock successful metadata discovery and token exchange + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + + if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } else if (urlString.includes("/token")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: "access123", + token_type: "Bearer", + expires_in: 3600, + refresh_token: "refresh123", + }), + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for token exchange + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (mockProvider.codeVerifier as jest.Mock).mockResolvedValue("test-verifier"); + (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); + + // Call auth with authorization code + const result = await auth(mockProvider, { + serverUrl: "https://api.example.com/mcp-server", + authorizationCode: "auth-code-123", + }); + + expect(result).toBe("AUTHORIZED"); + + // Find the token exchange call + const tokenCall = mockFetch.mock.calls.find(call => + call[0].toString().includes("/token") + ); + expect(tokenCall).toBeDefined(); + + const body = tokenCall![1].body as URLSearchParams; + expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); + expect(body.get("code")).toBe("auth-code-123"); + }); + + it("includes resource in token refresh", async () => { + // Mock successful metadata discovery and token refresh + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + + if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } else if (urlString.includes("/token")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + access_token: "new-access123", + token_type: "Bearer", + expires_in: 3600, + }), + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods for token refresh + (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (mockProvider.tokens as jest.Mock).mockResolvedValue({ + access_token: "old-access", + refresh_token: "refresh123", + }); + (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); + + // Call auth with existing tokens (should trigger refresh) + const result = await auth(mockProvider, { + serverUrl: "https://api.example.com/mcp-server", + }); + + expect(result).toBe("AUTHORIZED"); + + // Find the token refresh call + const tokenCall = mockFetch.mock.calls.find(call => + call[0].toString().includes("/token") + ); + expect(tokenCall).toBeDefined(); + + const body = tokenCall![1].body as URLSearchParams; + expect(body.get("resource")).toBe("https://api.example.com/mcp-server"); + expect(body.get("grant_type")).toBe("refresh_token"); + expect(body.get("refresh_token")).toBe("refresh123"); + }); + + it("skips default PRM resource validation when custom validateResourceURL is provided", async () => { + const mockValidateResourceURL = jest.fn().mockResolvedValue(undefined); + const providerWithCustomValidation = { + ...mockProvider, + validateResourceURL: mockValidateResourceURL, + }; + + // Mock protected resource metadata with mismatched resource URL + // This would normally throw an error in default validation, but should be skipped + mockFetch.mockImplementation((url) => { + const urlString = url.toString(); + + if (urlString.includes("/.well-known/oauth-protected-resource")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + resource: "https://different-resource.example.com/mcp-server", // Mismatched resource + authorization_servers: ["https://auth.example.com"], + }), + }); + } else if (urlString.includes("/.well-known/oauth-authorization-server")) { + return Promise.resolve({ + ok: true, + status: 200, + json: async () => ({ + issuer: "https://auth.example.com", + authorization_endpoint: "https://auth.example.com/authorize", + token_endpoint: "https://auth.example.com/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }); + } + + return Promise.resolve({ ok: false, status: 404 }); + }); + + // Mock provider methods + (providerWithCustomValidation.clientInformation as jest.Mock).mockResolvedValue({ + client_id: "test-client", + client_secret: "test-secret", + }); + (providerWithCustomValidation.tokens as jest.Mock).mockResolvedValue(undefined); + (providerWithCustomValidation.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); + (providerWithCustomValidation.redirectToAuthorization as jest.Mock).mockResolvedValue(undefined); + + // Call auth - should succeed despite resource mismatch because custom validation overrides default + const result = await auth(providerWithCustomValidation, { + serverUrl: "https://api.example.com/mcp-server", + }); + + expect(result).toBe("REDIRECT"); + + // Verify custom validation method was called + expect(mockValidateResourceURL).toHaveBeenCalledWith( + "https://api.example.com/mcp-server", + "https://different-resource.example.com/mcp-server" + ); + }); }); }); diff --git a/src/client/auth.ts b/src/client/auth.ts index 7a91eb25..28d9d833 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -2,6 +2,7 @@ import pkceChallenge from "pkce-challenge"; import { LATEST_PROTOCOL_VERSION } from "../types.js"; import type { OAuthClientMetadata, OAuthClientInformation, OAuthTokens, OAuthMetadata, OAuthClientInformationFull, OAuthProtectedResourceMetadata } from "../shared/auth.js"; import { OAuthClientInformationFullSchema, OAuthMetadataSchema, OAuthProtectedResourceMetadataSchema, OAuthTokensSchema } from "../shared/auth.js"; +import { resourceUrlFromServerUrl } from "../shared/auth-utils.js"; /** * Implements an end-to-end OAuth client to be used with one MCP server. @@ -71,6 +72,15 @@ export interface OAuthClientProvider { * the authorization result. */ codeVerifier(): string | Promise; + + /** + * If defined, overrides the selection and validation of the + * RFC 8707 Resource Indicator. If left undefined, default + * validation behavior will be used. + * + * Implementations must verify the returned resource matches the MCP server. + */ + validateResourceURL?(serverUrl: string | URL, resource?: string): Promise; } export type AuthResult = "AUTHORIZED" | "REDIRECT"; @@ -99,11 +109,10 @@ export async function auth( scope?: string; resourceMetadataUrl?: URL }): Promise { + let resourceMetadata: OAuthProtectedResourceMetadata | undefined; let authorizationServerUrl = serverUrl; try { - const resourceMetadata = await discoverOAuthProtectedResourceMetadata( - resourceMetadataUrl || serverUrl); - + resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, {resourceMetadataUrl}); if (resourceMetadata.authorization_servers && resourceMetadata.authorization_servers.length > 0) { authorizationServerUrl = resourceMetadata.authorization_servers[0]; } @@ -111,6 +120,8 @@ export async function auth( console.warn("Could not load OAuth Protected Resource metadata, falling back to /.well-known/oauth-authorization-server", error) } + const resource: URL | undefined = await selectResourceURL(serverUrl, provider, resourceMetadata); + const metadata = await discoverOAuthMetadata(authorizationServerUrl); // Handle client registration if needed @@ -142,6 +153,7 @@ export async function auth( authorizationCode, codeVerifier, redirectUri: provider.redirectUrl, + resource, }); await provider.saveTokens(tokens); @@ -158,6 +170,7 @@ export async function auth( metadata, clientInformation, refreshToken: tokens.refresh_token, + resource, }); await provider.saveTokens(newTokens); @@ -176,6 +189,7 @@ export async function auth( state, redirectUrl: provider.redirectUrl, scope: scope || provider.clientMetadata.scope, + resource, }); await provider.saveCodeVerifier(codeVerifier); @@ -183,6 +197,19 @@ export async function auth( return "REDIRECT"; } +async function selectResourceURL(serverUrl: string| URL, provider: OAuthClientProvider, resourceMetadata?: OAuthProtectedResourceMetadata): Promise { + if (provider.validateResourceURL) { + return await provider.validateResourceURL(serverUrl, resourceMetadata?.resource); + } + + const resource = resourceUrlFromServerUrl(typeof serverUrl === "string" ? new URL(serverUrl) : serverUrl); + if (resourceMetadata && resourceMetadata.resource !== resource.href) { + throw new Error(`Protected resource ${resourceMetadata.resource} does not match expected ${resource}`); + } + + return resource; +} + /** * Extract resource_metadata from response header. */ @@ -310,12 +337,14 @@ export async function startAuthorization( redirectUrl, scope, state, + resource, }: { metadata?: OAuthMetadata; clientInformation: OAuthClientInformation; redirectUrl: string | URL; scope?: string; state?: string; + resource?: URL; }, ): Promise<{ authorizationUrl: URL; codeVerifier: string }> { const responseType = "code"; @@ -365,6 +394,10 @@ export async function startAuthorization( authorizationUrl.searchParams.set("scope", scope); } + if (resource) { + authorizationUrl.searchParams.set("resource", resource.href); + } + return { authorizationUrl, codeVerifier }; } @@ -379,12 +412,14 @@ export async function exchangeAuthorization( authorizationCode, codeVerifier, redirectUri, + resource, }: { metadata?: OAuthMetadata; clientInformation: OAuthClientInformation; authorizationCode: string; codeVerifier: string; redirectUri: string | URL; + resource?: URL; }, ): Promise { const grantType = "authorization_code"; @@ -418,6 +453,10 @@ export async function exchangeAuthorization( params.set("client_secret", clientInformation.client_secret); } + if (resource) { + params.set("resource", resource.href); + } + const response = await fetch(tokenUrl, { method: "POST", headers: { @@ -442,10 +481,12 @@ export async function refreshAuthorization( metadata, clientInformation, refreshToken, + resource, }: { metadata?: OAuthMetadata; clientInformation: OAuthClientInformation; refreshToken: string; + resource?: URL; }, ): Promise { const grantType = "refresh_token"; @@ -477,6 +518,10 @@ export async function refreshAuthorization( params.set("client_secret", clientInformation.client_secret); } + if (resource) { + params.set("resource", resource.href); + } + const response = await fetch(tokenUrl, { method: "POST", headers: { diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 714e1fdd..3cb4e8a3 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -398,7 +398,7 @@ describe("SSEClientTransport", () => { 'Content-Type': 'application/json', }) .end(JSON.stringify({ - resource: "https://resource.example.com", + resource: resourceBaseUrl.href, authorization_servers: [`${authBaseUrl}`], })); return; @@ -450,7 +450,7 @@ describe("SSEClientTransport", () => { 'Content-Type': 'application/json', }) .end(JSON.stringify({ - resource: "https://resource.example.com", + resource: resourceBaseUrl.href, authorization_servers: [`${authBaseUrl}`], })); return; @@ -601,7 +601,7 @@ describe("SSEClientTransport", () => { 'Content-Type': 'application/json', }) .end(JSON.stringify({ - resource: "https://resource.example.com", + resource: resourceBaseUrl.href, authorization_servers: [`${authBaseUrl}`], })); return; @@ -723,7 +723,7 @@ describe("SSEClientTransport", () => { 'Content-Type': 'application/json', }) .end(JSON.stringify({ - resource: "https://resource.example.com", + resource: resourceBaseUrl.href, authorization_servers: [`${authBaseUrl}`], })); return; @@ -851,7 +851,7 @@ describe("SSEClientTransport", () => { 'Content-Type': 'application/json', }) .end(JSON.stringify({ - resource: "https://resource.example.com", + resource: resourceBaseUrl.href, authorization_servers: [`${authBaseUrl}`], })); return; diff --git a/src/examples/README.md b/src/examples/README.md index 68e1ece2..ac92e8de 100644 --- a/src/examples/README.md +++ b/src/examples/README.md @@ -76,6 +76,9 @@ npx tsx src/examples/server/simpleStreamableHttp.ts # To add a demo of authentication to this example, use: npx tsx src/examples/server/simpleStreamableHttp.ts --oauth + +# To mitigate impersonation risks, enable strict Resource Identifier verification: +npx tsx src/examples/server/simpleStreamableHttp.ts --oauth --oauth-strict ``` ##### JSON Response Mode Server diff --git a/src/examples/server/demoInMemoryOAuthProvider.ts b/src/examples/server/demoInMemoryOAuthProvider.ts index 024208d6..fe8d3f9c 100644 --- a/src/examples/server/demoInMemoryOAuthProvider.ts +++ b/src/examples/server/demoInMemoryOAuthProvider.ts @@ -1,10 +1,11 @@ import { randomUUID } from 'node:crypto'; import { AuthorizationParams, OAuthServerProvider } from '../../server/auth/provider.js'; import { OAuthRegisteredClientsStore } from '../../server/auth/clients.js'; -import { OAuthClientInformationFull, OAuthMetadata, OAuthTokens } from 'src/shared/auth.js'; +import { OAuthClientInformationFull, OAuthMetadata, OAuthTokens } from '../../shared/auth.js'; import express, { Request, Response } from "express"; -import { AuthInfo } from 'src/server/auth/types.js'; -import { createOAuthMetadata, mcpAuthRouter } from 'src/server/auth/router.js'; +import { AuthInfo } from '../../server/auth/types.js'; +import { createOAuthMetadata, mcpAuthRouter } from '../../server/auth/router.js'; +import { resourceUrlFromServerUrl } from '../../shared/auth-utils.js'; export class DemoInMemoryClientsStore implements OAuthRegisteredClientsStore { @@ -34,6 +35,17 @@ export class DemoInMemoryAuthProvider implements OAuthServerProvider { params: AuthorizationParams, client: OAuthClientInformationFull}>(); private tokens = new Map(); + private validateResource?: (resource?: URL) => boolean; + + constructor({mcpServerUrl}: {mcpServerUrl?: URL} = {}) { + if (mcpServerUrl) { + const expectedResource = resourceUrlFromServerUrl(mcpServerUrl); + this.validateResource = (resource?: URL) => { + if (!resource) return false; + return resource.toString() === expectedResource.toString(); + }; + } + } async authorize( client: OAuthClientInformationFull, @@ -89,6 +101,10 @@ export class DemoInMemoryAuthProvider implements OAuthServerProvider { throw new Error(`Authorization code was not issued to this client, ${codeData.client.client_id} != ${client.client_id}`); } + if (this.validateResource && !this.validateResource(codeData.params.resource)) { + throw new Error(`Invalid resource: ${codeData.params.resource}`); + } + this.codes.delete(authorizationCode); const token = randomUUID(); @@ -97,7 +113,8 @@ export class DemoInMemoryAuthProvider implements OAuthServerProvider { clientId: client.client_id, scopes: codeData.params.scopes || [], expiresAt: Date.now() + 3600000, // 1 hour - type: 'access' + resource: codeData.params.resource, + type: 'access', }; this.tokens.set(token, tokenData); @@ -113,7 +130,8 @@ export class DemoInMemoryAuthProvider implements OAuthServerProvider { async exchangeRefreshToken( _client: OAuthClientInformationFull, _refreshToken: string, - _scopes?: string[] + _scopes?: string[], + _resource?: URL ): Promise { throw new Error('Not implemented for example demo'); } @@ -129,18 +147,19 @@ export class DemoInMemoryAuthProvider implements OAuthServerProvider { clientId: tokenData.clientId, scopes: tokenData.scopes, expiresAt: Math.floor(tokenData.expiresAt / 1000), + resource: tokenData.resource, }; } } -export const setupAuthServer = (authServerUrl: URL): OAuthMetadata => { +export const setupAuthServer = (authServerUrl: URL, mcpServerUrl: URL): OAuthMetadata => { // Create separate auth server app // NOTE: This is a separate app on a separate port to illustrate // how to separate an OAuth Authorization Server from a Resource // server in the SDK. The SDK is not intended to be provide a standalone // authorization server. - const provider = new DemoInMemoryAuthProvider(); + const provider = new DemoInMemoryAuthProvider({mcpServerUrl}); const authApp = express(); authApp.use(express.json()); // For introspection requests @@ -168,7 +187,8 @@ export const setupAuthServer = (authServerUrl: URL): OAuthMetadata => { active: true, client_id: tokenInfo.clientId, scope: tokenInfo.scopes.join(' '), - exp: tokenInfo.expiresAt + exp: tokenInfo.expiresAt, + aud: tokenInfo.resource, }); return } catch (error) { diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index ebe31920..9eb87d92 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -12,6 +12,7 @@ import { OAuthMetadata } from 'src/shared/auth.js'; // Check for OAuth flag const useOAuth = process.argv.includes('--oauth'); +const strictOAuth = process.argv.includes('--oauth-strict'); // Create an MCP server with implementation details const getServer = () => { @@ -279,10 +280,10 @@ app.use(express.json()); let authMiddleware = null; if (useOAuth) { // Create auth middleware for MCP endpoints - const mcpServerUrl = new URL(`http://localhost:${MCP_PORT}`); + const mcpServerUrl = new URL(`http://localhost:${MCP_PORT}/mcp`); const authServerUrl = new URL(`http://localhost:${AUTH_PORT}`); - const oauthMetadata: OAuthMetadata = setupAuthServer(authServerUrl); + const oauthMetadata: OAuthMetadata = setupAuthServer(authServerUrl, mcpServerUrl); const tokenVerifier = { verifyAccessToken: async (token: string) => { @@ -308,6 +309,15 @@ if (useOAuth) { } const data = await response.json(); + + if (strictOAuth) { + if (!data.aud) { + throw new Error(`Resource Indicator (RFC8707) missing`); + } + if (data.aud !== mcpServerUrl.href) { + throw new Error(`Expected resource indicator ${mcpServerUrl}, got: ${data.aud}`); + } + } // Convert the response to AuthInfo format return { diff --git a/src/server/auth/handlers/authorize.test.ts b/src/server/auth/handlers/authorize.test.ts index e921d5ea..438db6a6 100644 --- a/src/server/auth/handlers/authorize.test.ts +++ b/src/server/auth/handlers/authorize.test.ts @@ -276,6 +276,34 @@ describe('Authorization Handler', () => { }); }); + describe('Resource parameter validation', () => { + it('propagates resource parameter', async () => { + const mockProviderWithResource = jest.spyOn(mockProvider, 'authorize'); + + const response = await supertest(app) + .get('/authorize') + .query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256', + resource: 'https://api.example.com/resource' + }); + + expect(response.status).toBe(302); + expect(mockProviderWithResource).toHaveBeenCalledWith( + validClient, + expect.objectContaining({ + resource: new URL('https://api.example.com/resource'), + redirectUri: 'https://example.com/callback', + codeChallenge: 'challenge123' + }), + expect.any(Object) + ); + }); + }); + describe('Successful authorization', () => { it('handles successful authorization with all parameters', async () => { const response = await supertest(app) diff --git a/src/server/auth/handlers/authorize.ts b/src/server/auth/handlers/authorize.ts index 3e9a336b..0a6283a8 100644 --- a/src/server/auth/handlers/authorize.ts +++ b/src/server/auth/handlers/authorize.ts @@ -35,6 +35,7 @@ const RequestAuthorizationParamsSchema = z.object({ code_challenge_method: z.literal("S256"), scope: z.string().optional(), state: z.string().optional(), + resource: z.string().url().optional(), }); export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: AuthorizationHandlerOptions): RequestHandler { @@ -115,7 +116,7 @@ export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: A throw new InvalidRequestError(parseResult.error.message); } - const { scope, code_challenge } = parseResult.data; + const { scope, code_challenge, resource } = parseResult.data; state = parseResult.data.state; // Validate scopes @@ -138,6 +139,7 @@ export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: A scopes: requestedScopes, redirectUri: redirect_uri, codeChallenge: code_challenge, + resource: resource ? new URL(resource) : undefined, }, res); } catch (error) { // Post-redirect errors - redirect with error parameters diff --git a/src/server/auth/handlers/token.test.ts b/src/server/auth/handlers/token.test.ts index c165fe7f..4b7fae02 100644 --- a/src/server/auth/handlers/token.test.ts +++ b/src/server/auth/handlers/token.test.ts @@ -264,12 +264,14 @@ describe('Token Handler', () => { }); it('returns tokens for valid code exchange', async () => { + const mockExchangeCode = jest.spyOn(mockProvider, 'exchangeAuthorizationCode'); const response = await supertest(app) .post('/token') .type('form') .send({ client_id: 'valid-client', client_secret: 'valid-secret', + resource: 'https://api.example.com/resource', grant_type: 'authorization_code', code: 'valid_code', code_verifier: 'valid_verifier' @@ -280,6 +282,13 @@ describe('Token Handler', () => { expect(response.body.token_type).toBe('bearer'); expect(response.body.expires_in).toBe(3600); expect(response.body.refresh_token).toBe('mock_refresh_token'); + expect(mockExchangeCode).toHaveBeenCalledWith( + validClient, + 'valid_code', + undefined, // code_verifier is undefined after PKCE validation + undefined, // redirect_uri + new URL('https://api.example.com/resource') // resource parameter + ); }); it('passes through code verifier when using proxy provider', async () => { @@ -440,12 +449,14 @@ describe('Token Handler', () => { }); it('returns new tokens for valid refresh token', async () => { + const mockExchangeRefresh = jest.spyOn(mockProvider, 'exchangeRefreshToken'); const response = await supertest(app) .post('/token') .type('form') .send({ client_id: 'valid-client', client_secret: 'valid-secret', + resource: 'https://api.example.com/resource', grant_type: 'refresh_token', refresh_token: 'valid_refresh_token' }); @@ -455,6 +466,12 @@ describe('Token Handler', () => { expect(response.body.token_type).toBe('bearer'); expect(response.body.expires_in).toBe(3600); expect(response.body.refresh_token).toBe('new_mock_refresh_token'); + expect(mockExchangeRefresh).toHaveBeenCalledWith( + validClient, + 'valid_refresh_token', + undefined, // scopes + new URL('https://api.example.com/resource') // resource parameter + ); }); it('respects requested scopes on refresh', async () => { diff --git a/src/server/auth/handlers/token.ts b/src/server/auth/handlers/token.ts index eadbd751..1d97805b 100644 --- a/src/server/auth/handlers/token.ts +++ b/src/server/auth/handlers/token.ts @@ -32,11 +32,13 @@ const AuthorizationCodeGrantSchema = z.object({ code: z.string(), code_verifier: z.string(), redirect_uri: z.string().optional(), + resource: z.string().url().optional(), }); const RefreshTokenGrantSchema = z.object({ refresh_token: z.string(), scope: z.string().optional(), + resource: z.string().url().optional(), }); export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHandlerOptions): RequestHandler { @@ -89,7 +91,7 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand throw new InvalidRequestError(parseResult.error.message); } - const { code, code_verifier, redirect_uri } = parseResult.data; + const { code, code_verifier, redirect_uri, resource } = parseResult.data; const skipLocalPkceValidation = provider.skipLocalPkceValidation; @@ -107,7 +109,8 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand client, code, skipLocalPkceValidation ? code_verifier : undefined, - redirect_uri + redirect_uri, + resource ? new URL(resource) : undefined ); res.status(200).json(tokens); break; @@ -119,10 +122,10 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand throw new InvalidRequestError(parseResult.error.message); } - const { refresh_token, scope } = parseResult.data; + const { refresh_token, scope, resource } = parseResult.data; const scopes = scope?.split(" "); - const tokens = await provider.exchangeRefreshToken(client, refresh_token, scopes); + const tokens = await provider.exchangeRefreshToken(client, refresh_token, scopes, resource ? new URL(resource) : undefined); res.status(200).json(tokens); break; } diff --git a/src/server/auth/provider.ts b/src/server/auth/provider.ts index 7815b713..18beb216 100644 --- a/src/server/auth/provider.ts +++ b/src/server/auth/provider.ts @@ -8,6 +8,7 @@ export type AuthorizationParams = { scopes?: string[]; codeChallenge: string; redirectUri: string; + resource?: URL; }; /** @@ -40,13 +41,14 @@ export interface OAuthServerProvider { client: OAuthClientInformationFull, authorizationCode: string, codeVerifier?: string, - redirectUri?: string + redirectUri?: string, + resource?: URL ): Promise; /** * Exchanges a refresh token for an access token. */ - exchangeRefreshToken(client: OAuthClientInformationFull, refreshToken: string, scopes?: string[]): Promise; + exchangeRefreshToken(client: OAuthClientInformationFull, refreshToken: string, scopes?: string[], resource?: URL): Promise; /** * Verifies an access token and returns information about it. diff --git a/src/server/auth/providers/proxyProvider.test.ts b/src/server/auth/providers/proxyProvider.test.ts index 69039c3e..4e98d0dc 100644 --- a/src/server/auth/providers/proxyProvider.test.ts +++ b/src/server/auth/providers/proxyProvider.test.ts @@ -88,6 +88,7 @@ describe("Proxy OAuth Server Provider", () => { codeChallenge: "test-challenge", state: "test-state", scopes: ["read", "write"], + resource: new URL('https://api.example.com/resource'), }, mockResponse ); @@ -100,6 +101,7 @@ describe("Proxy OAuth Server Provider", () => { expectedUrl.searchParams.set("code_challenge_method", "S256"); expectedUrl.searchParams.set("state", "test-state"); expectedUrl.searchParams.set("scope", "read write"); + expectedUrl.searchParams.set('resource', 'https://api.example.com/resource'); expect(mockResponse.redirect).toHaveBeenCalledWith(expectedUrl.toString()); }); @@ -164,6 +166,41 @@ describe("Proxy OAuth Server Provider", () => { expect(tokens).toEqual(mockTokenResponse); }); + it('includes resource parameter in authorization code exchange', async () => { + const tokens = await provider.exchangeAuthorizationCode( + validClient, + 'test-code', + 'test-verifier', + 'https://example.com/callback', + new URL('https://api.example.com/resource') + ); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: expect.stringContaining('resource=' + encodeURIComponent('https://api.example.com/resource')) + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + + it('handles authorization code exchange without resource parameter', async () => { + const tokens = await provider.exchangeAuthorizationCode( + validClient, + 'test-code', + 'test-verifier' + ); + + const fetchCall = (global.fetch as jest.Mock).mock.calls[0]; + const body = fetchCall[1].body as string; + expect(body).not.toContain('resource='); + expect(tokens).toEqual(mockTokenResponse); + }); + it("exchanges refresh token for new tokens", async () => { const tokens = await provider.exchangeRefreshToken( validClient, @@ -184,6 +221,26 @@ describe("Proxy OAuth Server Provider", () => { expect(tokens).toEqual(mockTokenResponse); }); + it('includes resource parameter in refresh token exchange', async () => { + const tokens = await provider.exchangeRefreshToken( + validClient, + 'test-refresh-token', + ['read', 'write'], + new URL('https://api.example.com/resource') + ); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: expect.stringContaining('resource=' + encodeURIComponent('https://api.example.com/resource')) + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); }); describe("client registration", () => { diff --git a/src/server/auth/providers/proxyProvider.ts b/src/server/auth/providers/proxyProvider.ts index db7460e5..de74862b 100644 --- a/src/server/auth/providers/proxyProvider.ts +++ b/src/server/auth/providers/proxyProvider.ts @@ -134,6 +134,7 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { // Add optional standard OAuth parameters if (params.state) searchParams.set("state", params.state); if (params.scopes?.length) searchParams.set("scope", params.scopes.join(" ")); + if (params.resource) searchParams.set("resource", params.resource.href); targetUrl.search = searchParams.toString(); res.redirect(targetUrl.toString()); @@ -152,7 +153,8 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { client: OAuthClientInformationFull, authorizationCode: string, codeVerifier?: string, - redirectUri?: string + redirectUri?: string, + resource?: URL ): Promise { const params = new URLSearchParams({ grant_type: "authorization_code", @@ -172,6 +174,10 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { params.append("redirect_uri", redirectUri); } + if (resource) { + params.append("resource", resource.href); + } + const response = await fetch(this._endpoints.tokenUrl, { method: "POST", headers: { @@ -192,7 +198,8 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { async exchangeRefreshToken( client: OAuthClientInformationFull, refreshToken: string, - scopes?: string[] + scopes?: string[], + resource?: URL ): Promise { const params = new URLSearchParams({ @@ -209,6 +216,10 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { params.set("scope", scopes.join(" ")); } + if (resource) { + params.set("resource", resource.href); + } + const response = await fetch(this._endpoints.tokenUrl, { method: "POST", headers: { diff --git a/src/server/auth/types.ts b/src/server/auth/types.ts index c25c2b60..0189e9ed 100644 --- a/src/server/auth/types.ts +++ b/src/server/auth/types.ts @@ -22,6 +22,12 @@ export interface AuthInfo { */ expiresAt?: number; + /** + * The RFC 8707 resource server identifier for which this token is valid. + * If set, this MUST match the MCP server's resource identifier (minus hash fragment). + */ + resource?: URL; + /** * Additional data associated with the token. * This field should be used for any additional data that needs to be attached to the auth info. diff --git a/src/shared/auth-utils.test.ts b/src/shared/auth-utils.test.ts new file mode 100644 index 00000000..c35bb122 --- /dev/null +++ b/src/shared/auth-utils.test.ts @@ -0,0 +1,30 @@ +import { resourceUrlFromServerUrl } from './auth-utils.js'; + +describe('auth-utils', () => { + describe('resourceUrlFromServerUrl', () => { + it('should remove fragments', () => { + expect(resourceUrlFromServerUrl(new URL('https://example.com/path#fragment')).href).toBe('https://example.com/path'); + expect(resourceUrlFromServerUrl(new URL('https://example.com#fragment')).href).toBe('https://example.com/'); + expect(resourceUrlFromServerUrl(new URL('https://example.com/path?query=1#fragment')).href).toBe('https://example.com/path?query=1'); + }); + + it('should return URL unchanged if no fragment', () => { + expect(resourceUrlFromServerUrl(new URL('https://example.com')).href).toBe('https://example.com/'); + expect(resourceUrlFromServerUrl(new URL('https://example.com/path')).href).toBe('https://example.com/path'); + expect(resourceUrlFromServerUrl(new URL('https://example.com/path?query=1')).href).toBe('https://example.com/path?query=1'); + }); + + it('should keep everything else unchanged', () => { + // Case sensitivity preserved + expect(resourceUrlFromServerUrl(new URL('https://EXAMPLE.COM/PATH')).href).toBe('https://example.com/PATH'); + // Ports preserved + expect(resourceUrlFromServerUrl(new URL('https://example.com:443/path')).href).toBe('https://example.com/path'); + expect(resourceUrlFromServerUrl(new URL('https://example.com:8080/path')).href).toBe('https://example.com:8080/path'); + // Query parameters preserved + expect(resourceUrlFromServerUrl(new URL('https://example.com?foo=bar&baz=qux')).href).toBe('https://example.com/?foo=bar&baz=qux'); + // Trailing slashes preserved + expect(resourceUrlFromServerUrl(new URL('https://example.com/')).href).toBe('https://example.com/'); + expect(resourceUrlFromServerUrl(new URL('https://example.com/path/')).href).toBe('https://example.com/path/'); + }); + }); +}); \ No newline at end of file diff --git a/src/shared/auth-utils.ts b/src/shared/auth-utils.ts new file mode 100644 index 00000000..086d812f --- /dev/null +++ b/src/shared/auth-utils.ts @@ -0,0 +1,14 @@ +/** + * Utilities for handling OAuth resource URIs. + */ + +/** + * Converts a server URL to a resource URL by removing the fragment. + * RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component". + * Keeps everything else unchanged (scheme, domain, port, path, query). + */ +export function resourceUrlFromServerUrl(url: URL): URL { + const resourceURL = new URL(url.href); + resourceURL.hash = ''; // Remove fragment + return resourceURL; +}