diff --git a/README.md b/README.md index d8a2db2b6..1b1c35cdb 100644 --- a/README.md +++ b/README.md @@ -423,43 +423,41 @@ The `elicit()` method returns an `ElicitationResult` with: Authentication can be used by servers that want to expose tools accessing protected resources. -`mcp.server.auth` implements an OAuth 2.0 server interface, which servers can use by -providing an implementation of the `OAuthAuthorizationServerProvider` protocol. +`mcp.server.auth` implements OAuth 2.1 resource server functionality, where MCP servers act as Resource Servers (RS) that validate tokens issued by separate Authorization Servers (AS). This follows the [MCP authorization specification](https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization) and implements RFC 9728 (Protected Resource Metadata) for AS discovery. + +MCP servers can use authentication by providing an implementation of the `TokenVerifier` protocol: ```python from mcp import FastMCP -from mcp.server.auth.provider import OAuthAuthorizationServerProvider -from mcp.server.auth.settings import ( - AuthSettings, - ClientRegistrationOptions, - RevocationOptions, -) +from mcp.server.auth.provider import TokenVerifier, TokenInfo +from mcp.server.auth.settings import AuthSettings -class MyOAuthServerProvider(OAuthAuthorizationServerProvider): - # See an example on how to implement at `examples/servers/simple-auth` - ... +class MyTokenVerifier(TokenVerifier): + # Implement token validation logic (typically via token introspection) + async def verify_token(self, token: str) -> TokenInfo: + # Verify with your authorization server + ... mcp = FastMCP( "My App", - auth_server_provider=MyOAuthServerProvider(), + token_verifier=MyTokenVerifier(), auth=AuthSettings( - issuer_url="https://myapp.com", - revocation_options=RevocationOptions( - enabled=True, - ), - client_registration_options=ClientRegistrationOptions( - enabled=True, - valid_scopes=["myscope", "myotherscope"], - default_scopes=["myscope"], - ), - required_scopes=["myscope"], + authorization_servers=["https://auth.example.com"], + required_scopes=["mcp:read", "mcp:write"], ), ) ``` -See [OAuthAuthorizationServerProvider](src/mcp/server/auth/provider.py) for more details. +For a complete example with separate Authorization Server and Resource Server implementations, see [`examples/servers/simple-auth/`](examples/servers/simple-auth/). + +**Architecture:** +- **Authorization Server (AS)**: Handles OAuth flows, user authentication, and token issuance +- **Resource Server (RS)**: Your MCP server that validates tokens and serves protected resources +- **Client**: Discovers AS through RFC 9728, obtains tokens, and uses them with the MCP server + +See [TokenVerifier](src/mcp/server/auth/provider.py) for more details on implementing token validation. ## Running Your Server diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index 577c392f3..6354f2026 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -160,8 +160,7 @@ async def connect(self): print(f"🔗 Attempting to connect to {self.server_url}...") try: - # Set up callback server - callback_server = CallbackServer(port=3000) + callback_server = CallbackServer(port=3030) callback_server.start() async def callback_handler() -> tuple[str, str | None]: @@ -175,7 +174,7 @@ async def callback_handler() -> tuple[str, str | None]: client_metadata_dict = { "client_name": "Simple Auth Client", - "redirect_uris": ["http://localhost:3000/callback"], + "redirect_uris": ["http://localhost:3030/callback"], "grant_types": ["authorization_code", "refresh_token"], "response_types": ["code"], "token_endpoint_auth_method": "client_secret_post", diff --git a/examples/servers/simple-auth/README.md b/examples/servers/simple-auth/README.md index 9906c4d36..3873cac70 100644 --- a/examples/servers/simple-auth/README.md +++ b/examples/servers/simple-auth/README.md @@ -1,91 +1,138 @@ -# Simple MCP Server with GitHub OAuth Authentication +# MCP OAuth Authentication Demo -This is a simple example of an MCP server with GitHub OAuth authentication. It demonstrates the essential components needed for OAuth integration with just a single tool. +This example demonstrates OAuth 2.0 authentication with the Model Context Protocol using **separate Authorization Server (AS) and Resource Server (RS)** to comply with the new RFC 9728 specification. -This is just an example of a server that uses auth, an official GitHub mcp server is [here](https://github.com/github/github-mcp-server) +--- -## Overview +## Setup Requirements -This simple demo to show to set up a server with: -- GitHub OAuth2 authorization flow -- Single tool: `get_user_profile` to retrieve GitHub user information +**Create a GitHub OAuth App:** +- Go to GitHub Settings > Developer settings > OAuth Apps > New OAuth App +- **Authorization callback URL:** `http://localhost:9000/github/callback` +- Note down your **Client ID** and **Client Secret** +**Set environment variables:** +```bash +export MCP_GITHUB_CLIENT_ID="your_client_id_here" +export MCP_GITHUB_CLIENT_SECRET="your_client_secret_here" +``` -## Prerequisites - -1. Create a GitHub OAuth App: - - Go to GitHub Settings > Developer settings > OAuth Apps > New OAuth App - - Application name: Any name (e.g., "Simple MCP Auth Demo") - - Homepage URL: `http://localhost:8000` - - Authorization callback URL: `http://localhost:8000/github/callback` - - Click "Register application" - - Note down your Client ID and Client Secret +--- -## Required Environment Variables +## Running the Servers -You MUST set these environment variables before running the server: +### Step 1: Start Authorization Server ```bash -export MCP_GITHUB_GITHUB_CLIENT_ID="your_client_id_here" -export MCP_GITHUB_GITHUB_CLIENT_SECRET="your_client_secret_here" +# Navigate to the simple-auth directory +cd examples/servers/simple-auth + +# Start Authorization Server on port 9000 +python -m mcp_simple_auth.auth_server --port=9000 ``` -The server will not start without these environment variables properly set. +**What it provides:** +- OAuth 2.0 flows (registration, authorization, token exchange) +- GitHub OAuth integration for user authentication +- Token introspection endpoint for Resource Servers (`/introspect`) +- User data proxy endpoint (`/github/user`) +--- -## Running the Server +### Step 2: Start Resource Server (MCP Server) ```bash -# Set environment variables first (see above) +# In another terminal, navigate to the simple-auth directory +cd examples/servers/simple-auth -# Run the server -uv run mcp-simple-auth +# Start Resource Server on port 8001, connected to Authorization Server +python -m mcp_simple_auth.server --port=8001 --auth-server=http://localhost:9000 --transport=streamable-http ``` -The server will start on `http://localhost:8000`. - -### Transport Options -This server supports multiple transport protocols that can run on the same port: +### Step 3: Test with Client -#### SSE (Server-Sent Events) - Default ```bash -uv run mcp-simple-auth -# or explicitly: -uv run mcp-simple-auth --transport sse +cd examples/clients/simple-auth-client +# Start client with streamable HTTP +MCP_SERVER_PORT=8001 MCP_TRANSPORT_TYPE=streamable_http python -m mcp_simple_auth_client.main ``` -SSE transport provides endpoint: -- `/sse` -#### Streamable HTTP +## How It Works + +### RFC 9728 Discovery + +**Client → Resource Server:** ```bash -uv run mcp-simple-auth --transport streamable-http +curl http://localhost:8001/.well-known/oauth-protected-resource +``` +```json +{ + "resource": "http://localhost:8001", + "authorization_servers": ["http://localhost:9000"] +} ``` -Streamable HTTP transport provides endpoint: -- `/mcp` +**Client → Authorization Server:** +```bash +curl http://localhost:9000/.well-known/oauth-authorization-server +``` +```json +{ + "issuer": "http://localhost:9000", + "authorization_endpoint": "http://localhost:9000/authorize", + "token_endpoint": "http://localhost:9000/token" +} +``` +## Legacy MCP Server as Authorization Server (Backwards Compatibility) -This ensures backward compatibility without needing multiple server instances. When using SSE transport (`--transport sse`), only the `/sse` endpoint is available. +For backwards compatibility with older MCP implementations, a legacy server is provided that acts as an Authorization Server (following the old spec where MCP servers could optionally provide OAuth): -## Available Tool +### Running the Legacy Server -### get_user_profile +```bash +# Start legacy authorization server on port 8002 +python -m mcp_simple_auth.legacy_as_server --port=8002 +``` + +**Differences from the new architecture:** +- **MCP server acts as AS:** The MCP server itself provides OAuth endpoints (old spec behavior) +- **No separate RS:** The server handles both authentication and MCP tools +- **Local token validation:** Tokens are validated internally without introspection +- **No RFC 9728 support:** Does not provide `/.well-known/oauth-protected-resource` +- **Direct OAuth discovery:** OAuth metadata is at the MCP server's URL -The only tool in this simple example. Returns the authenticated user's GitHub profile information. +### Testing with Legacy Server + +```bash +# Test with client (will automatically fall back to legacy discovery) +MCP_SERVER_PORT=8002 MCP_TRANSPORT_TYPE=streamable_http python -m mcp_simple_auth_client.main +``` -**Required scope**: `user` +The client will: +1. Try RFC 9728 discovery at `/.well-known/oauth-protected-resource` (404 on legacy server) +2. Fall back to direct OAuth discovery at `/.well-known/oauth-authorization-server` +3. Complete authentication with the MCP server acting as its own AS -**Returns**: GitHub user profile data including username, email, bio, etc. +This ensures existing MCP servers (which could optionally act as Authorization Servers under the old spec) continue to work while the ecosystem transitions to the new architecture where MCP servers are Resource Servers only. +## Manual Testing -## Troubleshooting +### Test Discovery +```bash +# Test Resource Server discovery endpoint (new architecture) +curl -v http://localhost:8001/.well-known/oauth-protected-resource -If the server fails to start, check: -1. Environment variables `MCP_GITHUB_GITHUB_CLIENT_ID` and `MCP_GITHUB_GITHUB_CLIENT_SECRET` are set -2. The GitHub OAuth app callback URL matches `http://localhost:8000/github/callback` -3. No other service is using port 8000 -4. The transport specified is valid (`sse` or `streamable-http`) +# Test Authorization Server metadata +curl -v http://localhost:9000/.well-known/oauth-authorization-server +``` -You can use [Inspector](https://github.com/modelcontextprotocol/inspector) to test Auth \ No newline at end of file +### Test Token Introspection +```bash +# After getting a token through OAuth flow: +curl -X POST http://localhost:9000/introspect \ + -H "Content-Type: application/x-www-form-urlencoded" \ + -d "token=your_access_token" +``` diff --git a/examples/servers/simple-auth/mcp_simple_auth/auth_server.py b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py new file mode 100644 index 000000000..d7b7b93cd --- /dev/null +++ b/examples/servers/simple-auth/mcp_simple_auth/auth_server.py @@ -0,0 +1,233 @@ +""" +Authorization Server for MCP Split Demo. + +This server handles OAuth flows, client registration, and token issuance. +Can be replaced with enterprise authorization servers like Auth0, Entra ID, etc. + +NOTE: this is a simplified example for demonstration purposes. +This is not a production-ready implementation. + +Usage: + python -m mcp_simple_auth.auth_server --port=9000 +""" + +import asyncio +import logging +import time + +import click +from pydantic import AnyHttpUrl, BaseModel +from starlette.applications import Starlette +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse, Response +from starlette.routing import Route +from uvicorn import Config, Server + +from mcp.server.auth.routes import cors_middleware, create_auth_routes +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions + +from .github_oauth_provider import GitHubOAuthProvider, GitHubOAuthSettings + +logger = logging.getLogger(__name__) + + +class AuthServerSettings(BaseModel): + """Settings for the Authorization Server.""" + + # Server settings + host: str = "localhost" + port: int = 9000 + server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:9000") + github_callback_path: str = "http://localhost:9000/github/callback" + + +class GitHubProxyAuthProvider(GitHubOAuthProvider): + """ + Authorization Server provider that proxies GitHub OAuth. + + This provider: + 1. Issues MCP tokens after GitHub authentication + 2. Stores token state for introspection by Resource Servers + 3. Maps MCP tokens to GitHub tokens for API access + """ + + def __init__(self, github_settings: GitHubOAuthSettings, github_callback_path: str): + super().__init__(github_settings, github_callback_path) + + +def create_authorization_server(server_settings: AuthServerSettings, github_settings: GitHubOAuthSettings) -> Starlette: + """Create the Authorization Server application.""" + oauth_provider = GitHubProxyAuthProvider(github_settings, server_settings.github_callback_path) + + auth_settings = AuthSettings( + issuer_url=server_settings.server_url, + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=[github_settings.mcp_scope], + default_scopes=[github_settings.mcp_scope], + ), + required_scopes=[github_settings.mcp_scope], + authorization_servers=None, + ) + + # Create OAuth routes + routes = create_auth_routes( + provider=oauth_provider, + issuer_url=auth_settings.issuer_url, + service_documentation_url=auth_settings.service_documentation_url, + client_registration_options=auth_settings.client_registration_options, + revocation_options=auth_settings.revocation_options, + ) + + # Add GitHub callback route + async def github_callback_handler(request: Request) -> Response: + """Handle GitHub OAuth callback.""" + code = request.query_params.get("code") + state = request.query_params.get("state") + + if not code or not state: + raise HTTPException(400, "Missing code or state parameter") + + redirect_uri = await oauth_provider.handle_github_callback(code, state) + return RedirectResponse(url=redirect_uri, status_code=302) + + routes.append(Route("/github/callback", endpoint=github_callback_handler, methods=["GET"])) + + # Add token introspection endpoint (RFC 7662) for Resource Servers + async def introspect_handler(request: Request) -> Response: + """ + Token introspection endpoint for Resource Servers. + + Resource Servers call this endpoint to validate tokens without + needing direct access to token storage. + """ + form = await request.form() + token = form.get("token") + if not token or not isinstance(token, str): + return JSONResponse({"active": False}, status_code=400) + + # Look up token in provider + access_token = await oauth_provider.load_access_token(token) + if not access_token: + return JSONResponse({"active": False}) + + # Return token info for Resource Server + return JSONResponse( + { + "active": True, + "client_id": access_token.client_id, + "scope": " ".join(access_token.scopes), + "exp": access_token.expires_at, + "iat": int(time.time()), + "token_type": "Bearer", + } + ) + + routes.append( + Route( + "/introspect", + endpoint=cors_middleware(introspect_handler, ["POST", "OPTIONS"]), + methods=["POST", "OPTIONS"], + ) + ) + + # Add GitHub user info endpoint (for Resource Server to fetch user data) + async def github_user_handler(request: Request) -> Response: + """ + Proxy endpoint to get GitHub user info using stored GitHub tokens. + + Resource Servers call this with MCP tokens to get GitHub user data + without exposing GitHub tokens to clients. + """ + # Extract Bearer token + auth_header = request.headers.get("authorization", "") + if not auth_header.startswith("Bearer "): + return JSONResponse({"error": "unauthorized"}, status_code=401) + + mcp_token = auth_header[7:] + + # Get GitHub user info using the provider method + user_info = await oauth_provider.get_github_user_info(mcp_token) + return JSONResponse(user_info) + + routes.append( + Route( + "/github/user", + endpoint=cors_middleware(github_user_handler, ["GET", "OPTIONS"]), + methods=["GET", "OPTIONS"], + ) + ) + + return Starlette(routes=routes) + + +async def run_server(server_settings: AuthServerSettings, github_settings: GitHubOAuthSettings): + """Run the Authorization Server.""" + auth_server = create_authorization_server(server_settings, github_settings) + + config = Config( + auth_server, + host=server_settings.host, + port=server_settings.port, + log_level="info", + ) + server = Server(config) + + logger.info("=" * 80) + logger.info("MCP AUTHORIZATION SERVER") + logger.info("=" * 80) + logger.info(f"Server URL: {server_settings.server_url}") + logger.info("Endpoints:") + logger.info(f" - OAuth Metadata: {server_settings.server_url}/.well-known/oauth-authorization-server") + logger.info(f" - Client Registration: {server_settings.server_url}/register") + logger.info(f" - Authorization: {server_settings.server_url}/authorize") + logger.info(f" - Token Exchange: {server_settings.server_url}/token") + logger.info(f" - Token Introspection: {server_settings.server_url}/introspect") + logger.info(f" - GitHub Callback: {server_settings.server_url}/github/callback") + logger.info(f" - GitHub User Proxy: {server_settings.server_url}/github/user") + logger.info("") + logger.info("Resource Servers should use /introspect to validate tokens") + logger.info("Configure GitHub App callback URL: " + server_settings.github_callback_path) + logger.info("=" * 80) + + await server.serve() + + +@click.command() +@click.option("--port", default=9000, help="Port to listen on") +def main(port: int) -> int: + """ + Run the MCP Authorization Server. + + This server handles OAuth flows and can be used by multiple Resource Servers. + + Environment variables needed: + - MCP_GITHUB_CLIENT_ID: GitHub OAuth Client ID + - MCP_GITHUB_CLIENT_SECRET: GitHub OAuth Client Secret + """ + logging.basicConfig(level=logging.INFO) + + # Load GitHub settings from environment variables + github_settings = GitHubOAuthSettings() + + # Validate required fields + if not github_settings.github_client_id or not github_settings.github_client_secret: + raise ValueError("GitHub credentials not provided") + + # Create server settings + host = "localhost" + server_url = f"http://{host}:{port}" + server_settings = AuthServerSettings( + host=host, + port=port, + server_url=AnyHttpUrl(server_url), + github_callback_path=f"{server_url}/github/callback", + ) + + asyncio.run(run_server(server_settings, github_settings)) + return 0 + + +if __name__ == "__main__": + main() # type: ignore[call-arg] diff --git a/examples/servers/simple-auth/mcp_simple_auth/github_oauth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/github_oauth_provider.py new file mode 100644 index 000000000..bb45ae6c5 --- /dev/null +++ b/examples/servers/simple-auth/mcp_simple_auth/github_oauth_provider.py @@ -0,0 +1,257 @@ +""" +Shared GitHub OAuth provider for MCP servers. + +This module contains the common GitHub OAuth functionality used by both +the standalone authorization server and the legacy combined server. + +NOTE: this is a simplified example for demonstration purposes. +This is not a production-ready implementation. + +""" + +import logging +import secrets +import time +from typing import Any + +from pydantic import AnyHttpUrl +from pydantic_settings import BaseSettings, SettingsConfigDict +from starlette.exceptions import HTTPException + +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + OAuthAuthorizationServerProvider, + RefreshToken, + construct_redirect_uri, +) +from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +logger = logging.getLogger(__name__) + + +class GitHubOAuthSettings(BaseSettings): + """Common GitHub OAuth settings.""" + + model_config = SettingsConfigDict(env_prefix="MCP_") + + # GitHub OAuth settings - MUST be provided via environment variables + github_client_id: str | None = None + github_client_secret: str | None = None + + # GitHub OAuth URLs + github_auth_url: str = "https://github.com/login/oauth/authorize" + github_token_url: str = "https://github.com/login/oauth/access_token" + + mcp_scope: str = "user" + github_scope: str = "read:user" + + +class GitHubOAuthProvider(OAuthAuthorizationServerProvider): + """ + OAuth provider that uses GitHub as the identity provider. + + This provider handles the OAuth flow by: + 1. Redirecting users to GitHub for authentication + 2. Exchanging GitHub tokens for MCP tokens + 3. Maintaining token mappings for API access + """ + + def __init__(self, settings: GitHubOAuthSettings, github_callback_url: str): + self.settings = settings + self.github_callback_url = github_callback_url + self.clients: dict[str, OAuthClientInformationFull] = {} + self.auth_codes: dict[str, AuthorizationCode] = {} + self.tokens: dict[str, AccessToken] = {} + self.state_mapping: dict[str, dict[str, str]] = {} + # Maps MCP tokens to GitHub tokens + self.token_mapping: dict[str, str] = {} + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + """Get OAuth client information.""" + return self.clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull): + """Register a new OAuth client.""" + self.clients[client_info.client_id] = client_info + + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: + """Generate an authorization URL for GitHub OAuth flow.""" + state = params.state or secrets.token_hex(16) + + # Store state mapping for callback + self.state_mapping[state] = { + "redirect_uri": str(params.redirect_uri), + "code_challenge": params.code_challenge, + "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly), + "client_id": client.client_id, + } + + # Build GitHub authorization URL + auth_url = ( + f"{self.settings.github_auth_url}" + f"?client_id={self.settings.github_client_id}" + f"&redirect_uri={self.github_callback_url}" + f"&scope={self.settings.github_scope}" + f"&state={state}" + ) + + return auth_url + + async def handle_github_callback(self, code: str, state: str) -> str: + """Handle GitHub OAuth callback and return redirect URI.""" + state_data = self.state_mapping.get(state) + if not state_data: + raise HTTPException(400, "Invalid state parameter") + + redirect_uri = state_data["redirect_uri"] + code_challenge = state_data["code_challenge"] + redirect_uri_provided_explicitly = state_data["redirect_uri_provided_explicitly"] == "True" + client_id = state_data["client_id"] + + # Exchange code for token with GitHub + async with create_mcp_http_client() as client: + response = await client.post( + self.settings.github_token_url, + data={ + "client_id": self.settings.github_client_id, + "client_secret": self.settings.github_client_secret, + "code": code, + "redirect_uri": self.github_callback_url, + }, + headers={"Accept": "application/json"}, + ) + + if response.status_code != 200: + raise HTTPException(400, "Failed to exchange code for token") + + data = response.json() + + if "error" in data: + raise HTTPException(400, data.get("error_description", data["error"])) + + github_token = data["access_token"] + + # Create MCP authorization code + new_code = f"mcp_{secrets.token_hex(16)}" + auth_code = AuthorizationCode( + code=new_code, + client_id=client_id, + redirect_uri=AnyHttpUrl(redirect_uri), + redirect_uri_provided_explicitly=redirect_uri_provided_explicitly, + expires_at=time.time() + 300, + scopes=[self.settings.mcp_scope], + code_challenge=code_challenge, + ) + self.auth_codes[new_code] = auth_code + + # Store GitHub token with MCP client_id + self.tokens[github_token] = AccessToken( + token=github_token, + client_id=client_id, + scopes=[self.settings.github_scope], + expires_at=None, + ) + + del self.state_mapping[state] + return construct_redirect_uri(redirect_uri, code=new_code, state=state) + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: + """Load an authorization code.""" + return self.auth_codes.get(authorization_code) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + """Exchange authorization code for tokens.""" + if authorization_code.code not in self.auth_codes: + raise ValueError("Invalid authorization code") + + # Generate MCP access token + mcp_token = f"mcp_{secrets.token_hex(32)}" + + # Store MCP token + self.tokens[mcp_token] = AccessToken( + token=mcp_token, + client_id=client.client_id, + scopes=authorization_code.scopes, + expires_at=int(time.time()) + 3600, + ) + + # Find GitHub token for this client + github_token = next( + ( + token + for token, data in self.tokens.items() + if (token.startswith("ghu_") or token.startswith("gho_")) and data.client_id == client.client_id + ), + None, + ) + + # Store mapping between MCP token and GitHub token + if github_token: + self.token_mapping[mcp_token] = github_token + + del self.auth_codes[authorization_code.code] + + return OAuthToken( + access_token=mcp_token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(authorization_code.scopes), + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + """Load and validate an access token.""" + access_token = self.tokens.get(token) + if not access_token: + return None + + # Check if expired + if access_token.expires_at and access_token.expires_at < time.time(): + del self.tokens[token] + return None + + return access_token + + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + """Load a refresh token - not supported in this example.""" + return None + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshToken, + scopes: list[str], + ) -> OAuthToken: + """Exchange refresh token - not supported in this example.""" + raise NotImplementedError("Refresh tokens not supported") + + async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None: + """Revoke a token.""" + if token in self.tokens: + del self.tokens[token] + + async def get_github_user_info(self, mcp_token: str) -> dict[str, Any]: + """Get GitHub user info using MCP token.""" + github_token = self.token_mapping.get(mcp_token) + if not github_token: + raise ValueError("No GitHub token found for MCP token") + + async with create_mcp_http_client() as client: + response = await client.get( + "https://api.github.com/user", + headers={ + "Authorization": f"Bearer {github_token}", + "Accept": "application/vnd.github.v3+json", + }, + ) + + if response.status_code != 200: + raise ValueError(f"GitHub API error: {response.status_code}") + + return response.json() diff --git a/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py b/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py new file mode 100644 index 000000000..08c344665 --- /dev/null +++ b/examples/servers/simple-auth/mcp_simple_auth/legacy_as_server.py @@ -0,0 +1,152 @@ +""" +Legacy Combined Authorization Server + Resource Server for MCP. + +This server implements the old spec where MCP servers could act as both AS and RS. +Used for backwards compatibility testing with the new split AS/RS architecture. + +NOTE: this is a simplified example for demonstration purposes. +This is not a production-ready implementation. + + +Usage: + python -m mcp_simple_auth.legacy_as_server --port=8002 +""" + +import logging +from typing import Any, Literal + +import click +from pydantic import AnyHttpUrl, BaseModel +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import RedirectResponse, Response + +from mcp.server.auth.middleware.auth_context import get_access_token +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions +from mcp.server.fastmcp.server import FastMCP + +from .github_oauth_provider import GitHubOAuthProvider, GitHubOAuthSettings + +logger = logging.getLogger(__name__) + + +class ServerSettings(BaseModel): + """Settings for the simple GitHub MCP server.""" + + # Server settings + host: str = "localhost" + port: int = 8000 + server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:8000") + github_callback_path: str = "http://localhost:8000/github/callback" + + +class SimpleGitHubOAuthProvider(GitHubOAuthProvider): + """GitHub OAuth provider for legacy MCP server.""" + + def __init__(self, github_settings: GitHubOAuthSettings, github_callback_path: str): + super().__init__(github_settings, github_callback_path) + + +def create_simple_mcp_server(server_settings: ServerSettings, github_settings: GitHubOAuthSettings) -> FastMCP: + """Create a simple FastMCP server with GitHub OAuth.""" + oauth_provider = SimpleGitHubOAuthProvider(github_settings, server_settings.github_callback_path) + + auth_settings = AuthSettings( + issuer_url=server_settings.server_url, + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=[github_settings.mcp_scope], + default_scopes=[github_settings.mcp_scope], + ), + required_scopes=[github_settings.mcp_scope], + # No authorization_servers parameter in legacy mode + authorization_servers=None, + ) + + app = FastMCP( + name="Simple GitHub MCP Server", + instructions="A simple MCP server with GitHub OAuth authentication", + auth_server_provider=oauth_provider, + host=server_settings.host, + port=server_settings.port, + debug=True, + auth=auth_settings, + ) + + @app.custom_route("/github/callback", methods=["GET"]) + async def github_callback_handler(request: Request) -> Response: + """Handle GitHub OAuth callback.""" + code = request.query_params.get("code") + state = request.query_params.get("state") + + if not code or not state: + raise HTTPException(400, "Missing code or state parameter") + + redirect_uri = await oauth_provider.handle_github_callback(code, state) + return RedirectResponse(status_code=302, url=redirect_uri) + + def get_github_token() -> str: + """Get the GitHub token for the authenticated user.""" + access_token = get_access_token() + if not access_token: + raise ValueError("Not authenticated") + + # Get GitHub token from mapping + github_token = oauth_provider.token_mapping.get(access_token.token) + + if not github_token: + raise ValueError("No GitHub token found for user") + + return github_token + + @app.tool() + async def get_user_profile() -> dict[str, Any]: + """Get the authenticated user's GitHub profile information. + + This is the only tool in our simple example. It requires the 'user' scope. + """ + access_token = get_access_token() + if not access_token: + raise ValueError("Not authenticated") + + return await oauth_provider.get_github_user_info(access_token.token) + + return app + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on") +@click.option( + "--transport", + default="streamable-http", + type=click.Choice(["sse", "streamable-http"]), + help="Transport protocol to use ('sse' or 'streamable-http')", +) +def main(port: int, transport: Literal["sse", "streamable-http"]) -> int: + """Run the simple GitHub MCP server.""" + logging.basicConfig(level=logging.INFO) + + # Load GitHub settings from environment variables + github_settings = GitHubOAuthSettings() + + # Validate required fields + if not github_settings.github_client_id or not github_settings.github_client_secret: + raise ValueError("GitHub credentials not provided") + # Create server settings + host = "localhost" + server_url = f"http://{host}:{port}" + server_settings = ServerSettings( + host=host, + port=port, + server_url=AnyHttpUrl(server_url), + github_callback_path=f"{server_url}/github/callback", + ) + + mcp_server = create_simple_mcp_server(server_settings, github_settings) + logger.info(f"Starting server with {transport} transport") + mcp_server.run(transport=transport) + return 0 + + +if __name__ == "__main__": + main() # type: ignore[call-arg] diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 6e16f8b9d..6a6a5b306 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -1,360 +1,213 @@ -"""Simple MCP Server with GitHub OAuth Authentication.""" +""" +MCP Resource Server with Token Introspection. + +This server validates tokens via Authorization Server introspection and serves MCP resources. +Demonstrates RFC 9728 Protected Resource Metadata for AS/RS separation. + +Usage: + python -m mcp_simple_auth.server --port=8001 --auth-server=http://localhost:9000 +""" import logging -import secrets -import time from typing import Any, Literal import click +import httpx from pydantic import AnyHttpUrl from pydantic_settings import BaseSettings, SettingsConfigDict -from starlette.exceptions import HTTPException -from starlette.requests import Request -from starlette.responses import JSONResponse, RedirectResponse, Response from mcp.server.auth.middleware.auth_context import get_access_token -from mcp.server.auth.provider import ( - AccessToken, - AuthorizationCode, - AuthorizationParams, - OAuthAuthorizationServerProvider, - RefreshToken, - construct_redirect_uri, -) -from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions +from mcp.server.auth.settings import AuthSettings from mcp.server.fastmcp.server import FastMCP -from mcp.shared._httpx_utils import create_mcp_http_client -from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +from .token_verifier import IntrospectionTokenVerifier logger = logging.getLogger(__name__) -class ServerSettings(BaseSettings): - """Settings for the simple GitHub MCP server.""" +class ResourceServerSettings(BaseSettings): + """Settings for the MCP Resource Server.""" - model_config = SettingsConfigDict(env_prefix="MCP_GITHUB_") + model_config = SettingsConfigDict(env_prefix="MCP_RESOURCE_") # Server settings host: str = "localhost" - port: int = 8000 - server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:8000") + port: int = 8001 + server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:8001") - # GitHub OAuth settings - MUST be provided via environment variables - github_client_id: str # Type: MCP_GITHUB_GITHUB_CLIENT_ID env var - github_client_secret: str # Type: MCP_GITHUB_GITHUB_CLIENT_SECRET env var - github_callback_path: str = "http://localhost:8000/github/callback" - - # GitHub OAuth URLs - github_auth_url: str = "https://github.com/login/oauth/authorize" - github_token_url: str = "https://github.com/login/oauth/access_token" + # Authorization Server settings + auth_server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:9000") + auth_server_introspection_endpoint: str = "http://localhost:9000/introspect" + auth_server_github_user_endpoint: str = "http://localhost:9000/github/user" + # MCP settings mcp_scope: str = "user" - github_scope: str = "read:user" def __init__(self, **data): - """Initialize settings with values from environment variables. - - Note: github_client_id and github_client_secret are required but can be - loaded automatically from environment variables (MCP_GITHUB_GITHUB_CLIENT_ID - and MCP_GITHUB_GITHUB_CLIENT_SECRET) and don't need to be passed explicitly. - """ + """Initialize settings with values from environment variables.""" super().__init__(**data) -class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider): - """Simple GitHub OAuth provider with essential functionality.""" - - def __init__(self, settings: ServerSettings): - self.settings = settings - self.clients: dict[str, OAuthClientInformationFull] = {} - self.auth_codes: dict[str, AuthorizationCode] = {} - self.tokens: dict[str, AccessToken] = {} - self.state_mapping: dict[str, dict[str, str]] = {} - # Store GitHub tokens with MCP tokens using the format: - # {"mcp_token": "github_token"} - self.token_mapping: dict[str, str] = {} - - async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: - """Get OAuth client information.""" - return self.clients.get(client_id) - - async def register_client(self, client_info: OAuthClientInformationFull): - """Register a new OAuth client.""" - self.clients[client_info.client_id] = client_info - - async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: - """Generate an authorization URL for GitHub OAuth flow.""" - state = params.state or secrets.token_hex(16) - - # Store the state mapping - self.state_mapping[state] = { - "redirect_uri": str(params.redirect_uri), - "code_challenge": params.code_challenge, - "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly), - "client_id": client.client_id, - } - - # Build GitHub authorization URL - auth_url = ( - f"{self.settings.github_auth_url}" - f"?client_id={self.settings.github_client_id}" - f"&redirect_uri={self.settings.github_callback_path}" - f"&scope={self.settings.github_scope}" - f"&state={state}" - ) - - return auth_url - - async def handle_github_callback(self, code: str, state: str) -> str: - """Handle GitHub OAuth callback.""" - state_data = self.state_mapping.get(state) - if not state_data: - raise HTTPException(400, "Invalid state parameter") - - redirect_uri = state_data["redirect_uri"] - code_challenge = state_data["code_challenge"] - redirect_uri_provided_explicitly = state_data["redirect_uri_provided_explicitly"] == "True" - client_id = state_data["client_id"] - - # Exchange code for token with GitHub - async with create_mcp_http_client() as client: - response = await client.post( - self.settings.github_token_url, - data={ - "client_id": self.settings.github_client_id, - "client_secret": self.settings.github_client_secret, - "code": code, - "redirect_uri": self.settings.github_callback_path, - }, - headers={"Accept": "application/json"}, - ) - - if response.status_code != 200: - raise HTTPException(400, "Failed to exchange code for token") - - data = response.json() - - if "error" in data: - raise HTTPException(400, data.get("error_description", data["error"])) - - github_token = data["access_token"] - - # Create MCP authorization code - new_code = f"mcp_{secrets.token_hex(16)}" - auth_code = AuthorizationCode( - code=new_code, - client_id=client_id, - redirect_uri=AnyHttpUrl(redirect_uri), - redirect_uri_provided_explicitly=redirect_uri_provided_explicitly, - expires_at=time.time() + 300, - scopes=[self.settings.mcp_scope], - code_challenge=code_challenge, - ) - self.auth_codes[new_code] = auth_code - - # Store GitHub token - we'll map the MCP token to this later - self.tokens[github_token] = AccessToken( - token=github_token, - client_id=client_id, - scopes=[self.settings.github_scope], - expires_at=None, - ) - - del self.state_mapping[state] - return construct_redirect_uri(redirect_uri, code=new_code, state=state) - - async def load_authorization_code( - self, client: OAuthClientInformationFull, authorization_code: str - ) -> AuthorizationCode | None: - """Load an authorization code.""" - return self.auth_codes.get(authorization_code) - - async def exchange_authorization_code( - self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode - ) -> OAuthToken: - """Exchange authorization code for tokens.""" - if authorization_code.code not in self.auth_codes: - raise ValueError("Invalid authorization code") - - # Generate MCP access token - mcp_token = f"mcp_{secrets.token_hex(32)}" - - # Store MCP token - self.tokens[mcp_token] = AccessToken( - token=mcp_token, - client_id=client.client_id, - scopes=authorization_code.scopes, - expires_at=int(time.time()) + 3600, - ) - - # Find GitHub token for this client - github_token = next( - ( - token - for token, data in self.tokens.items() - # see https://github.blog/engineering/platform-security/behind-githubs-new-authentication-token-formats/ - # which you get depends on your GH app setup. - if (token.startswith("ghu_") or token.startswith("gho_")) and data.client_id == client.client_id - ), - None, - ) - - # Store mapping between MCP token and GitHub token - if github_token: - self.token_mapping[mcp_token] = github_token - - del self.auth_codes[authorization_code.code] - - return OAuthToken( - access_token=mcp_token, - token_type="Bearer", - expires_in=3600, - scope=" ".join(authorization_code.scopes), - ) +def create_resource_server(settings: ResourceServerSettings) -> FastMCP: + """ + Create MCP Resource Server with token introspection. - async def load_access_token(self, token: str) -> AccessToken | None: - """Load and validate an access token.""" - access_token = self.tokens.get(token) - if not access_token: - return None - - # Check if expired - if access_token.expires_at and access_token.expires_at < time.time(): - del self.tokens[token] - return None - - return access_token - - async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: - """Load a refresh token - not supported.""" - return None - - async def exchange_refresh_token( - self, - client: OAuthClientInformationFull, - refresh_token: RefreshToken, - scopes: list[str], - ) -> OAuthToken: - """Exchange refresh token""" - raise NotImplementedError("Not supported") - - async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None: - """Revoke a token.""" - if token in self.tokens: - del self.tokens[token] - - -def create_simple_mcp_server(settings: ServerSettings) -> FastMCP: - """Create a simple FastMCP server with GitHub OAuth.""" - oauth_provider = SimpleGitHubOAuthProvider(settings) - - auth_settings = AuthSettings( - issuer_url=settings.server_url, - client_registration_options=ClientRegistrationOptions( - enabled=True, - valid_scopes=[settings.mcp_scope], - default_scopes=[settings.mcp_scope], - ), - required_scopes=[settings.mcp_scope], - ) + This server: + 1. Provides protected resource metadata (RFC 9728) + 2. Validates tokens via Authorization Server introspection + 3. Serves MCP tools and resources + """ + # Create token verifier for introspection + token_verifier = IntrospectionTokenVerifier(settings.auth_server_introspection_endpoint) + # Create FastMCP server as a Resource Server app = FastMCP( - name="Simple GitHub MCP Server", - instructions="A simple MCP server with GitHub OAuth authentication", - auth_server_provider=oauth_provider, + name="MCP Resource Server", + instructions="Resource Server that validates tokens via Authorization Server introspection", host=settings.host, port=settings.port, debug=True, - auth=auth_settings, + # Auth configuration for RS mode + token_verifier=token_verifier, + auth=AuthSettings( + issuer_url=settings.server_url, + required_scopes=[settings.mcp_scope], + authorization_servers=[settings.auth_server_url], + ), ) - @app.custom_route("/github/callback", methods=["GET"]) - async def github_callback_handler(request: Request) -> Response: - """Handle GitHub OAuth callback.""" - code = request.query_params.get("code") - state = request.query_params.get("state") - - if not code or not state: - raise HTTPException(400, "Missing code or state parameter") - - try: - redirect_uri = await oauth_provider.handle_github_callback(code, state) - return RedirectResponse(status_code=302, url=redirect_uri) - except HTTPException: - raise - except Exception as e: - logger.error("Unexpected error", exc_info=e) - return JSONResponse( - status_code=500, - content={ - "error": "server_error", - "error_description": "Unexpected error", - }, - ) + async def get_github_user_data() -> dict[str, Any]: + """ + Get GitHub user data via Authorization Server proxy endpoint. - def get_github_token() -> str: - """Get the GitHub token for the authenticated user.""" + This avoids exposing GitHub tokens to the Resource Server. + The Authorization Server handles the GitHub API call and returns the data. + """ access_token = get_access_token() if not access_token: raise ValueError("Not authenticated") - # Get GitHub token from mapping - github_token = oauth_provider.token_mapping.get(access_token.token) + # Call Authorization Server's GitHub proxy endpoint + async with httpx.AsyncClient() as client: + response = await client.get( + settings.auth_server_github_user_endpoint, + headers={ + "Authorization": f"Bearer {access_token.token}", + }, + ) - if not github_token: - raise ValueError("No GitHub token found for user") + if response.status_code != 200: + raise ValueError(f"GitHub user data fetch failed: {response.status_code} - {response.text}") - return github_token + return response.json() @app.tool() async def get_user_profile() -> dict[str, Any]: - """Get the authenticated user's GitHub profile information. + """ + Get the authenticated user's GitHub profile information. - This is the only tool in our simple example. It requires the 'user' scope. + This tool requires the 'user' scope and demonstrates how Resource Servers + can access user data without directly handling GitHub tokens. """ - github_token = get_github_token() + return await get_github_user_data() - async with create_mcp_http_client() as client: - response = await client.get( - "https://api.github.com/user", - headers={ - "Authorization": f"Bearer {github_token}", - "Accept": "application/vnd.github.v3+json", - }, - ) + @app.tool() + async def get_user_info() -> dict[str, Any]: + """ + Get information about the currently authenticated user. - if response.status_code != 200: - raise ValueError(f"GitHub API error: {response.status_code} - {response.text}") + Returns token and scope information from the Resource Server's perspective. + """ + access_token = get_access_token() + if not access_token: + raise ValueError("Not authenticated") - return response.json() + return { + "authenticated": True, + "client_id": access_token.client_id, + "scopes": access_token.scopes, + "token_expires_at": access_token.expires_at, + "token_type": "Bearer", + "resource_server": str(settings.server_url), + "authorization_server": str(settings.auth_server_url), + } return app @click.command() -@click.option("--port", default=8000, help="Port to listen on") -@click.option("--host", default="localhost", help="Host to bind to") +@click.option("--port", default=8001, help="Port to listen on") +@click.option("--auth-server", default="http://localhost:9000", help="Authorization Server URL") @click.option( "--transport", - default="sse", + default="streamable-http", type=click.Choice(["sse", "streamable-http"]), help="Transport protocol to use ('sse' or 'streamable-http')", ) -def main(port: int, host: str, transport: Literal["sse", "streamable-http"]) -> int: - """Run the simple GitHub MCP server.""" +def main(port: int, auth_server: str, transport: Literal["sse", "streamable-http"]) -> int: + """ + Run the MCP Resource Server. + + This server: + - Provides RFC 9728 Protected Resource Metadata + - Validates tokens via Authorization Server introspection + - Serves MCP tools requiring authentication + + Must be used with a running Authorization Server. + """ logging.basicConfig(level=logging.INFO) try: - # No hardcoded credentials - all from environment variables - settings = ServerSettings(host=host, port=port) + # Parse auth server URL + auth_server_url = AnyHttpUrl(auth_server) + + # Create settings + host = "localhost" + server_url = f"http://{host}:{port}" + settings = ResourceServerSettings( + host=host, + port=port, + server_url=AnyHttpUrl(server_url), + auth_server_url=auth_server_url, + auth_server_introspection_endpoint=f"{auth_server}/introspect", + auth_server_github_user_endpoint=f"{auth_server}/github/user", + ) except ValueError as e: - logger.error("Failed to load settings. Make sure environment variables are set:") - logger.error(" MCP_GITHUB_GITHUB_CLIENT_ID=") - logger.error(" MCP_GITHUB_GITHUB_CLIENT_SECRET=") - logger.error(f"Error: {e}") + logger.error(f"Configuration error: {e}") + logger.error("Make sure to provide a valid Authorization Server URL") return 1 - mcp_server = create_simple_mcp_server(settings) - logger.info(f"Starting server with {transport} transport") - mcp_server.run(transport=transport) - return 0 + try: + mcp_server = create_resource_server(settings) + + logger.info("=" * 80) + logger.info("📦 MCP RESOURCE SERVER") + logger.info("=" * 80) + logger.info(f"🌐 Server URL: {settings.server_url}") + logger.info(f"🔑 Authorization Server: {settings.auth_server_url}") + logger.info("📋 Endpoints:") + logger.info(f" ┌─ Protected Resource Metadata: {settings.server_url}/.well-known/oauth-protected-resource") + mcp_path = "sse" if transport == "sse" else "mcp" + logger.info(f" ├─ MCP Protocol: {settings.server_url}/{mcp_path}") + logger.info(f" └─ Token Introspection: {settings.auth_server_introspection_endpoint}") + logger.info("") + logger.info("🛠️ Available Tools:") + logger.info(" ├─ get_user_profile() - Get GitHub user profile") + logger.info(" └─ get_user_info() - Get authentication status") + logger.info("") + logger.info("🔍 Tokens validated via Authorization Server introspection") + logger.info("📱 Clients discover Authorization Server via Protected Resource Metadata") + logger.info("=" * 80) + + # Run the server - this should block and keep running + mcp_server.run(transport=transport) + logger.info("Server stopped") + return 0 + except Exception as e: + logger.error(f"Server error: {e}") + logger.exception("Exception details:") + return 1 + + +if __name__ == "__main__": + main() # type: ignore[call-arg] diff --git a/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py new file mode 100644 index 000000000..ba71322fa --- /dev/null +++ b/examples/servers/simple-auth/mcp_simple_auth/token_verifier.py @@ -0,0 +1,65 @@ +"""Example token verifier implementation using OAuth 2.0 Token Introspection (RFC 7662).""" + +import logging + +from mcp.server.auth.provider import AccessToken, TokenVerifier + +logger = logging.getLogger(__name__) + + +class IntrospectionTokenVerifier(TokenVerifier): + """Example token verifier that uses OAuth 2.0 Token Introspection (RFC 7662). + + This is a simple example implementation for demonstration purposes. + Production implementations should consider: + - Connection pooling and reuse + - More sophisticated error handling + - Rate limiting and retry logic + - Comprehensive configuration options + """ + + def __init__(self, introspection_endpoint: str): + self.introspection_endpoint = introspection_endpoint + + async def verify_token(self, token: str) -> AccessToken | None: + """Verify token via introspection endpoint.""" + import httpx + + # Validate URL to prevent SSRF attacks + if not self.introspection_endpoint.startswith(("https://", "http://localhost", "http://127.0.0.1")): + logger.warning(f"Rejecting introspection endpoint with unsafe scheme: {self.introspection_endpoint}") + return None + + # Configure secure HTTP client + timeout = httpx.Timeout(10.0, connect=5.0) + limits = httpx.Limits(max_connections=10, max_keepalive_connections=5) + + async with httpx.AsyncClient( + timeout=timeout, + limits=limits, + verify=True, # Enforce SSL verification + ) as client: + try: + response = await client.post( + self.introspection_endpoint, + data={"token": token}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + if response.status_code != 200: + logger.debug(f"Token introspection returned status {response.status_code}") + return None + + data = response.json() + if not data.get("active", False): + return None + + return AccessToken( + token=token, + client_id=data.get("client_id", "unknown"), + scopes=data.get("scope", "").split() if data.get("scope") else [], + expires_at=data.get("exp"), + ) + except Exception as e: + logger.warning(f"Token introspection failed: {e}") + return None diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 4e777d600..50ce74aa4 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -11,11 +11,13 @@ import string import time from collections.abc import AsyncGenerator, Awaitable, Callable +from dataclasses import dataclass, field from typing import Protocol -from urllib.parse import urlencode, urljoin +from urllib.parse import urlencode, urljoin, urlparse import anyio import httpx +from pydantic import BaseModel, Field, ValidationError from mcp.client.streamable_http import MCP_PROTOCOL_VERSION from mcp.shared.auth import ( @@ -23,12 +25,40 @@ OAuthClientMetadata, OAuthMetadata, OAuthToken, + ProtectedResourceMetadata, ) from mcp.types import LATEST_PROTOCOL_VERSION logger = logging.getLogger(__name__) +class OAuthFlowError(Exception): + """Base exception for OAuth flow errors.""" + + +class OAuthTokenError(OAuthFlowError): + """Raised when token operations fail.""" + + +class OAuthRegistrationError(OAuthFlowError): + """Raised when client registration fails.""" + + +class PKCEParameters(BaseModel): + """PKCE (Proof Key for Code Exchange) parameters.""" + + code_verifier: str = Field(..., min_length=43, max_length=128) + code_challenge: str = Field(..., min_length=43, max_length=128) + + @classmethod + def generate(cls) -> "PKCEParameters": + """Generate new PKCE parameters.""" + code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128)) + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=") + return cls(code_verifier=code_verifier, code_challenge=code_challenge) + + class TokenStorage(Protocol): """Protocol for token storage implementations.""" @@ -49,12 +79,70 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None ... +@dataclass +class OAuthContext: + """OAuth flow context.""" + + server_url: str + client_metadata: OAuthClientMetadata + storage: TokenStorage + redirect_handler: Callable[[str], Awaitable[None]] + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] + timeout: float = 300.0 + + # Discovered metadata + protected_resource_metadata: ProtectedResourceMetadata | None = None + oauth_metadata: OAuthMetadata | None = None + auth_server_url: str | None = None + + # Client registration + client_info: OAuthClientInformationFull | None = None + + # Token management + current_tokens: OAuthToken | None = None + token_expiry_time: float | None = None + + # State + lock: anyio.Lock = field(default_factory=anyio.Lock) + + def get_authorization_base_url(self, server_url: str) -> str: + """Extract base URL by removing path component.""" + parsed = urlparse(server_url) + return f"{parsed.scheme}://{parsed.netloc}" + + def update_token_expiry(self, token: OAuthToken) -> None: + """Update token expiry time.""" + if token.expires_in: + self.token_expiry_time = time.time() + token.expires_in + else: + self.token_expiry_time = None + + def is_token_valid(self) -> bool: + """Check if current token is valid.""" + return bool( + self.current_tokens + and self.current_tokens.access_token + and (not self.token_expiry_time or time.time() <= self.token_expiry_time) + ) + + def can_refresh_token(self) -> bool: + """Check if token can be refreshed.""" + return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info) + + def clear_tokens(self) -> None: + """Clear current tokens.""" + self.current_tokens = None + self.token_expiry_time = None + + class OAuthClientProvider(httpx.Auth): """ - Authentication for httpx using anyio. + OAuth2 authentication for httpx. Handles OAuth flow with automatic client registration and token storage. """ + requires_response_body = True + def __init__( self, server_url: str, @@ -64,407 +152,318 @@ def __init__( callback_handler: Callable[[], Awaitable[tuple[str, str | None]]], timeout: float = 300.0, ): - """ - Initialize OAuth2 authentication. - - Args: - server_url: Base URL of the OAuth server - client_metadata: OAuth client metadata - storage: Token storage implementation (defaults to in-memory) - redirect_handler: Function to handle authorization URL like opening browser - callback_handler: Function to wait for callback - and return (auth_code, state) - timeout: Timeout for OAuth flow in seconds - """ - self.server_url = server_url - self.client_metadata = client_metadata - self.storage = storage - self.redirect_handler = redirect_handler - self.callback_handler = callback_handler - self.timeout = timeout - - # Cached authentication state - self._current_tokens: OAuthToken | None = None - self._metadata: OAuthMetadata | None = None - self._client_info: OAuthClientInformationFull | None = None - self._token_expiry_time: float | None = None - - # PKCE flow parameters - self._code_verifier: str | None = None - self._code_challenge: str | None = None - - # State parameter for CSRF protection - self._auth_state: str | None = None - - # Thread safety lock - self._token_lock = anyio.Lock() - - def _generate_code_verifier(self) -> str: - """Generate a cryptographically random code verifier for PKCE.""" - return "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128)) - - def _generate_code_challenge(self, code_verifier: str) -> str: - """Generate a code challenge from a code verifier using SHA256.""" - digest = hashlib.sha256(code_verifier.encode()).digest() - return base64.urlsafe_b64encode(digest).decode().rstrip("=") - - def _get_authorization_base_url(self, server_url: str) -> str: - """ - Extract base URL by removing path component. - - Per MCP spec 2.3.2: https://api.example.com/v1/mcp -> https://api.example.com - """ - from urllib.parse import urlparse, urlunparse - - parsed = urlparse(server_url) - # Remove path component - return urlunparse((parsed.scheme, parsed.netloc, "", "", "", "")) - - async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None: - """ - Discover OAuth metadata from server's well-known endpoint. - """ - # Extract base URL per MCP spec - auth_base_url = self._get_authorization_base_url(server_url) - url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server") - headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} - - async with httpx.AsyncClient() as client: + """Initialize OAuth2 authentication.""" + self.context = OAuthContext( + server_url=server_url, + client_metadata=client_metadata, + storage=storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + timeout=timeout, + ) + self._initialized = False + + async def _discover_protected_resource(self) -> httpx.Request: + """Build discovery request for protected resource metadata.""" + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) + url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + + async def _handle_protected_resource_response(self, response: httpx.Response) -> None: + """Handle discovery response.""" + if response.status_code == 200: try: - response = await client.get(url, headers=headers) - if response.status_code == 404: - return None - response.raise_for_status() - metadata_json = response.json() - logger.debug(f"OAuth metadata discovered: {metadata_json}") - return OAuthMetadata.model_validate(metadata_json) - except Exception: - # Retry without MCP header for CORS compatibility - try: - response = await client.get(url) - if response.status_code == 404: - return None - response.raise_for_status() - metadata_json = response.json() - logger.debug(f"OAuth metadata discovered (no MCP header): {metadata_json}") - return OAuthMetadata.model_validate(metadata_json) - except Exception: - logger.exception("Failed to discover OAuth metadata") - return None - - async def _register_oauth_client( - self, - server_url: str, - client_metadata: OAuthClientMetadata, - metadata: OAuthMetadata | None = None, - ) -> OAuthClientInformationFull: - """ - Register OAuth client with server. - """ - if not metadata: - metadata = await self._discover_oauth_metadata(server_url) - - if metadata and metadata.registration_endpoint: - registration_url = str(metadata.registration_endpoint) + content = await response.aread() + metadata = ProtectedResourceMetadata.model_validate_json(content) + self.context.protected_resource_metadata = metadata + if metadata.authorization_servers: + self.context.auth_server_url = str(metadata.authorization_servers[0]) + except ValidationError: + pass + + async def _discover_oauth_metadata(self) -> httpx.Request: + """Build OAuth metadata discovery request.""" + if self.context.auth_server_url: + base_url = self.context.get_authorization_base_url(self.context.auth_server_url) else: - # Use fallback registration endpoint - auth_base_url = self._get_authorization_base_url(server_url) - registration_url = urljoin(auth_base_url, "/register") + base_url = self.context.get_authorization_base_url(self.context.server_url) - # Handle default scope - if client_metadata.scope is None and metadata and metadata.scopes_supported is not None: - client_metadata.scope = " ".join(metadata.scopes_supported) + url = urljoin(base_url, "/.well-known/oauth-authorization-server") + return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - # Serialize client metadata - registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) - - async with httpx.AsyncClient() as client: + async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: + """Handle OAuth metadata response.""" + if response.status_code == 200: try: - response = await client.post( - registration_url, - json=registration_data, - headers={"Content-Type": "application/json"}, - ) - - if response.status_code not in (200, 201): - raise httpx.HTTPStatusError( - f"Registration failed: {response.status_code}", - request=response.request, - response=response, - ) - - response_data = response.json() - logger.debug(f"Registration successful: {response_data}") - return OAuthClientInformationFull.model_validate(response_data) - - except httpx.HTTPStatusError: - raise - except Exception: - logger.exception("Registration error") - raise - - async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: - """ - HTTPX auth flow integration. - """ - - if not self._has_valid_token(): - await self.initialize() - await self.ensure_token() - # Add Bearer token if available - if self._current_tokens and self._current_tokens.access_token: - request.headers["Authorization"] = f"Bearer {self._current_tokens.access_token}" - - response = yield request - - # Clear token on 401 to trigger re-auth - if response.status_code == 401: - self._current_tokens = None - - def _has_valid_token(self) -> bool: - """Check if current token is valid.""" - if not self._current_tokens or not self._current_tokens.access_token: - return False - - # Check expiry time - if self._token_expiry_time and time.time() > self._token_expiry_time: - return False - - return True - - async def _validate_token_scopes(self, token_response: OAuthToken) -> None: - """ - Validate returned scopes against requested scopes. - - Per OAuth 2.1 Section 3.3: server may grant subset, not superset. - """ - if not token_response.scope: - # No scope returned = validation passes - return + content = await response.aread() + metadata = OAuthMetadata.model_validate_json(content) + self.context.oauth_metadata = metadata + # Apply default scope if none specified + if self.context.client_metadata.scope is None and metadata.scopes_supported is not None: + self.context.client_metadata.scope = " ".join(metadata.scopes_supported) + except ValidationError: + pass + + async def _register_client(self) -> httpx.Request | None: + """Build registration request or skip if already registered.""" + if self.context.client_info: + return None + + if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint: + registration_url = str(self.context.oauth_metadata.registration_endpoint) + else: + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) + registration_url = urljoin(auth_base_url, "/register") - # Check explicitly requested scopes only - requested_scopes: set[str] = set() + registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) - if self.client_metadata.scope: - # Validate against explicit scope request - requested_scopes = set(self.client_metadata.scope.split()) + return httpx.Request( + "POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"} + ) - # Check for unauthorized scopes - returned_scopes = set(token_response.scope.split()) - unauthorized_scopes = returned_scopes - requested_scopes + async def _handle_registration_response(self, response: httpx.Response) -> None: + """Handle registration response.""" + if response.status_code not in (200, 201): + raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") - if unauthorized_scopes: - raise Exception( - f"Server granted unauthorized scopes: {unauthorized_scopes}. " - f"Requested: {requested_scopes}, Returned: {returned_scopes}" - ) + try: + content = await response.aread() + client_info = OAuthClientInformationFull.model_validate_json(content) + self.context.client_info = client_info + await self.context.storage.set_client_info(client_info) + except ValidationError as e: + raise OAuthRegistrationError(f"Invalid registration response: {e}") + + async def _perform_authorization(self) -> tuple[str, str]: + """Perform the authorization redirect and get auth code.""" + if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: + auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) else: - # No explicit scopes requested - accept server defaults - logger.debug( - f"No explicit scopes requested, accepting server-granted " - f"scopes: {set(token_response.scope.split())}" - ) + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) + auth_endpoint = urljoin(auth_base_url, "/authorize") - async def initialize(self) -> None: - """Load stored tokens and client info.""" - self._current_tokens = await self.storage.get_tokens() - self._client_info = await self.storage.get_client_info() + if not self.context.client_info: + raise OAuthFlowError("No client info available for authorization") - async def _get_or_register_client(self) -> OAuthClientInformationFull: - """Get or register client with server.""" - if not self._client_info: - try: - self._client_info = await self._register_oauth_client( - self.server_url, self.client_metadata, self._metadata - ) - await self.storage.set_client_info(self._client_info) - except Exception: - logger.exception("Client registration failed") - raise - return self._client_info - - async def ensure_token(self) -> None: - """Ensure valid access token, refreshing or re-authenticating as needed.""" - async with self._token_lock: - # Return early if token is valid - if self._has_valid_token(): - return - - # Try refreshing existing token - if self._current_tokens and self._current_tokens.refresh_token and await self._refresh_access_token(): - return - - # Fall back to full OAuth flow - await self._perform_oauth_flow() - - async def _perform_oauth_flow(self) -> None: - """Execute OAuth2 authorization code flow with PKCE.""" - logger.debug("Starting authentication flow.") - - # Discover OAuth metadata - if not self._metadata: - self._metadata = await self._discover_oauth_metadata(self.server_url) - - # Ensure client registration - client_info = await self._get_or_register_client() - - # Generate PKCE challenge - self._code_verifier = self._generate_code_verifier() - self._code_challenge = self._generate_code_challenge(self._code_verifier) - - # Get authorization endpoint - if self._metadata and self._metadata.authorization_endpoint: - auth_url_base = str(self._metadata.authorization_endpoint) - else: - # Use fallback authorization endpoint - auth_base_url = self._get_authorization_base_url(self.server_url) - auth_url_base = urljoin(auth_base_url, "/authorize") + # Generate PKCE parameters + pkce_params = PKCEParameters.generate() + state = secrets.token_urlsafe(32) - # Build authorization URL - self._auth_state = secrets.token_urlsafe(32) auth_params = { "response_type": "code", - "client_id": client_info.client_id, - "redirect_uri": str(self.client_metadata.redirect_uris[0]), - "state": self._auth_state, - "code_challenge": self._code_challenge, + "client_id": self.context.client_info.client_id, + "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), + "state": state, + "code_challenge": pkce_params.code_challenge, "code_challenge_method": "S256", } - # Include explicit scopes only - if self.client_metadata.scope: - auth_params["scope"] = self.client_metadata.scope - - auth_url = f"{auth_url_base}?{urlencode(auth_params)}" + if self.context.client_metadata.scope: + auth_params["scope"] = self.context.client_metadata.scope - # Redirect user for authorization - await self.redirect_handler(auth_url) + authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" + await self.context.redirect_handler(authorization_url) - auth_code, returned_state = await self.callback_handler() + # Wait for callback + auth_code, returned_state = await self.context.callback_handler() - # Validate state parameter for CSRF protection - if returned_state is None or not secrets.compare_digest(returned_state, self._auth_state): - raise Exception(f"State parameter mismatch: {returned_state} != {self._auth_state}") - - # Clear state after validation - self._auth_state = None + if returned_state is None or not secrets.compare_digest(returned_state, state): + raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") if not auth_code: - raise Exception("No authorization code received") + raise OAuthFlowError("No authorization code received") + + # Return auth code and code verifier for token exchange + return auth_code, pkce_params.code_verifier - # Exchange authorization code for tokens - await self._exchange_code_for_token(auth_code, client_info) + async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Request: + """Build token exchange request.""" + if not self.context.client_info: + raise OAuthFlowError("Missing client info") - async def _exchange_code_for_token(self, auth_code: str, client_info: OAuthClientInformationFull) -> None: - """Exchange authorization code for access token.""" - # Get token endpoint - if self._metadata and self._metadata.token_endpoint: - token_url = str(self._metadata.token_endpoint) + if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: + token_url = str(self.context.oauth_metadata.token_endpoint) else: - # Use fallback token endpoint - auth_base_url = self._get_authorization_base_url(self.server_url) + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) token_url = urljoin(auth_base_url, "/token") token_data = { "grant_type": "authorization_code", "code": auth_code, - "redirect_uri": str(self.client_metadata.redirect_uris[0]), - "client_id": client_info.client_id, - "code_verifier": self._code_verifier, + "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), + "client_id": self.context.client_info.client_id, + "code_verifier": code_verifier, } - if client_info.client_secret: - token_data["client_secret"] = client_info.client_secret - - async with httpx.AsyncClient() as client: - response = await client.post( - token_url, - data=token_data, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - timeout=30.0, - ) + if self.context.client_info.client_secret: + token_data["client_secret"] = self.context.client_info.client_secret - if response.status_code != 200: - # Parse OAuth error response - try: - error_data = response.json() - error_msg = error_data.get("error_description", error_data.get("error", "Unknown error")) - raise Exception(f"Token exchange failed: {error_msg} " f"(HTTP {response.status_code})") - except Exception: - raise Exception(f"Token exchange failed: {response.status_code} {response.text}") - - # Parse token response - token_response = OAuthToken.model_validate(response.json()) - - # Validate token scopes - await self._validate_token_scopes(token_response) - - # Calculate token expiry - if token_response.expires_in: - self._token_expiry_time = time.time() + token_response.expires_in - else: - self._token_expiry_time = None - - # Store tokens - await self.storage.set_tokens(token_response) - self._current_tokens = token_response - - async def _refresh_access_token(self) -> bool: - """Refresh access token using refresh token.""" - if not self._current_tokens or not self._current_tokens.refresh_token: - return False + return httpx.Request( + "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"} + ) - # Get client credentials - client_info = await self._get_or_register_client() + async def _handle_token_response(self, response: httpx.Response) -> None: + """Handle token exchange response.""" + if response.status_code != 200: + raise OAuthTokenError(f"Token exchange failed: {response.status_code}") - # Get token endpoint - if self._metadata and self._metadata.token_endpoint: - token_url = str(self._metadata.token_endpoint) + try: + content = await response.aread() + token_response = OAuthToken.model_validate_json(content) + + # Validate scopes + if token_response.scope and self.context.client_metadata.scope: + requested_scopes = set(self.context.client_metadata.scope.split()) + returned_scopes = set(token_response.scope.split()) + unauthorized_scopes = returned_scopes - requested_scopes + if unauthorized_scopes: + raise OAuthTokenError(f"Server granted unauthorized scopes: {unauthorized_scopes}") + + self.context.current_tokens = token_response + self.context.update_token_expiry(token_response) + await self.context.storage.set_tokens(token_response) + except ValidationError as e: + raise OAuthTokenError(f"Invalid token response: {e}") + + async def _refresh_token(self) -> httpx.Request: + """Build token refresh request.""" + if not self.context.current_tokens or not self.context.current_tokens.refresh_token: + raise OAuthTokenError("No refresh token available") + + if not self.context.client_info: + raise OAuthTokenError("No client info available") + + if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: + token_url = str(self.context.oauth_metadata.token_endpoint) else: - # Use fallback token endpoint - auth_base_url = self._get_authorization_base_url(self.server_url) + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) token_url = urljoin(auth_base_url, "/token") refresh_data = { "grant_type": "refresh_token", - "refresh_token": self._current_tokens.refresh_token, - "client_id": client_info.client_id, + "refresh_token": self.context.current_tokens.refresh_token, + "client_id": self.context.client_info.client_id, } - if client_info.client_secret: - refresh_data["client_secret"] = client_info.client_secret + if self.context.client_info.client_secret: + refresh_data["client_secret"] = self.context.client_info.client_secret - try: - async with httpx.AsyncClient() as client: - response = await client.post( - token_url, - data=refresh_data, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - timeout=30.0, - ) - - if response.status_code != 200: - logger.error(f"Token refresh failed: {response.status_code}") - return False - - # Parse refreshed tokens - token_response = OAuthToken.model_validate(response.json()) - - # Validate token scopes - await self._validate_token_scopes(token_response) - - # Calculate token expiry - if token_response.expires_in: - self._token_expiry_time = time.time() + token_response.expires_in - else: - self._token_expiry_time = None + return httpx.Request( + "POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"} + ) - # Store refreshed tokens - await self.storage.set_tokens(token_response) - self._current_tokens = token_response + async def _handle_refresh_response(self, response: httpx.Response) -> bool: + """Handle token refresh response. Returns True if successful.""" + if response.status_code != 200: + logger.warning(f"Token refresh failed: {response.status_code}") + self.context.clear_tokens() + return False - return True + try: + content = await response.aread() + token_response = OAuthToken.model_validate_json(content) - except Exception: - logger.exception("Token refresh failed") + self.context.current_tokens = token_response + self.context.update_token_expiry(token_response) + await self.context.storage.set_tokens(token_response) + + return True + except ValidationError as e: + logger.error(f"Invalid refresh response: {e}") + self.context.clear_tokens() return False + + async def _initialize(self) -> None: + """Load stored tokens and client info.""" + self.context.current_tokens = await self.context.storage.get_tokens() + self.context.client_info = await self.context.storage.get_client_info() + self._initialized = True + + def _add_auth_header(self, request: httpx.Request) -> None: + """Add authorization header to request if we have valid tokens.""" + if self.context.current_tokens and self.context.current_tokens.access_token: + request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: + """HTTPX auth flow integration.""" + async with self.context.lock: + if not self._initialized: + await self._initialize() + + # Perform OAuth flow if not authenticated + if not self.context.is_token_valid(): + try: + # OAuth flow must be inline due to generator constraints + # Step 1: Discover protected resource metadata (spec revision 2025-06-18) + discovery_request = await self._discover_protected_resource() + discovery_response = yield discovery_request + await self._handle_protected_resource_response(discovery_response) + + # Step 2: Discover OAuth metadata + oauth_request = await self._discover_oauth_metadata() + oauth_response = yield oauth_request + await self._handle_oauth_metadata_response(oauth_response) + + # Step 3: Register client if needed + registration_request = await self._register_client() + if registration_request: + registration_response = yield registration_request + await self._handle_registration_response(registration_response) + + # Step 4: Perform authorization + auth_code, code_verifier = await self._perform_authorization() + + # Step 5: Exchange authorization code for tokens + token_request = await self._exchange_token(auth_code, code_verifier) + token_response = yield token_request + await self._handle_token_response(token_response) + except Exception as e: + logger.error(f"OAuth flow error: {e}") + raise + + # Add authorization header and make request + self._add_auth_header(request) + response = yield request + + # Handle 401 responses + if response.status_code == 401 and self.context.can_refresh_token(): + # Try to refresh token + refresh_request = await self._refresh_token() + refresh_response = yield refresh_request + + if await self._handle_refresh_response(refresh_response): + # Retry original request with new token + self._add_auth_header(request) + yield request + else: + # Refresh failed, need full re-authentication + self._initialized = False + + # OAuth flow must be inline due to generator constraints + # Step 1: Discover protected resource metadata (spec revision 2025-06-18) + discovery_request = await self._discover_protected_resource() + discovery_response = yield discovery_request + await self._handle_protected_resource_response(discovery_response) + + # Step 2: Discover OAuth metadata + oauth_request = await self._discover_oauth_metadata() + oauth_response = yield oauth_request + await self._handle_oauth_metadata_response(oauth_response) + + # Step 3: Register client if needed + registration_request = await self._register_client() + if registration_request: + registration_response = yield registration_request + await self._handle_registration_response(registration_response) + + # Step 4: Perform authorization + auth_code, code_verifier = await self._perform_authorization() + + # Step 5: Exchange authorization code for tokens + token_request = await self._exchange_token(auth_code, code_verifier) + token_response = yield token_request + await self._handle_token_response(token_response) + + # Retry with new tokens + self._add_auth_header(request) + yield request diff --git a/src/mcp/server/auth/handlers/metadata.py b/src/mcp/server/auth/handlers/metadata.py index e37e5d311..f12644215 100644 --- a/src/mcp/server/auth/handlers/metadata.py +++ b/src/mcp/server/auth/handlers/metadata.py @@ -4,7 +4,7 @@ from starlette.responses import Response from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.shared.auth import OAuthMetadata +from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata @dataclass @@ -16,3 +16,14 @@ async def handle(self, request: Request) -> Response: content=self.metadata, headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour ) + + +@dataclass +class ProtectedResourceMetadataHandler: + metadata: ProtectedResourceMetadata + + async def handle(self, request: Request) -> Response: + return PydanticJSONResponse( + content=self.metadata, + headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour + ) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index d73455200..450ee406c 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -151,7 +151,14 @@ async def handle(self, request: Request): authorize_request_redirect_uri = auth_code.redirect_uri else: authorize_request_redirect_uri = None - if token_request.redirect_uri != authorize_request_redirect_uri: + + # Convert both sides to strings for comparison to handle AnyUrl vs string issues + token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None + auth_redirect_str = ( + str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None + ) + + if token_redirect_str != auth_redirect_str: return self.response( TokenErrorResponse( error="invalid_request", diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 2fe1342b7..6251e5ad5 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,12 +1,13 @@ +import json import time from typing import Any +from pydantic import AnyHttpUrl from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser -from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection from starlette.types import Receive, Scope, Send -from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider +from mcp.server.auth.provider import AccessToken, TokenVerifier class AuthenticatedUser(SimpleUser): @@ -20,14 +21,11 @@ def __init__(self, auth_info: AccessToken): class BearerAuthBackend(AuthenticationBackend): """ - Authentication backend that validates Bearer tokens. + Authentication backend that validates Bearer tokens using a TokenVerifier. """ - def __init__( - self, - provider: OAuthAuthorizationServerProvider[Any, Any, Any], - ): - self.provider = provider + def __init__(self, token_verifier: TokenVerifier): + self.token_verifier = token_verifier async def authenticate(self, conn: HTTPConnection): auth_header = next( @@ -39,8 +37,8 @@ async def authenticate(self, conn: HTTPConnection): token = auth_header[7:] # Remove "Bearer " prefix - # Validate the token with the provider - auth_info = await self.provider.load_access_token(token) + # Validate the token with the verifier + auth_info = await self.token_verifier.verify_token(token) if not auth_info: return None @@ -59,27 +57,72 @@ class RequireAuthMiddleware: auth info in the request state. """ - def __init__(self, app: Any, required_scopes: list[str]): + def __init__( + self, + app: Any, + required_scopes: list[str], + resource_metadata_url: AnyHttpUrl | None = None, + ): """ Initialize the middleware. Args: app: ASGI application - provider: Authentication provider to validate tokens - required_scopes: Optional list of scopes that the token must have + required_scopes: List of scopes that the token must have + resource_metadata_url: Optional protected resource metadata URL for WWW-Authenticate header """ self.app = app self.required_scopes = required_scopes + self.resource_metadata_url = resource_metadata_url async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: auth_user = scope.get("user") if not isinstance(auth_user, AuthenticatedUser): - raise HTTPException(status_code=401, detail="Unauthorized") + await self._send_auth_error( + send, status_code=401, error="invalid_token", description="Authentication required" + ) + return + auth_credentials = scope.get("auth") for required_scope in self.required_scopes: # auth_credentials should always be provided; this is just paranoia if auth_credentials is None or required_scope not in auth_credentials.scopes: - raise HTTPException(status_code=403, detail="Insufficient scope") + await self._send_auth_error( + send, status_code=403, error="insufficient_scope", description=f"Required scope: {required_scope}" + ) + return await self.app(scope, receive, send) + + async def _send_auth_error(self, send: Send, status_code: int, error: str, description: str) -> None: + """Send an authentication error response with WWW-Authenticate header.""" + # Build WWW-Authenticate header value + www_auth_parts = [f'error="{error}"', f'error_description="{description}"'] + if self.resource_metadata_url: + www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"') + + www_authenticate = f"Bearer {', '.join(www_auth_parts)}" + + # Send response + body = {"error": error, "error_description": description} + body_bytes = json.dumps(body).encode() + + await send( + { + "type": "http.response.start", + "status": status_code, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body_bytes)).encode()), + (b"www-authenticate", www_authenticate.encode()), + ], + } + ) + + await send( + { + "type": "http.response.body", + "body": body_bytes, + } + ) diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index da18d7a71..acdd55bc2 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -86,6 +86,13 @@ class TokenError(Exception): error_description: str | None = None +class TokenVerifier(Protocol): + """Protocol for verifying bearer tokens.""" + + async def verify_token(self, token: str) -> AccessToken | None: + """Verify a bearer token and return access info if valid.""" + + # NOTE: FastMCP doesn't render any of these types in the user response, so it's # OK to add fields to subclasses which should not be exposed externally. AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode) @@ -278,3 +285,19 @@ def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str: redirect_uri = urlunparse(parsed_uri._replace(query=urlencode(query_params))) return redirect_uri + + +class ProviderTokenVerifier(TokenVerifier): + """Token verifier that uses an OAuthAuthorizationServerProvider. + + This is provided for backwards compatibility with existing auth_server_provider + configurations. For new implementations using AS/RS separation, consider using + the TokenVerifier protocol with a dedicated implementation like IntrospectionTokenVerifier. + """ + + def __init__(self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]"): + self.provider = provider + + async def verify_token(self, token: str) -> AccessToken | None: + """Verify token using the provider's load_access_token method.""" + return await self.provider.load_access_token(token) diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 8647334e0..305440242 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -180,3 +180,40 @@ def build_metadata( metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"] return metadata + + +def create_protected_resource_routes( + resource_url: AnyHttpUrl, + authorization_servers: list[AnyHttpUrl], + scopes_supported: list[str] | None = None, +) -> list[Route]: + """ + Create routes for OAuth 2.0 Protected Resource Metadata (RFC 9728). + + Args: + resource_url: The URL of this resource server + authorization_servers: List of authorization servers that can issue tokens + scopes_supported: Optional list of scopes supported by this resource + + Returns: + List of Starlette routes for protected resource metadata + """ + from mcp.server.auth.handlers.metadata import ProtectedResourceMetadataHandler + from mcp.shared.auth import ProtectedResourceMetadata + + metadata = ProtectedResourceMetadata( + resource=resource_url, + authorization_servers=authorization_servers, + scopes_supported=scopes_supported, + # bearer_methods_supported defaults to ["header"] in the model + ) + + handler = ProtectedResourceMetadataHandler(metadata) + + return [ + Route( + "/.well-known/oauth-protected-resource", + endpoint=cors_middleware(handler.handle, ["GET", "OPTIONS"]), + methods=["GET", "OPTIONS"], + ) + ] diff --git a/src/mcp/server/auth/settings.py b/src/mcp/server/auth/settings.py index 7306d91af..0269c31b6 100644 --- a/src/mcp/server/auth/settings.py +++ b/src/mcp/server/auth/settings.py @@ -15,9 +15,15 @@ class RevocationOptions(BaseModel): class AuthSettings(BaseModel): issuer_url: AnyHttpUrl = Field( ..., - description="URL advertised as OAuth issuer; this should be the URL the server " "is reachable at", + description="Base URL where this server is reachable. For AS: OAuth issuer URL. For RS: Resource server URL.", ) service_documentation_url: AnyHttpUrl | None = None client_registration_options: ClientRegistrationOptions | None = None revocation_options: RevocationOptions | None = None required_scopes: list[str] | None = None + + # Resource Server settings (when operating as RS only) + authorization_servers: list[AnyHttpUrl] | None = Field( + None, + description="Authorization servers that can issue tokens for this resource (RS mode)", + ) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 1b761e917..c74114127 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -30,10 +30,8 @@ BearerAuthBackend, RequireAuthMiddleware, ) -from mcp.server.auth.provider import OAuthAuthorizationServerProvider -from mcp.server.auth.settings import ( - AuthSettings, -) +from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier +from mcp.server.auth.settings import AuthSettings from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager @@ -141,6 +139,7 @@ def __init__( name: str | None = None, instructions: str | None = None, auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, + token_verifier: TokenVerifier | None = None, event_store: EventStore | None = None, *, tools: list[Tool] | None = None, @@ -156,14 +155,22 @@ def __init__( self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) - if (self.settings.auth is not None) != (auth_server_provider is not None): - # TODO: after we support separate authorization servers (see - # https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284) - # we should validate that if auth is enabled, we have either an - # auth_server_provider to host our own authorization server, - # OR the URL of a 3rd party authorization server. - raise ValueError("settings.auth must be specified if and only if auth_server_provider " "is specified") + # Validate auth configuration + if self.settings.auth is not None: + if auth_server_provider and token_verifier: + raise ValueError("Cannot specify both auth_server_provider and token_verifier") + if not auth_server_provider and not token_verifier: + raise ValueError("Must specify either auth_server_provider or token_verifier when auth is enabled") + else: + if auth_server_provider or token_verifier: + raise ValueError("Cannot specify auth_server_provider or token_verifier without auth settings") + self._auth_server_provider = auth_server_provider + self._token_verifier = token_verifier + + # Create token verifier from provider if needed (backwards compatibility) + if auth_server_provider and not token_verifier: + self._token_verifier = ProviderTokenVerifier(auth_server_provider) self._event_store = event_store self._custom_starlette_routes: list[Route] = [] self.dependencies = self.settings.dependencies @@ -701,49 +708,60 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): middleware: list[Middleware] = [] required_scopes = [] - # Add auth endpoints if auth provider is configured - if self._auth_server_provider: - assert self.settings.auth - from mcp.server.auth.routes import create_auth_routes - + # Set up auth if configured + if self.settings.auth: required_scopes = self.settings.auth.required_scopes or [] - middleware = [ - # extract auth info from request (but do not require it) - Middleware( - AuthenticationMiddleware, - backend=BearerAuthBackend( - provider=self._auth_server_provider, + # Add auth middleware if token verifier is available + if self._token_verifier: + middleware = [ + # extract auth info from request (but do not require it) + Middleware( + AuthenticationMiddleware, + backend=BearerAuthBackend(self._token_verifier), ), - ), - # Add the auth context middleware to store - # authenticated user in a contextvar - Middleware(AuthContextMiddleware), - ] - routes.extend( - create_auth_routes( - provider=self._auth_server_provider, - issuer_url=self.settings.auth.issuer_url, - service_documentation_url=self.settings.auth.service_documentation_url, - client_registration_options=self.settings.auth.client_registration_options, - revocation_options=self.settings.auth.revocation_options, + # Add the auth context middleware to store + # authenticated user in a contextvar + Middleware(AuthContextMiddleware), + ] + + # Add auth endpoints if auth server provider is configured + if self._auth_server_provider: + from mcp.server.auth.routes import create_auth_routes + + routes.extend( + create_auth_routes( + provider=self._auth_server_provider, + issuer_url=self.settings.auth.issuer_url, + service_documentation_url=self.settings.auth.service_documentation_url, + client_registration_options=self.settings.auth.client_registration_options, + revocation_options=self.settings.auth.revocation_options, + ) + ) + + # When auth is configured, require authentication + if self._token_verifier: + # Determine resource metadata URL + resource_metadata_url = None + if self.settings.auth and self.settings.auth.authorization_servers: + from pydantic import AnyHttpUrl + + resource_metadata_url = AnyHttpUrl( + str(self.settings.auth.issuer_url).rstrip("/") + "/.well-known/oauth-protected-resource" ) - ) - # When auth is not configured, we shouldn't require auth - if self._auth_server_provider: # Auth is enabled, wrap the endpoints with RequireAuthMiddleware routes.append( Route( self.settings.sse_path, - endpoint=RequireAuthMiddleware(handle_sse, required_scopes), + endpoint=RequireAuthMiddleware(handle_sse, required_scopes, resource_metadata_url), methods=["GET"], ) ) routes.append( Mount( self.settings.message_path, - app=RequireAuthMiddleware(sse.handle_post_message, required_scopes), + app=RequireAuthMiddleware(sse.handle_post_message, required_scopes, resource_metadata_url), ) ) else: @@ -766,6 +784,18 @@ async def sse_endpoint(request: Request) -> Response: app=sse.handle_post_message, ) ) + # Add protected resource metadata endpoint if configured as RS + if self.settings.auth and self.settings.auth.authorization_servers: + from mcp.server.auth.routes import create_protected_resource_routes + + routes.extend( + create_protected_resource_routes( + resource_url=self.settings.auth.issuer_url, + authorization_servers=self.settings.auth.authorization_servers, + scopes_supported=self.settings.auth.required_scopes, + ) + ) + # mount these routes last, so they have the lowest route matching precedence routes.extend(self._custom_starlette_routes) @@ -796,35 +826,49 @@ async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> middleware: list[Middleware] = [] required_scopes = [] - # Add auth endpoints if auth provider is configured - if self._auth_server_provider: - assert self.settings.auth - from mcp.server.auth.routes import create_auth_routes - + # Set up auth if configured + if self.settings.auth: required_scopes = self.settings.auth.required_scopes or [] - middleware = [ - Middleware( - AuthenticationMiddleware, - backend=BearerAuthBackend( - provider=self._auth_server_provider, + # Add auth middleware if token verifier is available + if self._token_verifier: + middleware = [ + Middleware( + AuthenticationMiddleware, + backend=BearerAuthBackend(self._token_verifier), ), - ), - Middleware(AuthContextMiddleware), - ] - routes.extend( - create_auth_routes( - provider=self._auth_server_provider, - issuer_url=self.settings.auth.issuer_url, - service_documentation_url=self.settings.auth.service_documentation_url, - client_registration_options=self.settings.auth.client_registration_options, - revocation_options=self.settings.auth.revocation_options, + Middleware(AuthContextMiddleware), + ] + + # Add auth endpoints if auth server provider is configured + if self._auth_server_provider: + from mcp.server.auth.routes import create_auth_routes + + routes.extend( + create_auth_routes( + provider=self._auth_server_provider, + issuer_url=self.settings.auth.issuer_url, + service_documentation_url=self.settings.auth.service_documentation_url, + client_registration_options=self.settings.auth.client_registration_options, + revocation_options=self.settings.auth.revocation_options, + ) ) - ) + + # Set up routes with or without auth + if self._token_verifier: + # Determine resource metadata URL + resource_metadata_url = None + if self.settings.auth and self.settings.auth.authorization_servers: + from pydantic import AnyHttpUrl + + resource_metadata_url = AnyHttpUrl( + str(self.settings.auth.issuer_url).rstrip("/") + "/.well-known/oauth-protected-resource" + ) + routes.append( Mount( self.settings.streamable_http_path, - app=RequireAuthMiddleware(handle_streamable_http, required_scopes), + app=RequireAuthMiddleware(handle_streamable_http, required_scopes, resource_metadata_url), ) ) else: @@ -836,6 +880,28 @@ async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> ) ) + # Add protected resource metadata endpoint if configured as RS + if self.settings.auth and self.settings.auth.authorization_servers: + from mcp.server.auth.handlers.metadata import ProtectedResourceMetadataHandler + from mcp.server.auth.routes import cors_middleware + from mcp.shared.auth import ProtectedResourceMetadata + + protected_resource_metadata = ProtectedResourceMetadata( + resource=self.settings.auth.issuer_url, + authorization_servers=self.settings.auth.authorization_servers, + scopes_supported=self.settings.auth.required_scopes, + ) + routes.append( + Route( + "/.well-known/oauth-protected-resource", + endpoint=cors_middleware( + ProtectedResourceMetadataHandler(protected_resource_metadata).handle, + ["GET", "OPTIONS"], + ), + methods=["GET", "OPTIONS"], + ) + ) + routes.extend(self._custom_starlette_routes) return Starlette( diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 4d2d57221..1f2d1659a 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -129,3 +129,16 @@ class OAuthMetadata(BaseModel): introspection_endpoint_auth_methods_supported: list[str] | None = None introspection_endpoint_auth_signing_alg_values_supported: None = None code_challenge_methods_supported: list[str] | None = None + + +class ProtectedResourceMetadata(BaseModel): + """ + RFC 9728 OAuth 2.0 Protected Resource Metadata. + See https://datatracker.ietf.org/doc/html/rfc9728#section-2 + """ + + resource: AnyHttpUrl + authorization_servers: list[AnyHttpUrl] = Field(..., min_length=1) + scopes_supported: list[str] | None = None + bearer_methods_supported: list[str] | None = Field(default=["header"]) # MCP only supports header method + resource_documentation: AnyHttpUrl | None = None diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index de4eb70af..8dee687a9 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1,25 +1,17 @@ """ -Tests for OAuth client authentication implementation. +Tests for refactored OAuth client authentication implementation. """ -import base64 -import hashlib import time -from unittest.mock import AsyncMock, Mock, patch -from urllib.parse import parse_qs, urlparse import httpx import pytest -from inline_snapshot import snapshot from pydantic import AnyHttpUrl, AnyUrl -from mcp.client.auth import OAuthClientProvider -from mcp.server.auth.routes import build_metadata -from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions +from mcp.client.auth import OAuthClientProvider, PKCEParameters from mcp.shared.auth import ( OAuthClientInformationFull, OAuthClientMetadata, - OAuthMetadata, OAuthToken, ) @@ -52,43 +44,15 @@ def mock_storage(): @pytest.fixture def client_metadata(): return OAuthClientMetadata( - redirect_uris=[AnyUrl("http://localhost:3000/callback")], client_name="Test Client", - grant_types=["authorization_code", "refresh_token"], - response_types=["code"], + client_uri=AnyHttpUrl("https://example.com"), + redirect_uris=[AnyUrl("http://localhost:3030/callback")], scope="read write", ) @pytest.fixture -def oauth_metadata(): - return OAuthMetadata( - issuer=AnyHttpUrl("https://auth.example.com"), - authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), - token_endpoint=AnyHttpUrl("https://auth.example.com/token"), - registration_endpoint=AnyHttpUrl("https://auth.example.com/register"), - scopes_supported=["read", "write", "admin"], - response_types_supported=["code"], - grant_types_supported=["authorization_code", "refresh_token"], - code_challenge_methods_supported=["S256"], - ) - - -@pytest.fixture -def oauth_client_info(): - return OAuthClientInformationFull( - client_id="test_client_id", - client_secret="test_client_secret", - redirect_uris=[AnyUrl("http://localhost:3000/callback")], - client_name="Test Client", - grant_types=["authorization_code", "refresh_token"], - response_types=["code"], - scope="read write", - ) - - -@pytest.fixture -def oauth_token(): +def valid_tokens(): return OAuthToken( access_token="test_access_token", token_type="Bearer", @@ -99,799 +63,253 @@ def oauth_token(): @pytest.fixture -async def oauth_provider(client_metadata, mock_storage): - async def mock_redirect_handler(url: str) -> None: +def oauth_provider(client_metadata, mock_storage): + async def redirect_handler(url: str) -> None: + """Mock redirect handler.""" pass - async def mock_callback_handler() -> tuple[str, str | None]: + async def callback_handler() -> tuple[str, str | None]: + """Mock callback handler.""" return "test_auth_code", "test_state" return OAuthClientProvider( server_url="https://api.example.com/v1/mcp", client_metadata=client_metadata, storage=mock_storage, - redirect_handler=mock_redirect_handler, - callback_handler=mock_callback_handler, + redirect_handler=redirect_handler, + callback_handler=callback_handler, ) -class TestOAuthClientProvider: - """Test OAuth client provider functionality.""" - - @pytest.mark.anyio - async def test_init(self, oauth_provider, client_metadata, mock_storage): - """Test OAuth provider initialization.""" - assert oauth_provider.server_url == "https://api.example.com/v1/mcp" - assert oauth_provider.client_metadata == client_metadata - assert oauth_provider.storage == mock_storage - assert oauth_provider.timeout == 300.0 +class TestPKCEParameters: + """Test PKCE parameter generation.""" - def test_generate_code_verifier(self, oauth_provider): - """Test PKCE code verifier generation.""" - verifier = oauth_provider._generate_code_verifier() + def test_pkce_generation(self): + """Test PKCE parameter generation creates valid values.""" + pkce = PKCEParameters.generate() - # Check length (128 characters) - assert len(verifier) == 128 + # Verify lengths + assert len(pkce.code_verifier) == 128 + assert 43 <= len(pkce.code_challenge) <= 128 - # Check charset (RFC 7636: A-Z, a-z, 0-9, "-", ".", "_", "~") + # Verify characters used in verifier allowed_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~") - assert set(verifier) <= allowed_chars + assert all(c in allowed_chars for c in pkce.code_verifier) - # Check uniqueness (generate multiple and ensure they're different) - verifiers = {oauth_provider._generate_code_verifier() for _ in range(10)} - assert len(verifiers) == 10 + # Verify base64url encoding in challenge (no padding) + assert "=" not in pkce.code_challenge - @pytest.mark.anyio - async def test_generate_code_challenge(self, oauth_provider): - """Test PKCE code challenge generation.""" - verifier = "test_code_verifier_123" - challenge = oauth_provider._generate_code_challenge(verifier) + def test_pkce_uniqueness(self): + """Test PKCE generates unique values each time.""" + pkce1 = PKCEParameters.generate() + pkce2 = PKCEParameters.generate() - # Manually calculate expected challenge - expected_digest = hashlib.sha256(verifier.encode()).digest() - expected_challenge = base64.urlsafe_b64encode(expected_digest).decode().rstrip("=") + assert pkce1.code_verifier != pkce2.code_verifier + assert pkce1.code_challenge != pkce2.code_challenge - assert challenge == expected_challenge - # Verify it's base64url without padding - assert "=" not in challenge - assert "+" not in challenge - assert "/" not in challenge +class TestOAuthContext: + """Test OAuth context functionality.""" @pytest.mark.anyio - async def test_get_authorization_base_url(self, oauth_provider): - """Test authorization base URL extraction.""" + async def test_oauth_provider_initialization(self, oauth_provider, client_metadata, mock_storage): + """Test OAuthClientProvider basic setup.""" + assert oauth_provider.context.server_url == "https://api.example.com/v1/mcp" + assert oauth_provider.context.client_metadata == client_metadata + assert oauth_provider.context.storage == mock_storage + assert oauth_provider.context.timeout == 300.0 + assert oauth_provider.context is not None + + def test_context_url_parsing(self, oauth_provider): + """Test get_authorization_base_url() extracts base URLs correctly.""" + context = oauth_provider.context + # Test with path - assert oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com" + assert context.get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com" # Test with no path - assert oauth_provider._get_authorization_base_url("https://api.example.com") == "https://api.example.com" + assert context.get_authorization_base_url("https://api.example.com") == "https://api.example.com" # Test with port assert ( - oauth_provider._get_authorization_base_url("https://api.example.com:8080/path/to/mcp") + context.get_authorization_base_url("https://api.example.com:8080/path/to/mcp") == "https://api.example.com:8080" ) - @pytest.mark.anyio - async def test_discover_oauth_metadata_success(self, oauth_provider, oauth_metadata): - """Test successful OAuth metadata discovery.""" - metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json") - - with patch("httpx.AsyncClient") as mock_client_class: - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = metadata_response - mock_client.get.return_value = mock_response - - result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp") - - assert result is not None - assert result.authorization_endpoint == oauth_metadata.authorization_endpoint - assert result.token_endpoint == oauth_metadata.token_endpoint - - # Verify correct URL was called - mock_client.get.assert_called_once() - call_args = mock_client.get.call_args[0] - assert call_args[0] == "https://api.example.com/.well-known/oauth-authorization-server" - - @pytest.mark.anyio - async def test_discover_oauth_metadata_not_found(self, oauth_provider): - """Test OAuth metadata discovery when not found.""" - with patch("httpx.AsyncClient") as mock_client_class: - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - - mock_response = Mock() - mock_response.status_code = 404 - mock_client.get.return_value = mock_response - - result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp") - - assert result is None - - @pytest.mark.anyio - async def test_discover_oauth_metadata_cors_fallback(self, oauth_provider, oauth_metadata): - """Test OAuth metadata discovery with CORS fallback.""" - metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json") - - with patch("httpx.AsyncClient") as mock_client_class: - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - - # First call fails (CORS), second succeeds - mock_response_success = Mock() - mock_response_success.status_code = 200 - mock_response_success.json.return_value = metadata_response - - mock_client.get.side_effect = [ - TypeError("CORS error"), # First call fails - mock_response_success, # Second call succeeds - ] - - result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp") - - assert result is not None - assert mock_client.get.call_count == 2 - - @pytest.mark.anyio - async def test_register_oauth_client_success(self, oauth_provider, oauth_metadata, oauth_client_info): - """Test successful OAuth client registration.""" - registration_response = oauth_client_info.model_dump(by_alias=True, mode="json") - - with patch("httpx.AsyncClient") as mock_client_class: - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - - mock_response = Mock() - mock_response.status_code = 201 - mock_response.json.return_value = registration_response - mock_client.post.return_value = mock_response - - result = await oauth_provider._register_oauth_client( - "https://api.example.com/v1/mcp", - oauth_provider.client_metadata, - oauth_metadata, - ) - - assert result.client_id == oauth_client_info.client_id - assert result.client_secret == oauth_client_info.client_secret - - # Verify correct registration endpoint was used - mock_client.post.assert_called_once() - call_args = mock_client.post.call_args - assert call_args[0][0] == str(oauth_metadata.registration_endpoint) - - @pytest.mark.anyio - async def test_register_oauth_client_fallback_endpoint(self, oauth_provider, oauth_client_info): - """Test OAuth client registration with fallback endpoint.""" - registration_response = oauth_client_info.model_dump(by_alias=True, mode="json") - - with patch("httpx.AsyncClient") as mock_client_class: - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - - mock_response = Mock() - mock_response.status_code = 201 - mock_response.json.return_value = registration_response - mock_client.post.return_value = mock_response - - # Mock metadata discovery to return None (fallback) - with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None): - result = await oauth_provider._register_oauth_client( - "https://api.example.com/v1/mcp", - oauth_provider.client_metadata, - None, - ) - - assert result.client_id == oauth_client_info.client_id - - # Verify fallback endpoint was used - mock_client.post.assert_called_once() - call_args = mock_client.post.call_args - assert call_args[0][0] == "https://api.example.com/register" - - @pytest.mark.anyio - async def test_register_oauth_client_failure(self, oauth_provider): - """Test OAuth client registration failure.""" - with patch("httpx.AsyncClient") as mock_client_class: - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - - mock_response = Mock() - mock_response.status_code = 400 - mock_response.text = "Bad Request" - mock_client.post.return_value = mock_response - - # Mock metadata discovery to return None (fallback) - with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None): - with pytest.raises(httpx.HTTPStatusError): - await oauth_provider._register_oauth_client( - "https://api.example.com/v1/mcp", - oauth_provider.client_metadata, - None, - ) - - @pytest.mark.anyio - async def test_has_valid_token_no_token(self, oauth_provider): - """Test token validation with no token.""" - assert not oauth_provider._has_valid_token() - - @pytest.mark.anyio - async def test_has_valid_token_valid(self, oauth_provider, oauth_token): - """Test token validation with valid token.""" - oauth_provider._current_tokens = oauth_token - oauth_provider._token_expiry_time = time.time() + 3600 # Future expiry - - assert oauth_provider._has_valid_token() - - @pytest.mark.anyio - async def test_has_valid_token_expired(self, oauth_provider, oauth_token): - """Test token validation with expired token.""" - oauth_provider._current_tokens = oauth_token - oauth_provider._token_expiry_time = time.time() - 3600 # Past expiry - - assert not oauth_provider._has_valid_token() - - @pytest.mark.anyio - async def test_validate_token_scopes_no_scope(self, oauth_provider): - """Test scope validation with no scope returned.""" - token = OAuthToken(access_token="test", token_type="Bearer") - - # Should not raise exception - await oauth_provider._validate_token_scopes(token) - - @pytest.mark.anyio - async def test_validate_token_scopes_valid(self, oauth_provider, client_metadata): - """Test scope validation with valid scopes.""" - oauth_provider.client_metadata = client_metadata - token = OAuthToken( - access_token="test", - token_type="Bearer", - scope="read write", + # Test with query params + assert ( + context.get_authorization_base_url("https://api.example.com/path?param=value") == "https://api.example.com" ) - # Should not raise exception - await oauth_provider._validate_token_scopes(token) - @pytest.mark.anyio - async def test_validate_token_scopes_subset(self, oauth_provider, client_metadata): - """Test scope validation with subset of requested scopes.""" - oauth_provider.client_metadata = client_metadata - token = OAuthToken( - access_token="test", - token_type="Bearer", - scope="read", - ) + async def test_token_validity_checking(self, oauth_provider, mock_storage, valid_tokens): + """Test is_token_valid() and can_refresh_token() logic.""" + context = oauth_provider.context - # Should not raise exception (servers can grant subset) - await oauth_provider._validate_token_scopes(token) + # No tokens - should be invalid + assert not context.is_token_valid() + assert not context.can_refresh_token() - @pytest.mark.anyio - async def test_validate_token_scopes_unauthorized(self, oauth_provider, client_metadata): - """Test scope validation with unauthorized scopes.""" - oauth_provider.client_metadata = client_metadata - token = OAuthToken( - access_token="test", - token_type="Bearer", - scope="read write admin", # Includes unauthorized "admin" + # Set valid tokens and client info + context.current_tokens = valid_tokens + context.token_expiry_time = time.time() + 1800 # 30 minutes from now + context.client_info = OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], ) - with pytest.raises(Exception, match="Server granted unauthorized scopes"): - await oauth_provider._validate_token_scopes(token) + # Should be valid + assert context.is_token_valid() + assert context.can_refresh_token() # Has refresh token and client info - @pytest.mark.anyio - async def test_validate_token_scopes_no_requested(self, oauth_provider): - """Test scope validation with no requested scopes accepts any server scopes.""" - # No scope in client metadata - oauth_provider.client_metadata.scope = None - token = OAuthToken( - access_token="test", - token_type="Bearer", - scope="admin super", - ) + # Expire the token + context.token_expiry_time = time.time() - 100 # Expired 100 seconds ago + assert not context.is_token_valid() + assert context.can_refresh_token() # Can still refresh - # Should not raise exception when no scopes were explicitly requested - # (accepts server defaults) - await oauth_provider._validate_token_scopes(token) + # Remove refresh token + context.current_tokens.refresh_token = None + assert not context.can_refresh_token() - @pytest.mark.anyio - async def test_initialize(self, oauth_provider, mock_storage, oauth_token, oauth_client_info): - """Test initialization loading from storage.""" - mock_storage._tokens = oauth_token - mock_storage._client_info = oauth_client_info + # Remove client info + context.current_tokens.refresh_token = "test_refresh_token" + context.client_info = None + assert not context.can_refresh_token() - await oauth_provider.initialize() + def test_clear_tokens(self, oauth_provider, valid_tokens): + """Test clear_tokens() removes token data.""" + context = oauth_provider.context + context.current_tokens = valid_tokens + context.token_expiry_time = time.time() + 1800 - assert oauth_provider._current_tokens == oauth_token - assert oauth_provider._client_info == oauth_client_info + # Clear tokens + context.clear_tokens() - @pytest.mark.anyio - async def test_get_or_register_client_existing(self, oauth_provider, oauth_client_info): - """Test getting existing client info.""" - oauth_provider._client_info = oauth_client_info + # Verify cleared + assert context.current_tokens is None + assert context.token_expiry_time is None - result = await oauth_provider._get_or_register_client() - assert result == oauth_client_info +class TestOAuthFlow: + """Test OAuth flow methods.""" @pytest.mark.anyio - async def test_get_or_register_client_register_new(self, oauth_provider, oauth_client_info): - """Test registering new client.""" - with patch.object(oauth_provider, "_register_oauth_client", return_value=oauth_client_info) as mock_register: - result = await oauth_provider._get_or_register_client() + async def test_discover_protected_resource_request(self, oauth_provider): + """Test protected resource discovery request building.""" + request = await oauth_provider._discover_protected_resource() - assert result == oauth_client_info - assert oauth_provider._client_info == oauth_client_info - mock_register.assert_called_once() + assert request.method == "GET" + assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource" + assert "mcp-protocol-version" in request.headers @pytest.mark.anyio - async def test_exchange_code_for_token_success(self, oauth_provider, oauth_client_info, oauth_token): - """Test successful code exchange for token.""" - oauth_provider._code_verifier = "test_verifier" - token_response = oauth_token.model_dump(by_alias=True, mode="json") - - with patch("httpx.AsyncClient") as mock_client_class: - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = token_response - mock_client.post.return_value = mock_response - - with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate: - await oauth_provider._exchange_code_for_token("test_auth_code", oauth_client_info) + async def test_discover_oauth_metadata_request(self, oauth_provider): + """Test OAuth metadata discovery request building.""" + request = await oauth_provider._discover_oauth_metadata() - assert oauth_provider._current_tokens.access_token == oauth_token.access_token - mock_validate.assert_called_once() + assert request.method == "GET" + assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server" + assert "mcp-protocol-version" in request.headers @pytest.mark.anyio - async def test_exchange_code_for_token_failure(self, oauth_provider, oauth_client_info): - """Test failed code exchange for token.""" - oauth_provider._code_verifier = "test_verifier" + async def test_register_client_request(self, oauth_provider): + """Test client registration request building.""" + request = await oauth_provider._register_client() - with patch("httpx.AsyncClient") as mock_client_class: - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - - mock_response = Mock() - mock_response.status_code = 400 - mock_response.text = "Invalid grant" - mock_client.post.return_value = mock_response - - with pytest.raises(Exception, match="Token exchange failed"): - await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info) + assert request is not None + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/register" + assert request.headers["Content-Type"] == "application/json" @pytest.mark.anyio - async def test_refresh_access_token_success(self, oauth_provider, oauth_client_info, oauth_token): - """Test successful token refresh.""" - oauth_provider._current_tokens = oauth_token - oauth_provider._client_info = oauth_client_info - - new_token = OAuthToken( - access_token="new_access_token", - token_type="Bearer", - expires_in=3600, - refresh_token="new_refresh_token", - scope="read write", + async def test_register_client_skip_if_registered(self, oauth_provider, mock_storage): + """Test client registration is skipped if already registered.""" + # Set existing client info + client_info = OAuthClientInformationFull( + client_id="existing_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], ) - token_response = new_token.model_dump(by_alias=True, mode="json") - - with patch("httpx.AsyncClient") as mock_client_class: - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - - mock_response = Mock() - mock_response.status_code = 200 - mock_response.json.return_value = token_response - mock_client.post.return_value = mock_response + oauth_provider.context.client_info = client_info - with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate: - result = await oauth_provider._refresh_access_token() - - assert result is True - assert oauth_provider._current_tokens.access_token == new_token.access_token - mock_validate.assert_called_once() + # Should return None (skip registration) + request = await oauth_provider._register_client() + assert request is None @pytest.mark.anyio - async def test_refresh_access_token_no_refresh_token(self, oauth_provider): - """Test token refresh with no refresh token.""" - oauth_provider._current_tokens = OAuthToken( - access_token="test", - token_type="Bearer", - # No refresh_token + async def test_token_exchange_request(self, oauth_provider): + """Test token exchange request building.""" + # Set up required context + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], ) - result = await oauth_provider._refresh_access_token() - assert result is False - - @pytest.mark.anyio - async def test_refresh_access_token_failure(self, oauth_provider, oauth_client_info, oauth_token): - """Test failed token refresh.""" - oauth_provider._current_tokens = oauth_token - oauth_provider._client_info = oauth_client_info - - with patch("httpx.AsyncClient") as mock_client_class: - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - - mock_response = Mock() - mock_response.status_code = 400 - mock_client.post.return_value = mock_response - - result = await oauth_provider._refresh_access_token() - assert result is False - - @pytest.mark.anyio - async def test_perform_oauth_flow_success(self, oauth_provider, oauth_metadata, oauth_client_info): - """Test successful OAuth flow.""" - oauth_provider._metadata = oauth_metadata - oauth_provider._client_info = oauth_client_info - - # Mock the redirect handler to capture the auth URL - auth_url_captured = None - - async def mock_redirect_handler(url: str) -> None: - nonlocal auth_url_captured - auth_url_captured = url - - oauth_provider.redirect_handler = mock_redirect_handler - - # Mock callback handler with matching state - async def mock_callback_handler() -> tuple[str, str | None]: - # Extract state from auth URL to return matching value - if auth_url_captured: - parsed_url = urlparse(auth_url_captured) - query_params = parse_qs(parsed_url.query) - state = query_params.get("state", [None])[0] - return "test_auth_code", state - return "test_auth_code", "test_state" - - oauth_provider.callback_handler = mock_callback_handler - - with patch.object(oauth_provider, "_exchange_code_for_token") as mock_exchange: - await oauth_provider._perform_oauth_flow() - - # Verify auth URL was generated correctly - assert auth_url_captured is not None - parsed_url = urlparse(auth_url_captured) - query_params = parse_qs(parsed_url.query) - - assert query_params["response_type"][0] == "code" - assert query_params["client_id"][0] == oauth_client_info.client_id - assert query_params["code_challenge_method"][0] == "S256" - assert "code_challenge" in query_params - assert "state" in query_params - - # Verify code exchange was called - mock_exchange.assert_called_once_with("test_auth_code", oauth_client_info) - - @pytest.mark.anyio - async def test_perform_oauth_flow_state_mismatch(self, oauth_provider, oauth_metadata, oauth_client_info): - """Test OAuth flow with state parameter mismatch.""" - oauth_provider._metadata = oauth_metadata - oauth_provider._client_info = oauth_client_info - - # Mock callback handler to return mismatched state - async def mock_callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "wrong_state" - - oauth_provider.callback_handler = mock_callback_handler - - async def mock_redirect_handler(url: str) -> None: - pass - - oauth_provider.redirect_handler = mock_redirect_handler - - with pytest.raises(Exception, match="State parameter mismatch"): - await oauth_provider._perform_oauth_flow() - - @pytest.mark.anyio - async def test_ensure_token_existing_valid(self, oauth_provider, oauth_token): - """Test ensure_token with existing valid token.""" - oauth_provider._current_tokens = oauth_token - oauth_provider._token_expiry_time = time.time() + 3600 - - await oauth_provider.ensure_token() - - # Should not trigger new auth flow - assert oauth_provider._current_tokens == oauth_token - - @pytest.mark.anyio - async def test_ensure_token_refresh(self, oauth_provider, oauth_token): - """Test ensure_token with expired token that can be refreshed.""" - oauth_provider._current_tokens = oauth_token - oauth_provider._token_expiry_time = time.time() - 3600 # Expired - - with patch.object(oauth_provider, "_refresh_access_token", return_value=True) as mock_refresh: - await oauth_provider.ensure_token() - mock_refresh.assert_called_once() - - @pytest.mark.anyio - async def test_ensure_token_full_flow(self, oauth_provider): - """Test ensure_token triggering full OAuth flow.""" - # No existing token - with patch.object(oauth_provider, "_perform_oauth_flow") as mock_flow: - await oauth_provider.ensure_token() - mock_flow.assert_called_once() - - @pytest.mark.anyio - async def test_async_auth_flow_add_token(self, oauth_provider, oauth_token): - """Test async auth flow adding Bearer token to request.""" - oauth_provider._current_tokens = oauth_token - oauth_provider._token_expiry_time = time.time() + 3600 + request = await oauth_provider._exchange_token("test_auth_code", "test_verifier") + + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/token" + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + + # Check form data + content = request.content.decode() + assert "grant_type=authorization_code" in content + assert "code=test_auth_code" in content + assert "code_verifier=test_verifier" in content + assert "client_id=test_client" in content + assert "client_secret=test_secret" in content + + @pytest.mark.anyio + async def test_refresh_token_request(self, oauth_provider, valid_tokens): + """Test refresh token request building.""" + # Set up required context + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="test_client", + client_secret="test_secret", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) - request = httpx.Request("GET", "https://api.example.com/data") + request = await oauth_provider._refresh_token() - # Mock response - mock_response = Mock() - mock_response.status_code = 200 + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/token" + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" - auth_flow = oauth_provider.async_auth_flow(request) - updated_request = await auth_flow.__anext__() + # Check form data + content = request.content.decode() + assert "grant_type=refresh_token" in content + assert "refresh_token=test_refresh_token" in content + assert "client_id=test_client" in content + assert "client_secret=test_secret" in content - assert updated_request.headers["Authorization"] == f"Bearer {oauth_token.access_token}" - # Send mock response - try: - await auth_flow.asend(mock_response) - except StopAsyncIteration: - pass +class TestAuthFlow: + """Test the auth flow in httpx.""" @pytest.mark.anyio - async def test_async_auth_flow_401_response(self, oauth_provider, oauth_token): - """Test async auth flow handling 401 response.""" - oauth_provider._current_tokens = oauth_token - oauth_provider._token_expiry_time = time.time() + 3600 + async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, valid_tokens): + """Test auth flow when tokens are already valid.""" + # Pre-store valid tokens + await mock_storage.set_tokens(valid_tokens) + oauth_provider.context.current_tokens = valid_tokens + oauth_provider.context.token_expiry_time = time.time() + 1800 + oauth_provider._initialized = True - request = httpx.Request("GET", "https://api.example.com/data") + # Create a test request + test_request = httpx.Request("GET", "https://api.example.com/test") - # Mock 401 response - mock_response = Mock() - mock_response.status_code = 401 + # Mock the auth flow + auth_flow = oauth_provider.async_auth_flow(test_request) - auth_flow = oauth_provider.async_auth_flow(request) - await auth_flow.__anext__() + # Should get the request with auth header added + request = await auth_flow.__anext__() + assert request.headers["Authorization"] == "Bearer test_access_token" - # Send 401 response + # Send a successful response + response = httpx.Response(200) try: - await auth_flow.asend(mock_response) + await auth_flow.asend(response) except StopAsyncIteration: - pass - - # Should clear current tokens - assert oauth_provider._current_tokens is None - - @pytest.mark.anyio - async def test_async_auth_flow_no_token(self, oauth_provider): - """Test async auth flow with no token triggers auth flow.""" - request = httpx.Request("GET", "https://api.example.com/data") - - with ( - patch.object(oauth_provider, "initialize") as mock_init, - patch.object(oauth_provider, "ensure_token") as mock_ensure, - ): - auth_flow = oauth_provider.async_auth_flow(request) - updated_request = await auth_flow.__anext__() - - mock_init.assert_called_once() - mock_ensure.assert_called_once() - - # No Authorization header should be added if no token - assert "Authorization" not in updated_request.headers - - @pytest.mark.anyio - async def test_scope_priority_client_metadata_first(self, oauth_provider, oauth_client_info): - """Test that client metadata scope takes priority.""" - oauth_provider.client_metadata.scope = "read write" - oauth_provider._client_info = oauth_client_info - oauth_provider._client_info.scope = "admin" - - # Build auth params to test scope logic - auth_params = { - "response_type": "code", - "client_id": "test_client", - "redirect_uri": "http://localhost:3000/callback", - "state": "test_state", - "code_challenge": "test_challenge", - "code_challenge_method": "S256", - } - - # Apply scope logic from _perform_oauth_flow - if oauth_provider.client_metadata.scope: - auth_params["scope"] = oauth_provider.client_metadata.scope - elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope: - auth_params["scope"] = oauth_provider._client_info.scope - - assert auth_params["scope"] == "read write" - - @pytest.mark.anyio - async def test_scope_priority_no_client_metadata_scope(self, oauth_provider, oauth_client_info): - """Test that no scope parameter is set when client metadata has no scope.""" - oauth_provider.client_metadata.scope = None - oauth_provider._client_info = oauth_client_info - oauth_provider._client_info.scope = "admin" - - # Build auth params to test scope logic - auth_params = { - "response_type": "code", - "client_id": "test_client", - "redirect_uri": "http://localhost:3000/callback", - "state": "test_state", - "code_challenge": "test_challenge", - "code_challenge_method": "S256", - } - - # Apply simplified scope logic from _perform_oauth_flow - if oauth_provider.client_metadata.scope: - auth_params["scope"] = oauth_provider.client_metadata.scope - # No fallback to client_info scope in simplified logic - - # No scope should be set since client metadata doesn't have explicit scope - assert "scope" not in auth_params - - @pytest.mark.anyio - async def test_scope_priority_no_scope(self, oauth_provider, oauth_client_info): - """Test that no scope parameter is set when no scopes specified.""" - oauth_provider.client_metadata.scope = None - oauth_provider._client_info = oauth_client_info - oauth_provider._client_info.scope = None - - # Build auth params to test scope logic - auth_params = { - "response_type": "code", - "client_id": "test_client", - "redirect_uri": "http://localhost:3000/callback", - "state": "test_state", - "code_challenge": "test_challenge", - "code_challenge_method": "S256", - } - - # Apply scope logic from _perform_oauth_flow - if oauth_provider.client_metadata.scope: - auth_params["scope"] = oauth_provider.client_metadata.scope - elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope: - auth_params["scope"] = oauth_provider._client_info.scope - - # No scope should be set - assert "scope" not in auth_params - - @pytest.mark.anyio - async def test_state_parameter_validation_uses_constant_time( - self, oauth_provider, oauth_metadata, oauth_client_info - ): - """Test that state parameter validation uses constant-time comparison.""" - oauth_provider._metadata = oauth_metadata - oauth_provider._client_info = oauth_client_info - - # Mock callback handler to return mismatched state - async def mock_callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "wrong_state" - - oauth_provider.callback_handler = mock_callback_handler - - async def mock_redirect_handler(url: str) -> None: - pass - - oauth_provider.redirect_handler = mock_redirect_handler - - # Patch secrets.compare_digest to verify it's being called - with patch("mcp.client.auth.secrets.compare_digest", return_value=False) as mock_compare: - with pytest.raises(Exception, match="State parameter mismatch"): - await oauth_provider._perform_oauth_flow() - - # Verify constant-time comparison was used - mock_compare.assert_called_once() - - @pytest.mark.anyio - async def test_state_parameter_validation_none_state(self, oauth_provider, oauth_metadata, oauth_client_info): - """Test that None state is handled correctly.""" - oauth_provider._metadata = oauth_metadata - oauth_provider._client_info = oauth_client_info - - # Mock callback handler to return None state - async def mock_callback_handler() -> tuple[str, str | None]: - return "test_auth_code", None - - oauth_provider.callback_handler = mock_callback_handler - - async def mock_redirect_handler(url: str) -> None: - pass - - oauth_provider.redirect_handler = mock_redirect_handler - - with pytest.raises(Exception, match="State parameter mismatch"): - await oauth_provider._perform_oauth_flow() - - @pytest.mark.anyio - async def test_token_exchange_error_basic(self, oauth_provider, oauth_client_info): - """Test token exchange error handling (basic).""" - oauth_provider._code_verifier = "test_verifier" - - with patch("httpx.AsyncClient") as mock_client_class: - mock_client = AsyncMock() - mock_client_class.return_value.__aenter__.return_value = mock_client - - # Mock error response - mock_response = Mock() - mock_response.status_code = 400 - mock_response.text = "Bad Request" - mock_client.post.return_value = mock_response - - with pytest.raises(Exception, match="Token exchange failed"): - await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info) - - -@pytest.mark.parametrize( - ( - "issuer_url", - "service_documentation_url", - "authorization_endpoint", - "token_endpoint", - "registration_endpoint", - "revocation_endpoint", - ), - ( - pytest.param( - "https://auth.example.com", - "https://auth.example.com/docs", - "https://auth.example.com/authorize", - "https://auth.example.com/token", - "https://auth.example.com/register", - "https://auth.example.com/revoke", - id="simple-url", - ), - pytest.param( - "https://auth.example.com/", - "https://auth.example.com/docs", - "https://auth.example.com/authorize", - "https://auth.example.com/token", - "https://auth.example.com/register", - "https://auth.example.com/revoke", - id="with-trailing-slash", - ), - pytest.param( - "https://auth.example.com/v1/mcp", - "https://auth.example.com/v1/mcp/docs", - "https://auth.example.com/v1/mcp/authorize", - "https://auth.example.com/v1/mcp/token", - "https://auth.example.com/v1/mcp/register", - "https://auth.example.com/v1/mcp/revoke", - id="with-path-param", - ), - ), -) -def test_build_metadata( - issuer_url: str, - service_documentation_url: str, - authorization_endpoint: str, - token_endpoint: str, - registration_endpoint: str, - revocation_endpoint: str, -): - metadata = build_metadata( - issuer_url=AnyHttpUrl(issuer_url), - service_documentation_url=AnyHttpUrl(service_documentation_url), - client_registration_options=ClientRegistrationOptions(enabled=True, valid_scopes=["read", "write", "admin"]), - revocation_options=RevocationOptions(enabled=True), - ) - - assert metadata == snapshot( - OAuthMetadata( - issuer=AnyHttpUrl(issuer_url), - authorization_endpoint=AnyHttpUrl(authorization_endpoint), - token_endpoint=AnyHttpUrl(token_endpoint), - registration_endpoint=AnyHttpUrl(registration_endpoint), - scopes_supported=["read", "write", "admin"], - grant_types_supported=["authorization_code", "refresh_token"], - token_endpoint_auth_methods_supported=["client_secret_post"], - service_documentation=AnyHttpUrl(service_documentation_url), - revocation_endpoint=AnyHttpUrl(revocation_endpoint), - revocation_endpoint_auth_methods_supported=["client_secret_post"], - code_challenge_methods_supported=["S256"], - ) - ) + pass # Expected diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index 79b813096..5bb0f969e 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -8,7 +8,6 @@ import pytest from starlette.authentication import AuthCredentials from starlette.datastructures import Headers -from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.types import Message, Receive, Scope, Send @@ -20,6 +19,7 @@ from mcp.server.auth.provider import ( AccessToken, OAuthAuthorizationServerProvider, + ProviderTokenVerifier, ) @@ -118,14 +118,14 @@ class TestBearerAuthBackend: async def test_no_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Test authentication with no Authorization header.""" - backend = BearerAuthBackend(provider=mock_oauth_provider) + backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) request = Request({"type": "http", "headers": []}) result = await backend.authenticate(request) assert result is None async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Test authentication with non-Bearer Authorization header.""" - backend = BearerAuthBackend(provider=mock_oauth_provider) + backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) request = Request( { "type": "http", @@ -137,7 +137,7 @@ async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthAuthorizat async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Test authentication with invalid token.""" - backend = BearerAuthBackend(provider=mock_oauth_provider) + backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) request = Request( { "type": "http", @@ -153,7 +153,7 @@ async def test_expired_token( expired_access_token: AccessToken, ): """Test authentication with expired token.""" - backend = BearerAuthBackend(provider=mock_oauth_provider) + backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token) request = Request( { @@ -170,7 +170,7 @@ async def test_valid_token( valid_access_token: AccessToken, ): """Test authentication with valid token.""" - backend = BearerAuthBackend(provider=mock_oauth_provider) + backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) request = Request( { @@ -194,7 +194,7 @@ async def test_token_without_expiry( no_expiry_access_token: AccessToken, ): """Test authentication with token that has no expiry.""" - backend = BearerAuthBackend(provider=mock_oauth_provider) + backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token) request = Request( { @@ -218,7 +218,7 @@ async def test_lowercase_bearer_prefix( valid_access_token: AccessToken, ): """Test with lowercase 'bearer' prefix in Authorization header""" - backend = BearerAuthBackend(provider=mock_oauth_provider) + backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) headers = Headers({"Authorization": "bearer valid_token"}) scope = {"type": "http", "headers": headers.raw} @@ -238,7 +238,7 @@ async def test_mixed_case_bearer_prefix( valid_access_token: AccessToken, ): """Test with mixed 'BeArEr' prefix in Authorization header""" - backend = BearerAuthBackend(provider=mock_oauth_provider) + backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) headers = Headers({"authorization": "BeArEr valid_token"}) scope = {"type": "http", "headers": headers.raw} @@ -258,7 +258,7 @@ async def test_mixed_case_authorization_header( valid_access_token: AccessToken, ): """Test authentication with mixed 'Authorization' header.""" - backend = BearerAuthBackend(provider=mock_oauth_provider) + backend = BearerAuthBackend(token_verifier=ProviderTokenVerifier(mock_oauth_provider)) add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) headers = Headers({"AuThOrIzAtIoN": "BeArEr valid_token"}) scope = {"type": "http", "headers": headers.raw} @@ -287,14 +287,18 @@ async def test_no_user(self): async def receive() -> Message: return {"type": "http.request"} + sent_messages = [] + async def send(message: Message) -> None: - pass + sent_messages.append(message) - with pytest.raises(HTTPException) as excinfo: - await middleware(scope, receive, send) + await middleware(scope, receive, send) - assert excinfo.value.status_code == 401 - assert excinfo.value.detail == "Unauthorized" + # Check that a 401 response was sent + assert len(sent_messages) == 2 + assert sent_messages[0]["type"] == "http.response.start" + assert sent_messages[0]["status"] == 401 + assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called async def test_non_authenticated_user(self): @@ -307,14 +311,18 @@ async def test_non_authenticated_user(self): async def receive() -> Message: return {"type": "http.request"} + sent_messages = [] + async def send(message: Message) -> None: - pass + sent_messages.append(message) - with pytest.raises(HTTPException) as excinfo: - await middleware(scope, receive, send) + await middleware(scope, receive, send) - assert excinfo.value.status_code == 401 - assert excinfo.value.detail == "Unauthorized" + # Check that a 401 response was sent + assert len(sent_messages) == 2 + assert sent_messages[0]["type"] == "http.response.start" + assert sent_messages[0]["status"] == 401 + assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called async def test_missing_required_scope(self, valid_access_token: AccessToken): @@ -332,14 +340,18 @@ async def test_missing_required_scope(self, valid_access_token: AccessToken): async def receive() -> Message: return {"type": "http.request"} + sent_messages = [] + async def send(message: Message) -> None: - pass + sent_messages.append(message) - with pytest.raises(HTTPException) as excinfo: - await middleware(scope, receive, send) + await middleware(scope, receive, send) - assert excinfo.value.status_code == 403 - assert excinfo.value.detail == "Insufficient scope" + # Check that a 403 response was sent + assert len(sent_messages) == 2 + assert sent_messages[0]["type"] == "http.response.start" + assert sent_messages[0]["status"] == 403 + assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called async def test_no_auth_credentials(self, valid_access_token: AccessToken): @@ -356,14 +368,18 @@ async def test_no_auth_credentials(self, valid_access_token: AccessToken): async def receive() -> Message: return {"type": "http.request"} + sent_messages = [] + async def send(message: Message) -> None: - pass + sent_messages.append(message) - with pytest.raises(HTTPException) as excinfo: - await middleware(scope, receive, send) + await middleware(scope, receive, send) - assert excinfo.value.status_code == 403 - assert excinfo.value.detail == "Insufficient scope" + # Check that a 403 response was sent + assert len(sent_messages) == 2 + assert sent_messages[0]["type"] == "http.response.start" + assert sent_messages[0]["status"] == 403 + assert any(h[0] == b"www-authenticate" for h in sent_messages[0]["headers"]) assert not app.called async def test_has_required_scopes(self, valid_access_token: AccessToken):