diff --git a/README.md b/README.md index 3889dc40b..a63cb4056 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ The Model Context Protocol allows applications to provide context for LLMs in a - Build MCP clients that can connect to any MCP server - Create MCP servers that expose resources, prompts and tools -- Use standard transports like stdio and SSE +- Use standard transports like stdio, SSE, and Streamable HTTP - Handle all MCP protocol messages and lifecycle events ## Installation @@ -334,7 +334,7 @@ mcp = FastMCP("My App", ) ``` -See [OAuthServerProvider](mcp/server/auth/provider.py) for more details. +See [OAuthServerProvider](src/mcp/server/auth/provider.py) for more details. ## Running Your Server @@ -387,8 +387,81 @@ python server.py mcp run server.py ``` +### Streamable HTTP Transport + +> **Note**: Streamable HTTP transport is superseding SSE transport for production deployments. + +```python +from mcp.server.fastmcp import FastMCP + +# Stateful server (maintains session state) +mcp = FastMCP("StatefulServer") + +# Stateless server (no session persistence) +mcp = FastMCP("StatelessServer", stateless_http=True) + +# Run server with streamable_http transport +mcp.run(transport="streamable-http") +``` + +You can mount multiple FastMCP servers in a FastAPI application: + +```python +# echo.py +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP(name="EchoServer", stateless_http=True) + + +@mcp.tool(description="A simple echo tool") +def echo(message: str) -> str: + return f"Echo: {message}" +``` + +```python +# math.py +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP(name="MathServer", stateless_http=True) + + +@mcp.tool(description="A simple add tool") +def add_two(n: int) -> str: + return n + 2 +``` + +```python +# main.py +from fastapi import FastAPI +from mcp.echo import echo +from mcp.math import math + + +app = FastAPI() + +# Use the session manager's lifespan +app = FastAPI(lifespan=lambda app: echo.mcp.session_manager.run()) +app.mount("/echo", echo.mcp.streamable_http_app()) +app.mount("/math", math.mcp.streamable_http_app()) +``` + +For low level server with Streamable HTTP implementations, see: +- Stateful server: [`examples/servers/simple-streamablehttp/`](examples/servers/simple-streamablehttp/) +- Stateless server: [`examples/servers/simple-streamablehttp-stateless/`](examples/servers/simple-streamablehttp-stateless/) + + + +The streamable HTTP transport supports: +- Stateful and stateless operation modes +- Resumability with event stores +- JSON or SSE response formats +- Better scalability for multi-node deployments + + ### Mounting to an Existing ASGI Server +> **Note**: SSE transport is being superseded by [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http). + You can mount the SSE server to an existing ASGI server using the `sse_app` method. This allows you to integrate the SSE server with other ASGI applications. ```python @@ -410,6 +483,43 @@ app = Starlette( app.router.routes.append(Host('mcp.acme.corp', app=mcp.sse_app())) ``` +When mounting multiple MCP servers under different paths, you can configure the mount path in several ways: + +```python +from starlette.applications import Starlette +from starlette.routing import Mount +from mcp.server.fastmcp import FastMCP + +# Create multiple MCP servers +github_mcp = FastMCP("GitHub API") +browser_mcp = FastMCP("Browser") +curl_mcp = FastMCP("Curl") +search_mcp = FastMCP("Search") + +# Method 1: Configure mount paths via settings (recommended for persistent configuration) +github_mcp.settings.mount_path = "/github" +browser_mcp.settings.mount_path = "/browser" + +# Method 2: Pass mount path directly to sse_app (preferred for ad-hoc mounting) +# This approach doesn't modify the server's settings permanently + +# Create Starlette app with multiple mounted servers +app = Starlette( + routes=[ + # Using settings-based configuration + Mount("/github", app=github_mcp.sse_app()), + Mount("/browser", app=browser_mcp.sse_app()), + # Using direct mount path parameter + Mount("/curl", app=curl_mcp.sse_app("/curl")), + Mount("/search", app=search_mcp.sse_app("/search")), + ] +) + +# Method 3: For direct execution, you can also pass the mount path to run() +if __name__ == "__main__": + search_mcp.run(transport="sse", mount_path="/search") +``` + For more information on mounting applications in Starlette, see the [Starlette documentation](https://www.starlette.io/routing/#submounting-routes). ## Examples @@ -584,7 +694,7 @@ if __name__ == "__main__": ### Writing MCP Clients -The SDK provides a high-level client interface for connecting to MCP servers: +The SDK provides a high-level client interface for connecting to MCP servers using various [transports](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports): ```python from mcp import ClientSession, StdioServerParameters, types @@ -648,6 +758,28 @@ if __name__ == "__main__": asyncio.run(run()) ``` +Clients can also connect using [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http): + +```python +from mcp.client.streamable_http import streamablehttp_client +from mcp import ClientSession + + +async def main(): + # Connect to a streamable HTTP server + async with streamablehttp_client("example/mcp") as ( + read_stream, + write_stream, + _, + ): + # Create a session using the client streams + async with ClientSession(read_stream, write_stream) as session: + # Initialize the connection + await session.initialize() + # Call a tool + tool_result = await session.call_tool("echo", {"message": "hello"}) +``` + ### MCP Primitives The MCP protocol defines three core primitives that servers can implement: diff --git a/examples/servers/simple-auth/README.md b/examples/servers/simple-auth/README.md new file mode 100644 index 000000000..1d0979d97 --- /dev/null +++ b/examples/servers/simple-auth/README.md @@ -0,0 +1,65 @@ +# Simple MCP Server with GitHub OAuth Authentication + +This is a simple example of an MCP server with GitHub OAuth authentication. It demonstrates the essential components needed for OAuth integration with just a single tool. + +This is just an example of a server that uses auth, an official GitHub mcp server is [here](https://github.com/github/github-mcp-server) + +## Overview + +This simple demo to show to set up a server with: +- GitHub OAuth2 authorization flow +- Single tool: `get_user_profile` to retrieve GitHub user information + + +## Prerequisites + +1. Create a GitHub OAuth App: + - Go to GitHub Settings > Developer settings > OAuth Apps > New OAuth App + - Application name: Any name (e.g., "Simple MCP Auth Demo") + - Homepage URL: `http://localhost:8000` + - Authorization callback URL: `http://localhost:8000/github/callback` + - Click "Register application" + - Note down your Client ID and Client Secret + +## Required Environment Variables + +You MUST set these environment variables before running the server: + +```bash +export MCP_GITHUB_GITHUB_CLIENT_ID="your_client_id_here" +export MCP_GITHUB_GITHUB_CLIENT_SECRET="your_client_secret_here" +``` + +The server will not start without these environment variables properly set. + + +## Running the Server + +```bash +# Set environment variables first (see above) + +# Run the server +uv run mcp-simple-auth +``` + +The server will start on `http://localhost:8000`. + +## Available Tool + +### get_user_profile + +The only tool in this simple example. Returns the authenticated user's GitHub profile information. + +**Required scope**: `user` + +**Returns**: GitHub user profile data including username, email, bio, etc. + + +## Troubleshooting + +If the server fails to start, check: +1. Environment variables `MCP_GITHUB_GITHUB_CLIENT_ID` and `MCP_GITHUB_GITHUB_CLIENT_SECRET` are set +2. The GitHub OAuth app callback URL matches `http://localhost:8000/github/callback` +3. No other service is using port 8000 + +You can use [Inspector](https://github.com/modelcontextprotocol/inspector) to test Auth \ No newline at end of file diff --git a/examples/servers/simple-auth/mcp_simple_auth/__init__.py b/examples/servers/simple-auth/mcp_simple_auth/__init__.py new file mode 100644 index 000000000..3e12b3183 --- /dev/null +++ b/examples/servers/simple-auth/mcp_simple_auth/__init__.py @@ -0,0 +1 @@ +"""Simple MCP server with GitHub OAuth authentication.""" diff --git a/examples/servers/simple-auth/mcp_simple_auth/__main__.py b/examples/servers/simple-auth/mcp_simple_auth/__main__.py new file mode 100644 index 000000000..a8840780b --- /dev/null +++ b/examples/servers/simple-auth/mcp_simple_auth/__main__.py @@ -0,0 +1,7 @@ +"""Main entry point for simple MCP server with GitHub OAuth authentication.""" + +import sys + +from mcp_simple_auth.server import main + +sys.exit(main()) diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py new file mode 100644 index 000000000..2f1e4086f --- /dev/null +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -0,0 +1,368 @@ +"""Simple MCP Server with GitHub OAuth Authentication.""" + +import logging +import secrets +import time +from typing import Any + +import click +from pydantic import AnyHttpUrl +from pydantic_settings import BaseSettings, SettingsConfigDict +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse, Response + +from mcp.server.auth.middleware.auth_context import get_access_token +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + OAuthAuthorizationServerProvider, + RefreshToken, + construct_redirect_uri, +) +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions +from mcp.server.fastmcp.server import FastMCP +from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +logger = logging.getLogger(__name__) + + +class ServerSettings(BaseSettings): + """Settings for the simple GitHub MCP server.""" + + model_config = SettingsConfigDict(env_prefix="MCP_GITHUB_") + + # Server settings + host: str = "localhost" + port: int = 8000 + server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:8000") + + # GitHub OAuth settings - MUST be provided via environment variables + github_client_id: str # Type: MCP_GITHUB_GITHUB_CLIENT_ID env var + github_client_secret: str # Type: MCP_GITHUB_GITHUB_CLIENT_SECRET env var + github_callback_path: str = "http://localhost:8000/github/callback" + + # GitHub OAuth URLs + github_auth_url: str = "https://github.com/login/oauth/authorize" + github_token_url: str = "https://github.com/login/oauth/access_token" + + mcp_scope: str = "user" + github_scope: str = "read:user" + + def __init__(self, **data): + """Initialize settings with values from environment variables. + + Note: github_client_id and github_client_secret are required but can be + loaded automatically from environment variables (MCP_GITHUB_GITHUB_CLIENT_ID + and MCP_GITHUB_GITHUB_CLIENT_SECRET) and don't need to be passed explicitly. + """ + super().__init__(**data) + + +class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider): + """Simple GitHub OAuth provider with essential functionality.""" + + def __init__(self, settings: ServerSettings): + self.settings = settings + self.clients: dict[str, OAuthClientInformationFull] = {} + self.auth_codes: dict[str, AuthorizationCode] = {} + self.tokens: dict[str, AccessToken] = {} + self.state_mapping: dict[str, dict[str, str]] = {} + # Store GitHub tokens with MCP tokens using the format: + # {"mcp_token": "github_token"} + self.token_mapping: dict[str, str] = {} + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + """Get OAuth client information.""" + return self.clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull): + """Register a new OAuth client.""" + self.clients[client_info.client_id] = client_info + + async def authorize( + self, client: OAuthClientInformationFull, params: AuthorizationParams + ) -> str: + """Generate an authorization URL for GitHub OAuth flow.""" + state = params.state or secrets.token_hex(16) + + # Store the state mapping + self.state_mapping[state] = { + "redirect_uri": str(params.redirect_uri), + "code_challenge": params.code_challenge, + "redirect_uri_provided_explicitly": str( + params.redirect_uri_provided_explicitly + ), + "client_id": client.client_id, + } + + # Build GitHub authorization URL + auth_url = ( + f"{self.settings.github_auth_url}" + f"?client_id={self.settings.github_client_id}" + f"&redirect_uri={self.settings.github_callback_path}" + f"&scope={self.settings.github_scope}" + f"&state={state}" + ) + + return auth_url + + async def handle_github_callback(self, code: str, state: str) -> str: + """Handle GitHub OAuth callback.""" + state_data = self.state_mapping.get(state) + if not state_data: + raise HTTPException(400, "Invalid state parameter") + + redirect_uri = state_data["redirect_uri"] + code_challenge = state_data["code_challenge"] + redirect_uri_provided_explicitly = ( + state_data["redirect_uri_provided_explicitly"] == "True" + ) + client_id = state_data["client_id"] + + # Exchange code for token with GitHub + async with create_mcp_http_client() as client: + response = await client.post( + self.settings.github_token_url, + data={ + "client_id": self.settings.github_client_id, + "client_secret": self.settings.github_client_secret, + "code": code, + "redirect_uri": self.settings.github_callback_path, + }, + headers={"Accept": "application/json"}, + ) + + if response.status_code != 200: + raise HTTPException(400, "Failed to exchange code for token") + + data = response.json() + + if "error" in data: + raise HTTPException(400, data.get("error_description", data["error"])) + + github_token = data["access_token"] + + # Create MCP authorization code + new_code = f"mcp_{secrets.token_hex(16)}" + auth_code = AuthorizationCode( + code=new_code, + client_id=client_id, + redirect_uri=AnyHttpUrl(redirect_uri), + redirect_uri_provided_explicitly=redirect_uri_provided_explicitly, + expires_at=time.time() + 300, + scopes=[self.settings.mcp_scope], + code_challenge=code_challenge, + ) + self.auth_codes[new_code] = auth_code + + # Store GitHub token - we'll map the MCP token to this later + self.tokens[github_token] = AccessToken( + token=github_token, + client_id=client_id, + scopes=[self.settings.github_scope], + expires_at=None, + ) + + del self.state_mapping[state] + return construct_redirect_uri(redirect_uri, code=new_code, state=state) + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: + """Load an authorization code.""" + return self.auth_codes.get(authorization_code) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + """Exchange authorization code for tokens.""" + if authorization_code.code not in self.auth_codes: + raise ValueError("Invalid authorization code") + + # Generate MCP access token + mcp_token = f"mcp_{secrets.token_hex(32)}" + + # Store MCP token + self.tokens[mcp_token] = AccessToken( + token=mcp_token, + client_id=client.client_id, + scopes=authorization_code.scopes, + expires_at=int(time.time()) + 3600, + ) + + # Find GitHub token for this client + github_token = next( + ( + token + for token, data in self.tokens.items() + # see https://github.blog/engineering/platform-security/behind-githubs-new-authentication-token-formats/ + # which you get depends on your GH app setup. + if (token.startswith("ghu_") or token.startswith("gho_")) + and data.client_id == client.client_id + ), + None, + ) + + # Store mapping between MCP token and GitHub token + if github_token: + self.token_mapping[mcp_token] = github_token + + del self.auth_codes[authorization_code.code] + + return OAuthToken( + access_token=mcp_token, + token_type="bearer", + expires_in=3600, + scope=" ".join(authorization_code.scopes), + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + """Load and validate an access token.""" + access_token = self.tokens.get(token) + if not access_token: + return None + + # Check if expired + if access_token.expires_at and access_token.expires_at < time.time(): + del self.tokens[token] + return None + + return access_token + + async def load_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: str + ) -> RefreshToken | None: + """Load a refresh token - not supported.""" + return None + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshToken, + scopes: list[str], + ) -> OAuthToken: + """Exchange refresh token""" + raise NotImplementedError("Not supported") + + async def revoke_token( + self, token: str, token_type_hint: str | None = None + ) -> None: + """Revoke a token.""" + if token in self.tokens: + del self.tokens[token] + + +def create_simple_mcp_server(settings: ServerSettings) -> FastMCP: + """Create a simple FastMCP server with GitHub OAuth.""" + oauth_provider = SimpleGitHubOAuthProvider(settings) + + auth_settings = AuthSettings( + issuer_url=settings.server_url, + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=[settings.mcp_scope], + default_scopes=[settings.mcp_scope], + ), + required_scopes=[settings.mcp_scope], + ) + + app = FastMCP( + name="Simple GitHub MCP Server", + instructions="A simple MCP server with GitHub OAuth authentication", + auth_server_provider=oauth_provider, + host=settings.host, + port=settings.port, + debug=True, + auth=auth_settings, + ) + + @app.custom_route("/github/callback", methods=["GET"]) + async def github_callback_handler(request: Request) -> Response: + """Handle GitHub OAuth callback.""" + code = request.query_params.get("code") + state = request.query_params.get("state") + + if not code or not state: + raise HTTPException(400, "Missing code or state parameter") + + try: + redirect_uri = await oauth_provider.handle_github_callback(code, state) + return RedirectResponse(status_code=302, url=redirect_uri) + except HTTPException: + raise + except Exception as e: + logger.error("Unexpected error", exc_info=e) + return JSONResponse( + status_code=500, + content={ + "error": "server_error", + "error_description": "Unexpected error", + }, + ) + + def get_github_token() -> str: + """Get the GitHub token for the authenticated user.""" + access_token = get_access_token() + if not access_token: + raise ValueError("Not authenticated") + + # Get GitHub token from mapping + github_token = oauth_provider.token_mapping.get(access_token.token) + + if not github_token: + raise ValueError("No GitHub token found for user") + + return github_token + + @app.tool() + async def get_user_profile() -> dict[str, Any]: + """Get the authenticated user's GitHub profile information. + + This is the only tool in our simple example. It requires the 'user' scope. + """ + github_token = get_github_token() + + async with create_mcp_http_client() as client: + response = await client.get( + "https://api.github.com/user", + headers={ + "Authorization": f"Bearer {github_token}", + "Accept": "application/vnd.github.v3+json", + }, + ) + + if response.status_code != 200: + raise ValueError( + f"GitHub API error: {response.status_code} - {response.text}" + ) + + return response.json() + + return app + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on") +@click.option("--host", default="localhost", help="Host to bind to") +def main(port: int, host: str) -> int: + """Run the simple GitHub MCP server.""" + logging.basicConfig(level=logging.INFO) + + try: + # No hardcoded credentials - all from environment variables + settings = ServerSettings(host=host, port=port) + except ValueError as e: + logger.error( + "Failed to load settings. Make sure environment variables are set:" + ) + logger.error(" MCP_GITHUB_GITHUB_CLIENT_ID=") + logger.error(" MCP_GITHUB_GITHUB_CLIENT_SECRET=") + logger.error(f"Error: {e}") + return 1 + + mcp_server = create_simple_mcp_server(settings) + mcp_server.run(transport="sse") + return 0 diff --git a/examples/servers/simple-auth/pyproject.toml b/examples/servers/simple-auth/pyproject.toml new file mode 100644 index 000000000..40ae278a4 --- /dev/null +++ b/examples/servers/simple-auth/pyproject.toml @@ -0,0 +1,31 @@ +[project] +name = "mcp-simple-auth" +version = "0.1.0" +description = "A simple MCP server demonstrating OAuth authentication" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +license = { text = "MIT" } +dependencies = [ + "anyio>=4.5", + "click>=8.1.0", + "httpx>=0.27", + "mcp", + "pydantic>=2.0", + "pydantic-settings>=2.5.2", + "sse-starlette>=1.6.1", + "uvicorn>=0.23.1; sys_platform != 'emscripten'", +] + +[project.scripts] +mcp-simple-auth = "mcp_simple_auth.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_auth"] + +[tool.uv] +dev-dependencies = ["pyright>=1.1.391", "pytest>=8.3.4", "ruff>=0.8.5"] \ No newline at end of file diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index 0552f2770..bc14b7cd0 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -90,6 +90,7 @@ async def get_prompt( if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette + from starlette.responses import Response from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") @@ -101,6 +102,7 @@ async def handle_sse(request): await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response() starlette_app = Starlette( debug=True, diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index 0ec1d926a..06f567fbe 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -46,6 +46,7 @@ async def read_resource(uri: FileUrl) -> str | bytes: if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette + from starlette.responses import Response from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") @@ -57,11 +58,12 @@ async def handle_sse(request): await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response() starlette_app = Starlette( debug=True, routes=[ - Route("/sse", endpoint=handle_sse), + Route("/sse", endpoint=handle_sse, methods=["GET"]), Mount("/messages/", app=sse.handle_post_message), ], ) diff --git a/examples/servers/simple-streamablehttp-stateless/README.md b/examples/servers/simple-streamablehttp-stateless/README.md new file mode 100644 index 000000000..2abb60614 --- /dev/null +++ b/examples/servers/simple-streamablehttp-stateless/README.md @@ -0,0 +1,41 @@ +# MCP Simple StreamableHttp Stateless Server Example + +A stateless MCP server example demonstrating the StreamableHttp transport without maintaining session state. This example is ideal for understanding how to deploy MCP servers in multi-node environments where requests can be routed to any instance. + +## Features + +- Uses the StreamableHTTP transport in stateless mode (mcp_session_id=None) +- Each request creates a new ephemeral connection +- No session state maintained between requests +- Task lifecycle scoped to individual requests +- Suitable for deployment in multi-node environments + + +## Usage + +Start the server: + +```bash +# Using default port 3000 +uv run mcp-simple-streamablehttp-stateless + +# Using custom port +uv run mcp-simple-streamablehttp-stateless --port 3000 + +# Custom logging level +uv run mcp-simple-streamablehttp-stateless --log-level DEBUG + +# Enable JSON responses instead of SSE streams +uv run mcp-simple-streamablehttp-stateless --json-response +``` + +The server exposes a tool named "start-notification-stream" that accepts three arguments: + +- `interval`: Time between notifications in seconds (e.g., 1.0) +- `count`: Number of notifications to send (e.g., 5) +- `caller`: Identifier string for the caller + + +## Client + +You can connect to this server using an HTTP client. For now, only the TypeScript SDK has streamable HTTP client examples, or you can use [Inspector](https://github.com/modelcontextprotocol/inspector) for testing. \ No newline at end of file diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__init__.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py new file mode 100644 index 000000000..f5f6e402d --- /dev/null +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py @@ -0,0 +1,4 @@ +from .server import main + +if __name__ == "__main__": + main() diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py new file mode 100644 index 000000000..f718df801 --- /dev/null +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -0,0 +1,141 @@ +import contextlib +import logging +from collections.abc import AsyncIterator + +import anyio +import click +import mcp.types as types +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from starlette.applications import Starlette +from starlette.routing import Mount +from starlette.types import Receive, Scope, Send + +logger = logging.getLogger(__name__) + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on for HTTP") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +@click.option( + "--json-response", + is_flag=True, + default=False, + help="Enable JSON responses instead of SSE streams", +) +def main( + port: int, + log_level: str, + json_response: bool, +) -> int: + # Configure logging + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server("mcp-streamable-http-stateless-demo") + + @app.call_tool() + async def call_tool( + name: str, arguments: dict + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ctx = app.request_context + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + await ctx.session.send_log_message( + level="info", + data=f"Notification {i+1}/{count} from caller: {caller}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + return [ + types.TextContent( + type="text", + text=( + f"Sent {count} notifications with {interval}s interval" + f" for caller: {caller}" + ), + ) + ] + + @app.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="start-notification-stream", + description=( + "Sends a stream of notifications with configurable count" + " and interval" + ), + inputSchema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + "interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + "description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": ( + "Identifier of the caller to include in notifications" + ), + }, + }, + }, + ) + ] + + # Create the session manager with true stateless mode + session_manager = StreamableHTTPSessionManager( + app=app, + event_store=None, + json_response=json_response, + stateless=True, + ) + + async def handle_streamable_http( + scope: Scope, receive: Receive, send: Send + ) -> None: + await session_manager.handle_request(scope, receive, send) + + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[None]: + """Context manager for session manager.""" + async with session_manager.run(): + logger.info("Application started with StreamableHTTP session manager!") + try: + yield + finally: + logger.info("Application shutting down...") + + # Create an ASGI application using the transport + starlette_app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=handle_streamable_http), + ], + lifespan=lifespan, + ) + + import uvicorn + + uvicorn.run(starlette_app, host="0.0.0.0", port=port) + + return 0 diff --git a/examples/servers/simple-streamablehttp-stateless/pyproject.toml b/examples/servers/simple-streamablehttp-stateless/pyproject.toml new file mode 100644 index 000000000..d2b089451 --- /dev/null +++ b/examples/servers/simple-streamablehttp-stateless/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "mcp-simple-streamablehttp-stateless" +version = "0.1.0" +description = "A simple MCP server exposing a StreamableHttp transport in stateless mode" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "automation", "web", "fetch", "http", "streamable", "stateless"] +license = { text = "MIT" } +dependencies = ["anyio>=4.5", "click>=8.1.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-streamablehttp-stateless = "mcp_simple_streamablehttp_stateless.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_streamablehttp_stateless"] + +[tool.pyright] +include = ["mcp_simple_streamablehttp_stateless"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.uv] +dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] \ No newline at end of file diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md new file mode 100644 index 000000000..f850b7286 --- /dev/null +++ b/examples/servers/simple-streamablehttp/README.md @@ -0,0 +1,55 @@ +# MCP Simple StreamableHttp Server Example + +A simple MCP server example demonstrating the StreamableHttp transport, which enables HTTP-based communication with MCP servers using streaming. + +## Features + +- Uses the StreamableHTTP transport for server-client communication +- Supports REST API operations (POST, GET, DELETE) for `/mcp` endpoint +- Task management with anyio task groups +- Ability to send multiple notifications over time to the client +- Proper resource cleanup and lifespan management +- Resumability support via InMemoryEventStore + +## Usage + +Start the server on the default or custom port: + +```bash + +# Using custom port +uv run mcp-simple-streamablehttp --port 3000 + +# Custom logging level +uv run mcp-simple-streamablehttp --log-level DEBUG + +# Enable JSON responses instead of SSE streams +uv run mcp-simple-streamablehttp --json-response +``` + +The server exposes a tool named "start-notification-stream" that accepts three arguments: + +- `interval`: Time between notifications in seconds (e.g., 1.0) +- `count`: Number of notifications to send (e.g., 5) +- `caller`: Identifier string for the caller + +## Resumability Support + +This server includes resumability support through the InMemoryEventStore. This enables clients to: + +- Reconnect to the server after a disconnection +- Resume event streaming from where they left off using the Last-Event-ID header + + +The server will: +- Generate unique event IDs for each SSE message +- Store events in memory for later replay +- Replay missed events when a client reconnects with a Last-Event-ID header + +Note: The InMemoryEventStore is designed for demonstration purposes only. For production use, consider implementing a persistent storage solution. + + + +## Client + +You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use [Inspector](https://github.com/modelcontextprotocol/inspector) \ No newline at end of file diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py new file mode 100644 index 000000000..f5f6e402d --- /dev/null +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py @@ -0,0 +1,4 @@ +from .server import main + +if __name__ == "__main__": + main() diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py new file mode 100644 index 000000000..28c58149f --- /dev/null +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py @@ -0,0 +1,105 @@ +""" +In-memory event store for demonstrating resumability functionality. + +This is a simple implementation intended for examples and testing, +not for production use where a persistent storage solution would be more appropriate. +""" + +import logging +from collections import deque +from dataclasses import dataclass +from uuid import uuid4 + +from mcp.server.streamable_http import ( + EventCallback, + EventId, + EventMessage, + EventStore, + StreamId, +) +from mcp.types import JSONRPCMessage + +logger = logging.getLogger(__name__) + + +@dataclass +class EventEntry: + """ + Represents an event entry in the event store. + """ + + event_id: EventId + stream_id: StreamId + message: JSONRPCMessage + + +class InMemoryEventStore(EventStore): + """ + Simple in-memory implementation of the EventStore interface for resumability. + This is primarily intended for examples and testing, not for production use + where a persistent storage solution would be more appropriate. + + This implementation keeps only the last N events per stream for memory efficiency. + """ + + def __init__(self, max_events_per_stream: int = 100): + """Initialize the event store. + + Args: + max_events_per_stream: Maximum number of events to keep per stream + """ + self.max_events_per_stream = max_events_per_stream + # for maintaining last N events per stream + self.streams: dict[StreamId, deque[EventEntry]] = {} + # event_id -> EventEntry for quick lookup + self.event_index: dict[EventId, EventEntry] = {} + + async def store_event( + self, stream_id: StreamId, message: JSONRPCMessage + ) -> EventId: + """Stores an event with a generated event ID.""" + event_id = str(uuid4()) + event_entry = EventEntry( + event_id=event_id, stream_id=stream_id, message=message + ) + + # Get or create deque for this stream + if stream_id not in self.streams: + self.streams[stream_id] = deque(maxlen=self.max_events_per_stream) + + # If deque is full, the oldest event will be automatically removed + # We need to remove it from the event_index as well + if len(self.streams[stream_id]) == self.max_events_per_stream: + oldest_event = self.streams[stream_id][0] + self.event_index.pop(oldest_event.event_id, None) + + # Add new event + self.streams[stream_id].append(event_entry) + self.event_index[event_id] = event_entry + + return event_id + + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + """Replays events that occurred after the specified event ID.""" + if last_event_id not in self.event_index: + logger.warning(f"Event ID {last_event_id} not found in store") + return None + + # Get the stream and find events after the last one + last_event = self.event_index[last_event_id] + stream_id = last_event.stream_id + stream_events = self.streams.get(last_event.stream_id, deque()) + + # Events in deque are already in chronological order + found_last = False + for event in stream_events: + if found_last: + await send_callback(EventMessage(event.message, event.event_id)) + elif event.event_id == last_event_id: + found_last = True + + return stream_id diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py new file mode 100644 index 000000000..1a76097b5 --- /dev/null +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -0,0 +1,169 @@ +import contextlib +import logging +from collections.abc import AsyncIterator + +import anyio +import click +import mcp.types as types +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from pydantic import AnyUrl +from starlette.applications import Starlette +from starlette.routing import Mount +from starlette.types import Receive, Scope, Send + +from .event_store import InMemoryEventStore + +# Configure logging +logger = logging.getLogger(__name__) + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on for HTTP") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +@click.option( + "--json-response", + is_flag=True, + default=False, + help="Enable JSON responses instead of SSE streams", +) +def main( + port: int, + log_level: str, + json_response: bool, +) -> int: + # Configure logging + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server("mcp-streamable-http-demo") + + @app.call_tool() + async def call_tool( + name: str, arguments: dict + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ctx = app.request_context + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + # Include more detailed message for resumability demonstration + notification_msg = ( + f"[{i+1}/{count}] Event from '{caller}' - " + f"Use Last-Event-ID to resume if disconnected" + ) + await ctx.session.send_log_message( + level="info", + data=notification_msg, + logger="notification_stream", + # Associates this notification with the original request + # Ensures notifications are sent to the correct response stream + # Without this, notifications will either go to: + # - a standalone SSE stream (if GET request is supported) + # - nowhere (if GET request isn't supported) + related_request_id=ctx.request_id, + ) + logger.debug(f"Sent notification {i+1}/{count} for caller: {caller}") + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + # This will send a resource notificaiton though standalone SSE + # established by GET request + await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource")) + return [ + types.TextContent( + type="text", + text=( + f"Sent {count} notifications with {interval}s interval" + f" for caller: {caller}" + ), + ) + ] + + @app.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="start-notification-stream", + description=( + "Sends a stream of notifications with configurable count" + " and interval" + ), + inputSchema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + "interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + "description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": ( + "Identifier of the caller to include in notifications" + ), + }, + }, + }, + ) + ] + + # Create event store for resumability + # The InMemoryEventStore enables resumability support for StreamableHTTP transport. + # It stores SSE events with unique IDs, allowing clients to: + # 1. Receive event IDs for each SSE message + # 2. Resume streams by sending Last-Event-ID in GET requests + # 3. Replay missed events after reconnection + # Note: This in-memory implementation is for demonstration ONLY. + # For production, use a persistent storage solution. + event_store = InMemoryEventStore() + + # Create the session manager with our app and event store + session_manager = StreamableHTTPSessionManager( + app=app, + event_store=event_store, # Enable resumability + json_response=json_response, + ) + + # ASGI handler for streamable HTTP connections + async def handle_streamable_http( + scope: Scope, receive: Receive, send: Send + ) -> None: + await session_manager.handle_request(scope, receive, send) + + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[None]: + """Context manager for managing session manager lifecycle.""" + async with session_manager.run(): + logger.info("Application started with StreamableHTTP session manager!") + try: + yield + finally: + logger.info("Application shutting down...") + + # Create an ASGI application using the transport + starlette_app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=handle_streamable_http), + ], + lifespan=lifespan, + ) + + import uvicorn + + uvicorn.run(starlette_app, host="0.0.0.0", port=port) + + return 0 diff --git a/examples/servers/simple-streamablehttp/pyproject.toml b/examples/servers/simple-streamablehttp/pyproject.toml new file mode 100644 index 000000000..c35887d1f --- /dev/null +++ b/examples/servers/simple-streamablehttp/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "mcp-simple-streamablehttp" +version = "0.1.0" +description = "A simple MCP server exposing a StreamableHttp transport for testing" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "automation", "web", "fetch", "http", "streamable"] +license = { text = "MIT" } +dependencies = ["anyio>=4.5", "click>=8.1.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-streamablehttp = "mcp_simple_streamablehttp.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_streamablehttp"] + +[tool.pyright] +include = ["mcp_simple_streamablehttp"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.uv] +dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] \ No newline at end of file diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 3eace52ea..5f4e28bb7 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -1,8 +1,8 @@ import anyio import click -import httpx import mcp.types as types from mcp.server.lowlevel import Server +from mcp.shared._httpx_utils import create_mcp_http_client async def fetch_website( @@ -11,7 +11,7 @@ async def fetch_website( headers = { "User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)" } - async with httpx.AsyncClient(follow_redirects=True, headers=headers) as client: + async with create_mcp_http_client(headers=headers) as client: response = await client.get(url) response.raise_for_status() return [types.TextContent(type="text", text=response.text)] @@ -60,6 +60,7 @@ async def list_tools() -> list[types.Tool]: if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette + from starlette.responses import Response from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") @@ -71,11 +72,12 @@ async def handle_sse(request): await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response() starlette_app = Starlette( debug=True, routes=[ - Route("/sse", endpoint=handle_sse), + Route("/sse", endpoint=handle_sse, methods=["GET"]), Mount("/messages/", app=sse.handle_post_message), ], ) diff --git a/src/mcp/cli/claude.py b/src/mcp/cli/claude.py index 5a0ce0ab4..17c957df2 100644 --- a/src/mcp/cli/claude.py +++ b/src/mcp/cli/claude.py @@ -2,6 +2,7 @@ import json import os +import shutil import sys from pathlib import Path from typing import Any @@ -30,6 +31,16 @@ def get_claude_config_path() -> Path | None: return path return None +def get_uv_path() -> str: + """Get the full path to the uv executable.""" + uv_path = shutil.which("uv") + if not uv_path: + logger.error( + "uv executable not found in PATH, falling back to 'uv'. " + "Please ensure uv is installed and in your PATH" + ) + return "uv" # Fall back to just "uv" if not found + return uv_path def update_claude_config( file_spec: str, @@ -54,6 +65,7 @@ def update_claude_config( Claude Desktop may not be installed or properly set up. """ config_dir = get_claude_config_path() + uv_path = get_uv_path() if not config_dir: raise RuntimeError( "Claude Desktop config directory not found. Please ensure Claude Desktop" @@ -117,7 +129,7 @@ def update_claude_config( # Add fastmcp run command args.extend(["mcp", "run", file_spec]) - server_config: dict[str, Any] = {"command": "uv", "args": args} + server_config: dict[str, Any] = {"command": uv_path, "args": args} # Add environment variables if specified if env_vars: diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index 84e15bd56..2ec68e56c 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -11,8 +11,8 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import JSONRPCMessage if not sys.warnoptions: import warnings @@ -36,8 +36,8 @@ async def message_handler( async def run_session( - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], client_info: types.Implementation | None = None, ): async with ClientSession( diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index fc86f0110..7bb8821f7 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -7,6 +7,7 @@ import mcp.types as types from mcp.shared.context import RequestContext +from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -92,8 +93,8 @@ class ClientSession( ): def __init__( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, @@ -260,6 +261,7 @@ async def call_tool( read_timeout_seconds: timedelta | None = None, ) -> types.CallToolResult: """Send a tools/call request.""" + return await self.send_request( types.ClientRequest( types.CallToolRequest( diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 4f6241a72..29195cbd9 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -10,6 +10,8 @@ from httpx_sse import aconnect_sse import mcp.types as types +from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -31,11 +33,11 @@ async def sse_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. """ - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -43,7 +45,7 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx.AsyncClient(headers=headers) as client: + async with create_mcp_http_client(headers=headers) as client: async with aconnect_sse( client, "GET", @@ -97,7 +99,8 @@ async def sse_reader( await read_stream_writer.send(exc) continue - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) case _: logger.warning( f"Unknown SSE event: {sse.event}" @@ -111,11 +114,13 @@ async def sse_reader( async def post_writer(endpoint_url: str): try: async with write_stream_reader: - async for message in write_stream_reader: - logger.debug(f"Sending client message: {message}") + async for session_message in write_stream_reader: + logger.debug( + f"Sending client message: {session_message}" + ) response = await client.post( endpoint_url, - json=message.model_dump( + json=session_message.message.model_dump( by_alias=True, mode="json", exclude_none=True, diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 83de57a2b..e8be5aff5 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, Field import mcp.types as types +from mcp.shared.message import SessionMessage from .win32 import ( create_windows_process, @@ -98,11 +99,11 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout. """ - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -143,7 +144,8 @@ async def stdout_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() @@ -152,8 +154,10 @@ async def stdin_writer(): try: async with write_stream_reader: - async for message in write_stream_reader: - json = message.model_dump_json(by_alias=True, exclude_none=True) + async for session_message in write_stream_reader: + json = session_message.message.model_dump_json( + by_alias=True, exclude_none=True + ) await process.stdin.send( (json + "\n").encode( encoding=server.encoding, diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py new file mode 100644 index 000000000..183653b9a --- /dev/null +++ b/src/mcp/client/streamable_http.py @@ -0,0 +1,483 @@ +""" +StreamableHTTP Client Transport Module + +This module implements the StreamableHTTP transport for MCP clients, +providing support for HTTP POST requests with optional SSE streaming responses +and session management. +""" + +import logging +from collections.abc import AsyncGenerator, Awaitable, Callable +from contextlib import asynccontextmanager +from dataclasses import dataclass +from datetime import timedelta +from typing import Any + +import anyio +import httpx +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from httpx_sse import EventSource, ServerSentEvent, aconnect_sse + +from mcp.shared._httpx_utils import create_mcp_http_client +from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.types import ( + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + RequestId, +) + +logger = logging.getLogger(__name__) + + +SessionMessageOrError = SessionMessage | Exception +StreamWriter = MemoryObjectSendStream[SessionMessageOrError] +StreamReader = MemoryObjectReceiveStream[SessionMessage] +GetSessionIdCallback = Callable[[], str | None] + +MCP_SESSION_ID = "mcp-session-id" +LAST_EVENT_ID = "last-event-id" +CONTENT_TYPE = "content-type" +ACCEPT = "Accept" + + +JSON = "application/json" +SSE = "text/event-stream" + + +class StreamableHTTPError(Exception): + """Base exception for StreamableHTTP transport errors.""" + + pass + + +class ResumptionError(StreamableHTTPError): + """Raised when resumption request is invalid.""" + + pass + + +@dataclass +class RequestContext: + """Context for a request operation.""" + + client: httpx.AsyncClient + headers: dict[str, str] + session_id: str | None + session_message: SessionMessage + metadata: ClientMessageMetadata | None + read_stream_writer: StreamWriter + sse_read_timeout: timedelta + + +class StreamableHTTPTransport: + """StreamableHTTP client transport implementation.""" + + def __init__( + self, + url: str, + headers: dict[str, Any] | None = None, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + ) -> None: + """Initialize the StreamableHTTP transport. + + Args: + url: The endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations. + sse_read_timeout: Timeout for SSE read operations. + """ + self.url = url + self.headers = headers or {} + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout + self.session_id: str | None = None + self.request_headers = { + ACCEPT: f"{JSON}, {SSE}", + CONTENT_TYPE: JSON, + **self.headers, + } + + def _update_headers_with_session( + self, base_headers: dict[str, str] + ) -> dict[str, str]: + """Update headers with session ID if available.""" + headers = base_headers.copy() + if self.session_id: + headers[MCP_SESSION_ID] = self.session_id + return headers + + def _is_initialization_request(self, message: JSONRPCMessage) -> bool: + """Check if the message is an initialization request.""" + return ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "initialize" + ) + + def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: + """Check if the message is an initialized notification.""" + return ( + isinstance(message.root, JSONRPCNotification) + and message.root.method == "notifications/initialized" + ) + + def _maybe_extract_session_id_from_response( + self, + response: httpx.Response, + ) -> None: + """Extract and store session ID from response headers.""" + new_session_id = response.headers.get(MCP_SESSION_ID) + if new_session_id: + self.session_id = new_session_id + logger.info(f"Received session ID: {self.session_id}") + + async def _handle_sse_event( + self, + sse: ServerSentEvent, + read_stream_writer: StreamWriter, + original_request_id: RequestId | None = None, + resumption_callback: Callable[[str], Awaitable[None]] | None = None, + ) -> bool: + """Handle an SSE event, returning True if the response is complete.""" + if sse.event == "message": + try: + message = JSONRPCMessage.model_validate_json(sse.data) + logger.debug(f"SSE message: {message}") + + # If this is a response and we have original_request_id, replace it + if original_request_id is not None and isinstance( + message.root, JSONRPCResponse | JSONRPCError + ): + message.root.id = original_request_id + + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + + # Call resumption token callback if we have an ID + if sse.id and resumption_callback: + await resumption_callback(sse.id) + + # If this is a response or error return True indicating completion + # Otherwise, return False to continue listening + return isinstance(message.root, JSONRPCResponse | JSONRPCError) + + except Exception as exc: + logger.error(f"Error parsing SSE message: {exc}") + await read_stream_writer.send(exc) + return False + else: + logger.warning(f"Unknown SSE event: {sse.event}") + return False + + async def handle_get_stream( + self, + client: httpx.AsyncClient, + read_stream_writer: StreamWriter, + ) -> None: + """Handle GET stream for server-initiated messages.""" + try: + if not self.session_id: + return + + headers = self._update_headers_with_session(self.request_headers) + + async with aconnect_sse( + client, + "GET", + self.url, + headers=headers, + timeout=httpx.Timeout( + self.timeout.seconds, read=self.sse_read_timeout.seconds + ), + ) as event_source: + event_source.response.raise_for_status() + logger.debug("GET SSE connection established") + + async for sse in event_source.aiter_sse(): + await self._handle_sse_event(sse, read_stream_writer) + + except Exception as exc: + logger.debug(f"GET stream error (non-fatal): {exc}") + + async def _handle_resumption_request(self, ctx: RequestContext) -> None: + """Handle a resumption request using GET with SSE.""" + headers = self._update_headers_with_session(ctx.headers) + if ctx.metadata and ctx.metadata.resumption_token: + headers[LAST_EVENT_ID] = ctx.metadata.resumption_token + else: + raise ResumptionError("Resumption request requires a resumption token") + + # Extract original request ID to map responses + original_request_id = None + if isinstance(ctx.session_message.message.root, JSONRPCRequest): + original_request_id = ctx.session_message.message.root.id + + async with aconnect_sse( + ctx.client, + "GET", + self.url, + headers=headers, + timeout=httpx.Timeout( + self.timeout.seconds, read=ctx.sse_read_timeout.seconds + ), + ) as event_source: + event_source.response.raise_for_status() + logger.debug("Resumption GET SSE connection established") + + async for sse in event_source.aiter_sse(): + is_complete = await self._handle_sse_event( + sse, + ctx.read_stream_writer, + original_request_id, + ctx.metadata.on_resumption_token_update if ctx.metadata else None, + ) + if is_complete: + break + + async def _handle_post_request(self, ctx: RequestContext) -> None: + """Handle a POST request with response processing.""" + headers = self._update_headers_with_session(ctx.headers) + message = ctx.session_message.message + is_initialization = self._is_initialization_request(message) + + async with ctx.client.stream( + "POST", + self.url, + json=message.model_dump(by_alias=True, mode="json", exclude_none=True), + headers=headers, + ) as response: + if response.status_code == 202: + logger.debug("Received 202 Accepted") + return + + if response.status_code == 404: + if isinstance(message.root, JSONRPCRequest): + await self._send_session_terminated_error( + ctx.read_stream_writer, + message.root.id, + ) + return + + response.raise_for_status() + if is_initialization: + self._maybe_extract_session_id_from_response(response) + + content_type = response.headers.get(CONTENT_TYPE, "").lower() + + if content_type.startswith(JSON): + await self._handle_json_response(response, ctx.read_stream_writer) + elif content_type.startswith(SSE): + await self._handle_sse_response(response, ctx) + else: + await self._handle_unexpected_content_type( + content_type, + ctx.read_stream_writer, + ) + + async def _handle_json_response( + self, + response: httpx.Response, + read_stream_writer: StreamWriter, + ) -> None: + """Handle JSON response from the server.""" + try: + content = await response.aread() + message = JSONRPCMessage.model_validate_json(content) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + except Exception as exc: + logger.error(f"Error parsing JSON response: {exc}") + await read_stream_writer.send(exc) + + async def _handle_sse_response( + self, response: httpx.Response, ctx: RequestContext + ) -> None: + """Handle SSE response from the server.""" + try: + event_source = EventSource(response) + async for sse in event_source.aiter_sse(): + await self._handle_sse_event( + sse, + ctx.read_stream_writer, + resumption_callback=( + ctx.metadata.on_resumption_token_update + if ctx.metadata + else None + ), + ) + except Exception as e: + logger.exception("Error reading SSE stream:") + await ctx.read_stream_writer.send(e) + + async def _handle_unexpected_content_type( + self, + content_type: str, + read_stream_writer: StreamWriter, + ) -> None: + """Handle unexpected content type in response.""" + error_msg = f"Unexpected content type: {content_type}" + logger.error(error_msg) + await read_stream_writer.send(ValueError(error_msg)) + + async def _send_session_terminated_error( + self, + read_stream_writer: StreamWriter, + request_id: RequestId, + ) -> None: + """Send a session terminated error response.""" + jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=32600, message="Session terminated"), + ) + session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) + await read_stream_writer.send(session_message) + + async def post_writer( + self, + client: httpx.AsyncClient, + write_stream_reader: StreamReader, + read_stream_writer: StreamWriter, + write_stream: MemoryObjectSendStream[SessionMessage], + start_get_stream: Callable[[], None], + ) -> None: + """Handle writing requests to the server.""" + try: + async with write_stream_reader: + async for session_message in write_stream_reader: + message = session_message.message + metadata = ( + session_message.metadata + if isinstance(session_message.metadata, ClientMessageMetadata) + else None + ) + + # Check if this is a resumption request + is_resumption = bool(metadata and metadata.resumption_token) + + logger.debug(f"Sending client message: {message}") + + # Handle initialized notification + if self._is_initialized_notification(message): + start_get_stream() + + ctx = RequestContext( + client=client, + headers=self.request_headers, + session_id=self.session_id, + session_message=session_message, + metadata=metadata, + read_stream_writer=read_stream_writer, + sse_read_timeout=self.sse_read_timeout, + ) + + if is_resumption: + await self._handle_resumption_request(ctx) + else: + await self._handle_post_request(ctx) + + except Exception as exc: + logger.error(f"Error in post_writer: {exc}") + finally: + await read_stream_writer.aclose() + await write_stream.aclose() + + async def terminate_session(self, client: httpx.AsyncClient) -> None: + """Terminate the session by sending a DELETE request.""" + if not self.session_id: + return + + try: + headers = self._update_headers_with_session(self.request_headers) + response = await client.delete(self.url, headers=headers) + + if response.status_code == 405: + logger.debug("Server does not allow session termination") + elif response.status_code != 200: + logger.warning(f"Session termination failed: {response.status_code}") + except Exception as exc: + logger.warning(f"Session termination failed: {exc}") + + def get_session_id(self) -> str | None: + """Get the current session ID.""" + return self.session_id + + +@asynccontextmanager +async def streamablehttp_client( + url: str, + headers: dict[str, Any] | None = None, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + terminate_on_close: bool = True, +) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback, + ], + None, +]: + """ + Client transport for StreamableHTTP. + + `sse_read_timeout` determines how long (in seconds) the client will wait for a new + event before disconnecting. All other HTTP operations are controlled by `timeout`. + + Yields: + Tuple containing: + - read_stream: Stream for reading messages from the server + - write_stream: Stream for sending messages to the server + - get_session_id_callback: Function to retrieve the current session ID + """ + transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) + + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + SessionMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + SessionMessage + ](0) + + async with anyio.create_task_group() as tg: + try: + logger.info(f"Connecting to StreamableHTTP endpoint: {url}") + + async with create_mcp_http_client( + headers=transport.request_headers, + timeout=httpx.Timeout( + transport.timeout.seconds, read=transport.sse_read_timeout.seconds + ), + ) as client: + # Define callbacks that need access to tg + def start_get_stream() -> None: + tg.start_soon( + transport.handle_get_stream, client, read_stream_writer + ) + + tg.start_soon( + transport.post_writer, + client, + write_stream_reader, + read_stream_writer, + write_stream, + start_get_stream, + ) + + try: + yield ( + read_stream, + write_stream, + transport.get_session_id, + ) + finally: + if transport.session_id and terminate_on_close: + await transport.terminate_session(client) + tg.cancel_scope.cancel() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 2c2ed38b9..ac542fb3f 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -10,6 +10,7 @@ from websockets.typing import Subprotocol import mcp.types as types +from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -19,8 +20,8 @@ async def websocket_client( url: str, ) -> AsyncGenerator[ tuple[ - MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - MemoryObjectSendStream[types.JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], ], None, ]: @@ -39,10 +40,10 @@ async def websocket_client( # Create two in-memory streams: # - One for incoming messages (read_stream, written by ws_reader) # - One for outgoing messages (write_stream, read by ws_writer) - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -59,7 +60,8 @@ async def ws_reader(): async for raw_text in ws: try: message = types.JSONRPCMessage.model_validate_json(raw_text) - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) except ValidationError as exc: # If JSON parse or model validation fails, send the exception await read_stream_writer.send(exc) @@ -70,9 +72,9 @@ async def ws_writer(): sends them to the server. """ async with write_stream_reader: - async for message in write_stream_reader: + async for session_message in write_stream_reader: # Convert to a dict, then to JSON - msg_dict = message.model_dump( + msg_dict = session_message.message.model_dump( by_alias=True, mode="json", exclude_none=True ) await ws.send(json.dumps(msg_dict)) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 295605af7..30b5e2ba6 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -34,8 +34,15 @@ def __init__( self.provider = provider async def authenticate(self, conn: HTTPConnection): - auth_header = conn.headers.get("Authorization") - if not auth_header or not auth_header.startswith("Bearer "): + auth_header = next( + ( + conn.headers.get(key) + for key in conn.headers + if key.lower() == "authorization" + ), + None, + ) + if not auth_header or not auth_header.lower().startswith("bearer "): return None token = auth_header[7:] # Remove "Bearer " prefix diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 29dd6a43a..4c56ca247 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -177,7 +177,7 @@ def build_metadata( issuer=issuer_url, authorization_endpoint=authorization_url, token_endpoint=token_url, - scopes_supported=None, + scopes_supported=client_registration_options.valid_scopes, response_types_supported=["code"], response_modes_supported=None, grant_types_supported=["authorization_code", "refresh_token"], diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 65d342e1a..c31f29d4c 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -47,6 +47,8 @@ from mcp.server.session import ServerSession, ServerSessionT from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server +from mcp.server.streamable_http import EventStore +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.shared.context import LifespanContextT, RequestContext from mcp.types import ( AnyFunction, @@ -87,8 +89,16 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # HTTP settings host: str = "0.0.0.0" port: int = 8000 + mount_path: str = "/" # Mount path (e.g. "/github", defaults to root path) sse_path: str = "/sse" message_path: str = "/messages/" + streamable_http_path: str = "/mcp" + + # StreamableHTTP settings + json_response: bool = False + stateless_http: bool = ( + False # If True, uses true stateless mode (new transport per request) + ) # resource settings warn_on_duplicate_resources: bool = True @@ -130,6 +140,7 @@ def __init__( instructions: str | None = None, auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, + event_store: EventStore | None = None, **settings: Any, ): self.settings = Settings(**settings) @@ -161,8 +172,10 @@ def __init__( "is specified" ) self._auth_server_provider = auth_server_provider + self._event_store = event_store self._custom_starlette_routes: list[Route] = [] self.dependencies = self.settings.dependencies + self._session_manager: StreamableHTTPSessionManager | None = None # Set up MCP protocol handlers self._setup_handlers() @@ -178,20 +191,47 @@ def name(self) -> str: def instructions(self) -> str | None: return self._mcp_server.instructions - def run(self, transport: Literal["stdio", "sse"] = "stdio") -> None: + @property + def session_manager(self) -> StreamableHTTPSessionManager: + """Get the StreamableHTTP session manager. + + This is exposed to enable advanced use cases like mounting multiple + FastMCP servers in a single FastAPI application. + + Raises: + RuntimeError: If called before streamable_http_app() has been called. + """ + if self._session_manager is None: + raise RuntimeError( + "Session manager can only be accessed after" + "calling streamable_http_app()." + "The session manager is created lazily" + "to avoid unnecessary initialization." + ) + return self._session_manager + + def run( + self, + transport: Literal["stdio", "sse", "streamable-http"] = "stdio", + mount_path: str | None = None, + ) -> None: """Run the FastMCP server. Note this is a synchronous function. Args: - transport: Transport protocol to use ("stdio" or "sse") + transport: Transport protocol to use ("stdio", "sse", or "streamable-http") + mount_path: Optional mount path for SSE transport """ - TRANSPORTS = Literal["stdio", "sse"] + TRANSPORTS = Literal["stdio", "sse", "streamable-http"] if transport not in TRANSPORTS.__args__: # type: ignore raise ValueError(f"Unknown transport: {transport}") - if transport == "stdio": - anyio.run(self.run_stdio_async) - else: # transport == "sse" - anyio.run(self.run_sse_async) + match transport: + case "stdio": + anyio.run(self.run_stdio_async) + case "sse": + anyio.run(lambda: self.run_sse_async(mount_path)) + case "streamable-http": + anyio.run(self.run_streamable_http_async) def _setup_handlers(self) -> None: """Set up core MCP protocol handlers.""" @@ -552,11 +592,11 @@ async def run_stdio_async(self) -> None: self._mcp_server.create_initialization_options(), ) - async def run_sse_async(self) -> None: + async def run_sse_async(self, mount_path: str | None = None) -> None: """Run the server using SSE transport.""" import uvicorn - starlette_app = self.sse_app() + starlette_app = self.sse_app(mount_path) config = uvicorn.Config( starlette_app, @@ -567,14 +607,66 @@ async def run_sse_async(self) -> None: server = uvicorn.Server(config) await server.serve() - def sse_app(self) -> Starlette: + async def run_streamable_http_async(self) -> None: + """Run the server using StreamableHTTP transport.""" + import uvicorn + + starlette_app = self.streamable_http_app() + + config = uvicorn.Config( + starlette_app, + host=self.settings.host, + port=self.settings.port, + log_level=self.settings.log_level.lower(), + ) + server = uvicorn.Server(config) + await server.serve() + + def _normalize_path(self, mount_path: str, endpoint: str) -> str: + """ + Combine mount path and endpoint to return a normalized path. + + Args: + mount_path: The mount path (e.g. "/github" or "/") + endpoint: The endpoint path (e.g. "/messages/") + + Returns: + Normalized path (e.g. "/github/messages/") + """ + # Special case: root path + if mount_path == "/": + return endpoint + + # Remove trailing slash from mount path + if mount_path.endswith("/"): + mount_path = mount_path[:-1] + + # Ensure endpoint starts with slash + if not endpoint.startswith("/"): + endpoint = "/" + endpoint + + # Combine paths + return mount_path + endpoint + + def sse_app(self, mount_path: str | None = None) -> Starlette: """Return an instance of the SSE server app.""" from starlette.middleware import Middleware from starlette.routing import Mount, Route + # Update mount_path in settings if provided + if mount_path is not None: + self.settings.mount_path = mount_path + + # Create normalized endpoint considering the mount path + normalized_message_endpoint = self._normalize_path( + self.settings.mount_path, self.settings.message_path + ) + # Set up auth context and dependencies - sse = SseServerTransport(self.settings.message_path) + sse = SseServerTransport( + normalized_message_endpoint, + ) async def handle_sse(scope: Scope, receive: Receive, send: Send): # Add client ID from auth context into request context if available @@ -589,6 +681,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): streams[1], self._mcp_server.create_initialization_options(), ) + return Response() # Create routes routes: list[Route | Mount] = [] @@ -624,19 +717,42 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): ) ) - routes.append( - Route( - self.settings.sse_path, - endpoint=RequireAuthMiddleware(handle_sse, required_scopes), - methods=["GET"], + # When auth is not configured, we shouldn't require auth + if self._auth_server_provider: + # Auth is enabled, wrap the endpoints with RequireAuthMiddleware + routes.append( + Route( + self.settings.sse_path, + endpoint=RequireAuthMiddleware(handle_sse, required_scopes), + methods=["GET"], + ) ) - ) - routes.append( - Mount( - self.settings.message_path, - app=RequireAuthMiddleware(sse.handle_post_message, required_scopes), + routes.append( + Mount( + self.settings.message_path, + app=RequireAuthMiddleware(sse.handle_post_message, required_scopes), + ) + ) + else: + # Auth is disabled, no need for RequireAuthMiddleware + # Since handle_sse is an ASGI app, we need to create a compatible endpoint + async def sse_endpoint(request: Request) -> Response: + # Convert the Starlette request to ASGI parameters + return await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage] + + routes.append( + Route( + self.settings.sse_path, + endpoint=sse_endpoint, + methods=["GET"], + ) + ) + routes.append( + Mount( + self.settings.message_path, + app=sse.handle_post_message, + ) ) - ) # mount these routes last, so they have the lowest route matching precedence routes.extend(self._custom_starlette_routes) @@ -645,6 +761,80 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): debug=self.settings.debug, routes=routes, middleware=middleware ) + def streamable_http_app(self) -> Starlette: + """Return an instance of the StreamableHTTP server app.""" + from starlette.middleware import Middleware + from starlette.routing import Mount + + # Create session manager on first call (lazy initialization) + if self._session_manager is None: + self._session_manager = StreamableHTTPSessionManager( + app=self._mcp_server, + event_store=self._event_store, + json_response=self.settings.json_response, + stateless=self.settings.stateless_http, # Use the stateless setting + ) + + # Create the ASGI handler + async def handle_streamable_http( + scope: Scope, receive: Receive, send: Send + ) -> None: + await self.session_manager.handle_request(scope, receive, send) + + # Create routes + routes: list[Route | Mount] = [] + middleware: list[Middleware] = [] + required_scopes = [] + + # Add auth endpoints if auth provider is configured + if self._auth_server_provider: + assert self.settings.auth + from mcp.server.auth.routes import create_auth_routes + + required_scopes = self.settings.auth.required_scopes or [] + + middleware = [ + Middleware( + AuthenticationMiddleware, + backend=BearerAuthBackend( + provider=self._auth_server_provider, + ), + ), + Middleware(AuthContextMiddleware), + ] + routes.extend( + create_auth_routes( + provider=self._auth_server_provider, + issuer_url=self.settings.auth.issuer_url, + service_documentation_url=self.settings.auth.service_documentation_url, + client_registration_options=self.settings.auth.client_registration_options, + revocation_options=self.settings.auth.revocation_options, + ) + ) + routes.append( + Mount( + self.settings.streamable_http_path, + app=RequireAuthMiddleware(handle_streamable_http, required_scopes), + ) + ) + else: + # Auth is disabled, no wrapper needed + routes.append( + Mount( + self.settings.streamable_http_path, + app=handle_streamable_http, + ) + ) + + routes.extend(self._custom_starlette_routes) + + return Starlette( + debug=self.settings.debug, + routes=routes, + middleware=middleware, + lifespan=lambda app: self.session_manager.run(), + ) + async def list_prompts(self) -> list[MCPPrompt]: """List all available prompts.""" prompts = self._prompt_manager.list_prompts() @@ -814,7 +1004,10 @@ async def log( **extra: Additional structured data to include """ await self.request_context.session.send_log_message( - level=level, data=message, logger=logger_name + level=level, + data=message, + logger=logger_name, + related_request_id=self.request_id, ) @property diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index b4f6330b5..4b97b33da 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -84,6 +84,7 @@ async def main(): from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder logger = logging.getLogger(__name__) @@ -471,19 +472,29 @@ async def handler(req: types.CompleteRequest): async def run( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], initialization_options: InitializationOptions, # When False, exceptions are returned as messages to the client. # When True, exceptions are raised, which will cause the server to shut down # but also make tracing exceptions much easier during testing and when using # in-process servers. raise_exceptions: bool = False, + # When True, the server is stateless and + # clients can perform initialization with any node. The client must still follow + # the initialization lifecycle, but can do so with any available node + # rather than requiring initialization for each connection. + stateless: bool = False, ): async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) session = await stack.enter_async_context( - ServerSession(read_stream, write_stream, initialization_options) + ServerSession( + read_stream, + write_stream, + initialization_options, + stateless=stateless, + ) ) async with anyio.create_task_group() as tg: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 568ecd4b9..c769d1aa3 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.models import InitializationOptions +from mcp.shared.message import SessionMessage from mcp.shared.session import ( BaseSession, RequestResponder, @@ -82,14 +83,20 @@ class ServerSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], init_options: InitializationOptions, + stateless: bool = False, ) -> None: super().__init__( read_stream, write_stream, types.ClientRequest, types.ClientNotification ) - self._initialization_state = InitializationState.NotInitialized + self._initialization_state = ( + InitializationState.Initialized + if stateless + else InitializationState.NotInitialized + ) + self._init_options = init_options self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( anyio.create_memory_object_stream[ServerRequestResponder](0) @@ -97,9 +104,6 @@ def __init__( self._exit_stack.push_async_callback( lambda: self._incoming_message_stream_reader.aclose() ) - self._exit_stack.push_async_callback( - lambda: self._incoming_message_stream_writer.aclose() - ) @property def client_params(self) -> types.InitializeRequestParams | None: @@ -137,6 +141,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: return True + async def _receive_loop(self) -> None: + async with self._incoming_message_stream_writer: + await super()._receive_loop() + async def _received_request( self, responder: RequestResponder[types.ClientRequest, types.ServerResult] ): @@ -179,7 +187,11 @@ async def _received_notification( ) async def send_log_message( - self, level: types.LoggingLevel, data: Any, logger: str | None = None + self, + level: types.LoggingLevel, + data: Any, + logger: str | None = None, + related_request_id: types.RequestId | None = None, ) -> None: """Send a log message notification.""" await self.send_notification( @@ -192,7 +204,8 @@ async def send_log_message( logger=logger, ), ) - ) + ), + related_request_id, ) async def send_resource_updated(self, uri: AnyUrl) -> None: @@ -261,7 +274,11 @@ async def send_ping(self) -> types.EmptyResult: ) async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None + self, + progress_token: str | int, + progress: float, + total: float | None = None, + related_request_id: str | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -274,7 +291,8 @@ async def send_progress_notification( total=total, ), ) - ) + ), + related_request_id, ) async def send_resource_list_changed(self) -> None: diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index d051c25bf..cc41a80d6 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -10,7 +10,7 @@ # Create Starlette routes for SSE and message handling routes = [ - Route("/sse", endpoint=handle_sse), + Route("/sse", endpoint=handle_sse, methods=["GET"]), Mount("/messages/", app=sse.handle_post_message), ] @@ -22,12 +22,18 @@ async def handle_sse(request): await app.run( streams[0], streams[1], app.create_initialization_options() ) + # Return empty response to avoid NoneType error + return Response() # Create and run Starlette app starlette_app = Starlette(routes=routes) uvicorn.run(starlette_app, host="0.0.0.0", port=port) ``` +Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType' +object is not callable" error when client disconnects. The example above returns +an empty Response() after the SSE connection ends to fix this. + See SseServerTransport class documentation for more details. """ @@ -46,6 +52,7 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types +from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -63,9 +70,7 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[ - UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception] - ] + _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] def __init__(self, endpoint: str) -> None: """ @@ -85,11 +90,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): raise ValueError("connect_sse can only handle HTTP requests") logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -109,23 +114,34 @@ async def sse_writer(): await sse_stream_writer.send({"event": "endpoint", "data": session_uri}) logger.debug(f"Sent endpoint event: {session_uri}") - async for message in write_stream_reader: - logger.debug(f"Sending message via SSE: {message}") + async for session_message in write_stream_reader: + logger.debug(f"Sending message via SSE: {session_message}") await sse_stream_writer.send( { "event": "message", - "data": message.model_dump_json( + "data": session_message.message.model_dump_json( by_alias=True, exclude_none=True ), } ) async with anyio.create_task_group() as tg: - response = EventSourceResponse( - content=sse_stream_reader, data_sender_callable=sse_writer - ) + + async def response_wrapper(scope: Scope, receive: Receive, send: Send): + """ + The EventSourceResponse returning signals a client close / disconnect. + In this case we close our side of the streams to signal the client that + the connection has been closed. + """ + await EventSourceResponse( + content=sse_stream_reader, data_sender_callable=sse_writer + )(scope, receive, send) + await read_stream_writer.aclose() + await write_stream_reader.aclose() + logging.debug(f"Client session disconnected {session_id}") + logger.debug("Starting SSE response task") - tg.start_soon(response, scope, receive, send) + tg.start_soon(response_wrapper, scope, receive, send) logger.debug("Yielding read and write streams") yield (read_stream, write_stream) @@ -169,7 +185,8 @@ async def handle_post_message( await writer.send(err) return - logger.debug(f"Sending message to writer: {message}") + session_message = SessionMessage(message) + logger.debug(f"Sending session message to writer: {session_message}") response = Response("Accepted", status_code=202) await response(scope, receive, send) - await writer.send(message) + await writer.send(session_message) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 0e0e49129..f0bbe5a31 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -27,6 +27,7 @@ async def run_server(): from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types +from mcp.shared.message import SessionMessage @asynccontextmanager @@ -47,11 +48,11 @@ async def stdio_server( if not stdout: stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -66,15 +67,18 @@ async def stdin_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() async def stdout_writer(): try: async with write_stream_reader: - async for message in write_stream_reader: - json = message.model_dump_json(by_alias=True, exclude_none=True) + async for session_message in write_stream_reader: + json = session_message.message.model_dump_json( + by_alias=True, exclude_none=True + ) await stdout.write(json + "\n") await stdout.flush() except anyio.ClosedResourceError: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py new file mode 100644 index 000000000..ace74b33b --- /dev/null +++ b/src/mcp/server/streamable_http.py @@ -0,0 +1,926 @@ +""" +StreamableHTTP Server Transport Module + +This module implements an HTTP transport layer with Streamable HTTP. + +The transport handles bidirectional communication using HTTP requests and +responses, with streaming support for long-running operations. +""" + +import json +import logging +import re +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Awaitable, Callable +from contextlib import asynccontextmanager +from dataclasses import dataclass +from http import HTTPStatus + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError +from sse_starlette import EventSourceResponse +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.types import ( + INTERNAL_ERROR, + INVALID_PARAMS, + INVALID_REQUEST, + PARSE_ERROR, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + RequestId, +) + +logger = logging.getLogger(__name__) + +# Maximum size for incoming messages +MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB + +# Header names +MCP_SESSION_ID_HEADER = "mcp-session-id" +LAST_EVENT_ID_HEADER = "last-event-id" + +# Content types +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_SSE = "text/event-stream" + +# Special key for the standalone GET stream +GET_STREAM_KEY = "_GET_stream" + +# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) +# Pattern ensures entire string contains only valid characters by using ^ and $ anchors +SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") + +# Type aliases +StreamId = str +EventId = str + + +@dataclass +class EventMessage: + """ + A JSONRPCMessage with an optional event ID for stream resumability. + """ + + message: JSONRPCMessage + event_id: str | None = None + + +EventCallback = Callable[[EventMessage], Awaitable[None]] + + +class EventStore(ABC): + """ + Interface for resumability support via event storage. + """ + + @abstractmethod + async def store_event( + self, stream_id: StreamId, message: JSONRPCMessage + ) -> EventId: + """ + Stores an event for later retrieval. + + Args: + stream_id: ID of the stream the event belongs to + message: The JSON-RPC message to store + + Returns: + The generated event ID for the stored event + """ + pass + + @abstractmethod + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + """ + Replays events that occurred after the specified event ID. + + Args: + last_event_id: The ID of the last event the client received + send_callback: A callback function to send events to the client + + Returns: + The stream ID of the replayed events + """ + pass + + +class StreamableHTTPServerTransport: + """ + HTTP server transport with event streaming support for MCP. + + Handles JSON-RPC messages in HTTP POST requests with SSE streaming. + Supports optional JSON responses and session management. + """ + + # Server notification streams for POST requests as well as standalone SSE stream + _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = ( + None + ) + _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None + _write_stream: MemoryObjectSendStream[SessionMessage] | None = None + _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None + + def __init__( + self, + mcp_session_id: str | None, + is_json_response_enabled: bool = False, + event_store: EventStore | None = None, + ) -> None: + """ + Initialize a new StreamableHTTP server transport. + + Args: + mcp_session_id: Optional session identifier for this connection. + Must contain only visible ASCII characters (0x21-0x7E). + is_json_response_enabled: If True, return JSON responses for requests + instead of SSE streams. Default is False. + event_store: Event store for resumability support. If provided, + resumability will be enabled, allowing clients to + reconnect and resume messages. + + Raises: + ValueError: If the session ID contains invalid characters. + """ + if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch( + mcp_session_id + ): + raise ValueError( + "Session ID must only contain visible ASCII characters (0x21-0x7E)" + ) + + self.mcp_session_id = mcp_session_id + self.is_json_response_enabled = is_json_response_enabled + self._event_store = event_store + self._request_streams: dict[ + RequestId, + tuple[ + MemoryObjectSendStream[EventMessage], + MemoryObjectReceiveStream[EventMessage], + ], + ] = {} + self._terminated = False + + def _create_error_response( + self, + error_message: str, + status_code: HTTPStatus, + error_code: int = INVALID_REQUEST, + headers: dict[str, str] | None = None, + ) -> Response: + """Create an error response with a simple string message.""" + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Return a properly formatted JSON error response + error_response = JSONRPCError( + jsonrpc="2.0", + id="server-error", # We don't have a request ID for general errors + error=ErrorData( + code=error_code, + message=error_message, + ), + ) + + return Response( + error_response.model_dump_json(by_alias=True, exclude_none=True), + status_code=status_code, + headers=response_headers, + ) + + def _create_json_response( + self, + response_message: JSONRPCMessage | None, + status_code: HTTPStatus = HTTPStatus.OK, + headers: dict[str, str] | None = None, + ) -> Response: + """Create a JSON response from a JSONRPCMessage""" + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + return Response( + response_message.model_dump_json(by_alias=True, exclude_none=True) + if response_message + else None, + status_code=status_code, + headers=response_headers, + ) + + def _get_session_id(self, request: Request) -> str | None: + """Extract the session ID from request headers.""" + return request.headers.get(MCP_SESSION_ID_HEADER) + + def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: + """Create event data dictionary from an EventMessage.""" + event_data = { + "event": "message", + "data": event_message.message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + # If an event ID was provided, include it + if event_message.event_id: + event_data["id"] = event_message.event_id + + return event_data + + async def _clean_up_memory_streams(self, request_id: RequestId) -> None: + """Clean up memory streams for a given request ID.""" + if request_id in self._request_streams: + try: + # Close the request stream + await self._request_streams[request_id][0].aclose() + await self._request_streams[request_id][1].aclose() + except Exception as e: + logger.debug(f"Error closing memory streams: {e}") + finally: + # Remove the request stream from the mapping + self._request_streams.pop(request_id, None) + + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: + """Application entry point that handles all HTTP requests""" + request = Request(scope, receive) + if self._terminated: + # If the session has been terminated, return 404 Not Found + response = self._create_error_response( + "Not Found: Session has been terminated", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return + + if request.method == "POST": + await self._handle_post_request(scope, request, receive, send) + elif request.method == "GET": + await self._handle_get_request(request, send) + elif request.method == "DELETE": + await self._handle_delete_request(request, send) + else: + await self._handle_unsupported_request(request, send) + + def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: + """Check if the request accepts the required media types.""" + accept_header = request.headers.get("accept", "") + accept_types = [media_type.strip() for media_type in accept_header.split(",")] + + has_json = any( + media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types + ) + has_sse = any( + media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types + ) + + return has_json, has_sse + + def _check_content_type(self, request: Request) -> bool: + """Check if the request has the correct Content-Type.""" + content_type = request.headers.get("content-type", "") + content_type_parts = [ + part.strip() for part in content_type.split(";")[0].split(",") + ] + + return any(part == CONTENT_TYPE_JSON for part in content_type_parts) + + async def _handle_post_request( + self, scope: Scope, request: Request, receive: Receive, send: Send + ) -> None: + """Handle POST requests containing JSON-RPC messages.""" + writer = self._read_stream_writer + if writer is None: + raise ValueError( + "No read stream writer available. Ensure connect() is called first." + ) + try: + # Check Accept headers + has_json, has_sse = self._check_accept_headers(request) + if not (has_json and has_sse): + response = self._create_error_response( + ( + "Not Acceptable: Client must accept both application/json and " + "text/event-stream" + ), + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(scope, receive, send) + return + + # Validate Content-Type + if not self._check_content_type(request): + response = self._create_error_response( + "Unsupported Media Type: Content-Type must be application/json", + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + ) + await response(scope, receive, send) + return + + # Parse the body - only read it once + body = await request.body() + if len(body) > MAXIMUM_MESSAGE_SIZE: + response = self._create_error_response( + "Payload Too Large: Message exceeds maximum size", + HTTPStatus.REQUEST_ENTITY_TOO_LARGE, + ) + await response(scope, receive, send) + return + + try: + raw_message = json.loads(body) + except json.JSONDecodeError as e: + response = self._create_error_response( + f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR + ) + await response(scope, receive, send) + return + + try: + message = JSONRPCMessage.model_validate(raw_message) + except ValidationError as e: + response = self._create_error_response( + f"Validation error: {str(e)}", + HTTPStatus.BAD_REQUEST, + INVALID_PARAMS, + ) + await response(scope, receive, send) + return + + # Check if this is an initialization request + is_initialization_request = ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "initialize" + ) + + if is_initialization_request: + # Check if the server already has an established session + if self.mcp_session_id: + # Check if request has a session ID + request_session_id = self._get_session_id(request) + + # If request has a session ID but doesn't match, return 404 + if request_session_id and request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return + # For non-initialization requests, validate the session + elif not await self._validate_session(request, send): + return + + # For notifications and responses only, return 202 Accepted + if not isinstance(message.root, JSONRPCRequest): + # Create response object and send it + response = self._create_json_response( + None, + HTTPStatus.ACCEPTED, + ) + await response(scope, receive, send) + + # Process the message after sending the response + session_message = SessionMessage(message) + await writer.send(session_message) + + return + + # Extract the request ID outside the try block for proper scope + request_id = str(message.root.id) + # Register this stream for the request ID + self._request_streams[request_id] = anyio.create_memory_object_stream[ + EventMessage + ](0) + request_stream_reader = self._request_streams[request_id][1] + + if self.is_json_response_enabled: + # Process the message + session_message = SessionMessage(message) + await writer.send(session_message) + try: + # Process messages from the request-specific stream + # We need to collect all messages until we get a response + response_message = None + + # Use similar approach to SSE writer for consistency + async for event_message in request_stream_reader: + # If it's a response, this is what we're waiting for + if isinstance( + event_message.message.root, JSONRPCResponse | JSONRPCError + ): + response_message = event_message.message + break + # For notifications and request, keep waiting + else: + logger.debug( + f"received: {event_message.message.root.method}" + ) + + # At this point we should have a response + if response_message: + # Create JSON response + response = self._create_json_response(response_message) + await response(scope, receive, send) + else: + # This shouldn't happen in normal operation + logger.error( + "No response message received before stream closed" + ) + response = self._create_error_response( + "Error processing request: No response received", + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) + except Exception as e: + logger.exception(f"Error processing JSON response: {e}") + response = self._create_error_response( + f"Error processing request: {str(e)}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(scope, receive, send) + finally: + await self._clean_up_memory_streams(request_id) + else: + # Create SSE stream + sse_stream_writer, sse_stream_reader = ( + anyio.create_memory_object_stream[dict[str, str]](0) + ) + + async def sse_writer(): + # Get the request ID from the incoming request message + try: + async with sse_stream_writer, request_stream_reader: + # Process messages from the request-specific stream + async for event_message in request_stream_reader: + # Build the event data + event_data = self._create_event_data(event_message) + await sse_stream_writer.send(event_data) + + # If response, remove from pending streams and close + if isinstance( + event_message.message.root, + JSONRPCResponse | JSONRPCError, + ): + break + except Exception as e: + logger.exception(f"Error in SSE writer: {e}") + finally: + logger.debug("Closing SSE writer") + await self._clean_up_memory_streams(request_id) + + # Create and start EventSourceResponse + # SSE stream mode (original behavior) + # Set up headers + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + **( + {MCP_SESSION_ID_HEADER: self.mcp_session_id} + if self.mcp_session_id + else {} + ), + } + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=sse_writer, + headers=headers, + ) + + # Start the SSE response (this will send headers immediately) + try: + # First send the response to establish the SSE connection + async with anyio.create_task_group() as tg: + tg.start_soon(response, scope, receive, send) + # Then send the message to be processed by the server + session_message = SessionMessage(message) + await writer.send(session_message) + except Exception: + logger.exception("SSE response error") + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() + await self._clean_up_memory_streams(request_id) + + except Exception as err: + logger.exception("Error handling POST request") + response = self._create_error_response( + f"Error handling POST request: {err}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(scope, receive, send) + if writer: + await writer.send(Exception(err)) + return + + async def _handle_get_request(self, request: Request, send: Send) -> None: + """ + Handle GET request to establish SSE. + + This allows the server to communicate to the client without the client + first sending data via HTTP POST. The server can send JSON-RPC requests + and notifications on this stream. + """ + writer = self._read_stream_writer + if writer is None: + raise ValueError( + "No read stream writer available. Ensure connect() is called first." + ) + + # Validate Accept header - must include text/event-stream + _, has_sse = self._check_accept_headers(request) + + if not has_sse: + response = self._create_error_response( + "Not Acceptable: Client must accept text/event-stream", + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(request.scope, request.receive, send) + return + + if not await self._validate_session(request, send): + return + # Handle resumability: check for Last-Event-ID header + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): + await self._replay_events(last_event_id, request, send) + return + + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Check if we already have an active GET stream + if GET_STREAM_KEY in self._request_streams: + response = self._create_error_response( + "Conflict: Only one SSE stream is allowed per session", + HTTPStatus.CONFLICT, + ) + await response(request.scope, request.receive, send) + return + + # Create SSE stream + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ + dict[str, str] + ](0) + + async def standalone_sse_writer(): + try: + # Create a standalone message stream for server-initiated messages + + self._request_streams[GET_STREAM_KEY] = ( + anyio.create_memory_object_stream[EventMessage](0) + ) + standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1] + + async with sse_stream_writer, standalone_stream_reader: + # Process messages from the standalone stream + async for event_message in standalone_stream_reader: + # For the standalone stream, we handle: + # - JSONRPCNotification (server sends notifications to client) + # - JSONRPCRequest (server sends requests to client) + # We should NOT receive JSONRPCResponse + + # Send the message via SSE + event_data = self._create_event_data(event_message) + await sse_stream_writer.send(event_data) + except Exception as e: + logger.exception(f"Error in standalone SSE writer: {e}") + finally: + logger.debug("Closing standalone SSE writer") + await self._clean_up_memory_streams(GET_STREAM_KEY) + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=standalone_sse_writer, + headers=headers, + ) + + try: + # This will send headers immediately and establish the SSE connection + await response(request.scope, request.receive, send) + except Exception as e: + logger.exception(f"Error in standalone SSE response: {e}") + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() + await self._clean_up_memory_streams(GET_STREAM_KEY) + + async def _handle_delete_request(self, request: Request, send: Send) -> None: + """Handle DELETE requests for explicit session termination.""" + # Validate session ID + if not self.mcp_session_id: + # If no session ID set, return Method Not Allowed + response = self._create_error_response( + "Method Not Allowed: Session termination not supported", + HTTPStatus.METHOD_NOT_ALLOWED, + ) + await response(request.scope, request.receive, send) + return + + if not await self._validate_session(request, send): + return + + await self._terminate_session() + + response = self._create_json_response( + None, + HTTPStatus.OK, + ) + await response(request.scope, request.receive, send) + + async def _terminate_session(self) -> None: + """Terminate the current session, closing all streams. + + Once terminated, all requests with this session ID will receive 404 Not Found. + """ + + self._terminated = True + logger.info(f"Terminating session: {self.mcp_session_id}") + + # We need a copy of the keys to avoid modification during iteration + request_stream_keys = list(self._request_streams.keys()) + + # Close all request streams asynchronously + for key in request_stream_keys: + try: + await self._clean_up_memory_streams(key) + except Exception as e: + logger.debug(f"Error closing stream {key} during termination: {e}") + + # Clear the request streams dictionary immediately + self._request_streams.clear() + try: + if self._read_stream_writer is not None: + await self._read_stream_writer.aclose() + if self._read_stream is not None: + await self._read_stream.aclose() + if self._write_stream_reader is not None: + await self._write_stream_reader.aclose() + if self._write_stream is not None: + await self._write_stream.aclose() + except Exception as e: + logger.debug(f"Error closing streams: {e}") + + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: + """Handle unsupported HTTP methods.""" + headers = { + "Content-Type": CONTENT_TYPE_JSON, + "Allow": "GET, POST, DELETE", + } + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + response = self._create_error_response( + "Method Not Allowed", + HTTPStatus.METHOD_NOT_ALLOWED, + headers=headers, + ) + await response(request.scope, request.receive, send) + + async def _validate_session(self, request: Request, send: Send) -> bool: + """Validate the session ID in the request.""" + if not self.mcp_session_id: + # If we're not using session IDs, return True + return True + + # Get the session ID from the request headers + request_session_id = self._get_session_id(request) + + # If no session ID provided but required, return error + if not request_session_id: + response = self._create_error_response( + "Bad Request: Missing session ID", + HTTPStatus.BAD_REQUEST, + ) + await response(request.scope, request.receive, send) + return False + + # If session ID doesn't match, return error + if request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(request.scope, request.receive, send) + return False + + return True + + async def _replay_events( + self, last_event_id: str, request: Request, send: Send + ) -> None: + """ + Replays events that would have been sent after the specified event ID. + Only used when resumability is enabled. + """ + event_store = self._event_store + if not event_store: + return + + try: + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Create SSE stream for replay + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ + dict[str, str] + ](0) + + async def replay_sender(): + try: + async with sse_stream_writer: + # Define an async callback for sending events + async def send_event(event_message: EventMessage) -> None: + event_data = self._create_event_data(event_message) + await sse_stream_writer.send(event_data) + + # Replay past events and get the stream ID + stream_id = await event_store.replay_events_after( + last_event_id, send_event + ) + + # If stream ID not in mapping, create it + if stream_id and stream_id not in self._request_streams: + self._request_streams[stream_id] = ( + anyio.create_memory_object_stream[EventMessage](0) + ) + msg_reader = self._request_streams[stream_id][1] + + # Forward messages to SSE + async with msg_reader: + async for event_message in msg_reader: + event_data = self._create_event_data(event_message) + + await sse_stream_writer.send(event_data) + except Exception as e: + logger.exception(f"Error in replay sender: {e}") + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=replay_sender, + headers=headers, + ) + + try: + await response(request.scope, request.receive, send) + except Exception as e: + logger.exception(f"Error in replay response: {e}") + finally: + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() + + except Exception as e: + logger.exception(f"Error replaying events: {e}") + response = self._create_error_response( + f"Error replaying events: {str(e)}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(request.scope, request.receive, send) + + @asynccontextmanager + async def connect( + self, + ) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + ], + None, + ]: + """Context manager that provides read and write streams for a connection. + + Yields: + Tuple of (read_stream, write_stream) for bidirectional communication + """ + + # Create the memory streams for this connection + + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + SessionMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + SessionMessage + ](0) + + # Store the streams + self._read_stream_writer = read_stream_writer + self._read_stream = read_stream + self._write_stream_reader = write_stream_reader + self._write_stream = write_stream + + # Start a task group for message routing + async with anyio.create_task_group() as tg: + # Create a message router that distributes messages to request streams + async def message_router(): + try: + async for session_message in write_stream_reader: + # Determine which request stream(s) should receive this message + message = session_message.message + target_request_id = None + if isinstance( + message.root, JSONRPCNotification | JSONRPCRequest + ): + # Extract related_request_id from meta if it exists + if ( + session_message.metadata is not None + and isinstance( + session_message.metadata, + ServerMessageMetadata, + ) + and session_message.metadata.related_request_id + is not None + ): + target_request_id = str( + session_message.metadata.related_request_id + ) + else: + target_request_id = str(message.root.id) + + request_stream_id = target_request_id or GET_STREAM_KEY + + # Store the event if we have an event store, + # regardless of whether a client is connected + # messages will be replayed on the re-connect + event_id = None + if self._event_store: + event_id = await self._event_store.store_event( + request_stream_id, message + ) + logger.debug(f"Stored {event_id} from {request_stream_id}") + + if request_stream_id in self._request_streams: + try: + # Send both the message and the event ID + await self._request_streams[request_stream_id][0].send( + EventMessage(message, event_id) + ) + except ( + anyio.BrokenResourceError, + anyio.ClosedResourceError, + ): + # Stream might be closed, remove from registry + self._request_streams.pop(request_stream_id, None) + else: + logging.debug( + f"""Request stream {request_stream_id} not found + for message. Still processing message as the client + might reconnect and replay.""" + ) + except Exception as e: + logger.exception(f"Error in message router: {e}") + + # Start the message router + tg.start_soon(message_router) + + try: + # Yield the streams for the caller to use + yield read_stream, write_stream + finally: + for stream_id in list(self._request_streams.keys()): + try: + await self._clean_up_memory_streams(stream_id) + except Exception as e: + logger.debug(f"Error closing request stream: {e}") + pass + self._request_streams.clear() + + # Clean up the read and write streams + try: + await read_stream_writer.aclose() + await read_stream.aclose() + await write_stream_reader.aclose() + await write_stream.aclose() + except Exception as e: + logger.debug(f"Error closing streams: {e}") diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py new file mode 100644 index 000000000..e5ef8b4aa --- /dev/null +++ b/src/mcp/server/streamable_http_manager.py @@ -0,0 +1,258 @@ +"""StreamableHTTP Session Manager for MCP servers.""" + +from __future__ import annotations + +import contextlib +import logging +import threading +from collections.abc import AsyncIterator +from http import HTTPStatus +from typing import Any +from uuid import uuid4 + +import anyio +from anyio.abc import TaskStatus +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from mcp.server.lowlevel.server import Server as MCPServer +from mcp.server.streamable_http import ( + MCP_SESSION_ID_HEADER, + EventStore, + StreamableHTTPServerTransport, +) + +logger = logging.getLogger(__name__) + + +class StreamableHTTPSessionManager: + """ + Manages StreamableHTTP sessions with optional resumability via event store. + + This class abstracts away the complexity of session management, event storage, + and request handling for StreamableHTTP transports. It handles: + + 1. Session tracking for clients + 2. Resumability via an optional event store + 3. Connection management and lifecycle + 4. Request handling and transport setup + + Important: Only one StreamableHTTPSessionManager instance should be created + per application. The instance cannot be reused after its run() context has + completed. If you need to restart the manager, create a new instance. + + Args: + app: The MCP server instance + event_store: Optional event store for resumability support. + If provided, enables resumable connections where clients + can reconnect and receive missed events. + If None, sessions are still tracked but not resumable. + json_response: Whether to use JSON responses instead of SSE streams + stateless: If True, creates a completely fresh transport for each request + with no session tracking or state persistence between requests. + + """ + + def __init__( + self, + app: MCPServer[Any], + event_store: EventStore | None = None, + json_response: bool = False, + stateless: bool = False, + ): + self.app = app + self.event_store = event_store + self.json_response = json_response + self.stateless = stateless + + # Session tracking (only used if not stateless) + self._session_creation_lock = anyio.Lock() + self._server_instances: dict[str, StreamableHTTPServerTransport] = {} + + # The task group will be set during lifespan + self._task_group = None + # Thread-safe tracking of run() calls + self._run_lock = threading.Lock() + self._has_started = False + + @contextlib.asynccontextmanager + async def run(self) -> AsyncIterator[None]: + """ + Run the session manager with proper lifecycle management. + + This creates and manages the task group for all session operations. + + Important: This method can only be called once per instance. The same + StreamableHTTPSessionManager instance cannot be reused after this + context manager exits. Create a new instance if you need to restart. + + Use this in the lifespan context manager of your Starlette app: + + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + """ + # Thread-safe check to ensure run() is only called once + with self._run_lock: + if self._has_started: + raise RuntimeError( + "StreamableHTTPSessionManager .run() can only be called " + "once per instance. Create a new instance if you need to run again." + ) + self._has_started = True + + async with anyio.create_task_group() as tg: + # Store the task group for later use + self._task_group = tg + logger.info("StreamableHTTP session manager started") + try: + yield # Let the application run + finally: + logger.info("StreamableHTTP session manager shutting down") + # Cancel task group to stop all spawned tasks + tg.cancel_scope.cancel() + self._task_group = None + # Clear any remaining server instances + self._server_instances.clear() + + async def handle_request( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + """ + Process ASGI request with proper session handling and transport setup. + + Dispatches to the appropriate handler based on stateless mode. + + Args: + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + if self._task_group is None: + raise RuntimeError("Task group is not initialized. Make sure to use run().") + + # Dispatch to the appropriate handler + if self.stateless: + await self._handle_stateless_request(scope, receive, send) + else: + await self._handle_stateful_request(scope, receive, send) + + async def _handle_stateless_request( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + """ + Process request in stateless mode - creating a new transport for each request. + + Args: + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + logger.debug("Stateless mode: Creating new transport for this request") + # No session ID needed in stateless mode + http_transport = StreamableHTTPServerTransport( + mcp_session_id=None, # No session tracking in stateless mode + is_json_response_enabled=self.json_response, + event_store=None, # No event store in stateless mode + ) + + # Start server in a new task + async def run_stateless_server( + *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED + ): + async with http_transport.connect() as streams: + read_stream, write_stream = streams + task_status.started() + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=True, + ) + + # Assert task group is not None for type checking + assert self._task_group is not None + # Start the server task + await self._task_group.start(run_stateless_server) + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + + async def _handle_stateful_request( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + """ + Process request in stateful mode - maintaining session state between requests. + + Args: + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + request = Request(scope, receive) + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + + # Existing session case + if ( + request_mcp_session_id is not None + and request_mcp_session_id in self._server_instances + ): + transport = self._server_instances[request_mcp_session_id] + logger.debug("Session already exists, handling request directly") + await transport.handle_request(scope, receive, send) + return + + if request_mcp_session_id is None: + # New session case + logger.debug("Creating new transport") + async with self._session_creation_lock: + new_session_id = uuid4().hex + http_transport = StreamableHTTPServerTransport( + mcp_session_id=new_session_id, + is_json_response_enabled=self.json_response, + event_store=self.event_store, # May be None (no resumability) + ) + + assert http_transport.mcp_session_id is not None + self._server_instances[http_transport.mcp_session_id] = http_transport + logger.info(f"Created new transport with session ID: {new_session_id}") + + # Define the server runner + async def run_server( + *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED + ) -> None: + async with http_transport.connect() as streams: + read_stream, write_stream = streams + task_status.started() + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=False, # Stateful mode + ) + + # Assert task group is not None for type checking + assert self._task_group is not None + # Start the server task + await self._task_group.start(run_server) + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + else: + # Invalid session ID + response = Response( + "Bad Request: No valid session ID provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index aee855cf1..9dc3f2a25 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -8,6 +8,7 @@ from starlette.websockets import WebSocket import mcp.types as types +from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -22,11 +23,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -41,15 +42,18 @@ async def ws_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(client_message) + session_message = SessionMessage(client_message) + await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await websocket.close() async def ws_writer(): try: async with write_stream_reader: - async for message in write_stream_reader: - obj = message.model_dump_json(by_alias=True, exclude_none=True) + async for session_message in write_stream_reader: + obj = session_message.message.model_dump_json( + by_alias=True, exclude_none=True + ) await websocket.send_text(obj) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py new file mode 100644 index 000000000..95080bde1 --- /dev/null +++ b/src/mcp/shared/_httpx_utils.py @@ -0,0 +1,62 @@ +"""Utilities for creating standardized httpx AsyncClient instances.""" + +from typing import Any + +import httpx + +__all__ = ["create_mcp_http_client"] + + +def create_mcp_http_client( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, +) -> httpx.AsyncClient: + """Create a standardized httpx AsyncClient with MCP defaults. + + This function provides common defaults used throughout the MCP codebase: + - follow_redirects=True (always enabled) + - Default timeout of 30 seconds if not specified + + Args: + headers: Optional headers to include with all requests. + timeout: Request timeout as httpx.Timeout object. + Defaults to 30 seconds if not specified. + + Returns: + Configured httpx.AsyncClient instance with MCP defaults. + + Note: + The returned AsyncClient must be used as a context manager to ensure + proper cleanup of connections. + + Examples: + # Basic usage with MCP defaults + async with create_mcp_http_client() as client: + response = await client.get("https://api.example.com") + + # With custom headers + headers = {"Authorization": "Bearer token"} + async with create_mcp_http_client(headers) as client: + response = await client.get("/endpoint") + + # With both custom headers and timeout + timeout = httpx.Timeout(60.0, read=300.0) + async with create_mcp_http_client(headers, timeout) as client: + response = await client.get("/long-request") + """ + # Set MCP defaults + kwargs: dict[str, Any] = { + "follow_redirects": True, + } + + # Handle timeout + if timeout is None: + kwargs["timeout"] = httpx.Timeout(30.0) + else: + kwargs["timeout"] = timeout + + # Handle headers + if headers is not None: + kwargs["headers"] = headers + + return httpx.AsyncClient(**kwargs) diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index abf87a3aa..b53f8dd63 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -19,11 +19,11 @@ SamplingFnT, ) from mcp.server import Server -from mcp.types import JSONRPCMessage +from mcp.shared.message import SessionMessage MessageStream = tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], ] @@ -40,10 +40,10 @@ async def create_client_server_memory_streams() -> ( """ # Create streams for both directions server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + SessionMessage | Exception ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + SessionMessage | Exception ](1) client_streams = (server_to_client_receive, client_to_server_send) diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py new file mode 100644 index 000000000..5583f4795 --- /dev/null +++ b/src/mcp/shared/message.py @@ -0,0 +1,43 @@ +""" +Message wrapper with metadata support. + +This module defines a wrapper type that combines JSONRPCMessage with metadata +to support transport-specific features like resumability. +""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass + +from mcp.types import JSONRPCMessage, RequestId + +ResumptionToken = str + +ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]] + + +@dataclass +class ClientMessageMetadata: + """Metadata specific to client messages.""" + + resumption_token: ResumptionToken | None = None + on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = ( + None + ) + + +@dataclass +class ServerMessageMetadata: + """Metadata specific to server messages.""" + + related_request_id: RequestId | None = None + + +MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None + + +@dataclass +class SessionMessage: + """A message with specific metadata for transport-specific features.""" + + message: JSONRPCMessage + metadata: MessageMetadata = None diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 11daedc98..cce8b1184 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -6,13 +6,13 @@ from typing import Any, Generic, TypeVar import anyio -import anyio.lowlevel import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel from typing_extensions import Self from mcp.shared.exceptions import McpError +from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.types import ( CancelledNotification, ClientNotification, @@ -172,8 +172,8 @@ class BaseSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], receive_request_type: type[ReceiveRequestT], receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out @@ -213,6 +213,7 @@ async def send_request( request: SendRequestT, result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, + metadata: MessageMetadata = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the @@ -240,7 +241,11 @@ async def send_request( # TODO: Support progress callbacks - await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) + await self._write_stream.send( + SessionMessage( + message=JSONRPCMessage(jsonrpc_request), metadata=metadata + ) + ) # request read timeout takes precedence over session read timeout timeout = None @@ -274,24 +279,36 @@ async def send_request( await response_stream.aclose() await response_stream_reader.aclose() - async def send_notification(self, notification: SendNotificationT) -> None: + async def send_notification( + self, + notification: SendNotificationT, + related_request_id: RequestId | None = None, + ) -> None: """ Emits a notification, which is a one-way message that does not expect a response. """ + # Some transport implementations may need to set the related_request_id + # to attribute to the notifications to the request that triggered them. jsonrpc_notification = JSONRPCNotification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - - await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) + session_message = SessionMessage( + message=JSONRPCMessage(jsonrpc_notification), + metadata=ServerMessageMetadata(related_request_id=related_request_id) + if related_request_id + else None, + ) + await self._write_stream.send(session_message) async def _send_response( self, request_id: RequestId, response: SendResultT | ErrorData ) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - await self._write_stream.send(JSONRPCMessage(jsonrpc_error)) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) + await self._write_stream.send(session_message) else: jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", @@ -300,7 +317,8 @@ async def _send_response( by_alias=True, mode="json", exclude_none=True ), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) + await self._write_stream.send(session_message) async def _receive_loop(self) -> None: async with ( @@ -310,15 +328,15 @@ async def _receive_loop(self) -> None: async for message in self._read_stream: if isinstance(message, Exception): await self._handle_incoming(message) - elif isinstance(message.root, JSONRPCRequest): + elif isinstance(message.message.root, JSONRPCRequest): validated_request = self._receive_request_type.model_validate( - message.root.model_dump( + message.message.root.model_dump( by_alias=True, mode="json", exclude_none=True ) ) responder = RequestResponder( - request_id=message.root.id, + request_id=message.message.root.id, request_meta=validated_request.root.params.meta if validated_request.root.params else None, @@ -333,10 +351,10 @@ async def _receive_loop(self) -> None: if not responder._completed: # type: ignore[reportPrivateUsage] await self._handle_incoming(responder) - elif isinstance(message.root, JSONRPCNotification): + elif isinstance(message.message.root, JSONRPCNotification): try: notification = self._receive_notification_type.model_validate( - message.root.model_dump( + message.message.root.model_dump( by_alias=True, mode="json", exclude_none=True ) ) @@ -352,12 +370,12 @@ async def _receive_loop(self) -> None: # For other validation errors, log and continue logging.warning( f"Failed to validate notification: {e}. " - f"Message was: {message.root}" + f"Message was: {message.message.root}" ) else: # Response or error - stream = self._response_streams.pop(message.root.id, None) + stream = self._response_streams.pop(message.message.root.id, None) if stream: - await stream.send(message.root) + await stream.send(message.message.root) else: await self._handle_incoming( RuntimeError( diff --git a/tests/client/test_config.py b/tests/client/test_config.py index 97030e069..6577d663c 100644 --- a/tests/client/test_config.py +++ b/tests/client/test_config.py @@ -48,3 +48,28 @@ def test_command_execution(mock_config_path: Path): assert result.returncode == 0 assert "usage" in result.stdout.lower() + + +def test_absolute_uv_path(mock_config_path: Path): + """Test that the absolute path to uv is used when available.""" + # Mock the shutil.which function to return a fake path + mock_uv_path = "/usr/local/bin/uv" + + with patch("mcp.cli.claude.get_uv_path", return_value=mock_uv_path): + # Setup + server_name = "test_server" + file_spec = "test_server.py:app" + + # Update config + success = update_claude_config(file_spec=file_spec, server_name=server_name) + assert success + + # Read the generated config + config_file = mock_config_path / "claude_desktop_config.json" + config = json.loads(config_file.read_text()) + + # Verify the command is the absolute path + server_config = config["mcpServers"][server_name] + command = server_config["command"] + + assert command == mock_uv_path \ No newline at end of file diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 797f817e1..0c9eeb397 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -78,6 +78,8 @@ async def message_handler( ) assert log_result.isError is False assert len(logging_collector.log_messages) == 1 - assert logging_collector.log_messages[0] == LoggingMessageNotificationParams( - level="info", logger="test_logger", data="Test log message" - ) + # Create meta object with related_request_id added dynamically + log = logging_collector.log_messages[0] + assert log.level == "info" + assert log.logger == "test_logger" + assert log.data == "Test log message" diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 543ebb2f0..6abcf70cb 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -3,6 +3,7 @@ import mcp.types as types from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( LATEST_PROTOCOL_VERSION, @@ -24,10 +25,10 @@ @pytest.mark.anyio async def test_client_session_initialize(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) initialized_notification = None @@ -35,7 +36,8 @@ async def test_client_session_initialize(): async def mock_server(): nonlocal initialized_notification - jsonrpc_request = await client_to_server_receive.receive() + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -59,17 +61,20 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) ) ) ) - jsonrpc_notification = await client_to_server_receive.receive() + session_notification = await client_to_server_receive.receive() + jsonrpc_notification = session_notification.message assert isinstance(jsonrpc_notification.root, JSONRPCNotification) initialized_notification = ClientNotification.model_validate( jsonrpc_notification.model_dump( @@ -116,10 +121,10 @@ async def message_handler( @pytest.mark.anyio async def test_client_session_custom_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) custom_client_info = Implementation(name="test-client", version="1.2.3") @@ -128,7 +133,8 @@ async def test_client_session_custom_client_info(): async def mock_server(): nonlocal received_client_info - jsonrpc_request = await client_to_server_receive.receive() + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -146,13 +152,15 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) ) ) ) @@ -181,10 +189,10 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_default_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) received_client_info = None @@ -192,7 +200,8 @@ async def test_client_session_default_client_info(): async def mock_server(): nonlocal received_client_info - jsonrpc_request = await client_to_server_receive.receive() + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -210,13 +219,15 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) ) ) ) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 95747ffd1..523ba199a 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -3,6 +3,7 @@ import pytest from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.shared.message import SessionMessage from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore @@ -22,7 +23,8 @@ async def test_stdio_client(): async with write_stream: for message in messages: - await write_stream.send(message) + session_message = SessionMessage(message) + await write_stream.send(session_message) read_messages = [] async with read_stream: @@ -30,7 +32,7 @@ async def test_stdio_client(): if isinstance(message, Exception): raise message - read_messages.append(message) + read_messages.append(message.message) if len(read_messages) == 2: break diff --git a/tests/issues/test_188_concurrency.py b/tests/issues/test_188_concurrency.py index 2aa6c49cb..d0a86885f 100644 --- a/tests/issues/test_188_concurrency.py +++ b/tests/issues/test_188_concurrency.py @@ -35,7 +35,7 @@ async def slow_resource(): end_time = anyio.current_time() duration = end_time - start_time - assert duration < 3 * _sleep_time_seconds + assert duration < 6 * _sleep_time_seconds print(duration) diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index 00e187895..cf5eb6083 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -3,6 +3,7 @@ from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions +from mcp.shared.message import SessionMessage from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, @@ -64,8 +65,10 @@ async def run_server(): jsonrpc="2.0", ) - await client_writer.send(JSONRPCMessage(root=init_req)) - await server_reader.receive() # Get init response but don't need to check it + await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req))) + response = ( + await server_reader.receive() + ) # Get init response but don't need to check it # Send initialized notification initialized_notification = JSONRPCNotification( @@ -73,21 +76,23 @@ async def run_server(): params=NotificationParams().model_dump(by_alias=True, exclude_none=True), jsonrpc="2.0", ) - await client_writer.send(JSONRPCMessage(root=initialized_notification)) + await client_writer.send( + SessionMessage(JSONRPCMessage(root=initialized_notification)) + ) # Send ping request with custom ID ping_request = JSONRPCRequest( id=custom_request_id, method="ping", params={}, jsonrpc="2.0" ) - await client_writer.send(JSONRPCMessage(root=ping_request)) + await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request))) # Read response response = await server_reader.receive() # Verify response ID matches request ID assert ( - response.root.id == custom_request_id + response.message.root.id == custom_request_id ), "Response ID should match request ID" # Cancel server task diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index 9acb5ff09..e8c17a4c4 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -7,6 +7,7 @@ import pytest from starlette.authentication import AuthCredentials +from starlette.datastructures import Headers from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.types import Message, Receive, Scope, Send @@ -221,6 +222,66 @@ async def test_token_without_expiry( assert user.access_token == no_expiry_access_token assert user.scopes == ["read", "write"] + async def test_lowercase_bearer_prefix( + self, + mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], + valid_access_token: AccessToken, + ): + """Test with lowercase 'bearer' prefix in Authorization header""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) + headers = Headers({"Authorization": "bearer valid_token"}) + scope = {"type": "http", "headers": headers.raw} + request = Request(scope) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == valid_access_token + + async def test_mixed_case_bearer_prefix( + self, + mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], + valid_access_token: AccessToken, + ): + """Test with mixed 'BeArEr' prefix in Authorization header""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) + headers = Headers({"authorization": "BeArEr valid_token"}) + scope = {"type": "http", "headers": headers.raw} + request = Request(scope) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == valid_access_token + + async def test_mixed_case_authorization_header( + self, + mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], + valid_access_token: AccessToken, + ): + """Test authentication with mixed 'Authorization' header.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) + headers = Headers({"AuThOrIzAtIoN": "BeArEr valid_token"}) + scope = {"type": "http", "headers": headers.raw} + request = Request(scope) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == valid_access_token + @pytest.mark.anyio class TestRequireAuthMiddleware: diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py new file mode 100644 index 000000000..67911e9e7 --- /dev/null +++ b/tests/server/fastmcp/test_integration.py @@ -0,0 +1,325 @@ +""" +Integration tests for FastMCP server functionality. + +These tests validate the proper functioning of FastMCP in various configurations, +including with and without authentication. +""" + +import multiprocessing +import socket +import time +from collections.abc import Generator + +import pytest +import uvicorn + +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client +from mcp.server.fastmcp import FastMCP +from mcp.types import InitializeResult, TextContent + + +@pytest.fixture +def server_port() -> int: + """Get a free port for testing.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def server_url(server_port: int) -> str: + """Get the server URL for testing.""" + return f"http://127.0.0.1:{server_port}" + + +@pytest.fixture +def http_server_port() -> int: + """Get a free port for testing the StreamableHTTP server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def http_server_url(http_server_port: int) -> str: + """Get the StreamableHTTP server URL for testing.""" + return f"http://127.0.0.1:{http_server_port}" + + +@pytest.fixture +def stateless_http_server_port() -> int: + """Get a free port for testing the stateless StreamableHTTP server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def stateless_http_server_url(stateless_http_server_port: int) -> str: + """Get the stateless StreamableHTTP server URL for testing.""" + return f"http://127.0.0.1:{stateless_http_server_port}" + + +# Create a function to make the FastMCP server app +def make_fastmcp_app(): + """Create a FastMCP server without auth settings.""" + from starlette.applications import Starlette + + mcp = FastMCP(name="NoAuthServer") + + # Add a simple tool + @mcp.tool(description="A simple echo tool") + def echo(message: str) -> str: + return f"Echo: {message}" + + # Create the SSE app + app: Starlette = mcp.sse_app() + + return mcp, app + + +def make_fastmcp_streamable_http_app(): + """Create a FastMCP server with StreamableHTTP transport.""" + from starlette.applications import Starlette + + mcp = FastMCP(name="NoAuthServer") + + # Add a simple tool + @mcp.tool(description="A simple echo tool") + def echo(message: str) -> str: + return f"Echo: {message}" + + # Create the StreamableHTTP app + app: Starlette = mcp.streamable_http_app() + + return mcp, app + + +def make_fastmcp_stateless_http_app(): + """Create a FastMCP server with stateless StreamableHTTP transport.""" + from starlette.applications import Starlette + + mcp = FastMCP(name="StatelessServer", stateless_http=True) + + # Add a simple tool + @mcp.tool(description="A simple echo tool") + def echo(message: str) -> str: + return f"Echo: {message}" + + # Create the StreamableHTTP app + app: Starlette = mcp.streamable_http_app() + + return mcp, app + + +def run_server(server_port: int) -> None: + """Run the server.""" + _, app = make_fastmcp_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting server on port {server_port}") + server.run() + + +def run_streamable_http_server(server_port: int) -> None: + """Run the StreamableHTTP server.""" + _, app = make_fastmcp_streamable_http_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting StreamableHTTP server on port {server_port}") + server.run() + + +def run_stateless_http_server(server_port: int) -> None: + """Run the stateless StreamableHTTP server.""" + _, app = make_fastmcp_stateless_http_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting stateless StreamableHTTP server on port {server_port}") + server.run() + + +@pytest.fixture() +def server(server_port: int) -> Generator[None, None, None]: + """Start the server in a separate process and clean up after the test.""" + proc = multiprocessing.Process(target=run_server, args=(server_port,), daemon=True) + print("Starting server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("Waiting for server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + print("Killing server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("Server process failed to terminate") + + +@pytest.fixture() +def streamable_http_server(http_server_port: int) -> Generator[None, None, None]: + """Start the StreamableHTTP server in a separate process.""" + proc = multiprocessing.Process( + target=run_streamable_http_server, args=(http_server_port,), daemon=True + ) + print("Starting StreamableHTTP server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("Waiting for StreamableHTTP server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", http_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"StreamableHTTP server failed to start after {max_attempts} attempts" + ) + + yield + + print("Killing StreamableHTTP server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("StreamableHTTP server process failed to terminate") + + +@pytest.fixture() +def stateless_http_server( + stateless_http_server_port: int, +) -> Generator[None, None, None]: + """Start the stateless StreamableHTTP server in a separate process.""" + proc = multiprocessing.Process( + target=run_stateless_http_server, + args=(stateless_http_server_port,), + daemon=True, + ) + print("Starting stateless StreamableHTTP server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("Waiting for stateless StreamableHTTP server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", stateless_http_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"Stateless server failed to start after {max_attempts} attempts" + ) + + yield + + print("Killing stateless StreamableHTTP server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("Stateless StreamableHTTP server process failed to terminate") + + +@pytest.mark.anyio +async def test_fastmcp_without_auth(server: None, server_url: str) -> None: + """Test that FastMCP works when auth settings are not provided.""" + # Connect to the server + async with sse_client(server_url + "/sse") as streams: + async with ClientSession(*streams) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "NoAuthServer" + + # Test that we can call tools without authentication + tool_result = await session.call_tool("echo", {"message": "hello"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "Echo: hello" + + +@pytest.mark.anyio +async def test_fastmcp_streamable_http( + streamable_http_server: None, http_server_url: str +) -> None: + """Test that FastMCP works with StreamableHTTP transport.""" + # Connect to the server using StreamableHTTP + async with streamablehttp_client(http_server_url + "/mcp") as ( + read_stream, + write_stream, + _, + ): + # Create a session using the client streams + async with ClientSession(read_stream, write_stream) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "NoAuthServer" + + # Test that we can call tools without authentication + tool_result = await session.call_tool("echo", {"message": "hello"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "Echo: hello" + + +@pytest.mark.anyio +async def test_fastmcp_stateless_streamable_http( + stateless_http_server: None, stateless_http_server_url: str +) -> None: + """Test that FastMCP works with stateless StreamableHTTP transport.""" + # Connect to the server using StreamableHTTP + async with streamablehttp_client(stateless_http_server_url + "/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "StatelessServer" + tool_result = await session.call_tool("echo", {"message": "hello"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "Echo: hello" + + for i in range(3): + tool_result = await session.call_tool("echo", {"message": f"test_{i}"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == f"Echo: test_{i}" diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index e76e59c52..64700d959 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -1,9 +1,11 @@ import base64 from pathlib import Path from typing import TYPE_CHECKING +from unittest.mock import patch import pytest from pydantic import AnyUrl +from starlette.routing import Mount, Route from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage @@ -31,6 +33,97 @@ async def test_create_server(self): assert mcp.name == "FastMCP" assert mcp.instructions == "Server instructions" + @pytest.mark.anyio + async def test_normalize_path(self): + """Test path normalization for mount paths.""" + mcp = FastMCP() + + # Test root path + assert mcp._normalize_path("/", "/messages/") == "/messages/" + + # Test path with trailing slash + assert mcp._normalize_path("/github/", "/messages/") == "/github/messages/" + + # Test path without trailing slash + assert mcp._normalize_path("/github", "/messages/") == "/github/messages/" + + # Test endpoint without leading slash + assert mcp._normalize_path("/github", "messages/") == "/github/messages/" + + # Test both with trailing/leading slashes + assert mcp._normalize_path("/api/", "/v1/") == "/api/v1/" + + @pytest.mark.anyio + async def test_sse_app_with_mount_path(self): + """Test SSE app creation with different mount paths.""" + # Test with default mount path + mcp = FastMCP() + with patch.object( + mcp, "_normalize_path", return_value="/messages/" + ) as mock_normalize: + mcp.sse_app() + # Verify _normalize_path was called with correct args + mock_normalize.assert_called_once_with("/", "/messages/") + + # Test with custom mount path in settings + mcp = FastMCP() + mcp.settings.mount_path = "/custom" + with patch.object( + mcp, "_normalize_path", return_value="/custom/messages/" + ) as mock_normalize: + mcp.sse_app() + # Verify _normalize_path was called with correct args + mock_normalize.assert_called_once_with("/custom", "/messages/") + + # Test with mount_path parameter + mcp = FastMCP() + with patch.object( + mcp, "_normalize_path", return_value="/param/messages/" + ) as mock_normalize: + mcp.sse_app(mount_path="/param") + # Verify _normalize_path was called with correct args + mock_normalize.assert_called_once_with("/param", "/messages/") + + @pytest.mark.anyio + async def test_starlette_routes_with_mount_path(self): + """Test that Starlette routes are correctly configured with mount path.""" + # Test with mount path in settings + mcp = FastMCP() + mcp.settings.mount_path = "/api" + app = mcp.sse_app() + + # Find routes by type + sse_routes = [r for r in app.routes if isinstance(r, Route)] + mount_routes = [r for r in app.routes if isinstance(r, Mount)] + + # Verify routes exist + assert len(sse_routes) == 1, "Should have one SSE route" + assert len(mount_routes) == 1, "Should have one mount route" + + # Verify path values + assert sse_routes[0].path == "/sse", "SSE route path should be /sse" + assert ( + mount_routes[0].path == "/messages" + ), "Mount route path should be /messages" + + # Test with mount path as parameter + mcp = FastMCP() + app = mcp.sse_app(mount_path="/param") + + # Find routes by type + sse_routes = [r for r in app.routes if isinstance(r, Route)] + mount_routes = [r for r in app.routes if isinstance(r, Mount)] + + # Verify routes exist + assert len(sse_routes) == 1, "Should have one SSE route" + assert len(mount_routes) == 1, "Should have one mount route" + + # Verify path values + assert sse_routes[0].path == "/sse", "SSE route path should be /sse" + assert ( + mount_routes[0].path == "/messages" + ), "Mount route path should be /messages" + @pytest.mark.anyio async def test_non_ascii_description(self): """Test that FastMCP handles non-ASCII characters in descriptions correctly""" @@ -518,8 +611,6 @@ async def async_tool(x: int, ctx: Context) -> str: @pytest.mark.anyio async def test_context_logging(self): - from unittest.mock import patch - import mcp.server.session """Test that context logging methods work.""" @@ -544,14 +635,28 @@ async def logging_tool(msg: str, ctx: Context) -> str: assert mock_log.call_count == 4 mock_log.assert_any_call( - level="debug", data="Debug message", logger=None + level="debug", + data="Debug message", + logger=None, + related_request_id="1", + ) + mock_log.assert_any_call( + level="info", + data="Info message", + logger=None, + related_request_id="1", ) - mock_log.assert_any_call(level="info", data="Info message", logger=None) mock_log.assert_any_call( - level="warning", data="Warning message", logger=None + level="warning", + data="Warning message", + logger=None, + related_request_id="1", ) mock_log.assert_any_call( - level="error", data="Error message", logger=None + level="error", + data="Error message", + logger=None, + related_request_id="1", ) @pytest.mark.anyio diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 309a44b87..a3ff59bc1 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -10,6 +10,7 @@ from mcp.server.fastmcp import Context, FastMCP from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.models import InitializationOptions +from mcp.shared.message import SessionMessage from mcp.types import ( ClientCapabilities, Implementation, @@ -82,41 +83,49 @@ async def run_server(): clientInfo=Implementation(name="test-client", version="0.1.0"), ) await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), + SessionMessage( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) ) ) ) response = await receive_stream2.receive() + response = response.message # Send initialized notification await send_stream1.send( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", + SessionMessage( + JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) ) ) ) # Call the tool to verify lifespan context await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, + SessionMessage( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) ) ) ) # Get response and verify response = await receive_stream2.receive() + response = response.message assert response.root.result["content"][0]["text"] == "true" # Cancel server task @@ -178,41 +187,49 @@ async def run_server(): clientInfo=Implementation(name="test-client", version="0.1.0"), ) await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), + SessionMessage( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) ) ) ) response = await receive_stream2.receive() + response = response.message # Send initialized notification await send_stream1.send( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", + SessionMessage( + JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) ) ) ) # Call the tool to verify lifespan context await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, + SessionMessage( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) ) ) ) # Get response and verify response = await receive_stream2.receive() + response = response.message assert response.root.result["content"][0]["text"] == "true" # Cancel server task diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index 47d03ad23..e9eff9ed0 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -8,10 +8,10 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( ClientResult, - JSONRPCMessage, ServerNotification, ServerRequest, Tool, @@ -46,10 +46,10 @@ async def list_tools(): ] server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](10) # Message handler for client diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 561a94b64..f2f033588 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -7,11 +7,11 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( ClientNotification, InitializedNotification, - JSONRPCMessage, PromptsCapability, ResourcesCapability, ServerCapabilities, @@ -21,10 +21,10 @@ @pytest.mark.anyio async def test_server_session_initialize(): server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) # Create a message handler to catch exceptions diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 85c5bf219..c546a7167 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -4,6 +4,7 @@ import pytest from mcp.server.stdio import stdio_server +from mcp.shared.message import SessionMessage from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse @@ -29,7 +30,7 @@ async def test_stdio_server(): async for message in read_stream: if isinstance(message, Exception): raise message - received_messages.append(message) + received_messages.append(message.message) if len(received_messages) == 2: break @@ -50,7 +51,8 @@ async def test_stdio_server(): async with write_stream: for response in responses: - await write_stream.send(response) + session_message = SessionMessage(response) + await write_stream.send(session_message) stdout.seek(0) output_lines = stdout.readlines() diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py new file mode 100644 index 000000000..32782e458 --- /dev/null +++ b/tests/server/test_streamable_http_manager.py @@ -0,0 +1,81 @@ +"""Tests for StreamableHTTPSessionManager.""" + +import anyio +import pytest + +from mcp.server.lowlevel import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager + + +@pytest.mark.anyio +async def test_run_can_only_be_called_once(): + """Test that run() can only be called once per instance.""" + app = Server("test-server") + manager = StreamableHTTPSessionManager(app=app) + + # First call should succeed + async with manager.run(): + pass + + # Second call should raise RuntimeError + with pytest.raises(RuntimeError) as excinfo: + async with manager.run(): + pass + + assert ( + "StreamableHTTPSessionManager .run() can only be called once per instance" + in str(excinfo.value) + ) + + +@pytest.mark.anyio +async def test_run_prevents_concurrent_calls(): + """Test that concurrent calls to run() are prevented.""" + app = Server("test-server") + manager = StreamableHTTPSessionManager(app=app) + + errors = [] + + async def try_run(): + try: + async with manager.run(): + # Simulate some work + await anyio.sleep(0.1) + except RuntimeError as e: + errors.append(e) + + # Try to run concurrently + async with anyio.create_task_group() as tg: + tg.start_soon(try_run) + tg.start_soon(try_run) + + # One should succeed, one should fail + assert len(errors) == 1 + assert ( + "StreamableHTTPSessionManager .run() can only be called once per instance" + in str(errors[0]) + ) + + +@pytest.mark.anyio +async def test_handle_request_without_run_raises_error(): + """Test that handle_request raises error if run() hasn't been called.""" + app = Server("test-server") + manager = StreamableHTTPSessionManager(app=app) + + # Mock ASGI parameters + scope = {"type": "http", "method": "POST", "path": "/test"} + + async def receive(): + return {"type": "http.request", "body": b""} + + async def send(message): + pass + + # Should raise error because run() hasn't been called + with pytest.raises(RuntimeError) as excinfo: + await manager.handle_request(scope, receive, send) + + assert "Task group is not initialized. Make sure to use run()." in str( + excinfo.value + ) diff --git a/tests/shared/test_httpx_utils.py b/tests/shared/test_httpx_utils.py new file mode 100644 index 000000000..dcc6fd003 --- /dev/null +++ b/tests/shared/test_httpx_utils.py @@ -0,0 +1,24 @@ +"""Tests for httpx utility functions.""" + +import httpx + +from mcp.shared._httpx_utils import create_mcp_http_client + + +def test_default_settings(): + """Test that default settings are applied correctly.""" + client = create_mcp_http_client() + + assert client.follow_redirects is True + assert client.timeout.connect == 30.0 + + +def test_custom_parameters(): + """Test custom headers and timeout are set correctly.""" + headers = {"Authorization": "Bearer token"} + timeout = httpx.Timeout(60.0) + + client = create_mcp_http_client(headers, timeout) + + assert client.headers["Authorization"] == "Bearer token" + assert client.timeout.connect == 60.0 diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index f5158c3c3..4558bb88c 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -10,6 +10,7 @@ from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request +from starlette.responses import Response from starlette.routing import Mount, Route from mcp.client.session import ClientSession @@ -83,13 +84,14 @@ def make_server_app() -> Starlette: sse = SseServerTransport("/messages/") server = ServerTest() - async def handle_sse(request: Request) -> None: + async def handle_sse(request: Request) -> Response: async with sse.connect_sse( request.scope, request.receive, request._send ) as streams: await server.run( streams[0], streams[1], server.create_initialization_options() ) + return Response() app = Starlette( routes=[ diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py new file mode 100644 index 000000000..28d29ac23 --- /dev/null +++ b/tests/shared/test_streamable_http.py @@ -0,0 +1,1056 @@ +""" +Tests for the StreamableHTTP server and client transport. + +Contains tests for both server and client sides of the StreamableHTTP transport. +""" + +import multiprocessing +import socket +import time +from collections.abc import Generator + +import anyio +import httpx +import pytest +import requests +import uvicorn +from pydantic import AnyUrl +from starlette.applications import Starlette +from starlette.routing import Mount + +import mcp.types as types +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamablehttp_client +from mcp.server import Server +from mcp.server.streamable_http import ( + MCP_SESSION_ID_HEADER, + SESSION_ID_PATTERN, + EventCallback, + EventId, + EventMessage, + EventStore, + StreamableHTTPServerTransport, + StreamId, +) +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.shared.exceptions import McpError +from mcp.shared.message import ( + ClientMessageMetadata, +) +from mcp.shared.session import RequestResponder +from mcp.types import ( + InitializeResult, + TextContent, + TextResourceContents, + Tool, +) + +# Test constants +SERVER_NAME = "test_streamable_http_server" +TEST_SESSION_ID = "test-session-id-12345" +INIT_REQUEST = { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": "2025-03-26", + "capabilities": {}, + }, + "id": "init-1", +} + + +# Simple in-memory event store for testing +class SimpleEventStore(EventStore): + """Simple in-memory event store for testing.""" + + def __init__(self): + self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = [] + self._event_id_counter = 0 + + async def store_event( + self, stream_id: StreamId, message: types.JSONRPCMessage + ) -> EventId: + """Store an event and return its ID.""" + self._event_id_counter += 1 + event_id = str(self._event_id_counter) + self._events.append((stream_id, event_id, message)) + return event_id + + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + """Replay events after the specified ID.""" + # Find the index of the last event ID + start_index = None + for i, (_, event_id, _) in enumerate(self._events): + if event_id == last_event_id: + start_index = i + 1 + break + + if start_index is None: + # If event ID not found, start from beginning + start_index = 0 + + stream_id = None + # Replay events + for _, event_id, message in self._events[start_index:]: + await send_callback(EventMessage(message, event_id)) + # Capture the stream ID from the first replayed event + if stream_id is None and len(self._events) > start_index: + stream_id = self._events[start_index][0] + + return stream_id + + +# Test server implementation that follows MCP protocol +class ServerTest(Server): + def __init__(self): + super().__init__(SERVER_NAME) + + @self.read_resource() + async def handle_read_resource(uri: AnyUrl) -> str | bytes: + if uri.scheme == "foobar": + return f"Read {uri.host}" + elif uri.scheme == "slow": + # Simulate a slow resource + await anyio.sleep(2.0) + return f"Slow response from {uri.host}" + + raise ValueError(f"Unknown resource: {uri}") + + @self.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="test_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="test_tool_with_standalone_notification", + description="A test tool that sends a notification", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="long_running_with_checkpoints", + description="A long-running tool that sends periodic notifications", + inputSchema={"type": "object", "properties": {}}, + ), + ] + + @self.call_tool() + async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + ctx = self.request_context + + # When the tool is called, send a notification to test GET stream + if name == "test_tool_with_standalone_notification": + await ctx.session.send_resource_updated( + uri=AnyUrl("http://test_resource") + ) + return [TextContent(type="text", text=f"Called {name}")] + + elif name == "long_running_with_checkpoints": + # Send notifications that are part of the response stream + # This simulates a long-running tool that sends logs + + await ctx.session.send_log_message( + level="info", + data="Tool started", + logger="tool", + related_request_id=ctx.request_id, # need for stream association + ) + + await anyio.sleep(0.1) + + await ctx.session.send_log_message( + level="info", + data="Tool is almost done", + logger="tool", + related_request_id=ctx.request_id, + ) + + return [TextContent(type="text", text="Completed!")] + + return [TextContent(type="text", text=f"Called {name}")] + + +def create_app( + is_json_response_enabled=False, event_store: EventStore | None = None +) -> Starlette: + """Create a Starlette application for testing using the session manager. + + Args: + is_json_response_enabled: If True, use JSON responses instead of SSE streams. + event_store: Optional event store for testing resumability. + """ + # Create server instance + server = ServerTest() + + # Create the session manager + session_manager = StreamableHTTPSessionManager( + app=server, + event_store=event_store, + json_response=is_json_response_enabled, + ) + + # Create an ASGI application that uses the session manager + app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=session_manager.handle_request), + ], + lifespan=lambda app: session_manager.run(), + ) + + return app + + +def run_server( + port: int, is_json_response_enabled=False, event_store: EventStore | None = None +) -> None: + """Run the test server. + + Args: + port: Port to listen on. + is_json_response_enabled: If True, use JSON responses instead of SSE streams. + event_store: Optional event store for testing resumability. + """ + + app = create_app(is_json_response_enabled, event_store) + # Configure server + config = uvicorn.Config( + app=app, + host="127.0.0.1", + port=port, + log_level="info", + limit_concurrency=10, + timeout_keep_alive=5, + access_log=False, + ) + + # Start the server + server = uvicorn.Server(config=config) + + # This is important to catch exceptions and prevent test hangs + try: + server.run() + except Exception: + import traceback + + traceback.print_exc() + + +# Test fixtures - using same approach as SSE tests +@pytest.fixture +def basic_server_port() -> int: + """Find an available port for the basic server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def json_server_port() -> int: + """Find an available port for the JSON response server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def basic_server(basic_server_port: int) -> Generator[None, None, None]: + """Start a basic server.""" + proc = multiprocessing.Process( + target=run_server, kwargs={"port": basic_server_port}, daemon=True + ) + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", basic_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + # Clean up + proc.kill() + proc.join(timeout=2) + + +@pytest.fixture +def event_store() -> SimpleEventStore: + """Create a test event store.""" + return SimpleEventStore() + + +@pytest.fixture +def event_server_port() -> int: + """Find an available port for the event store server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def event_server( + event_server_port: int, event_store: SimpleEventStore +) -> Generator[tuple[SimpleEventStore, str], None, None]: + """Start a server with event store enabled.""" + proc = multiprocessing.Process( + target=run_server, + kwargs={"port": event_server_port, "event_store": event_store}, + daemon=True, + ) + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", event_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield event_store, f"http://127.0.0.1:{event_server_port}" + + # Clean up + proc.kill() + proc.join(timeout=2) + + +@pytest.fixture +def json_response_server(json_server_port: int) -> Generator[None, None, None]: + """Start a server with JSON response enabled.""" + proc = multiprocessing.Process( + target=run_server, + kwargs={"port": json_server_port, "is_json_response_enabled": True}, + daemon=True, + ) + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", json_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + # Clean up + proc.kill() + proc.join(timeout=2) + + +@pytest.fixture +def basic_server_url(basic_server_port: int) -> str: + """Get the URL for the basic test server.""" + return f"http://127.0.0.1:{basic_server_port}" + + +@pytest.fixture +def json_server_url(json_server_port: int) -> str: + """Get the URL for the JSON response test server.""" + return f"http://127.0.0.1:{json_server_port}" + + +# Basic request validation tests +def test_accept_header_validation(basic_server, basic_server_url): + """Test that Accept header is properly validated.""" + # Test without Accept header + response = requests.post( + f"{basic_server_url}/mcp", + headers={"Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + +def test_content_type_validation(basic_server, basic_server_url): + """Test that Content-Type header is properly validated.""" + # Test with incorrect Content-Type + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "text/plain", + }, + data="This is not JSON", + ) + assert response.status_code == 415 + assert "Unsupported Media Type" in response.text + + +def test_json_validation(basic_server, basic_server_url): + """Test that JSON content is properly validated.""" + # Test with invalid JSON + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + data="this is not valid json", + ) + assert response.status_code == 400 + assert "Parse error" in response.text + + +def test_json_parsing(basic_server, basic_server_url): + """Test that JSON content is properly parse.""" + # Test with valid JSON but invalid JSON-RPC + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"foo": "bar"}, + ) + assert response.status_code == 400 + assert "Validation error" in response.text + + +def test_method_not_allowed(basic_server, basic_server_url): + """Test that unsupported HTTP methods are rejected.""" + # Test with unsupported method (PUT) + response = requests.put( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 405 + assert "Method Not Allowed" in response.text + + +def test_session_validation(basic_server, basic_server_url): + """Test session ID validation.""" + # session_id not used directly in this test + + # Test without session ID + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, + ) + assert response.status_code == 400 + assert "Missing session ID" in response.text + + +def test_session_id_pattern(): + """Test that SESSION_ID_PATTERN correctly validates session IDs.""" + # Valid session IDs (visible ASCII characters from 0x21 to 0x7E) + valid_session_ids = [ + "test-session-id", + "1234567890", + "session!@#$%^&*()_+-=[]{}|;:,.<>?/", + "~`", + ] + + for session_id in valid_session_ids: + assert SESSION_ID_PATTERN.match(session_id) is not None + # Ensure fullmatch matches too (whole string) + assert SESSION_ID_PATTERN.fullmatch(session_id) is not None + + # Invalid session IDs + invalid_session_ids = [ + "", # Empty string + " test", # Space (0x20) + "test\t", # Tab + "test\n", # Newline + "test\r", # Carriage return + "test" + chr(0x7F), # DEL character + "test" + chr(0x80), # Extended ASCII + "test" + chr(0x00), # Null character + "test" + chr(0x20), # Space (0x20) + ] + + for session_id in invalid_session_ids: + # For invalid IDs, either match will fail or fullmatch will fail + if SESSION_ID_PATTERN.match(session_id) is not None: + # If match succeeds, fullmatch should fail (partial match case) + assert SESSION_ID_PATTERN.fullmatch(session_id) is None + + +def test_streamable_http_transport_init_validation(): + """Test that StreamableHTTPServerTransport validates session ID on init.""" + # Valid session ID should initialize without errors + valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") + assert valid_transport.mcp_session_id == "valid-id" + + # None should be accepted + none_transport = StreamableHTTPServerTransport(mcp_session_id=None) + assert none_transport.mcp_session_id is None + + # Invalid session ID should raise ValueError + with pytest.raises(ValueError) as excinfo: + StreamableHTTPServerTransport(mcp_session_id="invalid id with space") + assert "Session ID must only contain visible ASCII characters" in str(excinfo.value) + + # Test with control characters + with pytest.raises(ValueError): + StreamableHTTPServerTransport(mcp_session_id="test\nid") + + with pytest.raises(ValueError): + StreamableHTTPServerTransport(mcp_session_id="test\n") + + +def test_session_termination(basic_server, basic_server_url): + """Test session termination via DELETE and subsequent request handling.""" + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + response = requests.delete( + f"{basic_server_url}/mcp", + headers={MCP_SESSION_ID_HEADER: session_id}, + ) + assert response.status_code == 200 + + # Try to use the terminated session + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "ping", "id": 2}, + ) + assert response.status_code == 404 + assert "Session has been terminated" in response.text + + +def test_response(basic_server, basic_server_url): + """Test response handling for a valid request.""" + mcp_url = f"{basic_server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + + # Try to use the terminated session + tools_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, + stream=True, + ) + assert tools_response.status_code == 200 + assert tools_response.headers.get("Content-Type") == "text/event-stream" + + +def test_json_response(json_response_server, json_server_url): + """Test response handling when is_json_response_enabled is True.""" + mcp_url = f"{json_server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" + + +def test_get_sse_stream(basic_server, basic_server_url): + """Test establishing an SSE stream via GET request.""" + # First, we need to initialize a session + mcp_url = f"{basic_server_url}/mcp" + init_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Now attempt to establish an SSE stream via GET + get_response = requests.get( + mcp_url, + headers={ + "Accept": "text/event-stream", + MCP_SESSION_ID_HEADER: session_id, + }, + stream=True, + ) + + # Verify we got a successful response with the right content type + assert get_response.status_code == 200 + assert get_response.headers.get("Content-Type") == "text/event-stream" + + # Test that a second GET request gets rejected (only one stream allowed) + second_get = requests.get( + mcp_url, + headers={ + "Accept": "text/event-stream", + MCP_SESSION_ID_HEADER: session_id, + }, + stream=True, + ) + + # Should get CONFLICT (409) since there's already a stream + # Note: This might fail if the first stream fully closed before this runs, + # but generally it should work in the test environment where it runs quickly + assert second_get.status_code == 409 + + +def test_get_validation(basic_server, basic_server_url): + """Test validation for GET requests.""" + # First, we need to initialize a session + mcp_url = f"{basic_server_url}/mcp" + init_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test without Accept header + response = requests.get( + mcp_url, + headers={ + MCP_SESSION_ID_HEADER: session_id, + }, + stream=True, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test with wrong Accept header + response = requests.get( + mcp_url, + headers={ + "Accept": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + +# Client-specific fixtures +@pytest.fixture +async def http_client(basic_server, basic_server_url): + """Create test client matching the SSE test pattern.""" + async with httpx.AsyncClient(base_url=basic_server_url) as client: + yield client + + +@pytest.fixture +async def initialized_client_session(basic_server, basic_server_url): + """Create initialized StreamableHTTP client session.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: + await session.initialize() + yield session + + +@pytest.mark.anyio +async def test_streamablehttp_client_basic_connection(basic_server, basic_server_url): + """Test basic client connection with initialization.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + +@pytest.mark.anyio +async def test_streamablehttp_client_resource_read(initialized_client_session): + """Test client resource read functionality.""" + response = await initialized_client_session.read_resource( + uri=AnyUrl("foobar://test-resource") + ) + assert len(response.contents) == 1 + assert response.contents[0].uri == AnyUrl("foobar://test-resource") + assert response.contents[0].text == "Read test-resource" + + +@pytest.mark.anyio +async def test_streamablehttp_client_tool_invocation(initialized_client_session): + """Test client tool invocation.""" + # First list tools + tools = await initialized_client_session.list_tools() + assert len(tools.tools) == 3 + assert tools.tools[0].name == "test_tool" + + # Call the tool + result = await initialized_client_session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" + + +@pytest.mark.anyio +async def test_streamablehttp_client_error_handling(initialized_client_session): + """Test error handling in client.""" + with pytest.raises(McpError) as exc_info: + await initialized_client_session.read_resource( + uri=AnyUrl("unknown://test-error") + ) + assert exc_info.value.error.code == 0 + assert "Unknown resource: unknown://test-error" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_streamablehttp_client_session_persistence( + basic_server, basic_server_url +): + """Test that session ID persists across requests.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make multiple requests to verify session persistence + tools = await session.list_tools() + assert len(tools.tools) == 3 + + # Read a resource + resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) + assert isinstance(resource.contents[0], TextResourceContents) is True + content = resource.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Read test-persist" + + +@pytest.mark.anyio +async def test_streamablehttp_client_json_response( + json_response_server, json_server_url +): + """Test client with JSON response mode.""" + async with streamablehttp_client(f"{json_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + # Check tool listing + tools = await session.list_tools() + assert len(tools.tools) == 3 + + # Call a tool and verify JSON response handling + result = await session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" + + +@pytest.mark.anyio +async def test_streamablehttp_client_get_stream(basic_server, basic_server_url): + """Test GET stream functionality for server-initiated messages.""" + import mcp.types as types + from mcp.shared.session import RequestResponder + + notifications_received = [] + + # Define message handler to capture notifications + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, types.ServerNotification): + notifications_received.append(message) + + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: + # Initialize the session - this triggers the GET stream setup + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call the special tool that sends a notification + await session.call_tool("test_tool_with_standalone_notification", {}) + + # Verify we received the notification + assert len(notifications_received) > 0 + + # Verify the notification is a ResourceUpdatedNotification + resource_update_found = False + for notif in notifications_received: + if isinstance(notif.root, types.ResourceUpdatedNotification): + assert str(notif.root.params.uri) == "http://test_resource/" + resource_update_found = True + + assert ( + resource_update_found + ), "ResourceUpdatedNotification not received via GET stream" + + +@pytest.mark.anyio +async def test_streamablehttp_client_session_termination( + basic_server, basic_server_url +): + """Test client session termination functionality.""" + + captured_session_id = None + + # Create the streamablehttp_client with a custom httpx client to capture headers + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + get_session_id, + ): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + captured_session_id = get_session_id() + assert captured_session_id is not None + + # Make a request to confirm session is working + tools = await session.list_tools() + assert len(tools.tools) == 3 + + headers = {} + if captured_session_id: + headers[MCP_SESSION_ID_HEADER] = captured_session_id + + async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + # Attempt to make a request after termination + with pytest.raises( + McpError, + match="Session terminated", + ): + await session.list_tools() + + +@pytest.mark.anyio +async def test_streamablehttp_client_resumption(event_server): + """Test client session to resume a long running tool.""" + _, server_url = event_server + + # Variables to track the state + captured_session_id = None + captured_resumption_token = None + captured_notifications = [] + tool_started = False + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, types.ServerNotification): + captured_notifications.append(message) + # Look for our special notification that indicates the tool is running + if isinstance(message.root, types.LoggingMessageNotification): + if message.root.params.data == "Tool started": + nonlocal tool_started + tool_started = True + + async def on_resumption_token_update(token: str) -> None: + nonlocal captured_resumption_token + captured_resumption_token = token + + # First, start the client session and begin the long-running tool + async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as ( + read_stream, + write_stream, + get_session_id, + ): + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + captured_session_id = get_session_id() + assert captured_session_id is not None + + # Start a long-running tool in a task + async with anyio.create_task_group() as tg: + + async def run_tool(): + metadata = ClientMessageMetadata( + on_resumption_token_update=on_resumption_token_update, + ) + await session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="long_running_with_checkpoints", arguments={} + ), + ) + ), + types.CallToolResult, + metadata=metadata, + ) + + tg.start_soon(run_tool) + + # Wait for the tool to start and at least one notification + # and then kill the task group + while not tool_started or not captured_resumption_token: + await anyio.sleep(0.1) + tg.cancel_scope.cancel() + + # Store pre notifications and clear the captured notifications + # for the post-resumption check + captured_notifications_pre = captured_notifications.copy() + captured_notifications = [] + + # Now resume the session with the same mcp-session-id + headers = {} + if captured_session_id: + headers[MCP_SESSION_ID_HEADER] = captured_session_id + + async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: + # Don't initialize - just use the existing session + + # Resume the tool with the resumption token + assert captured_resumption_token is not None + + metadata = ClientMessageMetadata( + resumption_token=captured_resumption_token, + ) + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="long_running_with_checkpoints", arguments={} + ), + ) + ), + types.CallToolResult, + metadata=metadata, + ) + + # We should get a complete result + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert "Completed" in result.content[0].text + + # We should have received the remaining notifications + assert len(captured_notifications) > 0 + + # Should not have the first notification + # Check that "Tool started" notification isn't repeated when resuming + assert not any( + isinstance(n.root, types.LoggingMessageNotification) + and n.root.params.data == "Tool started" + for n in captured_notifications + ) + # there is no intersection between pre and post notifications + assert not any( + n in captured_notifications_pre for n in captured_notifications + ) diff --git a/uv.lock b/uv.lock index fdb788a79..88869fa50 100644 --- a/uv.lock +++ b/uv.lock @@ -8,8 +8,11 @@ resolution-mode = "lowest-direct" [manifest] members = [ "mcp", + "mcp-simple-auth", "mcp-simple-prompt", "mcp-simple-resource", + "mcp-simple-streamablehttp", + "mcp-simple-streamablehttp-stateless", "mcp-simple-tool", ] @@ -566,6 +569,47 @@ docs = [ { name = "mkdocstrings-python", specifier = ">=1.12.2" }, ] +[[package]] +name = "mcp-simple-auth" +version = "0.1.0" +source = { editable = "examples/servers/simple-auth" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "sse-starlette" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.1.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "pydantic", specifier = ">=2.0" }, + { name = "pydantic-settings", specifier = ">=2.5.2" }, + { name = "sse-starlette", specifier = ">=1.6.1" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.391" }, + { name = "pytest", specifier = ">=8.3.4" }, + { name = "ruff", specifier = ">=0.8.5" }, +] + [[package]] name = "mcp-simple-prompt" version = "0.1.0" @@ -632,6 +676,80 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-simple-streamablehttp" +version = "0.1.0" +source = { editable = "examples/servers/simple-streamablehttp" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.1.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + +[[package]] +name = "mcp-simple-streamablehttp-stateless" +version = "0.1.0" +source = { editable = "examples/servers/simple-streamablehttp-stateless" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.1.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-tool" version = "0.1.0"