From 78f0b11a09755d86c7840e93b94b502de7592c4b Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 2 May 2025 11:58:54 +0100 Subject: [PATCH 01/21] StreamableHttp - Server transport with state management (#553) --- .../servers/simple-streamablehttp/README.md | 37 + .../mcp_simple_streamablehttp/__init__.py | 0 .../mcp_simple_streamablehttp/__main__.py | 4 + .../mcp_simple_streamablehttp/server.py | 201 ++++++ .../simple-streamablehttp/pyproject.toml | 36 + src/mcp/server/fastmcp/server.py | 5 +- src/mcp/server/session.py | 18 +- src/mcp/server/streamableHttp.py | 644 ++++++++++++++++++ src/mcp/shared/session.py | 22 +- tests/client/test_logging_callback.py | 12 +- tests/issues/test_188_concurrency.py | 2 +- tests/server/fastmcp/test_server.py | 22 +- tests/server/test_streamableHttp.py | 543 +++++++++++++++ uv.lock | 38 ++ 14 files changed, 1568 insertions(+), 16 deletions(-) create mode 100644 examples/servers/simple-streamablehttp/README.md create mode 100644 examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py create mode 100644 examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py create mode 100644 examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py create mode 100644 examples/servers/simple-streamablehttp/pyproject.toml create mode 100644 src/mcp/server/streamableHttp.py create mode 100644 tests/server/test_streamableHttp.py diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md new file mode 100644 index 0000000000..e5aaa65250 --- /dev/null +++ b/examples/servers/simple-streamablehttp/README.md @@ -0,0 +1,37 @@ +# MCP Simple StreamableHttp Server Example + +A simple MCP server example demonstrating the StreamableHttp transport, which enables HTTP-based communication with MCP servers using streaming. 
+ +## Features + +- Uses the StreamableHTTP transport for server-client communication +- Supports REST API operations (POST, GET, DELETE) for `/mcp` endpoint +- Task management with anyio task groups +- Ability to send multiple notifications over time to the client +- Proper resource cleanup and lifespan management + +## Usage + +Start the server on the default or custom port: + +```bash + +# Using custom port +uv run mcp-simple-streamablehttp --port 3000 + +# Custom logging level +uv run mcp-simple-streamablehttp --log-level DEBUG + +# Enable JSON responses instead of SSE streams +uv run mcp-simple-streamablehttp --json-response +``` + +The server exposes a tool named "start-notification-stream" that accepts three arguments: + +- `interval`: Time between notifications in seconds (e.g., 1.0) +- `count`: Number of notifications to send (e.g., 5) +- `caller`: Identifier string for the caller + +## Client + +You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use (Inspector)[https://github.com/modelcontextprotocol/inspector] \ No newline at end of file diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py new file mode 100644 index 0000000000..f5f6e402df --- /dev/null +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py @@ -0,0 +1,4 @@ +from .server import main + +if __name__ == "__main__": + main() diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py new file mode 100644 index 0000000000..e7bc44306a --- /dev/null +++ 
b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -0,0 +1,201 @@ +import contextlib +import logging +from http import HTTPStatus +from uuid import uuid4 + +import anyio +import click +import mcp.types as types +from mcp.server.lowlevel import Server +from mcp.server.streamableHttp import ( + MCP_SESSION_ID_HEADER, + StreamableHTTPServerTransport, +) +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount + +# Configure logging +logger = logging.getLogger(__name__) + +# Global task group that will be initialized in the lifespan +task_group = None + + +@contextlib.asynccontextmanager +async def lifespan(app): + """Application lifespan context manager for managing task group.""" + global task_group + + async with anyio.create_task_group() as tg: + task_group = tg + logger.info("Application started, task group initialized!") + try: + yield + finally: + logger.info("Application shutting down, cleaning up resources...") + if task_group: + tg.cancel_scope.cancel() + task_group = None + logger.info("Resources cleaned up successfully.") + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on for HTTP") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +@click.option( + "--json-response", + is_flag=True, + default=False, + help="Enable JSON responses instead of SSE streams", +) +def main( + port: int, + log_level: str, + json_response: bool, +) -> int: + # Configure logging + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server("mcp-streamable-http-demo") + + @app.call_tool() + async def call_tool( + name: str, arguments: dict + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ctx = app.request_context + interval = 
arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + await ctx.session.send_log_message( + level="info", + data=f"Notification {i+1}/{count} from caller: {caller}", + logger="notification_stream", + # Associates this notification with the original request + # Ensures notifications are sent to the correct response stream + # Without this, notifications will either go to: + # - a standalone SSE stream (if GET request is supported) + # - nowhere (if GET request isn't supported) + related_request_id=ctx.request_id, + ) + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + return [ + types.TextContent( + type="text", + text=( + f"Sent {count} notifications with {interval}s interval" + f" for caller: {caller}" + ), + ) + ] + + @app.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="start-notification-stream", + description=( + "Sends a stream of notifications with configurable count" + " and interval" + ), + inputSchema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + "interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + "description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": ( + "Identifier of the caller to include in notifications" + ), + }, + }, + }, + ) + ] + + # We need to store the server instances between requests + server_instances = {} + # Lock to prevent race conditions when creating new sessions + session_creation_lock = anyio.Lock() + + # ASGI handler for streamable HTTP connections + async def handle_streamable_http(scope, receive, send): + request = Request(scope, receive) + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + if ( + 
request_mcp_session_id is not None + and request_mcp_session_id in server_instances + ): + transport = server_instances[request_mcp_session_id] + logger.debug("Session already exists, handling request directly") + await transport.handle_request(scope, receive, send) + elif request_mcp_session_id is None: + # try to establish new session + logger.debug("Creating new transport") + # Use lock to prevent race conditions when creating new sessions + async with session_creation_lock: + new_session_id = uuid4().hex + http_transport = StreamableHTTPServerTransport( + mcp_session_id=new_session_id, + is_json_response_enabled=json_response, + ) + server_instances[http_transport.mcp_session_id] = http_transport + async with http_transport.connect() as streams: + read_stream, write_stream = streams + + async def run_server(): + await app.run( + read_stream, + write_stream, + app.create_initialization_options(), + ) + + if not task_group: + raise RuntimeError("Task group is not initialized") + + task_group.start_soon(run_server) + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + else: + response = Response( + "Bad Request: No valid session ID provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) + + # Create an ASGI application using the transport + starlette_app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=handle_streamable_http), + ], + lifespan=lifespan, + ) + + import uvicorn + + uvicorn.run(starlette_app, host="0.0.0.0", port=port) + + return 0 diff --git a/examples/servers/simple-streamablehttp/pyproject.toml b/examples/servers/simple-streamablehttp/pyproject.toml new file mode 100644 index 0000000000..c35887d1fd --- /dev/null +++ b/examples/servers/simple-streamablehttp/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "mcp-simple-streamablehttp" +version = "0.1.0" +description = "A simple MCP server exposing a StreamableHttp transport for testing" +readme 
= "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." }] +keywords = ["mcp", "llm", "automation", "web", "fetch", "http", "streamable"] +license = { text = "MIT" } +dependencies = ["anyio>=4.5", "click>=8.1.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-streamablehttp = "mcp_simple_streamablehttp.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_streamablehttp"] + +[tool.pyright] +include = ["mcp_simple_streamablehttp"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.uv] +dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] \ No newline at end of file diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 65d342e1af..5b57eb13e9 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -814,7 +814,10 @@ async def log( **extra: Additional structured data to include """ await self.request_context.session.send_log_message( - level=level, data=message, logger=logger_name + level=level, + data=message, + logger=logger_name, + related_request_id=self.request_id, ) @property diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 568ecd4b95..3a1f210dde 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -179,7 +179,11 @@ async def _received_notification( ) async def send_log_message( - self, level: types.LoggingLevel, data: Any, logger: str | None = None + self, + level: types.LoggingLevel, + data: Any, + logger: str | None = None, + related_request_id: types.RequestId | None = None, ) -> None: """Send a log message notification.""" await self.send_notification( @@ -192,7 +196,8 @@ async def send_log_message( logger=logger, ), ) - ) + ), + related_request_id, ) async def 
send_resource_updated(self, uri: AnyUrl) -> None: @@ -261,7 +266,11 @@ async def send_ping(self) -> types.EmptyResult: ) async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None + self, + progress_token: str | int, + progress: float, + total: float | None = None, + related_request_id: str | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -274,7 +283,8 @@ async def send_progress_notification( total=total, ), ) - ) + ), + related_request_id, ) async def send_resource_list_changed(self) -> None: diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py new file mode 100644 index 0000000000..2e0f709092 --- /dev/null +++ b/src/mcp/server/streamableHttp.py @@ -0,0 +1,644 @@ +""" +StreamableHTTP Server Transport Module + +This module implements an HTTP transport layer with Streamable HTTP. + +The transport handles bidirectional communication using HTTP requests and +responses, with streaming support for long-running operations. 
+""" + +import json +import logging +import re +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from http import HTTPStatus +from typing import Any + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError +from sse_starlette import EventSourceResponse +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from mcp.types import ( + INTERNAL_ERROR, + INVALID_PARAMS, + INVALID_REQUEST, + PARSE_ERROR, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + RequestId, +) + +logger = logging.getLogger(__name__) + +# Maximum size for incoming messages +MAXIMUM_MESSAGE_SIZE = 4 * 1024 * 1024 # 4MB + +# Header names +MCP_SESSION_ID_HEADER = "mcp-session-id" +LAST_EVENT_ID_HEADER = "last-event-id" + +# Content types +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_SSE = "text/event-stream" + +# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) +# Pattern ensures entire string contains only valid characters by using ^ and $ anchors +SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") + + +class StreamableHTTPServerTransport: + """ + HTTP server transport with event streaming support for MCP. + + Handles JSON-RPC messages in HTTP POST requests with SSE streaming. + Supports optional JSON responses and session management. + """ + + # Server notification streams for POST requests as well as standalone SSE stream + _read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] | None = ( + None + ) + _write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] | None = None + + def __init__( + self, + mcp_session_id: str | None, + is_json_response_enabled: bool = False, + ) -> None: + """ + Initialize a new StreamableHTTP server transport. 
+ + Args: + mcp_session_id: Optional session identifier for this connection. + Must contain only visible ASCII characters (0x21-0x7E). + is_json_response_enabled: If True, return JSON responses for requests + instead of SSE streams. Default is False. + + Raises: + ValueError: If the session ID contains invalid characters. + """ + if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch( + mcp_session_id + ): + raise ValueError( + "Session ID must only contain visible ASCII characters (0x21-0x7E)" + ) + + self.mcp_session_id = mcp_session_id + self.is_json_response_enabled = is_json_response_enabled + self._request_streams: dict[ + RequestId, MemoryObjectSendStream[JSONRPCMessage] + ] = {} + self._terminated = False + + def _create_error_response( + self, + error_message: str, + status_code: HTTPStatus, + error_code: int = INVALID_REQUEST, + headers: dict[str, str] | None = None, + ) -> Response: + """Create an error response with a simple string message.""" + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Return a properly formatted JSON error response + error_response = JSONRPCError( + jsonrpc="2.0", + id="server-error", # We don't have a request ID for general errors + error=ErrorData( + code=error_code, + message=error_message, + ), + ) + + return Response( + error_response.model_dump_json(by_alias=True, exclude_none=True), + status_code=status_code, + headers=response_headers, + ) + + def _create_json_response( + self, + response_message: JSONRPCMessage | None, + status_code: HTTPStatus = HTTPStatus.OK, + headers: dict[str, str] | None = None, + ) -> Response: + """Create a JSON response from a JSONRPCMessage""" + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = 
self.mcp_session_id + + return Response( + response_message.model_dump_json(by_alias=True, exclude_none=True) + if response_message + else None, + status_code=status_code, + headers=response_headers, + ) + + def _get_session_id(self, request: Request) -> str | None: + """Extract the session ID from request headers.""" + return request.headers.get(MCP_SESSION_ID_HEADER) + + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: + """Application entry point that handles all HTTP requests""" + request = Request(scope, receive) + if self._terminated: + # If the session has been terminated, return 404 Not Found + response = self._create_error_response( + "Not Found: Session has been terminated", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return + + if request.method == "POST": + await self._handle_post_request(scope, request, receive, send) + elif request.method == "GET": + await self._handle_get_request(request, send) + elif request.method == "DELETE": + await self._handle_delete_request(request, send) + else: + await self._handle_unsupported_request(request, send) + + def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: + """Check if the request accepts the required media types.""" + accept_header = request.headers.get("accept", "") + accept_types = [media_type.strip() for media_type in accept_header.split(",")] + + has_json = any( + media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types + ) + has_sse = any( + media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types + ) + + return has_json, has_sse + + def _check_content_type(self, request: Request) -> bool: + """Check if the request has the correct Content-Type.""" + content_type = request.headers.get("content-type", "") + content_type_parts = [ + part.strip() for part in content_type.split(";")[0].split(",") + ] + + return any(part == CONTENT_TYPE_JSON for part in content_type_parts) + + async def 
_handle_post_request( + self, scope: Scope, request: Request, receive: Receive, send: Send + ) -> None: + """Handle POST requests containing JSON-RPC messages.""" + writer = self._read_stream_writer + if writer is None: + raise ValueError( + "No read stream writer available. Ensure connect() is called first." + ) + try: + # Check Accept headers + has_json, has_sse = self._check_accept_headers(request) + if not (has_json and has_sse): + response = self._create_error_response( + ( + "Not Acceptable: Client must accept both application/json and " + "text/event-stream" + ), + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(scope, receive, send) + return + + # Validate Content-Type + if not self._check_content_type(request): + response = self._create_error_response( + "Unsupported Media Type: Content-Type must be application/json", + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + ) + await response(scope, receive, send) + return + + # Parse the body - only read it once + body = await request.body() + if len(body) > MAXIMUM_MESSAGE_SIZE: + response = self._create_error_response( + "Payload Too Large: Message exceeds maximum size", + HTTPStatus.REQUEST_ENTITY_TOO_LARGE, + ) + await response(scope, receive, send) + return + + try: + raw_message = json.loads(body) + except json.JSONDecodeError as e: + response = self._create_error_response( + f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR + ) + await response(scope, receive, send) + return + + try: + message = JSONRPCMessage.model_validate(raw_message) + except ValidationError as e: + response = self._create_error_response( + f"Validation error: {str(e)}", + HTTPStatus.BAD_REQUEST, + INVALID_PARAMS, + ) + await response(scope, receive, send) + return + + # Check if this is an initialization request + is_initialization_request = ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "initialize" + ) + + if is_initialization_request: + # Check if the server already has an established session + if 
self.mcp_session_id: + # Check if request has a session ID + request_session_id = self._get_session_id(request) + + # If request has a session ID but doesn't match, return 404 + if request_session_id and request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return + # For non-initialization requests, validate the session + elif not await self._validate_session(request, send): + return + + # For notifications and responses only, return 202 Accepted + if not isinstance(message.root, JSONRPCRequest): + # Create response object and send it + response = self._create_json_response( + None, + HTTPStatus.ACCEPTED, + ) + await response(scope, receive, send) + + # Process the message after sending the response + await writer.send(message) + + return + + # Extract the request ID outside the try block for proper scope + request_id = str(message.root.id) + # Create promise stream for getting response + request_stream_writer, request_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) + + # Register this stream for the request ID + self._request_streams[request_id] = request_stream_writer + + if self.is_json_response_enabled: + # Process the message + await writer.send(message) + try: + # Process messages from the request-specific stream + # We need to collect all messages until we get a response + response_message = None + + # Use similar approach to SSE writer for consistency + async for received_message in request_stream_reader: + # If it's a response, this is what we're waiting for + if isinstance( + received_message.root, JSONRPCResponse | JSONRPCError + ): + response_message = received_message + break + # For notifications and request, keep waiting + else: + logger.debug(f"received: {received_message.root.method}") + + # At this point we should have a response + if response_message: + # Create JSON 
response + response = self._create_json_response(response_message) + await response(scope, receive, send) + else: + # This shouldn't happen in normal operation + logger.error( + "No response message received before stream closed" + ) + response = self._create_error_response( + "Error processing request: No response received", + HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) + except Exception as e: + logger.exception(f"Error processing JSON response: {e}") + response = self._create_error_response( + f"Error processing request: {str(e)}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(scope, receive, send) + finally: + # Clean up the request stream + if request_id in self._request_streams: + self._request_streams.pop(request_id, None) + await request_stream_reader.aclose() + await request_stream_writer.aclose() + else: + # Create SSE stream + sse_stream_writer, sse_stream_reader = ( + anyio.create_memory_object_stream[dict[str, Any]](0) + ) + + async def sse_writer(): + # Get the request ID from the incoming request message + try: + async with sse_stream_writer, request_stream_reader: + # Process messages from the request-specific stream + async for received_message in request_stream_reader: + # Build the event data + event_data = { + "event": "message", + "data": received_message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + await sse_stream_writer.send(event_data) + + # If response, remove from pending streams and close + if isinstance( + received_message.root, + JSONRPCResponse | JSONRPCError, + ): + if request_id: + self._request_streams.pop(request_id, None) + break + except Exception as e: + logger.exception(f"Error in SSE writer: {e}") + finally: + logger.debug("Closing SSE writer") + # Clean up the request-specific streams + if request_id and request_id in self._request_streams: + self._request_streams.pop(request_id, None) + + # Create and start EventSourceResponse + # SSE stream 
mode (original behavior) + # Set up headers + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + **( + {MCP_SESSION_ID_HEADER: self.mcp_session_id} + if self.mcp_session_id + else {} + ), + } + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=sse_writer, + headers=headers, + ) + + # Start the SSE response (this will send headers immediately) + try: + # First send the response to establish the SSE connection + async with anyio.create_task_group() as tg: + tg.start_soon(response, scope, receive, send) + # Then send the message to be processed by the server + await writer.send(message) + except Exception: + logger.exception("SSE response error") + # Clean up the request stream if something goes wrong + if request_id and request_id in self._request_streams: + self._request_streams.pop(request_id, None) + + except Exception as err: + logger.exception("Error handling POST request") + response = self._create_error_response( + f"Error handling POST request: {err}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(scope, receive, send) + if writer: + await writer.send(err) + return + + async def _handle_get_request(self, request: Request, send: Send) -> None: + """Handle GET requests for SSE stream establishment.""" + # Validate session ID if server has one + if not await self._validate_session(request, send): + return + # Validate Accept header - must include text/event-stream + _, has_sse = self._check_accept_headers(request) + + if not has_sse: + response = self._create_error_response( + "Not Acceptable: Client must accept text/event-stream", + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(request.scope, request.receive, send) + return + + # TODO: Implement SSE stream for GET requests + # For now, return 405 Method Not Allowed + response = self._create_error_response( + "SSE stream from GET request not implemented yet", + 
HTTPStatus.METHOD_NOT_ALLOWED, + ) + await response(request.scope, request.receive, send) + + async def _handle_delete_request(self, request: Request, send: Send) -> None: + """Handle DELETE requests for explicit session termination.""" + # Validate session ID + if not self.mcp_session_id: + # If no session ID set, return Method Not Allowed + response = self._create_error_response( + "Method Not Allowed: Session termination not supported", + HTTPStatus.METHOD_NOT_ALLOWED, + ) + await response(request.scope, request.receive, send) + return + + if not await self._validate_session(request, send): + return + + self._terminate_session() + + response = self._create_json_response( + None, + HTTPStatus.OK, + ) + await response(request.scope, request.receive, send) + + def _terminate_session(self) -> None: + """Terminate the current session, closing all streams. + + Once terminated, all requests with this session ID will receive 404 Not Found. + """ + + self._terminated = True + logger.info(f"Terminating session: {self.mcp_session_id}") + + # We need a copy of the keys to avoid modification during iteration + request_stream_keys = list(self._request_streams.keys()) + + # Close all request streams (synchronously) + for key in request_stream_keys: + try: + # Get the stream + stream = self._request_streams.get(key) + if stream: + # We must use close() here, not aclose() since this is a sync method + stream.close() + except Exception as e: + logger.debug(f"Error closing stream {key} during termination: {e}") + + # Clear the request streams dictionary immediately + self._request_streams.clear() + + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: + """Handle unsupported HTTP methods.""" + headers = { + "Content-Type": CONTENT_TYPE_JSON, + "Allow": "GET, POST, DELETE", + } + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + response = self._create_error_response( + "Method Not Allowed", + 
HTTPStatus.METHOD_NOT_ALLOWED, + headers=headers, + ) + await response(request.scope, request.receive, send) + + async def _validate_session(self, request: Request, send: Send) -> bool: + """Validate the session ID in the request.""" + if not self.mcp_session_id: + # If we're not using session IDs, return True + return True + + # Get the session ID from the request headers + request_session_id = self._get_session_id(request) + + # If no session ID provided but required, return error + if not request_session_id: + response = self._create_error_response( + "Bad Request: Missing session ID", + HTTPStatus.BAD_REQUEST, + ) + await response(request.scope, request.receive, send) + return False + + # If session ID doesn't match, return error + if request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(request.scope, request.receive, send) + return False + + return True + + @asynccontextmanager + async def connect( + self, + ) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[JSONRPCMessage | Exception], + MemoryObjectSendStream[JSONRPCMessage], + ], + None, + ]: + """Context manager that provides read and write streams for a connection. 
+ + Yields: + Tuple of (read_stream, write_stream) for bidirectional communication + """ + + # Create the memory streams for this connection + + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + JSONRPCMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + JSONRPCMessage + ](0) + + # Store the streams + self._read_stream_writer = read_stream_writer + self._write_stream_reader = write_stream_reader + + # Start a task group for message routing + async with anyio.create_task_group() as tg: + # Create a message router that distributes messages to request streams + async def message_router(): + try: + async for message in write_stream_reader: + # Determine which request stream(s) should receive this message + target_request_id = None + if isinstance( + message.root, JSONRPCNotification | JSONRPCRequest + ): + # Extract related_request_id from meta if it exists + if ( + (params := getattr(message.root, "params", None)) + and (meta := params.get("_meta")) + and (related_id := meta.get("related_request_id")) + is not None + ): + target_request_id = str(related_id) + else: + target_request_id = str(message.root.id) + + # Send to the specific request stream if available + if ( + target_request_id + and target_request_id in self._request_streams + ): + try: + await self._request_streams[target_request_id].send( + message + ) + except ( + anyio.BrokenResourceError, + anyio.ClosedResourceError, + ): + # Stream might be closed, remove from registry + self._request_streams.pop(target_request_id, None) + except Exception as e: + logger.exception(f"Error in message router: {e}") + + # Start the message router + tg.start_soon(message_router) + + try: + # Yield the streams for the caller to use + yield read_stream, write_stream + finally: + for stream in list(self._request_streams.values()): + try: + await stream.aclose() + except Exception: + pass + self._request_streams.clear() diff --git 
a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 11daedc988..c1259da742 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -6,7 +6,6 @@ from typing import Any, Generic, TypeVar import anyio -import anyio.lowlevel import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel @@ -24,6 +23,7 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + NotificationParams, RequestParams, ServerNotification, ServerRequest, @@ -274,16 +274,32 @@ async def send_request( await response_stream.aclose() await response_stream_reader.aclose() - async def send_notification(self, notification: SendNotificationT) -> None: + async def send_notification( + self, + notification: SendNotificationT, + related_request_id: RequestId | None = None, + ) -> None: """ Emits a notification, which is a one-way message that does not expect a response. """ + # Some transport implementations may need to set the related_request_id + # to attribute to the notifications to the request that triggered them. 
+ if related_request_id is not None and notification.root.params is not None: + # Create meta if it doesn't exist + if notification.root.params.meta is None: + meta_dict = {"related_request_id": related_request_id} + + else: + meta_dict = notification.root.params.meta.model_dump( + by_alias=True, mode="json", exclude_none=True + ) + meta_dict["related_request_id"] = related_request_id + notification.root.params.meta = NotificationParams.Meta(**meta_dict) jsonrpc_notification = JSONRPCNotification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) async def _send_response( diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 797f817e1a..588fa649f8 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -9,6 +9,7 @@ from mcp.shared.session import RequestResponder from mcp.types import ( LoggingMessageNotificationParams, + NotificationParams, TextContent, ) @@ -78,6 +79,11 @@ async def message_handler( ) assert log_result.isError is False assert len(logging_collector.log_messages) == 1 - assert logging_collector.log_messages[0] == LoggingMessageNotificationParams( - level="info", logger="test_logger", data="Test log message" - ) + # Create meta object with related_request_id added dynamically + meta = NotificationParams.Meta() + setattr(meta, "related_request_id", "2") + log = logging_collector.log_messages[0] + assert log.level == "info" + assert log.logger == "test_logger" + assert log.data == "Test log message" + assert log.meta == meta diff --git a/tests/issues/test_188_concurrency.py b/tests/issues/test_188_concurrency.py index 2aa6c49cb3..d0a86885fd 100644 --- a/tests/issues/test_188_concurrency.py +++ b/tests/issues/test_188_concurrency.py @@ -35,7 +35,7 @@ async def slow_resource(): end_time = anyio.current_time() duration = end_time - start_time - assert duration < 
3 * _sleep_time_seconds + assert duration < 6 * _sleep_time_seconds print(duration) diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index e76e59c52e..772c41529b 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -544,14 +544,28 @@ async def logging_tool(msg: str, ctx: Context) -> str: assert mock_log.call_count == 4 mock_log.assert_any_call( - level="debug", data="Debug message", logger=None + level="debug", + data="Debug message", + logger=None, + related_request_id="1", ) - mock_log.assert_any_call(level="info", data="Info message", logger=None) mock_log.assert_any_call( - level="warning", data="Warning message", logger=None + level="info", + data="Info message", + logger=None, + related_request_id="1", ) mock_log.assert_any_call( - level="error", data="Error message", logger=None + level="warning", + data="Warning message", + logger=None, + related_request_id="1", + ) + mock_log.assert_any_call( + level="error", + data="Error message", + logger=None, + related_request_id="1", ) @pytest.mark.anyio diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py new file mode 100644 index 0000000000..8904bf4f5f --- /dev/null +++ b/tests/server/test_streamableHttp.py @@ -0,0 +1,543 @@ +""" +Tests for the StreamableHTTP server transport validation. + +This file contains tests for request validation in the StreamableHTTP transport. 
+""" + +import contextlib +import multiprocessing +import socket +import time +from collections.abc import Generator +from http import HTTPStatus +from uuid import uuid4 + +import anyio +import pytest +import requests +import uvicorn +from pydantic import AnyUrl +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount + +from mcp.server import Server +from mcp.server.streamableHttp import ( + MCP_SESSION_ID_HEADER, + SESSION_ID_PATTERN, + StreamableHTTPServerTransport, +) +from mcp.shared.exceptions import McpError +from mcp.types import ( + ErrorData, + TextContent, + Tool, +) + +# Test constants +SERVER_NAME = "test_streamable_http_server" +TEST_SESSION_ID = "test-session-id-12345" +INIT_REQUEST = { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": "2025-03-26", + "capabilities": {}, + }, + "id": "init-1", +} + + +# Test server implementation that follows MCP protocol +class ServerTest(Server): + def __init__(self): + super().__init__(SERVER_NAME) + + @self.read_resource() + async def handle_read_resource(uri: AnyUrl) -> str | bytes: + if uri.scheme == "foobar": + return f"Read {uri.host}" + elif uri.scheme == "slow": + # Simulate a slow resource + await anyio.sleep(2.0) + return f"Slow response from {uri.host}" + + raise McpError( + error=ErrorData( + code=404, message="OOPS! 
no resource with that URI was found" + ) + ) + + @self.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="test_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + @self.call_tool() + async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + return [TextContent(type="text", text=f"Called {name}")] + + +def create_app(is_json_response_enabled=False) -> Starlette: + """Create a Starlette application for testing that matches the example server. + + Args: + is_json_response_enabled: If True, use JSON responses instead of SSE streams. + """ + # Create server instance + server = ServerTest() + + server_instances = {} + # Lock to prevent race conditions when creating new sessions + session_creation_lock = anyio.Lock() + task_group = None + + @contextlib.asynccontextmanager + async def lifespan(app): + """Application lifespan context manager for managing task group.""" + nonlocal task_group + + async with anyio.create_task_group() as tg: + task_group = tg + print("Application started, task group initialized!") + try: + yield + finally: + print("Application shutting down, cleaning up resources...") + if task_group: + tg.cancel_scope.cancel() + task_group = None + print("Resources cleaned up successfully.") + + async def handle_streamable_http(scope, receive, send): + request = Request(scope, receive) + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + + # Use existing transport if session ID matches + if ( + request_mcp_session_id is not None + and request_mcp_session_id in server_instances + ): + transport = server_instances[request_mcp_session_id] + + await transport.handle_request(scope, receive, send) + elif request_mcp_session_id is None: + async with session_creation_lock: + new_session_id = uuid4().hex + + http_transport = StreamableHTTPServerTransport( + mcp_session_id=new_session_id, + is_json_response_enabled=is_json_response_enabled, + ) + + 
async with http_transport.connect() as streams: + read_stream, write_stream = streams + + async def run_server(): + try: + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) + except Exception as e: + print(f"Server exception: {e}") + + if task_group is None: + response = Response( + "Internal Server Error: Task group is not initialized", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) + return + + # Store the instance before starting the task to prevent races + server_instances[http_transport.mcp_session_id] = http_transport + task_group.start_soon(run_server) + + await http_transport.handle_request(scope, receive, send) + else: + response = Response( + "Bad Request: No valid session ID provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) + + # Create an ASGI application + app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=handle_streamable_http), + ], + lifespan=lifespan, + ) + + return app + + +def run_server(port: int, is_json_response_enabled=False) -> None: + """Run the test server. + + Args: + port: Port to listen on. + is_json_response_enabled: If True, use JSON responses instead of SSE streams. 
+ """ + print( + f"Starting test server on port {port} with " + f"json_enabled={is_json_response_enabled}" + ) + + app = create_app(is_json_response_enabled) + # Configure server + config = uvicorn.Config( + app=app, + host="127.0.0.1", + port=port, + log_level="info", + limit_concurrency=10, + timeout_keep_alive=5, + access_log=False, + ) + + # Start the server + server = uvicorn.Server(config=config) + + # This is important to catch exceptions and prevent test hangs + try: + print("Server starting...") + server.run() + except Exception as e: + print(f"ERROR: Server failed to run: {e}") + import traceback + + traceback.print_exc() + + print("Server shutdown") + + +# Test fixtures - using same approach as SSE tests +@pytest.fixture +def basic_server_port() -> int: + """Find an available port for the basic server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def json_server_port() -> int: + """Find an available port for the JSON response server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def basic_server(basic_server_port: int) -> Generator[None, None, None]: + """Start a basic server.""" + proc = multiprocessing.Process( + target=run_server, kwargs={"port": basic_server_port}, daemon=True + ) + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", basic_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + # Clean up + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.fixture +def json_response_server(json_server_port: int) -> Generator[None, None, None]: + """Start a server with JSON 
response enabled.""" + proc = multiprocessing.Process( + target=run_server, + kwargs={"port": json_server_port, "is_json_response_enabled": True}, + daemon=True, + ) + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", json_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + # Clean up + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.fixture +def basic_server_url(basic_server_port: int) -> str: + """Get the URL for the basic test server.""" + return f"http://127.0.0.1:{basic_server_port}" + + +@pytest.fixture +def json_server_url(json_server_port: int) -> str: + """Get the URL for the JSON response test server.""" + return f"http://127.0.0.1:{json_server_port}" + + +# Basic request validation tests +def test_accept_header_validation(basic_server, basic_server_url): + """Test that Accept header is properly validated.""" + # Test without Accept header + response = requests.post( + f"{basic_server_url}/mcp", + headers={"Content-Type": "application/json"}, + json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + +def test_content_type_validation(basic_server, basic_server_url): + """Test that Content-Type header is properly validated.""" + # Test with incorrect Content-Type + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "text/plain", + }, + data="This is not JSON", + ) + assert response.status_code == 415 + assert "Unsupported Media Type" in response.text + + +def test_json_validation(basic_server, basic_server_url): + """Test that 
JSON content is properly validated.""" +    # Test with invalid JSON +    response = requests.post( +        f"{basic_server_url}/mcp", +        headers={ +            "Accept": "application/json, text/event-stream", +            "Content-Type": "application/json", +        }, +        data="this is not valid json", +    ) +    assert response.status_code == 400 +    assert "Parse error" in response.text + + +def test_json_parsing(basic_server, basic_server_url): +    """Test that JSON content is properly parsed.""" +    # Test with valid JSON but invalid JSON-RPC +    response = requests.post( +        f"{basic_server_url}/mcp", +        headers={ +            "Accept": "application/json, text/event-stream", +            "Content-Type": "application/json", +        }, +        json={"foo": "bar"}, +    ) +    assert response.status_code == 400 +    assert "Validation error" in response.text + + +def test_method_not_allowed(basic_server, basic_server_url): +    """Test that unsupported HTTP methods are rejected.""" +    # Test with unsupported method (PUT) +    response = requests.put( +        f"{basic_server_url}/mcp", +        headers={ +            "Accept": "application/json, text/event-stream", +            "Content-Type": "application/json", +        }, +        json={"jsonrpc": "2.0", "method": "initialize", "id": 1}, +    ) +    assert response.status_code == 405 +    assert "Method Not Allowed" in response.text + + +def test_session_validation(basic_server, basic_server_url): +    """Test session ID validation.""" +    # session_id not used directly in this test + +    # Test without session ID +    response = requests.post( +        f"{basic_server_url}/mcp", +        headers={ +            "Accept": "application/json, text/event-stream", +            "Content-Type": "application/json", +        }, +        json={"jsonrpc": "2.0", "method": "list_tools", "id": 1}, +    ) +    assert response.status_code == 400 +    assert "Missing session ID" in response.text + + +def test_session_id_pattern(): +    """Test that SESSION_ID_PATTERN correctly validates session IDs.""" +    # Valid session IDs (visible ASCII characters from 0x21 to 0x7E) +    valid_session_ids = [ +        "test-session-id", +        "1234567890", + 
"session!@#$%^&*()_+-=[]{}|;:,.<>?/", + "~`", + ] + + for session_id in valid_session_ids: + assert SESSION_ID_PATTERN.match(session_id) is not None + # Ensure fullmatch matches too (whole string) + assert SESSION_ID_PATTERN.fullmatch(session_id) is not None + + # Invalid session IDs + invalid_session_ids = [ + "", # Empty string + " test", # Space (0x20) + "test\t", # Tab + "test\n", # Newline + "test\r", # Carriage return + "test" + chr(0x7F), # DEL character + "test" + chr(0x80), # Extended ASCII + "test" + chr(0x00), # Null character + "test" + chr(0x20), # Space (0x20) + ] + + for session_id in invalid_session_ids: + # For invalid IDs, either match will fail or fullmatch will fail + if SESSION_ID_PATTERN.match(session_id) is not None: + # If match succeeds, fullmatch should fail (partial match case) + assert SESSION_ID_PATTERN.fullmatch(session_id) is None + + +def test_streamable_http_transport_init_validation(): + """Test that StreamableHTTPServerTransport validates session ID on init.""" + # Valid session ID should initialize without errors + valid_transport = StreamableHTTPServerTransport(mcp_session_id="valid-id") + assert valid_transport.mcp_session_id == "valid-id" + + # None should be accepted + none_transport = StreamableHTTPServerTransport(mcp_session_id=None) + assert none_transport.mcp_session_id is None + + # Invalid session ID should raise ValueError + with pytest.raises(ValueError) as excinfo: + StreamableHTTPServerTransport(mcp_session_id="invalid id with space") + assert "Session ID must only contain visible ASCII characters" in str(excinfo.value) + + # Test with control characters + with pytest.raises(ValueError): + StreamableHTTPServerTransport(mcp_session_id="test\nid") + + with pytest.raises(ValueError): + StreamableHTTPServerTransport(mcp_session_id="test\n") + + +def test_session_termination(basic_server, basic_server_url): + """Test session termination via DELETE and subsequent request handling.""" + response = requests.post( + 
f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + response = requests.delete( + f"{basic_server_url}/mcp", + headers={MCP_SESSION_ID_HEADER: session_id}, + ) + assert response.status_code == 200 + + # Try to use the terminated session + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "ping", "id": 2}, + ) + assert response.status_code == 404 + assert "Session has been terminated" in response.text + + +def test_response(basic_server, basic_server_url): + """Test response handling for a valid request.""" + mcp_url = f"{basic_server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + + # Now terminate the session + session_id = response.headers.get(MCP_SESSION_ID_HEADER) + + # Try to use the terminated session + tools_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, + stream=True, + ) + assert tools_response.status_code == 200 + assert tools_response.headers.get("Content-Type") == "text/event-stream" + + +def test_json_response(json_response_server, json_server_url): + """Test response handling when is_json_response_enabled is True.""" + mcp_url = f"{json_server_url}/mcp" + response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, 
text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert response.status_code == 200 + assert response.headers.get("Content-Type") == "application/json" diff --git a/uv.lock b/uv.lock index fdb788a797..01726cc960 100644 --- a/uv.lock +++ b/uv.lock @@ -10,6 +10,7 @@ members = [ "mcp", "mcp-simple-prompt", "mcp-simple-resource", + "mcp-simple-streamablehttp", "mcp-simple-tool", ] @@ -632,6 +633,43 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = "mcp-simple-streamablehttp" +version = "0.1.0" +source = { editable = "examples/servers/simple-streamablehttp" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.1.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." 
}, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-tool" version = "0.1.0" From 72b66a58b1c51dcd5f6f3dff8d597f46706ffe72 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 2 May 2025 13:40:21 +0100 Subject: [PATCH 02/21] StreamableHttp - stateless server support (#554) --- .../simple-streamablehttp-stateless/README.md | 41 +++++ .../__init__.py | 0 .../__main__.py | 4 + .../server.py | 168 ++++++++++++++++++ .../pyproject.toml | 36 ++++ src/mcp/server/lowlevel/server.py | 12 +- src/mcp/server/session.py | 8 +- uv.lock | 40 ++++- 8 files changed, 305 insertions(+), 4 deletions(-) create mode 100644 examples/servers/simple-streamablehttp-stateless/README.md create mode 100644 examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__init__.py create mode 100644 examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py create mode 100644 examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py create mode 100644 examples/servers/simple-streamablehttp-stateless/pyproject.toml diff --git a/examples/servers/simple-streamablehttp-stateless/README.md b/examples/servers/simple-streamablehttp-stateless/README.md new file mode 100644 index 0000000000..2abb60614c --- /dev/null +++ b/examples/servers/simple-streamablehttp-stateless/README.md @@ -0,0 +1,41 @@ +# MCP Simple StreamableHttp Stateless Server Example + +A stateless MCP server example demonstrating the StreamableHttp transport without maintaining session state. This example is ideal for understanding how to deploy MCP servers in multi-node environments where requests can be routed to any instance. 
+ +## Features + +- Uses the StreamableHTTP transport in stateless mode (mcp_session_id=None) +- Each request creates a new ephemeral connection +- No session state maintained between requests +- Task lifecycle scoped to individual requests +- Suitable for deployment in multi-node environments + + +## Usage + +Start the server: + +```bash +# Using default port 3000 +uv run mcp-simple-streamablehttp-stateless + +# Using custom port +uv run mcp-simple-streamablehttp-stateless --port 3000 + +# Custom logging level +uv run mcp-simple-streamablehttp-stateless --log-level DEBUG + +# Enable JSON responses instead of SSE streams +uv run mcp-simple-streamablehttp-stateless --json-response +``` + +The server exposes a tool named "start-notification-stream" that accepts three arguments: + +- `interval`: Time between notifications in seconds (e.g., 1.0) +- `count`: Number of notifications to send (e.g., 5) +- `caller`: Identifier string for the caller + + +## Client + +You can connect to this server using an HTTP client. For now, only the TypeScript SDK has streamable HTTP client examples, or you can use [Inspector](https://github.com/modelcontextprotocol/inspector) for testing. 
\ No newline at end of file diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__init__.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py new file mode 100644 index 0000000000..f5f6e402df --- /dev/null +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py @@ -0,0 +1,4 @@ +from .server import main + +if __name__ == "__main__": + main() diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py new file mode 100644 index 0000000000..da8158a980 --- /dev/null +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -0,0 +1,168 @@ +import contextlib +import logging + +import anyio +import click +import mcp.types as types +from mcp.server.lowlevel import Server +from mcp.server.streamableHttp import ( + StreamableHTTPServerTransport, +) +from starlette.applications import Starlette +from starlette.routing import Mount + +logger = logging.getLogger(__name__) +# Global task group that will be initialized in the lifespan +task_group = None + + +@contextlib.asynccontextmanager +async def lifespan(app): + """Application lifespan context manager for managing task group.""" + global task_group + + async with anyio.create_task_group() as tg: + task_group = tg + logger.info("Application started, task group initialized!") + try: + yield + finally: + logger.info("Application shutting down, cleaning up resources...") + if task_group: + tg.cancel_scope.cancel() + task_group = None + 
logger.info("Resources cleaned up successfully.") + + +@click.command() +@click.option("--port", default=3000, help="Port to listen on for HTTP") +@click.option( + "--log-level", + default="INFO", + help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", +) +@click.option( + "--json-response", + is_flag=True, + default=False, + help="Enable JSON responses instead of SSE streams", +) +def main( + port: int, + log_level: str, + json_response: bool, +) -> int: + # Configure logging + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + app = Server("mcp-streamable-http-stateless-demo") + + @app.call_tool() + async def call_tool( + name: str, arguments: dict + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + ctx = app.request_context + interval = arguments.get("interval", 1.0) + count = arguments.get("count", 5) + caller = arguments.get("caller", "unknown") + + # Send the specified number of notifications with the given interval + for i in range(count): + await ctx.session.send_log_message( + level="info", + data=f"Notification {i+1}/{count} from caller: {caller}", + logger="notification_stream", + related_request_id=ctx.request_id, + ) + if i < count - 1: # Don't wait after the last notification + await anyio.sleep(interval) + + return [ + types.TextContent( + type="text", + text=( + f"Sent {count} notifications with {interval}s interval" + f" for caller: {caller}" + ), + ) + ] + + @app.list_tools() + async def list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="start-notification-stream", + description=( + "Sends a stream of notifications with configurable count" + " and interval" + ), + inputSchema={ + "type": "object", + "required": ["interval", "count", "caller"], + "properties": { + "interval": { + "type": "number", + "description": "Interval between notifications in seconds", + }, + "count": { + "type": "number", + 
"description": "Number of notifications to send", + }, + "caller": { + "type": "string", + "description": ( + "Identifier of the caller to include in notifications" + ), + }, + }, + }, + ) + ] + + # ASGI handler for stateless HTTP connections + async def handle_streamable_http(scope, receive, send): + logger.debug("Creating new transport") + # Use lock to prevent race conditions when creating new sessions + http_transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=json_response, + ) + async with http_transport.connect() as streams: + read_stream, write_stream = streams + + if not task_group: + raise RuntimeError("Task group is not initialized") + + async def run_server(): + await app.run( + read_stream, + write_stream, + app.create_initialization_options(), + # Runs in standalone mode for stateless deployments + # where clients perform initialization with any node + standalone_mode=True, + ) + + # Start server task + task_group.start_soon(run_server) + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + + # Create an ASGI application using the transport + starlette_app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=handle_streamable_http), + ], + lifespan=lifespan, + ) + + import uvicorn + + uvicorn.run(starlette_app, host="0.0.0.0", port=port) + + return 0 diff --git a/examples/servers/simple-streamablehttp-stateless/pyproject.toml b/examples/servers/simple-streamablehttp-stateless/pyproject.toml new file mode 100644 index 0000000000..d2b089451f --- /dev/null +++ b/examples/servers/simple-streamablehttp-stateless/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "mcp-simple-streamablehttp-stateless" +version = "0.1.0" +description = "A simple MCP server exposing a StreamableHttp transport in stateless mode" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." 
}] +keywords = ["mcp", "llm", "automation", "web", "fetch", "http", "streamable", "stateless"] +license = { text = "MIT" } +dependencies = ["anyio>=4.5", "click>=8.1.0", "httpx>=0.27", "mcp", "starlette", "uvicorn"] + +[project.scripts] +mcp-simple-streamablehttp-stateless = "mcp_simple_streamablehttp_stateless.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_streamablehttp_stateless"] + +[tool.pyright] +include = ["mcp_simple_streamablehttp_stateless"] +venvPath = "." +venv = ".venv" + +[tool.ruff.lint] +select = ["E", "F", "I"] +ignore = [] + +[tool.ruff] +line-length = 88 +target-version = "py310" + +[tool.uv] +dev-dependencies = ["pyright>=1.1.378", "pytest>=8.3.3", "ruff>=0.6.9"] \ No newline at end of file diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index b4f6330b51..1a01ff837f 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -479,11 +479,21 @@ async def run( # but also make tracing exceptions much easier during testing and when using # in-process servers. raise_exceptions: bool = False, + # When True, the server runs as a stateless deployment where + # clients can perform initialization with any node. The client must still follow + # the initialization lifecycle, but can do so with any available node + # rather than requiring initialization for each connection. 
+ stateless: bool = False, ): async with AsyncExitStack() as stack: lifespan_context = await stack.enter_async_context(self.lifespan(self)) session = await stack.enter_async_context( - ServerSession(read_stream, write_stream, initialization_options) + ServerSession( + read_stream, + write_stream, + initialization_options, + stateless=stateless, + ) ) async with anyio.create_task_group() as tg: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 3a1f210dde..d1fe88553b 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -85,11 +85,17 @@ def __init__( read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[types.JSONRPCMessage], init_options: InitializationOptions, + stateless: bool = False, ) -> None: super().__init__( read_stream, write_stream, types.ClientRequest, types.ClientNotification ) - self._initialization_state = InitializationState.NotInitialized + self._initialization_state = ( + InitializationState.Initialized + if stateless + else InitializationState.NotInitialized + ) + self._init_options = init_options self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( anyio.create_memory_object_stream[ServerRequestResponder](0) diff --git a/uv.lock b/uv.lock index 01726cc960..cbdc33471e 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.10" [options] @@ -11,6 +10,7 @@ members = [ "mcp-simple-prompt", "mcp-simple-resource", "mcp-simple-streamablehttp", + "mcp-simple-streamablehttp-stateless", "mcp-simple-tool", ] @@ -547,7 +547,6 @@ requires-dist = [ { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] -provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ @@ -670,6 +669,43 @@ dev = [ { name = "ruff", specifier = ">=0.6.9" }, ] +[[package]] +name = 
"mcp-simple-streamablehttp-stateless" +version = "0.1.0" +source = { editable = "examples/servers/simple-streamablehttp-stateless" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", specifier = ">=4.5" }, + { name = "click", specifier = ">=8.1.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "starlette" }, + { name = "uvicorn" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.378" }, + { name = "pytest", specifier = ">=8.3.3" }, + { name = "ruff", specifier = ">=0.6.9" }, +] + [[package]] name = "mcp-simple-tool" version = "0.1.0" From 46523afe3048178c3482b2dc9ccae18ddb8bc762 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 2 May 2025 13:52:27 +0100 Subject: [PATCH 03/21] StreamableHttp - GET request standalone SSE (#561) --- .../mcp_simple_streamablehttp/server.py | 4 + src/mcp/server/streamableHttp.py | 110 +++++++++++++++--- tests/server/test_streamableHttp.py | 89 ++++++++++++++ 3 files changed, 186 insertions(+), 17 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index e7bc44306a..b5faffedb4 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -11,6 +11,7 @@ MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport, ) +from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -92,6 +93,9 @@ async def call_tool( if i < count - 1: # Don't wait after the last 
notification await anyio.sleep(interval) + # This will send a resource notificaiton though standalone SSE + # established by GET request + await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource")) return [ types.TextContent( type="text", diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamableHttp.py index 2e0f709092..8faff0162f 100644 --- a/src/mcp/server/streamableHttp.py +++ b/src/mcp/server/streamableHttp.py @@ -50,6 +50,9 @@ CONTENT_TYPE_JSON = "application/json" CONTENT_TYPE_SSE = "text/event-stream" +# Special key for the standalone GET stream +GET_STREAM_KEY = "_GET_stream" + # Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) # Pattern ensures entire string contains only valid characters by using ^ and $ anchors SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") @@ -443,10 +446,19 @@ async def sse_writer(): return async def _handle_get_request(self, request: Request, send: Send) -> None: - """Handle GET requests for SSE stream establishment.""" - # Validate session ID if server has one - if not await self._validate_session(request, send): - return + """ + Handle GET request to establish SSE. + + This allows the server to communicate to the client without the client + first sending data via HTTP POST. The server can send JSON-RPC requests + and notifications on this stream. + """ + writer = self._read_stream_writer + if writer is None: + raise ValueError( + "No read stream writer available. Ensure connect() is called first." 
+ ) + # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) @@ -458,13 +470,80 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - # TODO: Implement SSE stream for GET requests - # For now, return 405 Method Not Allowed - response = self._create_error_response( - "SSE stream from GET request not implemented yet", - HTTPStatus.METHOD_NOT_ALLOWED, + if not await self._validate_session(request, send): + return + + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Check if we already have an active GET stream + if GET_STREAM_KEY in self._request_streams: + response = self._create_error_response( + "Conflict: Only one SSE stream is allowed per session", + HTTPStatus.CONFLICT, + ) + await response(request.scope, request.receive, send) + return + + # Create SSE stream + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ + dict[str, Any] + ](0) + + async def standalone_sse_writer(): + try: + # Create a standalone message stream for server-initiated messages + standalone_stream_writer, standalone_stream_reader = ( + anyio.create_memory_object_stream[JSONRPCMessage](0) + ) + + # Register this stream using the special key + self._request_streams[GET_STREAM_KEY] = standalone_stream_writer + + async with sse_stream_writer, standalone_stream_reader: + # Process messages from the standalone stream + async for received_message in standalone_stream_reader: + # For the standalone stream, we handle: + # - JSONRPCNotification (server sends notifications to client) + # - JSONRPCRequest (server sends requests to client) + # We should NOT receive JSONRPCResponse + + # Send the message via SSE + event_data = { + "event": "message", + "data": 
received_message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + await sse_stream_writer.send(event_data) + except Exception as e: + logger.exception(f"Error in standalone SSE writer: {e}") + finally: + logger.debug("Closing standalone SSE writer") + # Remove the stream from request_streams + self._request_streams.pop(GET_STREAM_KEY, None) + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=standalone_sse_writer, + headers=headers, ) - await response(request.scope, request.receive, send) + + try: + # This will send headers immediately and establish the SSE connection + await response(request.scope, request.receive, send) + except Exception as e: + logger.exception(f"Error in standalone SSE response: {e}") + # Clean up the request stream + self._request_streams.pop(GET_STREAM_KEY, None) async def _handle_delete_request(self, request: Request, send: Send) -> None: """Handle DELETE requests for explicit session termination.""" @@ -611,13 +690,10 @@ async def message_router(): else: target_request_id = str(message.root.id) - # Send to the specific request stream if available - if ( - target_request_id - and target_request_id in self._request_streams - ): + request_stream_id = target_request_id or GET_STREAM_KEY + if request_stream_id in self._request_streams: try: - await self._request_streams[target_request_id].send( + await self._request_streams[request_stream_id].send( message ) except ( @@ -625,7 +701,7 @@ async def message_router(): anyio.ClosedResourceError, ): # Stream might be closed, remove from registry - self._request_streams.pop(target_request_id, None) + self._request_streams.pop(request_stream_id, None) except Exception as e: logger.exception(f"Error in message router: {e}") diff --git a/tests/server/test_streamableHttp.py b/tests/server/test_streamableHttp.py index 8904bf4f5f..f612575c3c 100644 --- a/tests/server/test_streamableHttp.py +++ 
b/tests/server/test_streamableHttp.py @@ -541,3 +541,92 @@ def test_json_response(json_response_server, json_server_url): ) assert response.status_code == 200 assert response.headers.get("Content-Type") == "application/json" + + +def test_get_sse_stream(basic_server, basic_server_url): + """Test establishing an SSE stream via GET request.""" + # First, we need to initialize a session + mcp_url = f"{basic_server_url}/mcp" + init_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Now attempt to establish an SSE stream via GET + get_response = requests.get( + mcp_url, + headers={ + "Accept": "text/event-stream", + MCP_SESSION_ID_HEADER: session_id, + }, + stream=True, + ) + + # Verify we got a successful response with the right content type + assert get_response.status_code == 200 + assert get_response.headers.get("Content-Type") == "text/event-stream" + + # Test that a second GET request gets rejected (only one stream allowed) + second_get = requests.get( + mcp_url, + headers={ + "Accept": "text/event-stream", + MCP_SESSION_ID_HEADER: session_id, + }, + stream=True, + ) + + # Should get CONFLICT (409) since there's already a stream + # Note: This might fail if the first stream fully closed before this runs, + # but generally it should work in the test environment where it runs quickly + assert second_get.status_code == 409 + + +def test_get_validation(basic_server, basic_server_url): + """Test validation for GET requests.""" + # First, we need to initialize a session + mcp_url = f"{basic_server_url}/mcp" + init_response = requests.post( + mcp_url, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert 
init_response.status_code == 200 + + # Get the session ID + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + assert session_id is not None + + # Test without Accept header + response = requests.get( + mcp_url, + headers={ + MCP_SESSION_ID_HEADER: session_id, + }, + stream=True, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text + + # Test with wrong Accept header + response = requests.get( + mcp_url, + headers={ + "Accept": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + ) + assert response.status_code == 406 + assert "Not Acceptable" in response.text From 9dfc9250908c18ffc9c7ea25766fb0ffcdc1105b Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 2 May 2025 13:59:02 +0100 Subject: [PATCH 04/21] StreamableHttp client transport (#573) --- .../mcp_simple_streamablehttp/server.py | 2 +- src/mcp/client/streamable_http.py | 258 ++++++++++++++++++ .../{streamableHttp.py => streamable_http.py} | 0 .../test_streamable_http.py} | 251 ++++++++++++++++- uv.lock | 2 + 5 files changed, 498 insertions(+), 15 deletions(-) create mode 100644 src/mcp/client/streamable_http.py rename src/mcp/server/{streamableHttp.py => streamable_http.py} (100%) rename tests/{server/test_streamableHttp.py => shared/test_streamable_http.py} (69%) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index b5faffedb4..71d4e5a376 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -7,7 +7,7 @@ import click import mcp.types as types from mcp.server.lowlevel import Server -from mcp.server.streamableHttp import ( +from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport, ) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py new file mode 100644 index 
0000000000..1e00424284 --- /dev/null +++ b/src/mcp/client/streamable_http.py @@ -0,0 +1,258 @@ +""" +StreamableHTTP Client Transport Module + +This module implements the StreamableHTTP transport for MCP clients, +providing support for HTTP POST requests with optional SSE streaming responses +and session management. +""" + +import logging +from contextlib import asynccontextmanager +from datetime import timedelta +from typing import Any + +import anyio +import httpx +from httpx_sse import EventSource, aconnect_sse + +from mcp.types import ( + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, +) + +logger = logging.getLogger(__name__) + +# Header names +MCP_SESSION_ID_HEADER = "mcp-session-id" +LAST_EVENT_ID_HEADER = "last-event-id" + +# Content types +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_SSE = "text/event-stream" + + +@asynccontextmanager +async def streamablehttp_client( + url: str, + headers: dict[str, Any] | None = None, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), +): + """ + Client transport for StreamableHTTP. + + `sse_read_timeout` determines how long (in seconds) the client will wait for a new + event before disconnecting. All other HTTP operations are controlled by `timeout`. 
+ + Yields: + Tuple of (read_stream, write_stream, terminate_callback) + """ + + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + JSONRPCMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + JSONRPCMessage + ](0) + + async def get_stream(): + """ + Optional GET stream for server-initiated messages + """ + nonlocal session_id + try: + # Only attempt GET if we have a session ID + if not session_id: + return + + get_headers = request_headers.copy() + get_headers[MCP_SESSION_ID_HEADER] = session_id + + async with aconnect_sse( + client, + "GET", + url, + headers=get_headers, + timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), + ) as event_source: + event_source.response.raise_for_status() + logger.debug("GET SSE connection established") + + async for sse in event_source.aiter_sse(): + if sse.event == "message": + try: + message = JSONRPCMessage.model_validate_json(sse.data) + logger.debug(f"GET message: {message}") + await read_stream_writer.send(message) + except Exception as exc: + logger.error(f"Error parsing GET message: {exc}") + await read_stream_writer.send(exc) + else: + logger.warning(f"Unknown SSE event from GET: {sse.event}") + except Exception as exc: + # GET stream is optional, so don't propagate errors + logger.debug(f"GET stream error (non-fatal): {exc}") + + async def post_writer(client: httpx.AsyncClient): + nonlocal session_id + try: + async with write_stream_reader: + async for message in write_stream_reader: + # Add session ID to headers if we have one + post_headers = request_headers.copy() + if session_id: + post_headers[MCP_SESSION_ID_HEADER] = session_id + + logger.debug(f"Sending client message: {message}") + + # Handle initial initialization request + is_initialization = ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "initialize" + ) + if ( + isinstance(message.root, JSONRPCNotification) + and message.root.method == 
"notifications/initialized" + ): + tg.start_soon(get_stream) + + async with client.stream( + "POST", + url, + json=message.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + headers=post_headers, + ) as response: + if response.status_code == 202: + logger.debug("Received 202 Accepted") + continue + # Check for 404 (session expired/invalid) + if response.status_code == 404: + if isinstance(message.root, JSONRPCRequest): + jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=message.root.id, + error=ErrorData( + code=32600, + message="Session terminated", + ), + ) + await read_stream_writer.send( + JSONRPCMessage(jsonrpc_error) + ) + continue + response.raise_for_status() + + # Extract session ID from response headers + if is_initialization: + new_session_id = response.headers.get(MCP_SESSION_ID_HEADER) + if new_session_id: + session_id = new_session_id + logger.info(f"Received session ID: {session_id}") + + # Handle different response types + content_type = response.headers.get("content-type", "").lower() + + if content_type.startswith(CONTENT_TYPE_JSON): + try: + content = await response.aread() + json_message = JSONRPCMessage.model_validate_json( + content + ) + await read_stream_writer.send(json_message) + except Exception as exc: + logger.error(f"Error parsing JSON response: {exc}") + await read_stream_writer.send(exc) + + elif content_type.startswith(CONTENT_TYPE_SSE): + # Parse SSE events from the response + try: + event_source = EventSource(response) + async for sse in event_source.aiter_sse(): + if sse.event == "message": + try: + await read_stream_writer.send( + JSONRPCMessage.model_validate_json( + sse.data + ) + ) + except Exception as exc: + logger.exception("Error parsing message") + await read_stream_writer.send(exc) + else: + logger.warning(f"Unknown event: {sse.event}") + + except Exception as e: + logger.exception("Error reading SSE stream:") + await read_stream_writer.send(e) + + else: + # For 202 Accepted with no body + if 
response.status_code == 202: + logger.debug("Received 202 Accepted") + continue + + error_msg = f"Unexpected content type: {content_type}" + logger.error(error_msg) + await read_stream_writer.send(ValueError(error_msg)) + + except Exception as exc: + logger.error(f"Error in post_writer: {exc}") + finally: + await read_stream_writer.aclose() + await write_stream.aclose() + + async def terminate_session(): + """ + Terminate the session by sending a DELETE request. + """ + nonlocal session_id + if not session_id: + return # No session to terminate + + try: + delete_headers = request_headers.copy() + delete_headers[MCP_SESSION_ID_HEADER] = session_id + + response = await client.delete( + url, + headers=delete_headers, + ) + + if response.status_code == 405: + # Server doesn't allow client-initiated termination + logger.debug("Server does not allow session termination") + elif response.status_code != 200: + logger.warning(f"Session termination failed: {response.status_code}") + except Exception as exc: + logger.warning(f"Session termination failed: {exc}") + + async with anyio.create_task_group() as tg: + try: + logger.info(f"Connecting to StreamableHTTP endpoint: {url}") + # Set up headers with required Accept header + request_headers = { + "Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_SSE}", + "Content-Type": CONTENT_TYPE_JSON, + **(headers or {}), + } + # Track session ID if provided by server + session_id: str | None = None + + async with httpx.AsyncClient( + headers=request_headers, + timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), + follow_redirects=True, + ) as client: + tg.start_soon(post_writer, client) + try: + yield read_stream, write_stream, terminate_session + finally: + tg.cancel_scope.cancel() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamable_http.py similarity index 100% rename from src/mcp/server/streamableHttp.py rename to 
src/mcp/server/streamable_http.py diff --git a/tests/server/test_streamableHttp.py b/tests/shared/test_streamable_http.py similarity index 69% rename from tests/server/test_streamableHttp.py rename to tests/shared/test_streamable_http.py index f612575c3c..48af09536c 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/shared/test_streamable_http.py @@ -1,7 +1,7 @@ """ -Tests for the StreamableHTTP server transport validation. +Tests for the StreamableHTTP server and client transport. -This file contains tests for request validation in the StreamableHTTP transport. +Contains tests for both server and client sides of the StreamableHTTP transport. """ import contextlib @@ -13,6 +13,7 @@ from uuid import uuid4 import anyio +import httpx import pytest import requests import uvicorn @@ -22,18 +23,16 @@ from starlette.responses import Response from starlette.routing import Mount +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server -from mcp.server.streamableHttp import ( +from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, SESSION_ID_PATTERN, StreamableHTTPServerTransport, ) from mcp.shared.exceptions import McpError -from mcp.types import ( - ErrorData, - TextContent, - Tool, -) +from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool # Test constants SERVER_NAME = "test_streamable_http_server" @@ -64,11 +63,7 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: await anyio.sleep(2.0) return f"Slow response from {uri.host}" - raise McpError( - error=ErrorData( - code=404, message="OOPS! 
no resource with that URI was found" - ) - ) + raise ValueError(f"Unknown resource: {uri}") @self.list_tools() async def handle_list_tools() -> list[Tool]: @@ -77,11 +72,23 @@ async def handle_list_tools() -> list[Tool]: name="test_tool", description="A test tool", inputSchema={"type": "object", "properties": {}}, - ) + ), + Tool( + name="test_tool_with_standalone_notification", + description="A test tool that sends a notification", + inputSchema={"type": "object", "properties": {}}, + ), ] @self.call_tool() async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + # When the tool is called, send a notification to test GET stream + if name == "test_tool_with_standalone_notification": + ctx = self.request_context + await ctx.session.send_resource_updated( + uri=AnyUrl("http://test_resource") + ) + return [TextContent(type="text", text=f"Called {name}")] @@ -630,3 +637,219 @@ def test_get_validation(basic_server, basic_server_url): ) assert response.status_code == 406 assert "Not Acceptable" in response.text + + +# Client-specific fixtures +@pytest.fixture +async def http_client(basic_server, basic_server_url): + """Create test client matching the SSE test pattern.""" + async with httpx.AsyncClient(base_url=basic_server_url) as client: + yield client + + +@pytest.fixture +async def initialized_client_session(basic_server, basic_server_url): + """Create initialized StreamableHTTP client session.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: + await session.initialize() + yield session + + +@pytest.mark.anyio +async def test_streamablehttp_client_basic_connection(basic_server, basic_server_url): + """Test basic client connection with initialization.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, 
+ ) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + +@pytest.mark.anyio +async def test_streamablehttp_client_resource_read(initialized_client_session): + """Test client resource read functionality.""" + response = await initialized_client_session.read_resource( + uri=AnyUrl("foobar://test-resource") + ) + assert len(response.contents) == 1 + assert response.contents[0].uri == AnyUrl("foobar://test-resource") + assert response.contents[0].text == "Read test-resource" + + +@pytest.mark.anyio +async def test_streamablehttp_client_tool_invocation(initialized_client_session): + """Test client tool invocation.""" + # First list tools + tools = await initialized_client_session.list_tools() + assert len(tools.tools) == 2 + assert tools.tools[0].name == "test_tool" + + # Call the tool + result = await initialized_client_session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" + + +@pytest.mark.anyio +async def test_streamablehttp_client_error_handling(initialized_client_session): + """Test error handling in client.""" + with pytest.raises(McpError) as exc_info: + await initialized_client_session.read_resource( + uri=AnyUrl("unknown://test-error") + ) + assert exc_info.value.error.code == 0 + assert "Unknown resource: unknown://test-error" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_streamablehttp_client_session_persistence( + basic_server, basic_server_url +): + """Test that session ID persists across requests.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make 
multiple requests to verify session persistence + tools = await session.list_tools() + assert len(tools.tools) == 2 + + # Read a resource + resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) + assert isinstance(resource.contents[0], TextResourceContents) is True + content = resource.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Read test-persist" + + +@pytest.mark.anyio +async def test_streamablehttp_client_json_response( + json_response_server, json_server_url +): + """Test client with JSON response mode.""" + async with streamablehttp_client(f"{json_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + # Check tool listing + tools = await session.list_tools() + assert len(tools.tools) == 2 + + # Call a tool and verify JSON response handling + result = await session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" + + +@pytest.mark.anyio +async def test_streamablehttp_client_get_stream(basic_server, basic_server_url): + """Test GET stream functionality for server-initiated messages.""" + import mcp.types as types + from mcp.shared.session import RequestResponder + + notifications_received = [] + + # Define message handler to capture notifications + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, types.ServerNotification): + notifications_received.append(message) + + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, 
write_stream, message_handler=message_handler + ) as session: + # Initialize the session - this triggers the GET stream setup + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call the special tool that sends a notification + await session.call_tool("test_tool_with_standalone_notification", {}) + + # Verify we received the notification + assert len(notifications_received) > 0 + + # Verify the notification is a ResourceUpdatedNotification + resource_update_found = False + for notif in notifications_received: + if isinstance(notif.root, types.ResourceUpdatedNotification): + assert str(notif.root.params.uri) == "http://test_resource/" + resource_update_found = True + + assert ( + resource_update_found + ), "ResourceUpdatedNotification not received via GET stream" + + +@pytest.mark.anyio +async def test_streamablehttp_client_session_termination( + basic_server, basic_server_url +): + """Test client session termination functionality.""" + + # Create the streamablehttp_client with a custom httpx client to capture headers + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + terminate_session, + ): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make a request to confirm session is working + tools = await session.list_tools() + assert len(tools.tools) == 2 + + # After exiting ClientSession context, explicitly terminate the session + await terminate_session() + with pytest.raises( + McpError, + match="Session terminated", + ): + await session.list_tools() diff --git a/uv.lock b/uv.lock index cbdc33471e..06dd240b25 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" [options] @@ -547,6 +548,7 @@ requires-dist = [ { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, { name 
= "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] +provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ From 3978c6e1b91e8830e82d97ab3c4e3b6559972021 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 2 May 2025 14:10:40 +0100 Subject: [PATCH 05/21] StreamableHttp -- resumability support for servers (#587) --- .../servers/simple-streamablehttp/README.md | 20 +- .../mcp_simple_streamablehttp/event_store.py | 105 +++++++++ .../mcp_simple_streamablehttp/server.py | 22 +- src/mcp/server/streamable_http.py | 218 +++++++++++++++--- tests/shared/test_streamable_http.py | 30 +-- 5 files changed, 340 insertions(+), 55 deletions(-) create mode 100644 examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py diff --git a/examples/servers/simple-streamablehttp/README.md b/examples/servers/simple-streamablehttp/README.md index e5aaa65250..f850b72868 100644 --- a/examples/servers/simple-streamablehttp/README.md +++ b/examples/servers/simple-streamablehttp/README.md @@ -9,6 +9,7 @@ A simple MCP server example demonstrating the StreamableHttp transport, which en - Task management with anyio task groups - Ability to send multiple notifications over time to the client - Proper resource cleanup and lifespan management +- Resumability support via InMemoryEventStore ## Usage @@ -32,6 +33,23 @@ The server exposes a tool named "start-notification-stream" that accepts three a - `count`: Number of notifications to send (e.g., 5) - `caller`: Identifier string for the caller +## Resumability Support + +This server includes resumability support through the InMemoryEventStore. 
This enables clients to: + +- Reconnect to the server after a disconnection +- Resume event streaming from where they left off using the Last-Event-ID header + + +The server will: +- Generate unique event IDs for each SSE message +- Store events in memory for later replay +- Replay missed events when a client reconnects with a Last-Event-ID header + +Note: The InMemoryEventStore is designed for demonstration purposes only. For production use, consider implementing a persistent storage solution. + + + ## Client -You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use (Inspector)[https://github.com/modelcontextprotocol/inspector] \ No newline at end of file +You can connect to this server using an HTTP client, for now only Typescript SDK has streamable HTTP client examples or you can use [Inspector](https://github.com/modelcontextprotocol/inspector) \ No newline at end of file diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py new file mode 100644 index 0000000000..28c58149f5 --- /dev/null +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/event_store.py @@ -0,0 +1,105 @@ +""" +In-memory event store for demonstrating resumability functionality. + +This is a simple implementation intended for examples and testing, +not for production use where a persistent storage solution would be more appropriate. +""" + +import logging +from collections import deque +from dataclasses import dataclass +from uuid import uuid4 + +from mcp.server.streamable_http import ( + EventCallback, + EventId, + EventMessage, + EventStore, + StreamId, +) +from mcp.types import JSONRPCMessage + +logger = logging.getLogger(__name__) + + +@dataclass +class EventEntry: + """ + Represents an event entry in the event store. 
+ """ + + event_id: EventId + stream_id: StreamId + message: JSONRPCMessage + + +class InMemoryEventStore(EventStore): + """ + Simple in-memory implementation of the EventStore interface for resumability. + This is primarily intended for examples and testing, not for production use + where a persistent storage solution would be more appropriate. + + This implementation keeps only the last N events per stream for memory efficiency. + """ + + def __init__(self, max_events_per_stream: int = 100): + """Initialize the event store. + + Args: + max_events_per_stream: Maximum number of events to keep per stream + """ + self.max_events_per_stream = max_events_per_stream + # for maintaining last N events per stream + self.streams: dict[StreamId, deque[EventEntry]] = {} + # event_id -> EventEntry for quick lookup + self.event_index: dict[EventId, EventEntry] = {} + + async def store_event( + self, stream_id: StreamId, message: JSONRPCMessage + ) -> EventId: + """Stores an event with a generated event ID.""" + event_id = str(uuid4()) + event_entry = EventEntry( + event_id=event_id, stream_id=stream_id, message=message + ) + + # Get or create deque for this stream + if stream_id not in self.streams: + self.streams[stream_id] = deque(maxlen=self.max_events_per_stream) + + # If deque is full, the oldest event will be automatically removed + # We need to remove it from the event_index as well + if len(self.streams[stream_id]) == self.max_events_per_stream: + oldest_event = self.streams[stream_id][0] + self.event_index.pop(oldest_event.event_id, None) + + # Add new event + self.streams[stream_id].append(event_entry) + self.event_index[event_id] = event_entry + + return event_id + + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + """Replays events that occurred after the specified event ID.""" + if last_event_id not in self.event_index: + logger.warning(f"Event ID {last_event_id} not found in store") + return 
None + + # Get the stream and find events after the last one + last_event = self.event_index[last_event_id] + stream_id = last_event.stream_id + stream_events = self.streams.get(last_event.stream_id, deque()) + + # Events in deque are already in chronological order + found_last = False + for event in stream_events: + if found_last: + await send_callback(EventMessage(event.message, event.event_id)) + elif event.event_id == last_event_id: + found_last = True + + return stream_id diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index 71d4e5a376..b2079bb27e 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -17,12 +17,24 @@ from starlette.responses import Response from starlette.routing import Mount +from .event_store import InMemoryEventStore + # Configure logging logger = logging.getLogger(__name__) # Global task group that will be initialized in the lifespan task_group = None +# Event store for resumability +# The InMemoryEventStore enables resumability support for StreamableHTTP transport. +# It stores SSE events with unique IDs, allowing clients to: +# 1. Receive event IDs for each SSE message +# 2. Resume streams by sending Last-Event-ID in GET requests +# 3. Replay missed events after reconnection +# Note: This in-memory implementation is for demonstration ONLY. +# For production, use a persistent storage solution. 
+event_store = InMemoryEventStore() + @contextlib.asynccontextmanager async def lifespan(app): @@ -79,9 +91,14 @@ async def call_tool( # Send the specified number of notifications with the given interval for i in range(count): + # Include more detailed message for resumability demonstration + notification_msg = ( + f"[{i+1}/{count}] Event from '{caller}' - " + f"Use Last-Event-ID to resume if disconnected" + ) await ctx.session.send_log_message( level="info", - data=f"Notification {i+1}/{count} from caller: {caller}", + data=notification_msg, logger="notification_stream", # Associates this notification with the original request # Ensures notifications are sent to the correct response stream @@ -90,6 +107,7 @@ async def call_tool( # - nowhere (if GET request isn't supported) related_request_id=ctx.request_id, ) + logger.debug(f"Sent notification {i+1}/{count} for caller: {caller}") if i < count - 1: # Don't wait after the last notification await anyio.sleep(interval) @@ -163,8 +181,10 @@ async def handle_streamable_http(scope, receive, send): http_transport = StreamableHTTPServerTransport( mcp_session_id=new_session_id, is_json_response_enabled=json_response, + event_store=event_store, # Enable resumability ) server_instances[http_transport.mcp_session_id] = http_transport + logger.info(f"Created new transport with session ID: {new_session_id}") async with http_transport.connect() as streams: read_stream, write_stream = streams diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 8faff0162f..ca707e27ba 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -10,10 +10,11 @@ import json import logging import re -from collections.abc import AsyncGenerator +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import asynccontextmanager +from dataclasses import dataclass from http import HTTPStatus -from typing import Any import anyio 
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -57,6 +58,63 @@ # Pattern ensures entire string contains only valid characters by using ^ and $ anchors SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") +# Type aliases +StreamId = str +EventId = str + + +@dataclass +class EventMessage: + """ + A JSONRPCMessage with an optional event ID for stream resumability. + """ + + message: JSONRPCMessage + event_id: str | None = None + + +EventCallback = Callable[[EventMessage], Awaitable[None]] + + +class EventStore(ABC): + """ + Interface for resumability support via event storage. + """ + + @abstractmethod + async def store_event( + self, stream_id: StreamId, message: JSONRPCMessage + ) -> EventId: + """ + Stores an event for later retrieval. + + Args: + stream_id: ID of the stream the event belongs to + message: The JSON-RPC message to store + + Returns: + The generated event ID for the stored event + """ + pass + + @abstractmethod + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + """ + Replays events that occurred after the specified event ID. + + Args: + last_event_id: The ID of the last event the client received + send_callback: A callback function to send events to the client + + Returns: + The stream ID of the replayed events + """ + pass + class StreamableHTTPServerTransport: """ @@ -76,6 +134,7 @@ def __init__( self, mcp_session_id: str | None, is_json_response_enabled: bool = False, + event_store: EventStore | None = None, ) -> None: """ Initialize a new StreamableHTTP server transport. @@ -85,6 +144,9 @@ def __init__( Must contain only visible ASCII characters (0x21-0x7E). is_json_response_enabled: If True, return JSON responses for requests instead of SSE streams. Default is False. + event_store: Event store for resumability support. If provided, + resumability will be enabled, allowing clients to + reconnect and resume messages. 
Raises: ValueError: If the session ID contains invalid characters. @@ -98,8 +160,9 @@ def __init__( self.mcp_session_id = mcp_session_id self.is_json_response_enabled = is_json_response_enabled + self._event_store = event_store self._request_streams: dict[ - RequestId, MemoryObjectSendStream[JSONRPCMessage] + RequestId, MemoryObjectSendStream[EventMessage] ] = {} self._terminated = False @@ -160,6 +223,21 @@ def _get_session_id(self, request: Request) -> str | None: """Extract the session ID from request headers.""" return request.headers.get(MCP_SESSION_ID_HEADER) + def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: + """Create event data dictionary from an EventMessage.""" + event_data = { + "event": "message", + "data": event_message.message.model_dump_json( + by_alias=True, exclude_none=True + ), + } + + # If an event ID was provided, include it + if event_message.event_id: + event_data["id"] = event_message.event_id + + return event_data + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Application entry point that handles all HTTP requests""" request = Request(scope, receive) @@ -308,7 +386,7 @@ async def _handle_post_request( request_id = str(message.root.id) # Create promise stream for getting response request_stream_writer, request_stream_reader = ( - anyio.create_memory_object_stream[JSONRPCMessage](0) + anyio.create_memory_object_stream[EventMessage](0) ) # Register this stream for the request ID @@ -323,16 +401,18 @@ async def _handle_post_request( response_message = None # Use similar approach to SSE writer for consistency - async for received_message in request_stream_reader: + async for event_message in request_stream_reader: # If it's a response, this is what we're waiting for if isinstance( - received_message.root, JSONRPCResponse | JSONRPCError + event_message.message.root, JSONRPCResponse | JSONRPCError ): - response_message = received_message + response_message = 
event_message.message break # For notifications and request, keep waiting else: - logger.debug(f"received: {received_message.root.method}") + logger.debug( + f"received: {event_message.message.root.method}" + ) # At this point we should have a response if response_message: @@ -366,7 +446,7 @@ async def _handle_post_request( else: # Create SSE stream sse_stream_writer, sse_stream_reader = ( - anyio.create_memory_object_stream[dict[str, Any]](0) + anyio.create_memory_object_stream[dict[str, str]](0) ) async def sse_writer(): @@ -374,20 +454,14 @@ async def sse_writer(): try: async with sse_stream_writer, request_stream_reader: # Process messages from the request-specific stream - async for received_message in request_stream_reader: + async for event_message in request_stream_reader: # Build the event data - event_data = { - "event": "message", - "data": received_message.model_dump_json( - by_alias=True, exclude_none=True - ), - } - + event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) # If response, remove from pending streams and close if isinstance( - received_message.root, + event_message.message.root, JSONRPCResponse | JSONRPCError, ): if request_id: @@ -472,6 +546,10 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: if not await self._validate_session(request, send): return + # Handle resumability: check for Last-Event-ID header + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): + await self._replay_events(last_event_id, request, send) + return headers = { "Cache-Control": "no-cache, no-transform", @@ -493,14 +571,14 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Create SSE stream sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ - dict[str, Any] + dict[str, str] ](0) async def standalone_sse_writer(): try: # Create a standalone message stream for server-initiated messages standalone_stream_writer, standalone_stream_reader = 
( - anyio.create_memory_object_stream[JSONRPCMessage](0) + anyio.create_memory_object_stream[EventMessage](0) ) # Register this stream using the special key @@ -508,20 +586,14 @@ async def standalone_sse_writer(): async with sse_stream_writer, standalone_stream_reader: # Process messages from the standalone stream - async for received_message in standalone_stream_reader: + async for event_message in standalone_stream_reader: # For the standalone stream, we handle: # - JSONRPCNotification (server sends notifications to client) # - JSONRPCRequest (server sends requests to client) # We should NOT receive JSONRPCResponse # Send the message via SSE - event_data = { - "event": "message", - "data": received_message.model_dump_json( - by_alias=True, exclude_none=True - ), - } - + event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) except Exception as e: logger.exception(f"Error in standalone SSE writer: {e}") @@ -639,6 +711,82 @@ async def _validate_session(self, request: Request, send: Send) -> bool: return True + async def _replay_events( + self, last_event_id: str, request: Request, send: Send + ) -> None: + """ + Replays events that would have been sent after the specified event ID. + Only used when resumability is enabled. 
+ """ + event_store = self._event_store + if not event_store: + return + + try: + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Create SSE stream for replay + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ + dict[str, str] + ](0) + + async def replay_sender(): + try: + async with sse_stream_writer: + # Define an async callback for sending events + async def send_event(event_message: EventMessage) -> None: + event_data = self._create_event_data(event_message) + await sse_stream_writer.send(event_data) + + # Replay past events and get the stream ID + stream_id = await event_store.replay_events_after( + last_event_id, send_event + ) + + # If stream ID not in mapping, create it + if stream_id and stream_id not in self._request_streams: + msg_writer, msg_reader = anyio.create_memory_object_stream[ + EventMessage + ](0) + self._request_streams[stream_id] = msg_writer + + # Forward messages to SSE + async with msg_reader: + async for event_message in msg_reader: + event_data = self._create_event_data(event_message) + + await sse_stream_writer.send(event_data) + except Exception as e: + logger.exception(f"Error in replay sender: {e}") + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=replay_sender, + headers=headers, + ) + + try: + await response(request.scope, request.receive, send) + except Exception as e: + logger.exception(f"Error in replay response: {e}") + + except Exception as e: + logger.exception(f"Error replaying events: {e}") + response = self._create_error_response( + f"Error replaying events: {str(e)}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(request.scope, request.receive, send) + @asynccontextmanager async def connect( self, @@ -691,10 +839,22 @@ async 
def message_router(): target_request_id = str(message.root.id) request_stream_id = target_request_id or GET_STREAM_KEY + + # Store the event if we have an event store, + # regardless of whether a client is connected + # messages will be replayed on the re-connect + event_id = None + if self._event_store: + event_id = await self._event_store.store_event( + request_stream_id, message + ) + logger.debug(f"Stored {event_id} from {request_stream_id}") + if request_stream_id in self._request_streams: try: + # Send both the message and the event ID await self._request_streams[request_stream_id].send( - message + EventMessage(message, event_id) ) except ( anyio.BrokenResourceError, diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 48af09536c..7331b392bc 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -113,15 +113,12 @@ async def lifespan(app): async with anyio.create_task_group() as tg: task_group = tg - print("Application started, task group initialized!") try: yield finally: - print("Application shutting down, cleaning up resources...") if task_group: tg.cancel_scope.cancel() task_group = None - print("Resources cleaned up successfully.") async def handle_streamable_http(scope, receive, send): request = Request(scope, receive) @@ -148,14 +145,11 @@ async def handle_streamable_http(scope, receive, send): read_stream, write_stream = streams async def run_server(): - try: - await server.run( - read_stream, - write_stream, - server.create_initialization_options(), - ) - except Exception as e: - print(f"Server exception: {e}") + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) if task_group is None: response = Response( @@ -196,10 +190,6 @@ def run_server(port: int, is_json_response_enabled=False) -> None: port: Port to listen on. is_json_response_enabled: If True, use JSON responses instead of SSE streams. 
""" - print( - f"Starting test server on port {port} with " - f"json_enabled={is_json_response_enabled}" - ) app = create_app(is_json_response_enabled) # Configure server @@ -218,16 +208,12 @@ def run_server(port: int, is_json_response_enabled=False) -> None: # This is important to catch exceptions and prevent test hangs try: - print("Server starting...") server.run() - except Exception as e: - print(f"ERROR: Server failed to run: {e}") + except Exception: import traceback traceback.print_exc() - print("Server shutdown") - # Test fixtures - using same approach as SSE tests @pytest.fixture @@ -273,8 +259,6 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]: # Clean up proc.kill() proc.join(timeout=2) - if proc.is_alive(): - print("server process failed to terminate") @pytest.fixture @@ -306,8 +290,6 @@ def json_response_server(json_server_port: int) -> Generator[None, None, None]: # Clean up proc.kill() proc.join(timeout=2) - if proc.is_alive(): - print("server process failed to terminate") @pytest.fixture From da0cf223553d50e48fba7652b2ef0eca26550e77 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 2 May 2025 14:29:00 +0100 Subject: [PATCH 06/21] Wrap JSONRPC messages with SessionMessage for metadata support (#590) --- src/mcp/client/__main__.py | 6 +- src/mcp/client/session.py | 5 +- src/mcp/client/sse.py | 20 +++-- src/mcp/client/stdio/__init__.py | 18 +++-- src/mcp/client/streamable_http.py | 23 ++++-- src/mcp/client/websocket.py | 20 ++--- src/mcp/server/lowlevel/server.py | 5 +- src/mcp/server/session.py | 5 +- src/mcp/server/sse.py | 24 +++--- src/mcp/server/stdio.py | 18 +++-- src/mcp/server/streamable_http.py | 25 +++--- src/mcp/server/websocket.py | 18 +++-- src/mcp/shared/memory.py | 10 +-- src/mcp/shared/message.py | 35 ++++++++ src/mcp/shared/session.py | 34 ++++---- tests/client/test_session.py | 73 ++++++++++------- tests/client/test_stdio.py | 6 +- tests/issues/test_192_request_id.py | 15 ++-- tests/server/test_lifespan.py | 81 
+++++++++++-------- .../server/test_lowlevel_tool_annotations.py | 6 +- tests/server/test_session.py | 6 +- tests/server/test_stdio.py | 6 +- 22 files changed, 286 insertions(+), 173 deletions(-) create mode 100644 src/mcp/shared/message.py diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index 84e15bd564..2ec68e56cd 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -11,8 +11,8 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import JSONRPCMessage if not sys.warnoptions: import warnings @@ -36,8 +36,8 @@ async def message_handler( async def run_session( - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], client_info: types.Implementation | None = None, ): async with ClientSession( diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index fc86f0110e..4d65bbebb6 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -7,6 +7,7 @@ import mcp.types as types from mcp.shared.context import RequestContext +from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -92,8 +93,8 @@ class ClientSession( ): def __init__( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, 
list_roots_callback: ListRootsFnT | None = None, diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 4f6241a720..ff04d2f961 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -10,6 +10,7 @@ from httpx_sse import aconnect_sse import mcp.types as types +from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -31,11 +32,11 @@ async def sse_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. """ - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -97,7 +98,8 @@ async def sse_reader( await read_stream_writer.send(exc) continue - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) case _: logger.warning( f"Unknown SSE event: {sse.event}" @@ -111,11 +113,13 @@ async def sse_reader( async def post_writer(endpoint_url: str): try: async with write_stream_reader: - async for message in write_stream_reader: - logger.debug(f"Sending client message: {message}") + async for session_message in write_stream_reader: + logger.debug( + f"Sending client message: {session_message}" + ) response = await client.post( endpoint_url, - json=message.model_dump( + json=session_message.message.model_dump( 
by_alias=True, mode="json", exclude_none=True, diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 83de57a2b9..e8be5aff5b 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, Field import mcp.types as types +from mcp.shared.message import SessionMessage from .win32 import ( create_windows_process, @@ -98,11 +99,11 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder Client transport for stdio: this will connect to a server by spawning a process and communicating with it over stdin/stdout. """ - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -143,7 +144,8 @@ async def stdout_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() @@ -152,8 +154,10 @@ async def stdin_writer(): try: async with write_stream_reader: - async for message in write_stream_reader: - json = message.model_dump_json(by_alias=True, exclude_none=True) + async for session_message in write_stream_reader: + json = session_message.message.model_dump_json( + by_alias=True, exclude_none=True + ) await process.stdin.send( (json 
+ "\n").encode( encoding=server.encoding, diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 1e00424284..7a8887cd95 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -15,6 +15,7 @@ import httpx from httpx_sse import EventSource, aconnect_sse +from mcp.shared.message import SessionMessage from mcp.types import ( ErrorData, JSONRPCError, @@ -52,10 +53,10 @@ async def streamablehttp_client( """ read_stream_writer, read_stream = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + SessionMessage | Exception ](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](0) async def get_stream(): @@ -86,7 +87,8 @@ async def get_stream(): try: message = JSONRPCMessage.model_validate_json(sse.data) logger.debug(f"GET message: {message}") - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) except Exception as exc: logger.error(f"Error parsing GET message: {exc}") await read_stream_writer.send(exc) @@ -100,7 +102,8 @@ async def post_writer(client: httpx.AsyncClient): nonlocal session_id try: async with write_stream_reader: - async for message in write_stream_reader: + async for session_message in write_stream_reader: + message = session_message.message # Add session ID to headers if we have one post_headers = request_headers.copy() if session_id: @@ -141,9 +144,10 @@ async def post_writer(client: httpx.AsyncClient): message="Session terminated", ), ) - await read_stream_writer.send( + session_message = SessionMessage( JSONRPCMessage(jsonrpc_error) ) + await read_stream_writer.send(session_message) continue response.raise_for_status() @@ -163,7 +167,8 @@ async def post_writer(client: httpx.AsyncClient): json_message = JSONRPCMessage.model_validate_json( content ) - await read_stream_writer.send(json_message) + session_message = SessionMessage(json_message) 
+ await read_stream_writer.send(session_message) except Exception as exc: logger.error(f"Error parsing JSON response: {exc}") await read_stream_writer.send(exc) @@ -175,11 +180,15 @@ async def post_writer(client: httpx.AsyncClient): async for sse in event_source.aiter_sse(): if sse.event == "message": try: - await read_stream_writer.send( + message = ( JSONRPCMessage.model_validate_json( sse.data ) ) + session_message = SessionMessage(message) + await read_stream_writer.send( + session_message + ) except Exception as exc: logger.exception("Error parsing message") await read_stream_writer.send(exc) diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 2c2ed38b9e..ac542fb3f6 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -10,6 +10,7 @@ from websockets.typing import Subprotocol import mcp.types as types +from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -19,8 +20,8 @@ async def websocket_client( url: str, ) -> AsyncGenerator[ tuple[ - MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - MemoryObjectSendStream[types.JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], ], None, ]: @@ -39,10 +40,10 @@ async def websocket_client( # Create two in-memory streams: # - One for incoming messages (read_stream, written by ws_reader) # - One for outgoing messages (write_stream, read by ws_writer) - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: 
MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -59,7 +60,8 @@ async def ws_reader(): async for raw_text in ws: try: message = types.JSONRPCMessage.model_validate_json(raw_text) - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) except ValidationError as exc: # If JSON parse or model validation fails, send the exception await read_stream_writer.send(exc) @@ -70,9 +72,9 @@ async def ws_writer(): sends them to the server. """ async with write_stream_reader: - async for message in write_stream_reader: + async for session_message in write_stream_reader: # Convert to a dict, then to JSON - msg_dict = message.model_dump( + msg_dict = session_message.message.model_dump( by_alias=True, mode="json", exclude_none=True ) await ws.send(json.dumps(msg_dict)) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 1a01ff837f..1cacd23b5e 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -84,6 +84,7 @@ async def main(): from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder logger = logging.getLogger(__name__) @@ -471,8 +472,8 @@ async def handler(req: types.CompleteRequest): async def run( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], initialization_options: InitializationOptions, # When False, exceptions are returned as messages to the client. 
# When True, exceptions are raised, which will cause the server to shut down diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index d1fe88553b..6171dacc10 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -47,6 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.models import InitializationOptions +from mcp.shared.message import SessionMessage from mcp.shared.session import ( BaseSession, RequestResponder, @@ -82,8 +83,8 @@ class ServerSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, ) -> None: diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index d051c25bf6..c781c64d5b 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -46,6 +46,7 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types +from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -63,9 +64,7 @@ class SseServerTransport: """ _endpoint: str - _read_stream_writers: dict[ - UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception] - ] + _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] def __init__(self, endpoint: str) -> None: """ @@ -85,11 +84,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): raise ValueError("connect_sse can only handle HTTP requests") logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + 
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -109,12 +108,12 @@ async def sse_writer(): await sse_stream_writer.send({"event": "endpoint", "data": session_uri}) logger.debug(f"Sent endpoint event: {session_uri}") - async for message in write_stream_reader: - logger.debug(f"Sending message via SSE: {message}") + async for session_message in write_stream_reader: + logger.debug(f"Sending message via SSE: {session_message}") await sse_stream_writer.send( { "event": "message", - "data": message.model_dump_json( + "data": session_message.message.model_dump_json( by_alias=True, exclude_none=True ), } @@ -169,7 +168,8 @@ async def handle_post_message( await writer.send(err) return - logger.debug(f"Sending message to writer: {message}") + session_message = SessionMessage(message) + logger.debug(f"Sending session message to writer: {session_message}") response = Response("Accepted", status_code=202) await response(scope, receive, send) - await writer.send(message) + await writer.send(session_message) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 0e0e491292..f0bbe5a316 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -27,6 +27,7 @@ async def run_server(): from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream import mcp.types as types +from mcp.shared.message import SessionMessage @asynccontextmanager @@ -47,11 +48,11 @@ async def stdio_server( if not stdout: stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) - read_stream: 
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -66,15 +67,18 @@ async def stdin_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(message) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() async def stdout_writer(): try: async with write_stream_reader: - async for message in write_stream_reader: - json = message.model_dump_json(by_alias=True, exclude_none=True) + async for session_message in write_stream_reader: + json = session_message.message.model_dump_json( + by_alias=True, exclude_none=True + ) await stdout.write(json + "\n") await stdout.flush() except anyio.ClosedResourceError: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index ca707e27ba..23ca43b4a3 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -24,6 +24,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.shared.message import SessionMessage from mcp.types import ( INTERNAL_ERROR, INVALID_PARAMS, @@ -125,10 +126,10 @@ class StreamableHTTPServerTransport: """ # Server notification streams for POST requests as well as standalone SSE stream - _read_stream_writer: MemoryObjectSendStream[JSONRPCMessage | Exception] 
| None = ( + _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = ( None ) - _write_stream_reader: MemoryObjectReceiveStream[JSONRPCMessage] | None = None + _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None def __init__( self, @@ -378,7 +379,8 @@ async def _handle_post_request( await response(scope, receive, send) # Process the message after sending the response - await writer.send(message) + session_message = SessionMessage(message) + await writer.send(session_message) return @@ -394,7 +396,8 @@ async def _handle_post_request( if self.is_json_response_enabled: # Process the message - await writer.send(message) + session_message = SessionMessage(message) + await writer.send(session_message) try: # Process messages from the request-specific stream # We need to collect all messages until we get a response @@ -500,7 +503,8 @@ async def sse_writer(): async with anyio.create_task_group() as tg: tg.start_soon(response, scope, receive, send) # Then send the message to be processed by the server - await writer.send(message) + session_message = SessionMessage(message) + await writer.send(session_message) except Exception: logger.exception("SSE response error") # Clean up the request stream if something goes wrong @@ -792,8 +796,8 @@ async def connect( self, ) -> AsyncGenerator[ tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], ], None, ]: @@ -806,10 +810,10 @@ async def connect( # Create the memory streams for this connection read_stream_writer, read_stream = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + SessionMessage | Exception ](0) write_stream, write_stream_reader = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](0) # Store the streams @@ -821,8 +825,9 @@ async def connect( # Create a message router that distributes 
messages to request streams async def message_router(): try: - async for message in write_stream_reader: + async for session_message in write_stream_reader: # Determine which request stream(s) should receive this message + message = session_message.message target_request_id = None if isinstance( message.root, JSONRPCNotification | JSONRPCRequest diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index aee855cf11..9dc3f2a25e 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -8,6 +8,7 @@ from starlette.websockets import WebSocket import mcp.types as types +from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -22,11 +23,11 @@ async def websocket_server(scope: Scope, receive: Receive, send: Send): websocket = WebSocket(scope, receive, send) await websocket.accept(subprotocol="mcp") - read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] - read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[types.JSONRPCMessage] - write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) @@ -41,15 +42,18 @@ async def ws_reader(): await read_stream_writer.send(exc) continue - await read_stream_writer.send(client_message) + session_message = SessionMessage(client_message) + await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await websocket.close() async def ws_writer(): try: async with write_stream_reader: - async for message in write_stream_reader: - obj = 
message.model_dump_json(by_alias=True, exclude_none=True) + async for session_message in write_stream_reader: + obj = session_message.message.model_dump_json( + by_alias=True, exclude_none=True + ) await websocket.send_text(obj) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index abf87a3aae..b53f8dd63c 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -19,11 +19,11 @@ SamplingFnT, ) from mcp.server import Server -from mcp.types import JSONRPCMessage +from mcp.shared.message import SessionMessage MessageStream = tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], ] @@ -40,10 +40,10 @@ async def create_client_server_memory_streams() -> ( """ # Create streams for both directions server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + SessionMessage | Exception ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage | Exception + SessionMessage | Exception ](1) client_streams = (server_to_client_receive, client_to_server_send) diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py new file mode 100644 index 0000000000..c9341c364c --- /dev/null +++ b/src/mcp/shared/message.py @@ -0,0 +1,35 @@ +""" +Message wrapper with metadata support. + +This module defines a wrapper type that combines JSONRPCMessage with metadata +to support transport-specific features like resumability. 
+""" + +from dataclasses import dataclass + +from mcp.types import JSONRPCMessage, RequestId + + +@dataclass +class ClientMessageMetadata: + """Metadata specific to client messages.""" + + resumption_token: str | None = None + + +@dataclass +class ServerMessageMetadata: + """Metadata specific to server messages.""" + + related_request_id: RequestId | None = None + + +MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None + + +@dataclass +class SessionMessage: + """A message with specific metadata for transport-specific features.""" + + message: JSONRPCMessage + metadata: MessageMetadata = None diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index c1259da742..cbf47be5f2 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -12,6 +12,7 @@ from typing_extensions import Self from mcp.shared.exceptions import McpError +from mcp.shared.message import SessionMessage from mcp.types import ( CancelledNotification, ClientNotification, @@ -172,8 +173,8 @@ class BaseSession( def __init__( self, - read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception], - write_stream: MemoryObjectSendStream[JSONRPCMessage], + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], receive_request_type: type[ReceiveRequestT], receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out @@ -240,7 +241,9 @@ async def send_request( # TODO: Support progress callbacks - await self._write_stream.send(JSONRPCMessage(jsonrpc_request)) + await self._write_stream.send( + SessionMessage(message=JSONRPCMessage(jsonrpc_request)) + ) # request read timeout takes precedence over session read timeout timeout = None @@ -300,14 +303,16 @@ async def send_notification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) + session_message = 
SessionMessage(message=JSONRPCMessage(jsonrpc_notification)) + await self._write_stream.send(session_message) async def _send_response( self, request_id: RequestId, response: SendResultT | ErrorData ) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - await self._write_stream.send(JSONRPCMessage(jsonrpc_error)) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) + await self._write_stream.send(session_message) else: jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", @@ -316,7 +321,8 @@ async def _send_response( by_alias=True, mode="json", exclude_none=True ), ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) + await self._write_stream.send(session_message) async def _receive_loop(self) -> None: async with ( @@ -326,15 +332,15 @@ async def _receive_loop(self) -> None: async for message in self._read_stream: if isinstance(message, Exception): await self._handle_incoming(message) - elif isinstance(message.root, JSONRPCRequest): + elif isinstance(message.message.root, JSONRPCRequest): validated_request = self._receive_request_type.model_validate( - message.root.model_dump( + message.message.root.model_dump( by_alias=True, mode="json", exclude_none=True ) ) responder = RequestResponder( - request_id=message.root.id, + request_id=message.message.root.id, request_meta=validated_request.root.params.meta if validated_request.root.params else None, @@ -349,10 +355,10 @@ async def _receive_loop(self) -> None: if not responder._completed: # type: ignore[reportPrivateUsage] await self._handle_incoming(responder) - elif isinstance(message.root, JSONRPCNotification): + elif isinstance(message.message.root, JSONRPCNotification): try: notification = self._receive_notification_type.model_validate( - message.root.model_dump( + message.message.root.model_dump( by_alias=True, mode="json", 
exclude_none=True ) ) @@ -368,12 +374,12 @@ async def _receive_loop(self) -> None: # For other validation errors, log and continue logging.warning( f"Failed to validate notification: {e}. " - f"Message was: {message.root}" + f"Message was: {message.message.root}" ) else: # Response or error - stream = self._response_streams.pop(message.root.id, None) + stream = self._response_streams.pop(message.message.root.id, None) if stream: - await stream.send(message.root) + await stream.send(message.message.root) else: await self._handle_incoming( RuntimeError( diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 543ebb2f00..6abcf70cbc 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -3,6 +3,7 @@ import mcp.types as types from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( LATEST_PROTOCOL_VERSION, @@ -24,10 +25,10 @@ @pytest.mark.anyio async def test_client_session_initialize(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) initialized_notification = None @@ -35,7 +36,8 @@ async def test_client_session_initialize(): async def mock_server(): nonlocal initialized_notification - jsonrpc_request = await client_to_server_receive.receive() + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -59,17 +61,20 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - 
id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) ) ) ) - jsonrpc_notification = await client_to_server_receive.receive() + session_notification = await client_to_server_receive.receive() + jsonrpc_notification = session_notification.message assert isinstance(jsonrpc_notification.root, JSONRPCNotification) initialized_notification = ClientNotification.model_validate( jsonrpc_notification.model_dump( @@ -116,10 +121,10 @@ async def message_handler( @pytest.mark.anyio async def test_client_session_custom_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) custom_client_info = Implementation(name="test-client", version="1.2.3") @@ -128,7 +133,8 @@ async def test_client_session_custom_client_info(): async def mock_server(): nonlocal received_client_info - jsonrpc_request = await client_to_server_receive.receive() + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -146,13 +152,15 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", 
exclude_none=True + ), + ) ) ) ) @@ -181,10 +189,10 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_default_client_info(): client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) received_client_info = None @@ -192,7 +200,8 @@ async def test_client_session_default_client_info(): async def mock_server(): nonlocal received_client_info - jsonrpc_request = await client_to_server_receive.receive() + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message assert isinstance(jsonrpc_request.root, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) @@ -210,13 +219,15 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) ) ) ) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 95747ffd19..523ba199a4 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -3,6 +3,7 @@ import pytest from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.shared.message import SessionMessage from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore @@ -22,7 +23,8 @@ async def test_stdio_client(): async with write_stream: for message in messages: - await write_stream.send(message) + session_message = SessionMessage(message) + await 
write_stream.send(session_message) read_messages = [] async with read_stream: @@ -30,7 +32,7 @@ async def test_stdio_client(): if isinstance(message, Exception): raise message - read_messages.append(message) + read_messages.append(message.message) if len(read_messages) == 2: break diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index 00e1878958..cf5eb6083e 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -3,6 +3,7 @@ from mcp.server.lowlevel import NotificationOptions, Server from mcp.server.models import InitializationOptions +from mcp.shared.message import SessionMessage from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, @@ -64,8 +65,10 @@ async def run_server(): jsonrpc="2.0", ) - await client_writer.send(JSONRPCMessage(root=init_req)) - await server_reader.receive() # Get init response but don't need to check it + await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req))) + response = ( + await server_reader.receive() + ) # Get init response but don't need to check it # Send initialized notification initialized_notification = JSONRPCNotification( @@ -73,21 +76,23 @@ async def run_server(): params=NotificationParams().model_dump(by_alias=True, exclude_none=True), jsonrpc="2.0", ) - await client_writer.send(JSONRPCMessage(root=initialized_notification)) + await client_writer.send( + SessionMessage(JSONRPCMessage(root=initialized_notification)) + ) # Send ping request with custom ID ping_request = JSONRPCRequest( id=custom_request_id, method="ping", params={}, jsonrpc="2.0" ) - await client_writer.send(JSONRPCMessage(root=ping_request)) + await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request))) # Read response response = await server_reader.receive() # Verify response ID matches request ID assert ( - response.root.id == custom_request_id + response.message.root.id == custom_request_id ), "Response ID should match request ID" # 
Cancel server task diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 309a44b870..a3ff59bc1b 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -10,6 +10,7 @@ from mcp.server.fastmcp import Context, FastMCP from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.models import InitializationOptions +from mcp.shared.message import SessionMessage from mcp.types import ( ClientCapabilities, Implementation, @@ -82,41 +83,49 @@ async def run_server(): clientInfo=Implementation(name="test-client", version="0.1.0"), ) await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), + SessionMessage( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) ) ) ) response = await receive_stream2.receive() + response = response.message # Send initialized notification await send_stream1.send( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", + SessionMessage( + JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) ) ) ) # Call the tool to verify lifespan context await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, + SessionMessage( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) ) ) ) # Get response and verify response = await receive_stream2.receive() + response = response.message assert response.root.result["content"][0]["text"] == "true" # Cancel server task @@ -178,41 +187,49 @@ async def run_server(): clientInfo=Implementation(name="test-client", 
version="0.1.0"), ) await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), + SessionMessage( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), + ) ) ) ) response = await receive_stream2.receive() + response = response.message # Send initialized notification await send_stream1.send( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", + SessionMessage( + JSONRPCMessage( + root=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) ) ) ) # Call the tool to verify lifespan context await send_stream1.send( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, + SessionMessage( + JSONRPCMessage( + root=JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, + ) ) ) ) # Get response and verify response = await receive_stream2.receive() + response = response.message assert response.root.result["content"][0]["text"] == "true" # Cancel server task diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index 47d03ad233..e9eff9ed0b 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -8,10 +8,10 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( ClientResult, - JSONRPCMessage, ServerNotification, ServerRequest, Tool, @@ -46,10 +46,10 @@ async def list_tools(): ] server_to_client_send, 
server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](10) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](10) # Message handler for client diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 561a94b64b..f2f0335882 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -7,11 +7,11 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ( ClientNotification, InitializedNotification, - JSONRPCMessage, PromptsCapability, ResourcesCapability, ServerCapabilities, @@ -21,10 +21,10 @@ @pytest.mark.anyio async def test_server_session_initialize(): server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](1) # Create a message handler to catch exceptions diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 85c5bf219b..c546a7167b 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -4,6 +4,7 @@ import pytest from mcp.server.stdio import stdio_server +from mcp.shared.message import SessionMessage from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse @@ -29,7 +30,7 @@ async def test_stdio_server(): async for message in read_stream: if isinstance(message, Exception): raise message - received_messages.append(message) + received_messages.append(message.message) if len(received_messages) == 2: break @@ -50,7 +51,8 @@ async def test_stdio_server(): async with write_stream: for response in responses: - await write_stream.send(response) + session_message = 
SessionMessage(response) + await write_stream.send(session_message) stdout.seek(0) output_lines = stdout.readlines() From cf8b66b82f2bd4cf0d67f909fc4ebc59c7bb63f2 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 2 May 2025 14:35:17 +0100 Subject: [PATCH 07/21] use metadata from SessionMessage to propagate related_request_id (#591) --- src/mcp/server/streamable_http.py | 17 +++++++++++------ src/mcp/shared/session.py | 21 +++++++-------------- tests/client/test_logging_callback.py | 4 ---- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 23ca43b4a3..53fff0d367 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -24,7 +24,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send -from mcp.shared.message import SessionMessage +from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.types import ( INTERNAL_ERROR, INVALID_PARAMS, @@ -520,7 +520,7 @@ async def sse_writer(): ) await response(scope, receive, send) if writer: - await writer.send(err) + await writer.send(Exception(err)) return async def _handle_get_request(self, request: Request, send: Send) -> None: @@ -834,12 +834,17 @@ async def message_router(): ): # Extract related_request_id from meta if it exists if ( - (params := getattr(message.root, "params", None)) - and (meta := params.get("_meta")) - and (related_id := meta.get("related_request_id")) + session_message.metadata is not None + and isinstance( + session_message.metadata, + ServerMessageMetadata, + ) + and session_message.metadata.related_request_id is not None ): - target_request_id = str(related_id) + target_request_id = str( + session_message.metadata.related_request_id + ) else: target_request_id = str(message.root.id) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index cbf47be5f2..d74c4d0664 100644 --- a/src/mcp/shared/session.py +++ 
b/src/mcp/shared/session.py @@ -12,7 +12,7 @@ from typing_extensions import Self from mcp.shared.exceptions import McpError -from mcp.shared.message import SessionMessage +from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.types import ( CancelledNotification, ClientNotification, @@ -24,7 +24,6 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, - NotificationParams, RequestParams, ServerNotification, ServerRequest, @@ -288,22 +287,16 @@ async def send_notification( """ # Some transport implementations may need to set the related_request_id # to attribute to the notifications to the request that triggered them. - if related_request_id is not None and notification.root.params is not None: - # Create meta if it doesn't exist - if notification.root.params.meta is None: - meta_dict = {"related_request_id": related_request_id} - - else: - meta_dict = notification.root.params.meta.model_dump( - by_alias=True, mode="json", exclude_none=True - ) - meta_dict["related_request_id"] = related_request_id - notification.root.params.meta = NotificationParams.Meta(**meta_dict) jsonrpc_notification = JSONRPCNotification( jsonrpc="2.0", **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) - session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_notification)) + session_message = SessionMessage( + message=JSONRPCMessage(jsonrpc_notification), + metadata=ServerMessageMetadata(related_request_id=related_request_id) + if related_request_id + else None, + ) await self._write_stream.send(session_message) async def _send_response( diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 588fa649f8..0c9eeb3970 100644 --- a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -9,7 +9,6 @@ from mcp.shared.session import RequestResponder from mcp.types import ( LoggingMessageNotificationParams, - NotificationParams, TextContent, ) @@ -80,10 +79,7 @@ async def 
message_handler( assert log_result.isError is False assert len(logging_collector.log_messages) == 1 # Create meta object with related_request_id added dynamically - meta = NotificationParams.Meta() - setattr(meta, "related_request_id", "2") log = logging_collector.log_messages[0] assert log.level == "info" assert log.logger == "test_logger" assert log.data == "Test log message" - assert log.meta == meta From 74f5fcfa0d9181a079e6f684a98a3e4c3f794c04 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 2 May 2025 14:49:50 +0100 Subject: [PATCH 08/21] StreamableHttp - client refactoring and resumability support (#595) --- src/mcp/client/session.py | 1 + src/mcp/client/streamable_http.py | 576 ++++++++++++++++++--------- src/mcp/shared/message.py | 10 +- src/mcp/shared/session.py | 7 +- tests/shared/test_streamable_http.py | 311 ++++++++++++++- 5 files changed, 710 insertions(+), 195 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 4d65bbebb6..7bb8821f71 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -261,6 +261,7 @@ async def call_tool( read_timeout_seconds: timedelta | None = None, ) -> types.CallToolResult: """Send a tools/call request.""" + return await self.send_request( types.ClientRequest( types.CallToolRequest( diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 7a8887cd95..ef424e3b33 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -7,207 +7,377 @@ """ import logging +from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import asynccontextmanager +from dataclasses import dataclass from datetime import timedelta from typing import Any import anyio import httpx -from httpx_sse import EventSource, aconnect_sse +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from httpx_sse import EventSource, ServerSentEvent, aconnect_sse -from mcp.shared.message import 
SessionMessage +from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( ErrorData, JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, + JSONRPCResponse, + RequestId, ) logger = logging.getLogger(__name__) -# Header names -MCP_SESSION_ID_HEADER = "mcp-session-id" -LAST_EVENT_ID_HEADER = "last-event-id" -# Content types -CONTENT_TYPE_JSON = "application/json" -CONTENT_TYPE_SSE = "text/event-stream" +SessionMessageOrError = SessionMessage | Exception +StreamWriter = MemoryObjectSendStream[SessionMessageOrError] +StreamReader = MemoryObjectReceiveStream[SessionMessage] +GetSessionIdCallback = Callable[[], str | None] +MCP_SESSION_ID = "mcp-session-id" +LAST_EVENT_ID = "last-event-id" +CONTENT_TYPE = "content-type" +ACCEPT = "Accept" -@asynccontextmanager -async def streamablehttp_client( - url: str, - headers: dict[str, Any] | None = None, - timeout: timedelta = timedelta(seconds=30), - sse_read_timeout: timedelta = timedelta(seconds=60 * 5), -): - """ - Client transport for StreamableHTTP. - `sse_read_timeout` determines how long (in seconds) the client will wait for a new - event before disconnecting. All other HTTP operations are controlled by `timeout`. 
+JSON = "application/json" +SSE = "text/event-stream" - Yields: - Tuple of (read_stream, write_stream, terminate_callback) - """ - read_stream_writer, read_stream = anyio.create_memory_object_stream[ - SessionMessage | Exception - ](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[ - SessionMessage - ](0) +class StreamableHTTPError(Exception): + """Base exception for StreamableHTTP transport errors.""" - async def get_stream(): - """ - Optional GET stream for server-initiated messages + pass + + +class ResumptionError(StreamableHTTPError): + """Raised when resumption request is invalid.""" + + pass + + +@dataclass +class RequestContext: + """Context for a request operation.""" + + client: httpx.AsyncClient + headers: dict[str, str] + session_id: str | None + session_message: SessionMessage + metadata: ClientMessageMetadata | None + read_stream_writer: StreamWriter + sse_read_timeout: timedelta + + +class StreamableHTTPTransport: + """StreamableHTTP client transport implementation.""" + + def __init__( + self, + url: str, + headers: dict[str, Any] | None = None, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + ) -> None: + """Initialize the StreamableHTTP transport. + + Args: + url: The endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations. + sse_read_timeout: Timeout for SSE read operations. 
""" - nonlocal session_id + self.url = url + self.headers = headers or {} + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout + self.session_id: str | None = None + self.request_headers = { + ACCEPT: f"{JSON}, {SSE}", + CONTENT_TYPE: JSON, + **self.headers, + } + + def _update_headers_with_session( + self, base_headers: dict[str, str] + ) -> dict[str, str]: + """Update headers with session ID if available.""" + headers = base_headers.copy() + if self.session_id: + headers[MCP_SESSION_ID] = self.session_id + return headers + + def _is_initialization_request(self, message: JSONRPCMessage) -> bool: + """Check if the message is an initialization request.""" + return ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "initialize" + ) + + def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: + """Check if the message is an initialized notification.""" + return ( + isinstance(message.root, JSONRPCNotification) + and message.root.method == "notifications/initialized" + ) + + def _maybe_extract_session_id_from_response( + self, + response: httpx.Response, + ) -> None: + """Extract and store session ID from response headers.""" + new_session_id = response.headers.get(MCP_SESSION_ID) + if new_session_id: + self.session_id = new_session_id + logger.info(f"Received session ID: {self.session_id}") + + async def _handle_sse_event( + self, + sse: ServerSentEvent, + read_stream_writer: StreamWriter, + original_request_id: RequestId | None = None, + resumption_callback: Callable[[str], Awaitable[None]] | None = None, + ) -> bool: + """Handle an SSE event, returning True if the response is complete.""" + if sse.event == "message": + try: + message = JSONRPCMessage.model_validate_json(sse.data) + logger.debug(f"SSE message: {message}") + + # If this is a response and we have original_request_id, replace it + if original_request_id is not None and isinstance( + message.root, JSONRPCResponse | JSONRPCError + ): + message.root.id 
= original_request_id + + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + + # Call resumption token callback if we have an ID + if sse.id and resumption_callback: + await resumption_callback(sse.id) + + # If this is a response or error return True indicating completion + # Otherwise, return False to continue listening + return isinstance(message.root, JSONRPCResponse | JSONRPCError) + + except Exception as exc: + logger.error(f"Error parsing SSE message: {exc}") + await read_stream_writer.send(exc) + return False + else: + logger.warning(f"Unknown SSE event: {sse.event}") + return False + + async def handle_get_stream( + self, + client: httpx.AsyncClient, + read_stream_writer: StreamWriter, + ) -> None: + """Handle GET stream for server-initiated messages.""" try: - # Only attempt GET if we have a session ID - if not session_id: + if not self.session_id: return - get_headers = request_headers.copy() - get_headers[MCP_SESSION_ID_HEADER] = session_id + headers = self._update_headers_with_session(self.request_headers) async with aconnect_sse( client, "GET", - url, - headers=get_headers, - timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), + self.url, + headers=headers, + timeout=httpx.Timeout( + self.timeout.seconds, read=self.sse_read_timeout.seconds + ), ) as event_source: event_source.response.raise_for_status() logger.debug("GET SSE connection established") async for sse in event_source.aiter_sse(): - if sse.event == "message": - try: - message = JSONRPCMessage.model_validate_json(sse.data) - logger.debug(f"GET message: {message}") - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - except Exception as exc: - logger.error(f"Error parsing GET message: {exc}") - await read_stream_writer.send(exc) - else: - logger.warning(f"Unknown SSE event from GET: {sse.event}") + await self._handle_sse_event(sse, read_stream_writer) + except Exception as exc: - # GET stream 
is optional, so don't propagate errors logger.debug(f"GET stream error (non-fatal): {exc}") - async def post_writer(client: httpx.AsyncClient): - nonlocal session_id + async def _handle_resumption_request(self, ctx: RequestContext) -> None: + """Handle a resumption request using GET with SSE.""" + headers = self._update_headers_with_session(ctx.headers) + if ctx.metadata and ctx.metadata.resumption_token: + headers[LAST_EVENT_ID] = ctx.metadata.resumption_token + else: + raise ResumptionError("Resumption request requires a resumption token") + + # Extract original request ID to map responses + original_request_id = None + if isinstance(ctx.session_message.message.root, JSONRPCRequest): + original_request_id = ctx.session_message.message.root.id + + async with aconnect_sse( + ctx.client, + "GET", + self.url, + headers=headers, + timeout=httpx.Timeout( + self.timeout.seconds, read=ctx.sse_read_timeout.seconds + ), + ) as event_source: + event_source.response.raise_for_status() + logger.debug("Resumption GET SSE connection established") + + async for sse in event_source.aiter_sse(): + is_complete = await self._handle_sse_event( + sse, + ctx.read_stream_writer, + original_request_id, + ctx.metadata.on_resumption_token_update if ctx.metadata else None, + ) + if is_complete: + break + + async def _handle_post_request(self, ctx: RequestContext) -> None: + """Handle a POST request with response processing.""" + headers = self._update_headers_with_session(ctx.headers) + message = ctx.session_message.message + is_initialization = self._is_initialization_request(message) + + async with ctx.client.stream( + "POST", + self.url, + json=message.model_dump(by_alias=True, mode="json", exclude_none=True), + headers=headers, + ) as response: + if response.status_code == 202: + logger.debug("Received 202 Accepted") + return + + if response.status_code == 404: + if isinstance(message.root, JSONRPCRequest): + await self._send_session_terminated_error( + ctx.read_stream_writer, + 
message.root.id, + ) + return + + response.raise_for_status() + if is_initialization: + self._maybe_extract_session_id_from_response(response) + + content_type = response.headers.get(CONTENT_TYPE, "").lower() + + if content_type.startswith(JSON): + await self._handle_json_response(response, ctx.read_stream_writer) + elif content_type.startswith(SSE): + await self._handle_sse_response(response, ctx) + else: + await self._handle_unexpected_content_type( + content_type, + ctx.read_stream_writer, + ) + + async def _handle_json_response( + self, + response: httpx.Response, + read_stream_writer: StreamWriter, + ) -> None: + """Handle JSON response from the server.""" + try: + content = await response.aread() + message = JSONRPCMessage.model_validate_json(content) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + except Exception as exc: + logger.error(f"Error parsing JSON response: {exc}") + await read_stream_writer.send(exc) + + async def _handle_sse_response( + self, response: httpx.Response, ctx: RequestContext + ) -> None: + """Handle SSE response from the server.""" + try: + event_source = EventSource(response) + async for sse in event_source.aiter_sse(): + await self._handle_sse_event( + sse, + ctx.read_stream_writer, + resumption_callback=( + ctx.metadata.on_resumption_token_update + if ctx.metadata + else None + ), + ) + except Exception as e: + logger.exception("Error reading SSE stream:") + await ctx.read_stream_writer.send(e) + + async def _handle_unexpected_content_type( + self, + content_type: str, + read_stream_writer: StreamWriter, + ) -> None: + """Handle unexpected content type in response.""" + error_msg = f"Unexpected content type: {content_type}" + logger.error(error_msg) + await read_stream_writer.send(ValueError(error_msg)) + + async def _send_session_terminated_error( + self, + read_stream_writer: StreamWriter, + request_id: RequestId, + ) -> None: + """Send a session terminated error response.""" + 
jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=32600, message="Session terminated"), + ) + session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) + await read_stream_writer.send(session_message) + + async def post_writer( + self, + client: httpx.AsyncClient, + write_stream_reader: StreamReader, + read_stream_writer: StreamWriter, + write_stream: MemoryObjectSendStream[SessionMessage], + start_get_stream: Callable[[], None], + ) -> None: + """Handle writing requests to the server.""" try: async with write_stream_reader: async for session_message in write_stream_reader: message = session_message.message - # Add session ID to headers if we have one - post_headers = request_headers.copy() - if session_id: - post_headers[MCP_SESSION_ID_HEADER] = session_id + metadata = ( + session_message.metadata + if isinstance(session_message.metadata, ClientMessageMetadata) + else None + ) + + # Check if this is a resumption request + is_resumption = bool(metadata and metadata.resumption_token) logger.debug(f"Sending client message: {message}") - # Handle initial initialization request - is_initialization = ( - isinstance(message.root, JSONRPCRequest) - and message.root.method == "initialize" + # Handle initialized notification + if self._is_initialized_notification(message): + start_get_stream() + + ctx = RequestContext( + client=client, + headers=self.request_headers, + session_id=self.session_id, + session_message=session_message, + metadata=metadata, + read_stream_writer=read_stream_writer, + sse_read_timeout=self.sse_read_timeout, ) - if ( - isinstance(message.root, JSONRPCNotification) - and message.root.method == "notifications/initialized" - ): - tg.start_soon(get_stream) - - async with client.stream( - "POST", - url, - json=message.model_dump( - by_alias=True, mode="json", exclude_none=True - ), - headers=post_headers, - ) as response: - if response.status_code == 202: - logger.debug("Received 202 Accepted") - continue - # 
Check for 404 (session expired/invalid) - if response.status_code == 404: - if isinstance(message.root, JSONRPCRequest): - jsonrpc_error = JSONRPCError( - jsonrpc="2.0", - id=message.root.id, - error=ErrorData( - code=32600, - message="Session terminated", - ), - ) - session_message = SessionMessage( - JSONRPCMessage(jsonrpc_error) - ) - await read_stream_writer.send(session_message) - continue - response.raise_for_status() - - # Extract session ID from response headers - if is_initialization: - new_session_id = response.headers.get(MCP_SESSION_ID_HEADER) - if new_session_id: - session_id = new_session_id - logger.info(f"Received session ID: {session_id}") - - # Handle different response types - content_type = response.headers.get("content-type", "").lower() - - if content_type.startswith(CONTENT_TYPE_JSON): - try: - content = await response.aread() - json_message = JSONRPCMessage.model_validate_json( - content - ) - session_message = SessionMessage(json_message) - await read_stream_writer.send(session_message) - except Exception as exc: - logger.error(f"Error parsing JSON response: {exc}") - await read_stream_writer.send(exc) - - elif content_type.startswith(CONTENT_TYPE_SSE): - # Parse SSE events from the response - try: - event_source = EventSource(response) - async for sse in event_source.aiter_sse(): - if sse.event == "message": - try: - message = ( - JSONRPCMessage.model_validate_json( - sse.data - ) - ) - session_message = SessionMessage(message) - await read_stream_writer.send( - session_message - ) - except Exception as exc: - logger.exception("Error parsing message") - await read_stream_writer.send(exc) - else: - logger.warning(f"Unknown event: {sse.event}") - - except Exception as e: - logger.exception("Error reading SSE stream:") - await read_stream_writer.send(e) - - else: - # For 202 Accepted with no body - if response.status_code == 202: - logger.debug("Received 202 Accepted") - continue - - error_msg = f"Unexpected content type: {content_type}" - 
logger.error(error_msg) - await read_stream_writer.send(ValueError(error_msg)) + + if is_resumption: + await self._handle_resumption_request(ctx) + else: + await self._handle_post_request(ctx) except Exception as exc: logger.error(f"Error in post_writer: {exc}") @@ -215,52 +385,98 @@ async def post_writer(client: httpx.AsyncClient): await read_stream_writer.aclose() await write_stream.aclose() - async def terminate_session(): - """ - Terminate the session by sending a DELETE request. - """ - nonlocal session_id - if not session_id: - return # No session to terminate + async def terminate_session(self, client: httpx.AsyncClient) -> None: + """Terminate the session by sending a DELETE request.""" + if not self.session_id: + return try: - delete_headers = request_headers.copy() - delete_headers[MCP_SESSION_ID_HEADER] = session_id - - response = await client.delete( - url, - headers=delete_headers, - ) + headers = self._update_headers_with_session(self.request_headers) + response = await client.delete(self.url, headers=headers) if response.status_code == 405: - # Server doesn't allow client-initiated termination logger.debug("Server does not allow session termination") elif response.status_code != 200: logger.warning(f"Session termination failed: {response.status_code}") except Exception as exc: logger.warning(f"Session termination failed: {exc}") + def get_session_id(self) -> str | None: + """Get the current session ID.""" + return self.session_id + + +@asynccontextmanager +async def streamablehttp_client( + url: str, + headers: dict[str, Any] | None = None, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + terminate_on_close: bool = True, +) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback, + ], + None, +]: + """ + Client transport for StreamableHTTP. 
+ + `sse_read_timeout` determines how long (in seconds) the client will wait for a new + event before disconnecting. All other HTTP operations are controlled by `timeout`. + + Yields: + Tuple containing: + - read_stream: Stream for reading messages from the server + - write_stream: Stream for sending messages to the server + - get_session_id_callback: Function to retrieve the current session ID + """ + transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) + + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + SessionMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + SessionMessage + ](0) + async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to StreamableHTTP endpoint: {url}") - # Set up headers with required Accept header - request_headers = { - "Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_SSE}", - "Content-Type": CONTENT_TYPE_JSON, - **(headers or {}), - } - # Track session ID if provided by server - session_id: str | None = None async with httpx.AsyncClient( - headers=request_headers, - timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), + headers=transport.request_headers, + timeout=httpx.Timeout( + transport.timeout.seconds, read=transport.sse_read_timeout.seconds + ), follow_redirects=True, ) as client: - tg.start_soon(post_writer, client) + # Define callbacks that need access to tg + def start_get_stream() -> None: + tg.start_soon( + transport.handle_get_stream, client, read_stream_writer + ) + + tg.start_soon( + transport.post_writer, + client, + write_stream_reader, + read_stream_writer, + write_stream, + start_get_stream, + ) + try: - yield read_stream, write_stream, terminate_session + yield ( + read_stream, + write_stream, + transport.get_session_id, + ) finally: + if transport.session_id and terminate_on_close: + await transport.terminate_session(client) tg.cancel_scope.cancel() finally: await 
read_stream_writer.aclose() diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index c9341c364c..5583f47951 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -5,16 +5,24 @@ to support transport-specific features like resumability. """ +from collections.abc import Awaitable, Callable from dataclasses import dataclass from mcp.types import JSONRPCMessage, RequestId +ResumptionToken = str + +ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]] + @dataclass class ClientMessageMetadata: """Metadata specific to client messages.""" - resumption_token: str | None = None + resumption_token: ResumptionToken | None = None + on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = ( + None + ) @dataclass diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index d74c4d0664..cce8b1184e 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -12,7 +12,7 @@ from typing_extensions import Self from mcp.shared.exceptions import McpError -from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.types import ( CancelledNotification, ClientNotification, @@ -213,6 +213,7 @@ async def send_request( request: SendRequestT, result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, + metadata: MessageMetadata = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. 
Raises an McpError if the @@ -241,7 +242,9 @@ async def send_request( # TODO: Support progress callbacks await self._write_stream.send( - SessionMessage(message=JSONRPCMessage(jsonrpc_request)) + SessionMessage( + message=JSONRPCMessage(jsonrpc_request), metadata=metadata + ) ) # request read timeout takes precedence over session read timeout diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 7331b392bc..f643602292 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -23,16 +23,31 @@ from starlette.responses import Response from starlette.routing import Mount +import mcp.types as types from mcp.client.session import ClientSession from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, SESSION_ID_PATTERN, + EventCallback, + EventId, + EventMessage, + EventStore, StreamableHTTPServerTransport, + StreamId, ) from mcp.shared.exceptions import McpError -from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool +from mcp.shared.message import ( + ClientMessageMetadata, +) +from mcp.shared.session import RequestResponder +from mcp.types import ( + InitializeResult, + TextContent, + TextResourceContents, + Tool, +) # Test constants SERVER_NAME = "test_streamable_http_server" @@ -49,6 +64,51 @@ } +# Simple in-memory event store for testing +class SimpleEventStore(EventStore): + """Simple in-memory event store for testing.""" + + def __init__(self): + self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = [] + self._event_id_counter = 0 + + async def store_event( + self, stream_id: StreamId, message: types.JSONRPCMessage + ) -> EventId: + """Store an event and return its ID.""" + self._event_id_counter += 1 + event_id = str(self._event_id_counter) + self._events.append((stream_id, event_id, message)) + return event_id + + async def replay_events_after( + 
self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + """Replay events after the specified ID.""" + # Find the index of the last event ID + start_index = None + for i, (_, event_id, _) in enumerate(self._events): + if event_id == last_event_id: + start_index = i + 1 + break + + if start_index is None: + # If event ID not found, start from beginning + start_index = 0 + + stream_id = None + # Replay events + for _, event_id, message in self._events[start_index:]: + await send_callback(EventMessage(message, event_id)) + # Capture the stream ID from the first replayed event + if stream_id is None and len(self._events) > start_index: + stream_id = self._events[start_index][0] + + return stream_id + + # Test server implementation that follows MCP protocol class ServerTest(Server): def __init__(self): @@ -78,25 +138,57 @@ async def handle_list_tools() -> list[Tool]: description="A test tool that sends a notification", inputSchema={"type": "object", "properties": {}}, ), + Tool( + name="long_running_with_checkpoints", + description="A long-running tool that sends periodic notifications", + inputSchema={"type": "object", "properties": {}}, + ), ] @self.call_tool() async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + ctx = self.request_context + # When the tool is called, send a notification to test GET stream if name == "test_tool_with_standalone_notification": - ctx = self.request_context await ctx.session.send_resource_updated( uri=AnyUrl("http://test_resource") ) + return [TextContent(type="text", text=f"Called {name}")] + + elif name == "long_running_with_checkpoints": + # Send notifications that are part of the response stream + # This simulates a long-running tool that sends logs + + await ctx.session.send_log_message( + level="info", + data="Tool started", + logger="tool", + related_request_id=ctx.request_id, # need for stream association + ) + + await anyio.sleep(0.1) + + await ctx.session.send_log_message( + 
level="info", + data="Tool is almost done", + logger="tool", + related_request_id=ctx.request_id, + ) + + return [TextContent(type="text", text="Completed!")] return [TextContent(type="text", text=f"Called {name}")] -def create_app(is_json_response_enabled=False) -> Starlette: +def create_app( + is_json_response_enabled=False, event_store: EventStore | None = None +) -> Starlette: """Create a Starlette application for testing that matches the example server. Args: is_json_response_enabled: If True, use JSON responses instead of SSE streams. + event_store: Optional event store for testing resumability. """ # Create server instance server = ServerTest() @@ -139,6 +231,7 @@ async def handle_streamable_http(scope, receive, send): http_transport = StreamableHTTPServerTransport( mcp_session_id=new_session_id, is_json_response_enabled=is_json_response_enabled, + event_store=event_store, ) async with http_transport.connect() as streams: @@ -183,15 +276,18 @@ async def run_server(): return app -def run_server(port: int, is_json_response_enabled=False) -> None: +def run_server( + port: int, is_json_response_enabled=False, event_store: EventStore | None = None +) -> None: """Run the test server. Args: port: Port to listen on. is_json_response_enabled: If True, use JSON responses instead of SSE streams. + event_store: Optional event store for testing resumability. 
""" - app = create_app(is_json_response_enabled) + app = create_app(is_json_response_enabled, event_store) # Configure server config = uvicorn.Config( app=app, @@ -261,6 +357,53 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]: proc.join(timeout=2) +@pytest.fixture +def event_store() -> SimpleEventStore: + """Create a test event store.""" + return SimpleEventStore() + + +@pytest.fixture +def event_server_port() -> int: + """Find an available port for the event store server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def event_server( + event_server_port: int, event_store: SimpleEventStore +) -> Generator[tuple[SimpleEventStore, str], None, None]: + """Start a server with event store enabled.""" + proc = multiprocessing.Process( + target=run_server, + kwargs={"port": event_server_port, "event_store": event_store}, + daemon=True, + ) + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", event_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield event_store, f"http://127.0.0.1:{event_server_port}" + + # Clean up + proc.kill() + proc.join(timeout=2) + + @pytest.fixture def json_response_server(json_server_port: int) -> Generator[None, None, None]: """Start a server with JSON response enabled.""" @@ -679,7 +822,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session) """Test client tool invocation.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 2 + assert len(tools.tools) == 3 assert tools.tools[0].name == "test_tool" # Call the tool @@ -720,7 +863,7 @@ async def test_streamablehttp_client_session_persistence( 
# Make multiple requests to verify session persistence tools = await session.list_tools() - assert len(tools.tools) == 2 + assert len(tools.tools) == 3 # Read a resource resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) @@ -751,7 +894,7 @@ async def test_streamablehttp_client_json_response( # Check tool listing tools = await session.list_tools() - assert len(tools.tools) == 2 + assert len(tools.tools) == 3 # Call a tool and verify JSON response handling result = await session.call_tool("test_tool", {}) @@ -813,25 +956,169 @@ async def test_streamablehttp_client_session_termination( ): """Test client session termination functionality.""" + captured_session_id = None + # Create the streamablehttp_client with a custom httpx client to capture headers async with streamablehttp_client(f"{basic_server_url}/mcp") as ( read_stream, write_stream, - terminate_session, + get_session_id, ): async with ClientSession(read_stream, write_stream) as session: # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) + captured_session_id = get_session_id() + assert captured_session_id is not None # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 2 + assert len(tools.tools) == 3 + + headers = {} + if captured_session_id: + headers[MCP_SESSION_ID_HEADER] = captured_session_id - # After exiting ClientSession context, explicitly terminate the session - await terminate_session() + async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + # Attempt to make a request after termination with pytest.raises( McpError, match="Session terminated", ): await session.list_tools() + + +@pytest.mark.anyio +async def test_streamablehttp_client_resumption(event_server): + """Test client session to resume a long running tool.""" + _, 
server_url = event_server + + # Variables to track the state + captured_session_id = None + captured_resumption_token = None + captured_notifications = [] + tool_started = False + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, types.ServerNotification): + captured_notifications.append(message) + # Look for our special notification that indicates the tool is running + if isinstance(message.root, types.LoggingMessageNotification): + if message.root.params.data == "Tool started": + nonlocal tool_started + tool_started = True + + async def on_resumption_token_update(token: str) -> None: + nonlocal captured_resumption_token + captured_resumption_token = token + + # First, start the client session and begin the long-running tool + async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as ( + read_stream, + write_stream, + get_session_id, + ): + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + captured_session_id = get_session_id() + assert captured_session_id is not None + + # Start a long-running tool in a task + async with anyio.create_task_group() as tg: + + async def run_tool(): + metadata = ClientMessageMetadata( + on_resumption_token_update=on_resumption_token_update, + ) + await session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="long_running_with_checkpoints", arguments={} + ), + ) + ), + types.CallToolResult, + metadata=metadata, + ) + + tg.start_soon(run_tool) + + # Wait for the tool to start and at least one notification + # and then kill the task group + while not tool_started or not captured_resumption_token: + await anyio.sleep(0.1) + 
tg.cancel_scope.cancel() + + # Store pre notifications and clear the captured notifications + # for the post-resumption check + captured_notifications_pre = captured_notifications.copy() + captured_notifications = [] + + # Now resume the session with the same mcp-session-id + headers = {} + if captured_session_id: + headers[MCP_SESSION_ID_HEADER] = captured_session_id + + async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: + # Don't initialize - just use the existing session + + # Resume the tool with the resumption token + assert captured_resumption_token is not None + + metadata = ClientMessageMetadata( + resumption_token=captured_resumption_token, + ) + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="long_running_with_checkpoints", arguments={} + ), + ) + ), + types.CallToolResult, + metadata=metadata, + ) + + # We should get a complete result + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert "Completed" in result.content[0].text + + # We should have received the remaining notifications + assert len(captured_notifications) > 0 + + # Should not have the first notification + # Check that "Tool started" notification isn't repeated when resuming + assert not any( + isinstance(n.root, types.LoggingMessageNotification) + and n.root.params.data == "Tool started" + for n in captured_notifications + ) + # there is no intersection between pre and post notifications + assert not any( + n in captured_notifications_pre for n in captured_notifications + ) From 5d8eaf77be00dbd9b33a7fe1e38cb0da77e49401 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 2 May 2025 14:59:17 +0100 Subject: [PATCH 09/21] Streamable Http - clean up server memory streams (#604) --- 
.../mcp_simple_streamablehttp/server.py | 20 +-- src/mcp/server/lowlevel/server.py | 2 +- src/mcp/server/streamable_http.py | 124 +++++++++++------- tests/shared/test_streamable_http.py | 31 ++--- 4 files changed, 108 insertions(+), 69 deletions(-) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index b2079bb27e..d36686720a 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -185,20 +185,22 @@ async def handle_streamable_http(scope, receive, send): ) server_instances[http_transport.mcp_session_id] = http_transport logger.info(f"Created new transport with session ID: {new_session_id}") - async with http_transport.connect() as streams: - read_stream, write_stream = streams - async def run_server(): - await app.run( - read_stream, - write_stream, - app.create_initialization_options(), - ) + async def run_server(task_status=None): + async with http_transport.connect() as streams: + read_stream, write_stream = streams + if task_status: + task_status.started() + await app.run( + read_stream, + write_stream, + app.create_initialization_options(), + ) if not task_group: raise RuntimeError("Task group is not initialized") - task_group.start_soon(run_server) + await task_group.start(run_server) # Handle the HTTP request and return the response await http_transport.handle_request(scope, receive, send) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 1cacd23b5e..4b97b33dad 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -480,7 +480,7 @@ async def run( # but also make tracing exceptions much easier during testing and when using # in-process servers. 
raise_exceptions: bool = False, - # When True, the server as stateless deployments where + # When True, the server is stateless and # clients can perform initialization with any node. The client must still follow # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 53fff0d367..ace74b33b4 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -129,6 +129,8 @@ class StreamableHTTPServerTransport: _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = ( None ) + _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None + _write_stream: MemoryObjectSendStream[SessionMessage] | None = None _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None def __init__( @@ -163,7 +165,11 @@ def __init__( self.is_json_response_enabled = is_json_response_enabled self._event_store = event_store self._request_streams: dict[ - RequestId, MemoryObjectSendStream[EventMessage] + RequestId, + tuple[ + MemoryObjectSendStream[EventMessage], + MemoryObjectReceiveStream[EventMessage], + ], ] = {} self._terminated = False @@ -239,6 +245,19 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: return event_data + async def _clean_up_memory_streams(self, request_id: RequestId) -> None: + """Clean up memory streams for a given request ID.""" + if request_id in self._request_streams: + try: + # Close the request stream + await self._request_streams[request_id][0].aclose() + await self._request_streams[request_id][1].aclose() + except Exception as e: + logger.debug(f"Error closing memory streams: {e}") + finally: + # Remove the request stream from the mapping + self._request_streams.pop(request_id, None) + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Application 
entry point that handles all HTTP requests""" request = Request(scope, receive) @@ -386,13 +405,11 @@ async def _handle_post_request( # Extract the request ID outside the try block for proper scope request_id = str(message.root.id) - # Create promise stream for getting response - request_stream_writer, request_stream_reader = ( - anyio.create_memory_object_stream[EventMessage](0) - ) - # Register this stream for the request ID - self._request_streams[request_id] = request_stream_writer + self._request_streams[request_id] = anyio.create_memory_object_stream[ + EventMessage + ](0) + request_stream_reader = self._request_streams[request_id][1] if self.is_json_response_enabled: # Process the message @@ -441,11 +458,7 @@ async def _handle_post_request( ) await response(scope, receive, send) finally: - # Clean up the request stream - if request_id in self._request_streams: - self._request_streams.pop(request_id, None) - await request_stream_reader.aclose() - await request_stream_writer.aclose() + await self._clean_up_memory_streams(request_id) else: # Create SSE stream sse_stream_writer, sse_stream_reader = ( @@ -467,16 +480,12 @@ async def sse_writer(): event_message.message.root, JSONRPCResponse | JSONRPCError, ): - if request_id: - self._request_streams.pop(request_id, None) break except Exception as e: logger.exception(f"Error in SSE writer: {e}") finally: logger.debug("Closing SSE writer") - # Clean up the request-specific streams - if request_id and request_id in self._request_streams: - self._request_streams.pop(request_id, None) + await self._clean_up_memory_streams(request_id) # Create and start EventSourceResponse # SSE stream mode (original behavior) @@ -507,9 +516,9 @@ async def sse_writer(): await writer.send(session_message) except Exception: logger.exception("SSE response error") - # Clean up the request stream if something goes wrong - if request_id and request_id in self._request_streams: - self._request_streams.pop(request_id, None) + await 
sse_stream_writer.aclose() + await sse_stream_reader.aclose() + await self._clean_up_memory_streams(request_id) except Exception as err: logger.exception("Error handling POST request") @@ -581,12 +590,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: async def standalone_sse_writer(): try: # Create a standalone message stream for server-initiated messages - standalone_stream_writer, standalone_stream_reader = ( + + self._request_streams[GET_STREAM_KEY] = ( anyio.create_memory_object_stream[EventMessage](0) ) - - # Register this stream using the special key - self._request_streams[GET_STREAM_KEY] = standalone_stream_writer + standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1] async with sse_stream_writer, standalone_stream_reader: # Process messages from the standalone stream @@ -603,8 +611,7 @@ async def standalone_sse_writer(): logger.exception(f"Error in standalone SSE writer: {e}") finally: logger.debug("Closing standalone SSE writer") - # Remove the stream from request_streams - self._request_streams.pop(GET_STREAM_KEY, None) + await self._clean_up_memory_streams(GET_STREAM_KEY) # Create and start EventSourceResponse response = EventSourceResponse( @@ -618,8 +625,9 @@ async def standalone_sse_writer(): await response(request.scope, request.receive, send) except Exception as e: logger.exception(f"Error in standalone SSE response: {e}") - # Clean up the request stream - self._request_streams.pop(GET_STREAM_KEY, None) + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() + await self._clean_up_memory_streams(GET_STREAM_KEY) async def _handle_delete_request(self, request: Request, send: Send) -> None: """Handle DELETE requests for explicit session termination.""" @@ -636,7 +644,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: if not await self._validate_session(request, send): return - self._terminate_session() + await self._terminate_session() response = 
self._create_json_response( None, @@ -644,7 +652,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: ) await response(request.scope, request.receive, send) - def _terminate_session(self) -> None: + async def _terminate_session(self) -> None: """Terminate the current session, closing all streams. Once terminated, all requests with this session ID will receive 404 Not Found. @@ -656,19 +664,26 @@ def _terminate_session(self) -> None: # We need a copy of the keys to avoid modification during iteration request_stream_keys = list(self._request_streams.keys()) - # Close all request streams (synchronously) + # Close all request streams asynchronously for key in request_stream_keys: try: - # Get the stream - stream = self._request_streams.get(key) - if stream: - # We must use close() here, not aclose() since this is a sync method - stream.close() + await self._clean_up_memory_streams(key) except Exception as e: logger.debug(f"Error closing stream {key} during termination: {e}") # Clear the request streams dictionary immediately self._request_streams.clear() + try: + if self._read_stream_writer is not None: + await self._read_stream_writer.aclose() + if self._read_stream is not None: + await self._read_stream.aclose() + if self._write_stream_reader is not None: + await self._write_stream_reader.aclose() + if self._write_stream is not None: + await self._write_stream.aclose() + except Exception as e: + logger.debug(f"Error closing streams: {e}") async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """Handle unsupported HTTP methods.""" @@ -756,10 +771,10 @@ async def send_event(event_message: EventMessage) -> None: # If stream ID not in mapping, create it if stream_id and stream_id not in self._request_streams: - msg_writer, msg_reader = anyio.create_memory_object_stream[ - EventMessage - ](0) - self._request_streams[stream_id] = msg_writer + self._request_streams[stream_id] = ( + 
anyio.create_memory_object_stream[EventMessage](0) + ) + msg_reader = self._request_streams[stream_id][1] # Forward messages to SSE async with msg_reader: @@ -781,6 +796,9 @@ async def send_event(event_message: EventMessage) -> None: await response(request.scope, request.receive, send) except Exception as e: logger.exception(f"Error in replay response: {e}") + finally: + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() except Exception as e: logger.exception(f"Error replaying events: {e}") @@ -818,7 +836,9 @@ async def connect( # Store the streams self._read_stream_writer = read_stream_writer + self._read_stream = read_stream self._write_stream_reader = write_stream_reader + self._write_stream = write_stream # Start a task group for message routing async with anyio.create_task_group() as tg: @@ -863,7 +883,7 @@ async def message_router(): if request_stream_id in self._request_streams: try: # Send both the message and the event ID - await self._request_streams[request_stream_id].send( + await self._request_streams[request_stream_id][0].send( EventMessage(message, event_id) ) except ( @@ -872,6 +892,12 @@ async def message_router(): ): # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) + else: + logging.debug( + f"""Request stream {request_stream_id} not found + for message. 
Still processing message as the client + might reconnect and replay.""" + ) except Exception as e: logger.exception(f"Error in message router: {e}") @@ -882,9 +908,19 @@ async def message_router(): # Yield the streams for the caller to use yield read_stream, write_stream finally: - for stream in list(self._request_streams.values()): + for stream_id in list(self._request_streams.keys()): try: - await stream.aclose() - except Exception: + await self._clean_up_memory_streams(stream_id) + except Exception as e: + logger.debug(f"Error closing request stream: {e}") pass self._request_streams.clear() + + # Clean up the read and write streams + try: + await read_stream_writer.aclose() + await read_stream.aclose() + await write_stream_reader.aclose() + await write_stream.aclose() + except Exception as e: + logger.debug(f"Error closing streams: {e}") diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f643602292..b1dc7ea338 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -234,29 +234,30 @@ async def handle_streamable_http(scope, receive, send): event_store=event_store, ) - async with http_transport.connect() as streams: - read_stream, write_stream = streams - - async def run_server(): + async def run_server(task_status=None): + async with http_transport.connect() as streams: + read_stream, write_stream = streams + if task_status: + task_status.started() await server.run( read_stream, write_stream, server.create_initialization_options(), ) - if task_group is None: - response = Response( - "Internal Server Error: Task group is not initialized", - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - ) - await response(scope, receive, send) - return + if task_group is None: + response = Response( + "Internal Server Error: Task group is not initialized", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) + return - # Store the instance before starting the 
task to prevent races - server_instances[http_transport.mcp_session_id] = http_transport - task_group.start_soon(run_server) + # Store the instance before starting the task to prevent races + server_instances[http_transport.mcp_session_id] = http_transport + await task_group.start(run_server) - await http_transport.handle_request(scope, receive, send) + await http_transport.handle_request(scope, receive, send) else: response = Response( "Bad Request: No valid session ID provided", From 83968b5b2f424beb25310ea6087a3df8b8050cd7 Mon Sep 17 00:00:00 2001 From: Akash D <14977531+akash329d@users.noreply.github.com> Date: Fri, 2 May 2025 09:32:46 -0700 Subject: [PATCH 10/21] Handle SSE Disconnects Properly (#612) --- .../simple-prompt/mcp_simple_prompt/server.py | 2 ++ .../mcp_simple_resource/server.py | 4 ++- .../simple-tool/mcp_simple_tool/server.py | 4 ++- src/mcp/server/fastmcp/server.py | 1 + src/mcp/server/session.py | 7 ++--- src/mcp/server/sse.py | 27 +++++++++++++++---- tests/shared/test_sse.py | 4 ++- 7 files changed, 38 insertions(+), 11 deletions(-) diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index 0552f2770e..bc14b7cd0a 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -90,6 +90,7 @@ async def get_prompt( if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette + from starlette.responses import Response from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") @@ -101,6 +102,7 @@ async def handle_sse(request): await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response() starlette_app = Starlette( debug=True, diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index 
0ec1d926af..06f567fbeb 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -46,6 +46,7 @@ async def read_resource(uri: FileUrl) -> str | bytes: if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette + from starlette.responses import Response from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") @@ -57,11 +58,12 @@ async def handle_sse(request): await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response() starlette_app = Starlette( debug=True, routes=[ - Route("/sse", endpoint=handle_sse), + Route("/sse", endpoint=handle_sse, methods=["GET"]), Mount("/messages/", app=sse.handle_post_message), ], ) diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 3eace52eaf..04224af5d2 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -60,6 +60,7 @@ async def list_tools() -> list[types.Tool]: if transport == "sse": from mcp.server.sse import SseServerTransport from starlette.applications import Starlette + from starlette.responses import Response from starlette.routing import Mount, Route sse = SseServerTransport("/messages/") @@ -71,11 +72,12 @@ async def handle_sse(request): await app.run( streams[0], streams[1], app.create_initialization_options() ) + return Response() starlette_app = Starlette( debug=True, routes=[ - Route("/sse", endpoint=handle_sse), + Route("/sse", endpoint=handle_sse, methods=["GET"]), Mount("/messages/", app=sse.handle_post_message), ], ) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 5b57eb13e9..a82880f1a5 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -589,6 +589,7 @@ async def handle_sse(scope: 
Scope, receive: Receive, send: Send): streams[1], self._mcp_server.create_initialization_options(), ) + return Response() # Create routes routes: list[Route | Mount] = [] diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 6171dacc10..c769d1aa3c 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -104,9 +104,6 @@ def __init__( self._exit_stack.push_async_callback( lambda: self._incoming_message_stream_reader.aclose() ) - self._exit_stack.push_async_callback( - lambda: self._incoming_message_stream_writer.aclose() - ) @property def client_params(self) -> types.InitializeRequestParams | None: @@ -144,6 +141,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: return True + async def _receive_loop(self) -> None: + async with self._incoming_message_stream_writer: + await super()._receive_loop() + async def _received_request( self, responder: RequestResponder[types.ClientRequest, types.ServerResult] ): diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index c781c64d5b..cc41a80d67 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -10,7 +10,7 @@ # Create Starlette routes for SSE and message handling routes = [ - Route("/sse", endpoint=handle_sse), + Route("/sse", endpoint=handle_sse, methods=["GET"]), Mount("/messages/", app=sse.handle_post_message), ] @@ -22,12 +22,18 @@ async def handle_sse(request): await app.run( streams[0], streams[1], app.create_initialization_options() ) + # Return empty response to avoid NoneType error + return Response() # Create and run Starlette app starlette_app = Starlette(routes=routes) uvicorn.run(starlette_app, host="0.0.0.0", port=port) ``` +Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType' +object is not callable" error when client disconnects. The example above returns +an empty Response() after the SSE connection ends to fix this. + See SseServerTransport class documentation for more details. 
""" @@ -120,11 +126,22 @@ async def sse_writer(): ) async with anyio.create_task_group() as tg: - response = EventSourceResponse( - content=sse_stream_reader, data_sender_callable=sse_writer - ) + + async def response_wrapper(scope: Scope, receive: Receive, send: Send): + """ + The EventSourceResponse returning signals a client close / disconnect. + In this case we close our side of the streams to signal the client that + the connection has been closed. + """ + await EventSourceResponse( + content=sse_stream_reader, data_sender_callable=sse_writer + )(scope, receive, send) + await read_stream_writer.aclose() + await write_stream_reader.aclose() + logging.debug(f"Client session disconnected {session_id}") + logger.debug("Starting SSE response task") - tg.start_soon(response, scope, receive, send) + tg.start_soon(response_wrapper, scope, receive, send) logger.debug("Yielding read and write streams") yield (read_stream, write_stream) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index f5158c3c37..4558bb88c0 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -10,6 +10,7 @@ from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request +from starlette.responses import Response from starlette.routing import Mount, Route from mcp.client.session import ClientSession @@ -83,13 +84,14 @@ def make_server_app() -> Starlette: sse = SseServerTransport("/messages/") server = ServerTest() - async def handle_sse(request: Request) -> None: + async def handle_sse(request: Request) -> Response: async with sse.connect_sse( request.scope, request.receive, request._send ) as streams: await server.run( streams[0], streams[1], server.create_initialization_options() ) + return Response() app = Starlette( routes=[ From 58c5e7223c40b2ec682fd7674545e8ceadd7cb20 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Fri, 2 May 2025 17:56:02 +0100 Subject: [PATCH 11/21] SSE FastMCP - do not go though auth when it's not 
needed (#619) --- src/mcp/server/fastmcp/server.py | 45 ++++++--- tests/server/fastmcp/test_integration.py | 112 +++++++++++++++++++++++ 2 files changed, 146 insertions(+), 11 deletions(-) create mode 100644 tests/server/fastmcp/test_integration.py diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index a82880f1a5..0e0b565c57 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -625,19 +625,42 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): ) ) - routes.append( - Route( - self.settings.sse_path, - endpoint=RequireAuthMiddleware(handle_sse, required_scopes), - methods=["GET"], + # When auth is not configured, we shouldn't require auth + if self._auth_server_provider: + # Auth is enabled, wrap the endpoints with RequireAuthMiddleware + routes.append( + Route( + self.settings.sse_path, + endpoint=RequireAuthMiddleware(handle_sse, required_scopes), + methods=["GET"], + ) ) - ) - routes.append( - Mount( - self.settings.message_path, - app=RequireAuthMiddleware(sse.handle_post_message, required_scopes), + routes.append( + Mount( + self.settings.message_path, + app=RequireAuthMiddleware(sse.handle_post_message, required_scopes), + ) + ) + else: + # Auth is disabled, no need for RequireAuthMiddleware + # Since handle_sse is an ASGI app, we need to create a compatible endpoint + async def sse_endpoint(request: Request) -> None: + # Convert the Starlette request to ASGI parameters + await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage] + + routes.append( + Route( + self.settings.sse_path, + endpoint=sse_endpoint, + methods=["GET"], + ) + ) + routes.append( + Mount( + self.settings.message_path, + app=sse.handle_post_message, + ) ) - ) # mount these routes last, so they have the lowest route matching precedence routes.extend(self._custom_starlette_routes) diff --git a/tests/server/fastmcp/test_integration.py 
b/tests/server/fastmcp/test_integration.py new file mode 100644 index 0000000000..281db2dbc7 --- /dev/null +++ b/tests/server/fastmcp/test_integration.py @@ -0,0 +1,112 @@ +""" +Integration tests for FastMCP server functionality. + +These tests validate the proper functioning of FastMCP in various configurations, +including with and without authentication. +""" + +import multiprocessing +import socket +import time +from collections.abc import Generator + +import pytest +import uvicorn + +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client +from mcp.server.fastmcp import FastMCP +from mcp.types import InitializeResult, TextContent + + +@pytest.fixture +def server_port() -> int: + """Get a free port for testing.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def server_url(server_port: int) -> str: + """Get the server URL for testing.""" + return f"http://127.0.0.1:{server_port}" + + +# Create a function to make the FastMCP server app +def make_fastmcp_app(): + """Create a FastMCP server without auth settings.""" + from starlette.applications import Starlette + + mcp = FastMCP(name="NoAuthServer") + + # Add a simple tool + @mcp.tool(description="A simple echo tool") + def echo(message: str) -> str: + return f"Echo: {message}" + + # Create the SSE app + app: Starlette = mcp.sse_app() + + return mcp, app + + +def run_server(server_port: int) -> None: + """Run the server.""" + _, app = make_fastmcp_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting server on port {server_port}") + server.run() + + +@pytest.fixture() +def server(server_port: int) -> Generator[None, None, None]: + """Start the server in a separate process and clean up after the test.""" + proc = multiprocessing.Process(target=run_server, args=(server_port,), daemon=True) + print("Starting server process") + 
proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("Waiting for server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + print("Killing server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("Server process failed to terminate") + + +@pytest.mark.anyio +async def test_fastmcp_without_auth(server: None, server_url: str) -> None: + """Test that FastMCP works when auth settings are not provided.""" + # Connect to the server + async with sse_client(server_url + "/sse") as streams: + async with ClientSession(*streams) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "NoAuthServer" + + # Test that we can call tools without authentication + tool_result = await session.call_tool("echo", {"message": "hello"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "Echo: hello" From 3b1b213a9669579ba04b3552e404f90a39baf8eb Mon Sep 17 00:00:00 2001 From: Akash D <14977531+akash329d@users.noreply.github.com> Date: Tue, 6 May 2025 17:10:43 -0700 Subject: [PATCH 12/21] Add message queue for SSE messages POST endpoint (#459) --- README.md | 24 ++ .../simple-prompt/mcp_simple_prompt/server.py | 5 +- pyproject.toml | 2 + src/mcp/client/sse.py | 6 +- src/mcp/client/stdio/__init__.py | 2 +- src/mcp/client/streamable_http.py | 6 +- src/mcp/client/websocket.py | 2 +- src/mcp/server/fastmcp/server.py | 31 +- src/mcp/server/message_queue/__init__.py | 16 + src/mcp/server/message_queue/base.py | 116 ++++++ src/mcp/server/message_queue/redis.py | 198 ++++++++++ 
src/mcp/server/sse.py | 42 ++- src/mcp/server/stdio.py | 2 +- src/mcp/server/streamable_http.py | 6 +- src/mcp/server/websocket.py | 2 +- src/mcp/shared/message.py | 14 +- tests/client/test_session.py | 6 +- tests/client/test_stdio.py | 2 +- tests/issues/test_192_request_id.py | 8 +- tests/server/message_dispatch/__init__.py | 1 + tests/server/message_dispatch/conftest.py | 28 ++ tests/server/message_dispatch/test_redis.py | 355 ++++++++++++++++++ .../test_redis_integration.py | 260 +++++++++++++ tests/server/test_lifespan.py | 12 +- tests/server/test_stdio.py | 2 +- uv.lock | 149 +++++++- 26 files changed, 1247 insertions(+), 50 deletions(-) create mode 100644 src/mcp/server/message_queue/__init__.py create mode 100644 src/mcp/server/message_queue/base.py create mode 100644 src/mcp/server/message_queue/redis.py create mode 100644 tests/server/message_dispatch/__init__.py create mode 100644 tests/server/message_dispatch/conftest.py create mode 100644 tests/server/message_dispatch/test_redis.py create mode 100644 tests/server/message_dispatch/test_redis_integration.py diff --git a/README.md b/README.md index 3889dc40b0..56aa3609ae 100644 --- a/README.md +++ b/README.md @@ -412,6 +412,30 @@ app.router.routes.append(Host('mcp.acme.corp', app=mcp.sse_app())) For more information on mounting applications in Starlette, see the [Starlette documentation](https://www.starlette.io/routing/#submounting-routes). +#### Message Dispatch Options + +By default, the SSE server uses an in-memory message dispatch system for incoming POST messages. 
For production deployments or distributed scenarios, you can use Redis or implement your own message dispatch system that conforms to the `MessageDispatch` protocol: + +```python +# Using the built-in Redis message dispatch +from mcp.server.fastmcp import FastMCP +from mcp.server.message_queue import RedisMessageDispatch + +# Create a Redis message dispatch +redis_dispatch = RedisMessageDispatch( + redis_url="redis://localhost:6379/0", prefix="mcp:pubsub:" +) + +# Pass the message dispatch instance to the server +mcp = FastMCP("My App", message_queue=redis_dispatch) +``` + +To use Redis, add the Redis dependency: + +```bash +uv add "mcp[redis]" +``` + ## Examples ### Echo Server diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index bc14b7cd0a..04b10ac75d 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -88,12 +88,15 @@ async def get_prompt( ) if transport == "sse": + from mcp.server.message_queue.redis import RedisMessageDispatch from mcp.server.sse import SseServerTransport from starlette.applications import Starlette from starlette.responses import Response from starlette.routing import Mount, Route - sse = SseServerTransport("/messages/") + message_dispatch = RedisMessageDispatch("redis://localhost:6379/0") + + sse = SseServerTransport("/messages/", message_dispatch=message_dispatch) async def handle_sse(request): async with sse.connect_sse( diff --git a/pyproject.toml b/pyproject.toml index 2b86fb3772..6ff2601e9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ rich = ["rich>=13.9.4"] cli = ["typer>=0.12.4", "python-dotenv>=1.0.0"] ws = ["websockets>=15.0.1"] +redis = ["redis>=5.2.1", "types-redis>=4.6.0.20241004"] [project.scripts] mcp = "mcp.cli:app [cli]" @@ -55,6 +56,7 @@ dev = [ "pytest-xdist>=3.6.1", "pytest-examples>=0.0.14", 
"pytest-pretty>=1.2.0", + "fakeredis==2.28.1", ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index ff04d2f961..7df251f792 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -98,7 +98,9 @@ async def sse_reader( await read_stream_writer.send(exc) continue - session_message = SessionMessage(message) + session_message = SessionMessage( + message=message + ) await read_stream_writer.send(session_message) case _: logger.warning( @@ -148,3 +150,5 @@ async def post_writer(endpoint_url: str): finally: await read_stream_writer.aclose() await write_stream.aclose() + await read_stream.aclose() + await write_stream_reader.aclose() diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index e8be5aff5b..21c7764e75 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -144,7 +144,7 @@ async def stdout_reader(): await read_stream_writer.send(exc) continue - session_message = SessionMessage(message) + session_message = SessionMessage(message=message) await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index ef424e3b33..ca26046b93 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -153,7 +153,7 @@ async def _handle_sse_event( ): message.root.id = original_request_id - session_message = SessionMessage(message) + session_message = SessionMessage(message=message) await read_stream_writer.send(session_message) # Call resumption token callback if we have an ID @@ -286,7 +286,7 @@ async def _handle_json_response( try: content = await response.aread() message = JSONRPCMessage.model_validate_json(content) - session_message = SessionMessage(message) + session_message = SessionMessage(message=message) await read_stream_writer.send(session_message) except Exception as exc: logger.error(f"Error 
parsing JSON response: {exc}") @@ -333,7 +333,7 @@ async def _send_session_terminated_error( id=request_id, error=ErrorData(code=32600, message="Session terminated"), ) - session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) await read_stream_writer.send(session_message) async def post_writer( diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index ac542fb3f6..598fdaf252 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -60,7 +60,7 @@ async def ws_reader(): async for raw_text in ws: try: message = types.JSONRPCMessage.model_validate_json(raw_text) - session_message = SessionMessage(message) + session_message = SessionMessage(message=message) await read_stream_writer.send(session_message) except ValidationError as exc: # If JSON parse or model validation fails, send the exception diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 0e0b565c57..1e5b69eba0 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -44,6 +44,7 @@ from mcp.server.lowlevel.server import LifespanResultT from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan +from mcp.server.message_queue import MessageDispatch from mcp.server.session import ServerSession, ServerSessionT from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server @@ -90,6 +91,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]): sse_path: str = "/sse" message_path: str = "/messages/" + # SSE message queue settings + message_dispatch: MessageDispatch | None = Field( + None, description="Custom message dispatch instance" + ) + # resource settings warn_on_duplicate_resources: bool = True @@ -569,12 +575,21 @@ async def run_sse_async(self) -> None: def sse_app(self) -> Starlette: """Return an instance of the SSE server app.""" + 
message_dispatch = self.settings.message_dispatch + if message_dispatch is None: + from mcp.server.message_queue import InMemoryMessageDispatch + + message_dispatch = InMemoryMessageDispatch() + logger.info("Using default in-memory message dispatch") + from starlette.middleware import Middleware from starlette.routing import Mount, Route # Set up auth context and dependencies - sse = SseServerTransport(self.settings.message_path) + sse = SseServerTransport( + self.settings.message_path, message_dispatch=message_dispatch + ) async def handle_sse(scope: Scope, receive: Receive, send: Send): # Add client ID from auth context into request context if available @@ -589,7 +604,14 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): streams[1], self._mcp_server.create_initialization_options(), ) - return Response() + return Response() + + @asynccontextmanager + async def lifespan(app: Starlette): + try: + yield + finally: + await message_dispatch.close() # Create routes routes: list[Route | Mount] = [] @@ -666,7 +688,10 @@ async def sse_endpoint(request: Request) -> None: # Create Starlette app with routes and middleware return Starlette( - debug=self.settings.debug, routes=routes, middleware=middleware + debug=self.settings.debug, + routes=routes, + middleware=middleware, + lifespan=lifespan, ) async def list_prompts(self) -> list[MCPPrompt]: diff --git a/src/mcp/server/message_queue/__init__.py b/src/mcp/server/message_queue/__init__.py new file mode 100644 index 0000000000..f4a8b9dfaf --- /dev/null +++ b/src/mcp/server/message_queue/__init__.py @@ -0,0 +1,16 @@ +""" +Message Dispatch Module for MCP Server + +This module implements dispatch interfaces for handling +messages between clients and servers. 
+""" + +from mcp.server.message_queue.base import InMemoryMessageDispatch, MessageDispatch + +# Try to import Redis implementation if available +try: + from mcp.server.message_queue.redis import RedisMessageDispatch +except ImportError: + RedisMessageDispatch = None + +__all__ = ["MessageDispatch", "InMemoryMessageDispatch", "RedisMessageDispatch"] diff --git a/src/mcp/server/message_queue/base.py b/src/mcp/server/message_queue/base.py new file mode 100644 index 0000000000..20c7145505 --- /dev/null +++ b/src/mcp/server/message_queue/base.py @@ -0,0 +1,116 @@ +import logging +from collections.abc import Awaitable, Callable +from contextlib import asynccontextmanager +from typing import Protocol, runtime_checkable +from uuid import UUID + +from pydantic import ValidationError + +from mcp.shared.message import SessionMessage + +logger = logging.getLogger(__name__) + +MessageCallback = Callable[[SessionMessage | Exception], Awaitable[None]] + + +@runtime_checkable +class MessageDispatch(Protocol): + """Abstract interface for SSE message dispatching. + + This interface allows messages to be published to sessions and callbacks to be + registered for message handling, enabling multiple servers to handle requests. + """ + + async def publish_message( + self, session_id: UUID, message: SessionMessage | str + ) -> bool: + """Publish a message for the specified session. + + Args: + session_id: The UUID of the session this message is for + message: The message to publish (SessionMessage or str for invalid JSON) + + Returns: + bool: True if message was published, False if session not found + """ + ... + + @asynccontextmanager + async def subscribe(self, session_id: UUID, callback: MessageCallback): + """Request-scoped context manager that subscribes to messages for a session. 
+ + Args: + session_id: The UUID of the session to subscribe to + callback: Async callback function to handle messages for this session + """ + yield + + async def session_exists(self, session_id: UUID) -> bool: + """Check if a session exists. + + Args: + session_id: The UUID of the session to check + + Returns: + bool: True if the session is active, False otherwise + """ + ... + + async def close(self) -> None: + """Close the message dispatch.""" + ... + + +class InMemoryMessageDispatch: + """Default in-memory implementation of the MessageDispatch interface. + + This implementation immediately dispatches messages to registered callbacks when + messages are received without any queuing behavior. + """ + + def __init__(self) -> None: + self._callbacks: dict[UUID, MessageCallback] = {} + + async def publish_message( + self, session_id: UUID, message: SessionMessage | str + ) -> bool: + """Publish a message for the specified session.""" + if session_id not in self._callbacks: + logger.warning(f"Message dropped: unknown session {session_id}") + return False + + # Parse string messages or recreate original ValidationError + if isinstance(message, str): + try: + callback_argument = SessionMessage.model_validate_json(message) + except ValidationError as exc: + callback_argument = exc + else: + callback_argument = message + + # Call the callback with either valid message or recreated ValidationError + await self._callbacks[session_id](callback_argument) + + logger.debug(f"Message dispatched to session {session_id}") + return True + + @asynccontextmanager + async def subscribe(self, session_id: UUID, callback: MessageCallback): + """Request-scoped context manager that subscribes to messages for a session.""" + self._callbacks[session_id] = callback + logger.debug(f"Subscribing to messages for session {session_id}") + + try: + yield + finally: + if session_id in self._callbacks: + del self._callbacks[session_id] + logger.debug(f"Unsubscribed from session {session_id}") + + 
async def session_exists(self, session_id: UUID) -> bool: + """Check if a session exists.""" + return session_id in self._callbacks + + async def close(self) -> None: + """Close the message dispatch.""" + pass diff --git a/src/mcp/server/message_queue/redis.py b/src/mcp/server/message_queue/redis.py new file mode 100644 index 0000000000..628ce026c5 --- /dev/null +++ b/src/mcp/server/message_queue/redis.py @@ -0,0 +1,198 @@ +import logging +from contextlib import asynccontextmanager +from typing import Any, cast +from uuid import UUID + +import anyio +from anyio import CancelScope, CapacityLimiter, lowlevel +from anyio.abc import TaskGroup +from pydantic import ValidationError + +from mcp.server.message_queue.base import MessageCallback +from mcp.shared.message import SessionMessage + +try: + import redis.asyncio as redis +except ImportError: + raise ImportError( + "Redis support requires the 'redis' package. " + "Install it with: 'uv add redis' or 'uv add \"mcp[redis]\"'" + ) + +logger = logging.getLogger(__name__) + + +class RedisMessageDispatch: + """Redis implementation of the MessageDispatch interface using pubsub. + + This implementation uses Redis pubsub for real-time message distribution across + multiple servers handling the same sessions. + """ + + def __init__( + self, + redis_url: str = "redis://localhost:6379/0", + prefix: str = "mcp:pubsub:", + session_ttl: int = 3600, # 1 hour default TTL for sessions + ) -> None: + """Initialize Redis message dispatch. + + Args: + redis_url: Redis connection string + prefix: Key prefix for Redis channels to avoid collisions + session_ttl: TTL in seconds for session keys (default: 1 hour) + """ + self._redis = redis.from_url(redis_url, decode_responses=True) # type: ignore + self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True) # type: ignore + self._prefix = prefix + self._session_ttl = session_ttl + # Maps session IDs to the callback and task group for that SSE session. 
+ self._session_state: dict[UUID, tuple[MessageCallback, TaskGroup]] = {} + # Ensures only one polling task runs at a time for message handling + self._limiter = CapacityLimiter(1) + logger.debug(f"Redis message dispatch initialized: {redis_url}") + + async def close(self): + await self._pubsub.aclose() # type: ignore + await self._redis.aclose() # type: ignore + + def _session_channel(self, session_id: UUID) -> str: + """Get the Redis channel for a session.""" + return f"{self._prefix}session:{session_id.hex}" + + def _session_key(self, session_id: UUID) -> str: + """Get the Redis key for a session.""" + return f"{self._prefix}session_active:{session_id.hex}" + + @asynccontextmanager + async def subscribe(self, session_id: UUID, callback: MessageCallback): + """Request-scoped context manager that subscribes to messages for a session.""" + session_key = self._session_key(session_id) + await self._redis.setex(session_key, self._session_ttl, "1") # type: ignore + + channel = self._session_channel(session_id) + await self._pubsub.subscribe(channel) # type: ignore + + logger.debug(f"Subscribing to Redis channel for session {session_id}") + async with anyio.create_task_group() as tg: + self._session_state[session_id] = (callback, tg) + tg.start_soon(self._listen_for_messages) + # Start heartbeat for this session + tg.start_soon(self._session_heartbeat, session_id) + try: + yield + finally: + with anyio.CancelScope(shield=True): + tg.cancel_scope.cancel() + await self._pubsub.unsubscribe(channel) # type: ignore + await self._redis.delete(session_key) # type: ignore + del self._session_state[session_id] + logger.debug(f"Unsubscribed from Redis channel: {session_id}") + + async def _session_heartbeat(self, session_id: UUID) -> None: + """Periodically refresh the TTL for a session.""" + session_key = self._session_key(session_id) + while True: + await lowlevel.checkpoint() + try: + # Refresh TTL at half the TTL interval to avoid expiration + await 
anyio.sleep(self._session_ttl / 2) + with anyio.CancelScope(shield=True): + await self._redis.expire(session_key, self._session_ttl) # type: ignore + except anyio.get_cancelled_exc_class(): + break + except Exception as e: + logger.error(f"Error refreshing TTL for session {session_id}: {e}") + + def _extract_session_id(self, channel: str) -> UUID | None: + """Extract and validate session ID from channel.""" + expected_prefix = f"{self._prefix}session:" + if not channel.startswith(expected_prefix): + return None + + session_hex = channel[len(expected_prefix) :] + try: + session_id = UUID(hex=session_hex) + if channel != self._session_channel(session_id): + logger.error(f"Channel format mismatch: {channel}") + return None + return session_id + except ValueError: + logger.error(f"Invalid UUID in channel: {channel}") + return None + + async def _listen_for_messages(self) -> None: + """Background task that listens for messages on subscribed channels.""" + async with self._limiter: + while True: + await lowlevel.checkpoint() + with CancelScope(shield=True): + message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore + ignore_subscribe_messages=True, + timeout=0.1, # type: ignore + ) + if message is None: + continue + + channel: str = cast(str, message["channel"]) + session_id = self._extract_session_id(channel) + if session_id is None: + logger.debug( + f"Ignoring message from non-MCP channel: {channel}" + ) + continue + + data: str = cast(str, message["data"]) + try: + if session_state := self._session_state.get(session_id): + session_state[1].start_soon( + self._handle_message, session_id, data + ) + else: + logger.warning( + f"Message dropped: unknown session {session_id}" + ) + except Exception as e: + logger.error(f"Error processing message for {session_id}: {e}") + + async def _handle_message(self, session_id: UUID, data: str) -> None: + """Process a message from Redis in the session's task group.""" + if (session_state := 
self._session_state.get(session_id)) is None: + logger.warning(f"Message dropped: callback removed for {session_id}") + return + + try: + # Parse message or pass validation error to callback + msg_or_error = None + try: + msg_or_error = SessionMessage.model_validate_json(data) + except ValidationError as exc: + msg_or_error = exc + + await session_state[0](msg_or_error) + except Exception as e: + logger.error(f"Error in message handler for {session_id}: {e}") + + async def publish_message( + self, session_id: UUID, message: SessionMessage | str + ) -> bool: + """Publish a message for the specified session.""" + if not await self.session_exists(session_id): + logger.warning(f"Message dropped: unknown session {session_id}") + return False + + # Pass raw JSON strings directly, preserving validation errors + if isinstance(message, str): + data = message + else: + data = message.model_dump_json() + + channel = self._session_channel(session_id) + await self._redis.publish(channel, data) # type: ignore[attr-defined] + logger.debug(f"Message published to Redis channel for session {session_id}") + return True + + async def session_exists(self, session_id: UUID) -> bool: + """Check if a session exists.""" + session_key = self._session_key(session_id) + return bool(await self._redis.exists(session_key)) # type: ignore diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index cc41a80d67..98f32629ef 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -52,9 +52,11 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types +from mcp.server.message_queue import InMemoryMessageDispatch, MessageDispatch from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.DEBUG) class SseServerTransport: @@ -70,17 +72,24 @@ class SseServerTransport: """ _endpoint: str + _message_dispatch: MessageDispatch _read_stream_writers: dict[UUID, 
MemoryObjectSendStream[SessionMessage | Exception]] - def __init__(self, endpoint: str) -> None: + def __init__( + self, endpoint: str, message_dispatch: MessageDispatch | None = None + ) -> None: """ Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL given. + + Args: + endpoint: The endpoint URL for SSE connections + message_dispatch: Optional message dispatch to use """ super().__init__() self._endpoint = endpoint - self._read_stream_writers = {} + self._message_dispatch = message_dispatch or InMemoryMessageDispatch() logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager @@ -101,7 +110,12 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): session_id = uuid4() session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}" - self._read_stream_writers[session_id] = read_stream_writer + + async def message_callback(message: SessionMessage | Exception) -> None: + """Callback that receives messages from the message queue""" + logger.debug(f"Got message from queue for session {session_id}") + await read_stream_writer.send(message) + logger.debug(f"Created new session with ID: {session_id}") sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ @@ -138,13 +152,16 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): )(scope, receive, send) await read_stream_writer.aclose() await write_stream_reader.aclose() + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() logging.debug(f"Client session disconnected {session_id}") logger.debug("Starting SSE response task") tg.start_soon(response_wrapper, scope, receive, send) - logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + async with self._message_dispatch.subscribe(session_id, message_callback): + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) async def 
handle_post_message( self, scope: Scope, receive: Receive, send: Send @@ -166,8 +183,7 @@ async def handle_post_message( response = Response("Invalid session ID", status_code=400) return await response(scope, receive, send) - writer = self._read_stream_writers.get(session_id) - if not writer: + if not await self._message_dispatch.session_exists(session_id): logger.warning(f"Could not find session for ID: {session_id}") response = Response("Could not find session", status_code=404) return await response(scope, receive, send) @@ -182,11 +198,15 @@ async def handle_post_message( logger.error(f"Failed to parse message: {err}") response = Response("Could not parse message", status_code=400) await response(scope, receive, send) - await writer.send(err) + # Pass raw JSON string; receiver will recreate identical ValidationError + # when parsing the same invalid JSON + await self._message_dispatch.publish_message(session_id, body.decode()) return - session_message = SessionMessage(message) - logger.debug(f"Sending session message to writer: {session_message}") + logger.debug(f"Publishing message for session {session_id}: {message}") response = Response("Accepted", status_code=202) await response(scope, receive, send) - await writer.send(session_message) + await self._message_dispatch.publish_message( + session_id, SessionMessage(message=message) + ) + logger.debug(f"Sending session message to writer: {message}") diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index f0bbe5a316..11c8f7ee4d 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -67,7 +67,7 @@ async def stdin_reader(): await read_stream_writer.send(exc) continue - session_message = SessionMessage(message) + session_message = SessionMessage(message=message) await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index ace74b33b4..79c8a89139 
100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -398,7 +398,7 @@ async def _handle_post_request( await response(scope, receive, send) # Process the message after sending the response - session_message = SessionMessage(message) + session_message = SessionMessage(message=message) await writer.send(session_message) return @@ -413,7 +413,7 @@ async def _handle_post_request( if self.is_json_response_enabled: # Process the message - session_message = SessionMessage(message) + session_message = SessionMessage(message=message) await writer.send(session_message) try: # Process messages from the request-specific stream @@ -512,7 +512,7 @@ async def sse_writer(): async with anyio.create_task_group() as tg: tg.start_soon(response, scope, receive, send) # Then send the message to be processed by the server - session_message = SessionMessage(message) + session_message = SessionMessage(message=message) await writer.send(session_message) except Exception: logger.exception("SSE response error") diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 9dc3f2a25e..bb0b1ca6ea 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -42,7 +42,7 @@ async def ws_reader(): await read_stream_writer.send(exc) continue - session_message = SessionMessage(client_message) + session_message = SessionMessage(message=client_message) await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 5583f47951..c96a0a1e63 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -6,7 +6,8 @@ """ from collections.abc import Awaitable, Callable -from dataclasses import dataclass + +from pydantic import BaseModel from mcp.types import JSONRPCMessage, RequestId @@ -15,8 +16,7 @@ ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]] -@dataclass -class ClientMessageMetadata: 
+class ClientMessageMetadata(BaseModel): """Metadata specific to client messages.""" resumption_token: ResumptionToken | None = None @@ -25,8 +25,7 @@ class ClientMessageMetadata: ) -@dataclass -class ServerMessageMetadata: +class ServerMessageMetadata(BaseModel): """Metadata specific to server messages.""" related_request_id: RequestId | None = None @@ -35,9 +34,8 @@ class ServerMessageMetadata: MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None -@dataclass -class SessionMessage: +class SessionMessage(BaseModel): """A message with specific metadata for transport-specific features.""" message: JSONRPCMessage - metadata: MessageMetadata = None + metadata: MessageMetadata | None = None diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 6abcf70cbc..cd3dae293d 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -62,7 +62,7 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( + message=JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, @@ -153,7 +153,7 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( + message=JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, @@ -220,7 +220,7 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( + message=JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 523ba199a4..d93c63aefe 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -23,7 +23,7 @@ async def test_stdio_client(): async with write_stream: for message in messages: - session_message = SessionMessage(message) + session_message = SessionMessage(message=message) await write_stream.send(session_message) 
read_messages = [] diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index cf5eb6083e..c05f08f8cd 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -65,7 +65,7 @@ async def run_server(): jsonrpc="2.0", ) - await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req))) + await client_writer.send(SessionMessage(message=JSONRPCMessage(root=init_req))) response = ( await server_reader.receive() ) # Get init response but don't need to check it @@ -77,7 +77,7 @@ async def run_server(): jsonrpc="2.0", ) await client_writer.send( - SessionMessage(JSONRPCMessage(root=initialized_notification)) + SessionMessage(message=JSONRPCMessage(root=initialized_notification)) ) # Send ping request with custom ID @@ -85,7 +85,9 @@ async def run_server(): id=custom_request_id, method="ping", params={}, jsonrpc="2.0" ) - await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request))) + await client_writer.send( + SessionMessage(message=JSONRPCMessage(root=ping_request)) + ) # Read response response = await server_reader.receive() diff --git a/tests/server/message_dispatch/__init__.py b/tests/server/message_dispatch/__init__.py new file mode 100644 index 0000000000..df0d26c3e6 --- /dev/null +++ b/tests/server/message_dispatch/__init__.py @@ -0,0 +1 @@ +# Message queue tests module diff --git a/tests/server/message_dispatch/conftest.py b/tests/server/message_dispatch/conftest.py new file mode 100644 index 0000000000..3422da2aab --- /dev/null +++ b/tests/server/message_dispatch/conftest.py @@ -0,0 +1,28 @@ +"""Shared fixtures for message queue tests.""" + +from collections.abc import AsyncGenerator +from unittest.mock import patch + +import pytest + +from mcp.server.message_queue.redis import RedisMessageDispatch + +# Set up fakeredis for testing +try: + from fakeredis import aioredis as fake_redis +except ImportError: + pytest.skip( + "fakeredis is required for testing Redis functionality", 
allow_module_level=True + ) + + +@pytest.fixture +async def message_dispatch() -> AsyncGenerator[RedisMessageDispatch, None]: + """Create a shared Redis message dispatch with a fake Redis client.""" + with patch("mcp.server.message_queue.redis.redis", fake_redis.FakeRedis): + # Shorter TTL for testing + message_dispatch = RedisMessageDispatch(session_ttl=5) + try: + yield message_dispatch + finally: + await message_dispatch.close() diff --git a/tests/server/message_dispatch/test_redis.py b/tests/server/message_dispatch/test_redis.py new file mode 100644 index 0000000000..d355f9e688 --- /dev/null +++ b/tests/server/message_dispatch/test_redis.py @@ -0,0 +1,355 @@ +from unittest.mock import AsyncMock +from uuid import uuid4 + +import anyio +import pytest +from pydantic import ValidationError + +import mcp.types as types +from mcp.server.message_queue.redis import RedisMessageDispatch +from mcp.shared.message import SessionMessage + + +@pytest.mark.anyio +async def test_session_heartbeat(message_dispatch): + """Test that session heartbeat refreshes TTL.""" + session_id = uuid4() + + async with message_dispatch.subscribe(session_id, AsyncMock()): + session_key = message_dispatch._session_key(session_id) + + # Initial TTL + initial_ttl = await message_dispatch._redis.ttl(session_key) # type: ignore + assert initial_ttl > 0 + + # Wait for heartbeat to run + await anyio.sleep(message_dispatch._session_ttl / 2 + 0.5) + + # TTL should be refreshed + refreshed_ttl = await message_dispatch._redis.ttl(session_key) # type: ignore + assert refreshed_ttl > 0 + assert refreshed_ttl <= message_dispatch._session_ttl + + +@pytest.mark.anyio +async def test_subscribe_unsubscribe(message_dispatch): + """Test subscribing and unsubscribing from a session.""" + session_id = uuid4() + callback = AsyncMock() + + # Subscribe + async with message_dispatch.subscribe(session_id, callback): + # Check that session is tracked + assert session_id in message_dispatch._session_state + assert await 
message_dispatch.session_exists(session_id) + + # After context exit, session should be cleaned up + assert session_id not in message_dispatch._session_state + assert not await message_dispatch.session_exists(session_id) + + +@pytest.mark.anyio +async def test_publish_message_valid_json(message_dispatch: RedisMessageDispatch): + """Test publishing a valid JSON-RPC message.""" + session_id = uuid4() + callback = AsyncMock() + message = types.JSONRPCMessage.model_validate( + {"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1} + ) + + # Subscribe to messages + async with message_dispatch.subscribe(session_id, callback): + # Publish message + published = await message_dispatch.publish_message( + session_id, SessionMessage(message=message) + ) + assert published + + # Give some time for the message to be processed + await anyio.sleep(0.1) + + # Callback should have been called with the message + callback.assert_called_once() + call_args = callback.call_args[0][0] + assert isinstance(call_args, SessionMessage) + assert isinstance(call_args.message.root, types.JSONRPCRequest) + assert ( + call_args.message.root.method == "test" + ) # Access method through root attribute + + +@pytest.mark.anyio +async def test_publish_message_invalid_json(message_dispatch): + """Test publishing an invalid JSON string.""" + session_id = uuid4() + callback = AsyncMock() + invalid_json = '{"invalid": "json",,}' # Invalid JSON + + # Subscribe to messages + async with message_dispatch.subscribe(session_id, callback): + # Publish invalid message + published = await message_dispatch.publish_message(session_id, invalid_json) + assert published + + # Give some time for the message to be processed + await anyio.sleep(0.1) + + # Callback should have been called with a ValidationError + callback.assert_called_once() + error = callback.call_args[0][0] + assert isinstance(error, ValidationError) + + +@pytest.mark.anyio +async def test_publish_to_nonexistent_session(message_dispatch: 
RedisMessageDispatch): + """Test publishing to a session that doesn't exist.""" + session_id = uuid4() + message = SessionMessage( + message=types.JSONRPCMessage.model_validate( + {"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1} + ) + ) + + published = await message_dispatch.publish_message(session_id, message) + assert not published + + +@pytest.mark.anyio +async def test_extract_session_id(message_dispatch): + """Test extracting session ID from channel name.""" + session_id = uuid4() + channel = message_dispatch._session_channel(session_id) + + # Valid channel + extracted_id = message_dispatch._extract_session_id(channel) + assert extracted_id == session_id + + # Invalid channel format + extracted_id = message_dispatch._extract_session_id("invalid_channel_name") + assert extracted_id is None + + # Invalid UUID in channel + invalid_channel = f"{message_dispatch._prefix}session:invalid_uuid" + extracted_id = message_dispatch._extract_session_id(invalid_channel) + assert extracted_id is None + + +@pytest.mark.anyio +async def test_multiple_sessions(message_dispatch: RedisMessageDispatch): + """Test handling multiple concurrent sessions.""" + session1 = uuid4() + session2 = uuid4() + callback1 = AsyncMock() + callback2 = AsyncMock() + + async with message_dispatch.subscribe(session1, callback1): + async with message_dispatch.subscribe(session2, callback2): + # Both sessions should exist + assert await message_dispatch.session_exists(session1) + assert await message_dispatch.session_exists(session2) + + # Publish to session1 + message1 = types.JSONRPCMessage.model_validate( + {"jsonrpc": "2.0", "method": "test1", "params": {}, "id": 1} + ) + await message_dispatch.publish_message( + session1, SessionMessage(message=message1) + ) + + # Publish to session2 + message2 = types.JSONRPCMessage.model_validate( + {"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2} + ) + await message_dispatch.publish_message( + session2, SessionMessage(message=message2) + 
) + + # Give some time for messages to be processed + await anyio.sleep(0.1) + + # Check callbacks + callback1.assert_called_once() + callback2.assert_called_once() + + call1_args = callback1.call_args[0][0] + assert isinstance(call1_args, SessionMessage) + assert call1_args.message.root.method == "test1" # type: ignore + + call2_args = callback2.call_args[0][0] + assert isinstance(call2_args, SessionMessage) + assert call2_args.message.root.method == "test2" # type: ignore + + +@pytest.mark.anyio +async def test_task_group_cancellation(message_dispatch): + """Test that task group is properly cancelled when context exits.""" + session_id = uuid4() + callback = AsyncMock() + + async with message_dispatch.subscribe(session_id, callback): + # Check that task group is active + _, task_group = message_dispatch._session_state[session_id] + assert task_group.cancel_scope.cancel_called is False + + # After context exit, task group should be cancelled + # And session state should be cleaned up + assert session_id not in message_dispatch._session_state + + +@pytest.mark.anyio +async def test_session_cancellation_isolation(message_dispatch): + """Test that cancelling one session doesn't affect other sessions.""" + session1 = uuid4() + session2 = uuid4() + + # Create a blocking callback for session1 to ensure it's running when cancelled + session1_event = anyio.Event() + session1_started = anyio.Event() + session1_cancelled = False + + async def blocking_callback1(msg): + session1_started.set() + try: + await session1_event.wait() + except anyio.get_cancelled_exc_class(): + nonlocal session1_cancelled + session1_cancelled = True + raise + + callback2 = AsyncMock() + + # Start session2 first + async with message_dispatch.subscribe(session2, callback2): + # Start session1 with a blocking callback + async with anyio.create_task_group() as tg: + + async def session1_runner(): + async with message_dispatch.subscribe(session1, blocking_callback1): + # Publish a message to trigger 
the blocking callback + message = types.JSONRPCMessage.model_validate( + {"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1} + ) + await message_dispatch.publish_message(session1, message) + + # Wait for the callback to start + await session1_started.wait() + + # Keep the context alive while we test cancellation + await anyio.sleep_forever() + + tg.start_soon(session1_runner) + + # Wait for session1's callback to start + await session1_started.wait() + + # Cancel session1 + tg.cancel_scope.cancel() + + # Give some time for cancellation to propagate + await anyio.sleep(0.1) + + # Verify session1 was cancelled + assert session1_cancelled + assert session1 not in message_dispatch._session_state + + # Verify session2 is still active and can receive messages + assert await message_dispatch.session_exists(session2) + message2 = types.JSONRPCMessage.model_validate( + {"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2} + ) + await message_dispatch.publish_message(session2, message2) + + # Give some time for the message to be processed + await anyio.sleep(0.1) + + # Verify session2 received the message + callback2.assert_called_once() + call_args = callback2.call_args[0][0] + assert call_args.root.method == "test2" + + +@pytest.mark.anyio +async def test_listener_task_handoff_on_cancellation(message_dispatch): + """ + Test that the single listening task is properly + handed off when a session is cancelled. 
+ """ + session1 = uuid4() + session2 = uuid4() + + session1_messages_received = 0 + session2_messages_received = 0 + + async def callback1(msg): + nonlocal session1_messages_received + session1_messages_received += 1 + + async def callback2(msg): + nonlocal session2_messages_received + session2_messages_received += 1 + + # Create a cancel scope for session1 + async with anyio.create_task_group() as tg: + session1_cancel_scope: anyio.CancelScope | None = None + + async def session1_runner(): + nonlocal session1_cancel_scope + with anyio.CancelScope() as cancel_scope: + session1_cancel_scope = cancel_scope + async with message_dispatch.subscribe(session1, callback1): + # Keep session alive until cancelled + await anyio.sleep_forever() + + # Start session1 + tg.start_soon(session1_runner) + + # Wait for session1 to be established + await anyio.sleep(0.1) + assert session1 in message_dispatch._session_state + + # Send message to session1 to verify it's working + message1 = types.JSONRPCMessage.model_validate( + {"jsonrpc": "2.0", "method": "test1", "params": {}, "id": 1} + ) + await message_dispatch.publish_message(session1, message1) + await anyio.sleep(0.1) + assert session1_messages_received == 1 + + # Start session2 while session1 is still active + async with message_dispatch.subscribe(session2, callback2): + # Both sessions should be active + assert session1 in message_dispatch._session_state + assert session2 in message_dispatch._session_state + + # Cancel session1 + assert session1_cancel_scope is not None + session1_cancel_scope.cancel() + + # Wait for cancellation to complete + await anyio.sleep(0.1) + + # Session1 should be gone, session2 should remain + assert session1 not in message_dispatch._session_state + assert session2 in message_dispatch._session_state + + # Send message to session2 to verify the listener was handed off + message2 = types.JSONRPCMessage.model_validate( + {"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2} + ) + await 
message_dispatch.publish_message(session2, message2) + await anyio.sleep(0.1) + + # Session2 should have received the message + assert session2_messages_received == 1 + + # Session1 shouldn't receive any more messages + assert session1_messages_received == 1 + + # Send another message to verify the listener is still working + message3 = types.JSONRPCMessage.model_validate( + {"jsonrpc": "2.0", "method": "test3", "params": {}, "id": 3} + ) + await message_dispatch.publish_message(session2, message3) + await anyio.sleep(0.1) + + assert session2_messages_received == 2 diff --git a/tests/server/message_dispatch/test_redis_integration.py b/tests/server/message_dispatch/test_redis_integration.py new file mode 100644 index 0000000000..f01113872d --- /dev/null +++ b/tests/server/message_dispatch/test_redis_integration.py @@ -0,0 +1,260 @@ +""" +Integration tests for Redis message dispatch functionality. + +These tests validate Redis message dispatch by making actual HTTP calls and testing +that messages flow correctly through the Redis backend. + +This version runs the server in a task instead of a separate process to allow +access to the fakeredis instance for verification of Redis keys. 
+""" + +import asyncio +import socket +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + +import anyio +import pytest +import uvicorn +from sse_starlette.sse import AppStatus +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount, Route + +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client +from mcp.server import Server +from mcp.server.message_queue.redis import RedisMessageDispatch +from mcp.server.sse import SseServerTransport +from mcp.types import TextContent, Tool + +SERVER_NAME = "test_server_for_redis_integration_v3" + + +@pytest.fixture +def server_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def server_url(server_port: int) -> str: + return f"http://127.0.0.1:{server_port}" + + +class RedisTestServer(Server): + """Test server with basic tool functionality.""" + + def __init__(self): + super().__init__(SERVER_NAME) + + @self.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="test_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_message", + description="Echo a message back", + inputSchema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + ), + ] + + @self.call_tool() + async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + if name == "echo_message": + message = args.get("message", "") + return [TextContent(type="text", text=f"Echo: {message}")] + return [TextContent(type="text", text=f"Called {name}")] + + +@pytest.fixture() +async def redis_server_and_app(message_dispatch: RedisMessageDispatch): + """Create a mock Redis instance and Starlette app for testing.""" + + # Create SSE transport with Redis message dispatch + sse = 
SseServerTransport("/messages/", message_dispatch=message_dispatch) + server = RedisTestServer() + + async def handle_sse(request: Request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await server.run( + streams[0], streams[1], server.create_initialization_options() + ) + return Response() + + @asynccontextmanager + async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: + """Manage the lifecycle of the application.""" + try: + yield + finally: + await message_dispatch.close() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ], + lifespan=lifespan, + ) + + return app, message_dispatch, message_dispatch._redis + + +@pytest.fixture() +async def server_and_redis(redis_server_and_app, server_port: int): + """Run the server in a task and return the Redis instance for inspection.""" + app, message_dispatch, mock_redis = redis_server_and_app + + # Create a server config + config = uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + server = uvicorn.Server(config=config) + try: + async with anyio.create_task_group() as tg: + # Start server in background + tg.start_soon(server.serve) + + # Wait for server to be ready + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + await anyio.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"Server failed to start after {max_attempts} attempts" + ) + + try: + yield mock_redis, message_dispatch + finally: + server.should_exit = True + finally: + # These class variables are set top-level in starlette-sse + # It isn't designed to be run multiple times in a single + # Python process so we need to manually reset them. 
+ AppStatus.should_exit = False + AppStatus.should_exit_event = None + + +@pytest.fixture() +async def client_session(server_and_redis, server_url: str): + """Create a client session for testing.""" + async with sse_client(server_url + "/sse") as streams: + async with ClientSession(*streams) as session: + result = await session.initialize() + assert result.serverInfo.name == SERVER_NAME + yield session + + +@pytest.mark.anyio +async def test_redis_integration_key_verification( + server_and_redis, client_session +) -> None: + """Test that Redis keys are created correctly for sessions.""" + mock_redis, _ = server_and_redis + + all_keys = await mock_redis.keys("*") # type: ignore + + assert len(all_keys) > 0 + + session_key = None + for key in all_keys: + if key.startswith("mcp:pubsub:session_active:"): + session_key = key + break + + assert session_key is not None, f"No session key found. Keys: {all_keys}" + + ttl = await mock_redis.ttl(session_key) # type: ignore + assert ttl > 0, f"Session key should have TTL, got: {ttl}" + + +@pytest.mark.anyio +async def test_tool_calls(server_and_redis, client_session) -> None: + """Test that messages are properly published through Redis.""" + mock_redis, _ = server_and_redis + + for i in range(3): + tool_result = await client_session.call_tool( + "echo_message", {"message": f"Test {i}"} + ) + assert tool_result.content[0].text == f"Echo: Test {i}" # type: ignore + + +@pytest.mark.anyio +async def test_session_cleanup(server_and_redis, server_url: str) -> None: + """Test Redis key cleanup when sessions end.""" + mock_redis, _ = server_and_redis + session_keys_seen = set() + + for i in range(3): + async with sse_client(server_url + "/sse") as streams: + async with ClientSession(*streams) as session: + await session.initialize() + + all_keys = await mock_redis.keys("*") # type: ignore + for key in all_keys: + if key.startswith("mcp:pubsub:session_active:"): + session_keys_seen.add(key) + value = await mock_redis.get(key) # type: 
ignore + assert value == "1" + + await anyio.sleep(0.1)  # Give time for cleanup + all_keys = await mock_redis.keys("*")  # type: ignore + assert ( + len(all_keys) == 0 + ), f"Session keys should be cleaned up, found: {all_keys}" + + # Verify we saw different session keys for each session + assert len(session_keys_seen) == 3, "Should have seen 3 unique session keys" + + +@pytest.mark.anyio +async def test_concurrent_tool_call(server_and_redis, server_url: str) -> None: + """Test multiple clients and verify Redis key management.""" + mock_redis, _ = server_and_redis + + async def client_task(client_id: int) -> str: + async with sse_client(server_url + "/sse") as streams: + async with ClientSession(*streams) as session: + await session.initialize() + + result = await session.call_tool( + "echo_message", + {"message": f"Message from client {client_id}"}, + ) + return result.content[0].text  # type: ignore + + # Run multiple clients concurrently + client_tasks = [client_task(i) for i in range(3)] + results = await asyncio.gather(*client_tasks) + + # Verify all clients received their respective messages + assert len(results) == 3 + for i, result in enumerate(results): + assert result == f"Echo: Message from client {i}" + + # After all clients disconnect, keys should be cleaned up + await anyio.sleep(0.1)  # Give time for cleanup + all_keys = await mock_redis.keys("*")  # type: ignore + assert len(all_keys) == 0, f"Session keys should be cleaned up, found: {all_keys}" diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index a3ff59bc1b..d8e76de1ac 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -84,7 +84,7 @@ async def run_server(): ) await send_stream1.send( SessionMessage( - JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCRequest( jsonrpc="2.0", id=1, @@ -100,7 +100,7 @@ async def run_server(): # Send initialized notification await send_stream1.send( SessionMessage( - JSONRPCMessage( + message=JSONRPCMessage( 
root=JSONRPCNotification( jsonrpc="2.0", method="notifications/initialized", @@ -112,7 +112,7 @@ async def run_server(): # Call the tool to verify lifespan context await send_stream1.send( SessionMessage( - JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCRequest( jsonrpc="2.0", id=2, @@ -188,7 +188,7 @@ async def run_server(): ) await send_stream1.send( SessionMessage( - JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCRequest( jsonrpc="2.0", id=1, @@ -204,7 +204,7 @@ async def run_server(): # Send initialized notification await send_stream1.send( SessionMessage( - JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCNotification( jsonrpc="2.0", method="notifications/initialized", @@ -216,7 +216,7 @@ async def run_server(): # Call the tool to verify lifespan context await send_stream1.send( SessionMessage( - JSONRPCMessage( + message=JSONRPCMessage( root=JSONRPCRequest( jsonrpc="2.0", id=2, diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index c546a7167b..570e4c1999 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -51,7 +51,7 @@ async def test_stdio_server(): async with write_stream: for response in responses: - session_message = SessionMessage(response) + session_message = SessionMessage(message=response) await write_stream.send(session_message) stdout.seek(0) diff --git a/uv.lock b/uv.lock index 06dd240b25..e819dbfe87 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.10" [options] @@ -39,6 +38,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/68/f9e9bf6324c46e6b8396610aef90ad423ec3e18c9079547ceafea3dce0ec/anyio-4.5.0-py3-none-any.whl", hash = "sha256:fdeb095b7cc5a5563175eedd926ec4ae55413bb4be5770c424af0ba46ccb4a78", size = 89250 }, ] +[[package]] +name = "async-timeout" +version = "5.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/a5/ae/136395dfbfe00dfc94da3f3e136d0b13f394cba8f4841120e34226265780/async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3", size = 9274 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233 }, +] + [[package]] name = "attrs" version = "24.3.0" @@ -267,6 +275,51 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, ] +[[package]] +name = "cryptography" +version = "44.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cd/25/4ce80c78963834b8a9fd1cc1266be5ed8d1840785c0f2e1b73b8d128d505/cryptography-44.0.2.tar.gz", hash = "sha256:c63454aa261a0cf0c5b4718349629793e9e634993538db841165b3df74f37ec0", size = 710807 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/92/ef/83e632cfa801b221570c5f58c0369db6fa6cef7d9ff859feab1aae1a8a0f/cryptography-44.0.2-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:efcfe97d1b3c79e486554efddeb8f6f53a4cdd4cf6086642784fa31fc384e1d7", size = 6676361 }, + { url = "https://files.pythonhosted.org/packages/30/ec/7ea7c1e4c8fc8329506b46c6c4a52e2f20318425d48e0fe597977c71dbce/cryptography-44.0.2-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29ecec49f3ba3f3849362854b7253a9f59799e3763b0c9d0826259a88efa02f1", size = 3952350 }, + { url = 
"https://files.pythonhosted.org/packages/27/61/72e3afdb3c5ac510330feba4fc1faa0fe62e070592d6ad00c40bb69165e5/cryptography-44.0.2-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc821e161ae88bfe8088d11bb39caf2916562e0a2dc7b6d56714a48b784ef0bb", size = 4166572 }, + { url = "https://files.pythonhosted.org/packages/26/e4/ba680f0b35ed4a07d87f9e98f3ebccb05091f3bf6b5a478b943253b3bbd5/cryptography-44.0.2-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:3c00b6b757b32ce0f62c574b78b939afab9eecaf597c4d624caca4f9e71e7843", size = 3958124 }, + { url = "https://files.pythonhosted.org/packages/9c/e8/44ae3e68c8b6d1cbc59040288056df2ad7f7f03bbcaca6b503c737ab8e73/cryptography-44.0.2-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7bdcd82189759aba3816d1f729ce42ffded1ac304c151d0a8e89b9996ab863d5", size = 3678122 }, + { url = "https://files.pythonhosted.org/packages/27/7b/664ea5e0d1eab511a10e480baf1c5d3e681c7d91718f60e149cec09edf01/cryptography-44.0.2-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:4973da6ca3db4405c54cd0b26d328be54c7747e89e284fcff166132eb7bccc9c", size = 4191831 }, + { url = "https://files.pythonhosted.org/packages/2a/07/79554a9c40eb11345e1861f46f845fa71c9e25bf66d132e123d9feb8e7f9/cryptography-44.0.2-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4e389622b6927d8133f314949a9812972711a111d577a5d1f4bee5e58736b80a", size = 3960583 }, + { url = "https://files.pythonhosted.org/packages/bb/6d/858e356a49a4f0b591bd6789d821427de18432212e137290b6d8a817e9bf/cryptography-44.0.2-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:f514ef4cd14bb6fb484b4a60203e912cfcb64f2ab139e88c2274511514bf7308", size = 4191753 }, + { url = "https://files.pythonhosted.org/packages/b2/80/62df41ba4916067fa6b125aa8c14d7e9181773f0d5d0bd4dcef580d8b7c6/cryptography-44.0.2-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:1bc312dfb7a6e5d66082c87c34c8a62176e684b6fe3d90fcfe1568de675e6688", size = 4079550 }, + { url = 
"https://files.pythonhosted.org/packages/f3/cd/2558cc08f7b1bb40683f99ff4327f8dcfc7de3affc669e9065e14824511b/cryptography-44.0.2-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b721b8b4d948b218c88cb8c45a01793483821e709afe5f622861fc6182b20a7", size = 4298367 }, + { url = "https://files.pythonhosted.org/packages/71/59/94ccc74788945bc3bd4cf355d19867e8057ff5fdbcac781b1ff95b700fb1/cryptography-44.0.2-cp37-abi3-win32.whl", hash = "sha256:51e4de3af4ec3899d6d178a8c005226491c27c4ba84101bfb59c901e10ca9f79", size = 2772843 }, + { url = "https://files.pythonhosted.org/packages/ca/2c/0d0bbaf61ba05acb32f0841853cfa33ebb7a9ab3d9ed8bb004bd39f2da6a/cryptography-44.0.2-cp37-abi3-win_amd64.whl", hash = "sha256:c505d61b6176aaf982c5717ce04e87da5abc9a36a5b39ac03905c4aafe8de7aa", size = 3209057 }, + { url = "https://files.pythonhosted.org/packages/9e/be/7a26142e6d0f7683d8a382dd963745e65db895a79a280a30525ec92be890/cryptography-44.0.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:8e0ddd63e6bf1161800592c71ac794d3fb8001f2caebe0966e77c5234fa9efc3", size = 6677789 }, + { url = "https://files.pythonhosted.org/packages/06/88/638865be7198a84a7713950b1db7343391c6066a20e614f8fa286eb178ed/cryptography-44.0.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81276f0ea79a208d961c433a947029e1a15948966658cf6710bbabb60fcc2639", size = 3951919 }, + { url = "https://files.pythonhosted.org/packages/d7/fc/99fe639bcdf58561dfad1faa8a7369d1dc13f20acd78371bb97a01613585/cryptography-44.0.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a1e657c0f4ea2a23304ee3f964db058c9e9e635cc7019c4aa21c330755ef6fd", size = 4167812 }, + { url = "https://files.pythonhosted.org/packages/53/7b/aafe60210ec93d5d7f552592a28192e51d3c6b6be449e7fd0a91399b5d07/cryptography-44.0.2-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:6210c05941994290f3f7f175a4a57dbbb2afd9273657614c506d5976db061181", size = 3958571 }, + { url = 
"https://files.pythonhosted.org/packages/16/32/051f7ce79ad5a6ef5e26a92b37f172ee2d6e1cce09931646eef8de1e9827/cryptography-44.0.2-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d1c3572526997b36f245a96a2b1713bf79ce99b271bbcf084beb6b9b075f29ea", size = 3679832 }, + { url = "https://files.pythonhosted.org/packages/78/2b/999b2a1e1ba2206f2d3bca267d68f350beb2b048a41ea827e08ce7260098/cryptography-44.0.2-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:b042d2a275c8cee83a4b7ae30c45a15e6a4baa65a179a0ec2d78ebb90e4f6699", size = 4193719 }, + { url = "https://files.pythonhosted.org/packages/72/97/430e56e39a1356e8e8f10f723211a0e256e11895ef1a135f30d7d40f2540/cryptography-44.0.2-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:d03806036b4f89e3b13b6218fefea8d5312e450935b1a2d55f0524e2ed7c59d9", size = 3960852 }, + { url = "https://files.pythonhosted.org/packages/89/33/c1cf182c152e1d262cac56850939530c05ca6c8d149aa0dcee490b417e99/cryptography-44.0.2-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:c7362add18b416b69d58c910caa217f980c5ef39b23a38a0880dfd87bdf8cd23", size = 4193906 }, + { url = "https://files.pythonhosted.org/packages/e1/99/87cf26d4f125380dc674233971069bc28d19b07f7755b29861570e513650/cryptography-44.0.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:8cadc6e3b5a1f144a039ea08a0bdb03a2a92e19c46be3285123d32029f40a922", size = 4081572 }, + { url = "https://files.pythonhosted.org/packages/b3/9f/6a3e0391957cc0c5f84aef9fbdd763035f2b52e998a53f99345e3ac69312/cryptography-44.0.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6f101b1f780f7fc613d040ca4bdf835c6ef3b00e9bd7125a4255ec574c7916e4", size = 4298631 }, + { url = "https://files.pythonhosted.org/packages/e2/a5/5bc097adb4b6d22a24dea53c51f37e480aaec3465285c253098642696423/cryptography-44.0.2-cp39-abi3-win32.whl", hash = "sha256:3dc62975e31617badc19a906481deacdeb80b4bb454394b4098e3f2525a488c5", size = 2773792 }, + { url = 
"https://files.pythonhosted.org/packages/33/cf/1f7649b8b9a3543e042d3f348e398a061923ac05b507f3f4d95f11938aa9/cryptography-44.0.2-cp39-abi3-win_amd64.whl", hash = "sha256:5f6f90b72d8ccadb9c6e311c775c8305381db88374c65fa1a68250aa8a9cb3a6", size = 3210957 }, + { url = "https://files.pythonhosted.org/packages/99/10/173be140714d2ebaea8b641ff801cbcb3ef23101a2981cbf08057876f89e/cryptography-44.0.2-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:af4ff3e388f2fa7bff9f7f2b31b87d5651c45731d3e8cfa0944be43dff5cfbdb", size = 3396886 }, + { url = "https://files.pythonhosted.org/packages/2f/b4/424ea2d0fce08c24ede307cead3409ecbfc2f566725d4701b9754c0a1174/cryptography-44.0.2-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:0529b1d5a0105dd3731fa65680b45ce49da4d8115ea76e9da77a875396727b41", size = 3892387 }, + { url = "https://files.pythonhosted.org/packages/28/20/8eaa1a4f7c68a1cb15019dbaad59c812d4df4fac6fd5f7b0b9c5177f1edd/cryptography-44.0.2-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:7ca25849404be2f8e4b3c59483d9d3c51298a22c1c61a0e84415104dacaf5562", size = 4109922 }, + { url = "https://files.pythonhosted.org/packages/11/25/5ed9a17d532c32b3bc81cc294d21a36c772d053981c22bd678396bc4ae30/cryptography-44.0.2-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:268e4e9b177c76d569e8a145a6939eca9a5fec658c932348598818acf31ae9a5", size = 3895715 }, + { url = "https://files.pythonhosted.org/packages/63/31/2aac03b19c6329b62c45ba4e091f9de0b8f687e1b0cd84f101401bece343/cryptography-44.0.2-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:9eb9d22b0a5d8fd9925a7764a054dca914000607dff201a24c791ff5c799e1fa", size = 4109876 }, + { url = "https://files.pythonhosted.org/packages/99/ec/6e560908349843718db1a782673f36852952d52a55ab14e46c42c8a7690a/cryptography-44.0.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2bf7bf75f7df9715f810d1b038870309342bff3069c5bd8c6b96128cb158668d", size = 3131719 }, + { url = 
"https://files.pythonhosted.org/packages/d6/d7/f30e75a6aa7d0f65031886fa4a1485c2fbfe25a1896953920f6a9cfe2d3b/cryptography-44.0.2-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:909c97ab43a9c0c0b0ada7a1281430e4e5ec0458e6d9244c0e821bbf152f061d", size = 3887513 }, + { url = "https://files.pythonhosted.org/packages/9c/b4/7a494ce1032323ca9db9a3661894c66e0d7142ad2079a4249303402d8c71/cryptography-44.0.2-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:96e7a5e9d6e71f9f4fca8eebfd603f8e86c5225bb18eb621b2c1e50b290a9471", size = 4107432 }, + { url = "https://files.pythonhosted.org/packages/45/f8/6b3ec0bc56123b344a8d2b3264a325646d2dcdbdd9848b5e6f3d37db90b3/cryptography-44.0.2-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d1b3031093a366ac767b3feb8bcddb596671b3aaff82d4050f984da0c248b615", size = 3891421 }, + { url = "https://files.pythonhosted.org/packages/57/ff/f3b4b2d007c2a646b0f69440ab06224f9cf37a977a72cdb7b50632174e8a/cryptography-44.0.2-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:04abd71114848aa25edb28e225ab5f268096f44cf0127f3d36975bdf1bdf3390", size = 4107081 }, +] + [[package]] name = "cssselect2" version = "0.8.0" @@ -307,6 +360,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/09/2aea36ff60d16dd8879bdb2f5b3ee0ba8d08cbbdcdfe870e695ce3784385/execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc", size = 40612 }, ] +[[package]] +name = "fakeredis" +version = "2.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "redis" }, + { name = "sortedcontainers" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/60/32/8c1c215e50cb055e24a8d5a8981edab665d131ea9068c420bf81eb0fcb63/fakeredis-2.28.1.tar.gz", hash = "sha256:5e542200b945aa0a7afdc0396efefe3cdabab61bc0f41736cc45f68960255964", size = 161179 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/e1/77/bca49c4960c22131da3acb647978983bea07f15c255fbef0a6559a774a7a/fakeredis-2.28.1-py3-none-any.whl", hash = "sha256:38c7c17fba5d5522af9d980a8f74a4da9900a3441e8f25c0fe93ea4205d695d1", size = 113685 }, +] + [[package]] name = "ghp-import" version = "2.1.0" @@ -507,6 +574,10 @@ cli = [ { name = "python-dotenv" }, { name = "typer" }, ] +redis = [ + { name = "redis" }, + { name = "types-redis" }, +] rich = [ { name = "rich" }, ] @@ -516,6 +587,7 @@ ws = [ [package.dev-dependencies] dev = [ + { name = "fakeredis" }, { name = "pyright" }, { name = "pytest" }, { name = "pytest-examples" }, @@ -541,17 +613,19 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" }, { name = "python-multipart", specifier = ">=0.0.9" }, + { name = "redis", marker = "extra == 'redis'", specifier = ">=5.2.1" }, { name = "rich", marker = "extra == 'rich'", specifier = ">=13.9.4" }, { name = "sse-starlette", specifier = ">=1.6.1" }, { name = "starlette", specifier = ">=0.27" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.4" }, + { name = "types-redis", marker = "extra == 'redis'", specifier = ">=4.6.0.20241004" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] -provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ + { name = "fakeredis", specifier = "==2.28.1" }, { name = "pyright", specifier = ">=1.1.391" }, { name = "pytest", specifier = ">=8.3.4" }, { name = "pytest-examples", specifier = ">=0.0.14" }, @@ -1323,6 +1397,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/66/bbb1dd374f5c870f59c5bb1db0e18cbe7fa739415a24cbd95b2d1f5ae0c4/pyyaml_env_tag-0.1-py3-none-any.whl", hash = "sha256:af31106dec8a4d68c60207c1886031cbf839b68aa7abccdb19868200532c2069", size = 3911 }, ] 
+[[package]] +name = "redis" +version = "5.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-timeout", marker = "python_full_version < '3.11.3'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/47/da/d283a37303a995cd36f8b92db85135153dc4f7a8e4441aa827721b442cfb/redis-5.2.1.tar.gz", hash = "sha256:16f2e22dff21d5125e8481515e386711a34cbec50f0e44413dd7d9c060a54e0f", size = 4608355 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/5f/fa26b9b2672cbe30e07d9a5bdf39cf16e3b80b42916757c5f92bca88e4ba/redis-5.2.1-py3-none-any.whl", hash = "sha256:ee7e1056b9aea0f04c6c2ed59452947f34c4940ee025f5dd83e6a6418b6989e4", size = 261502 }, +] + [[package]] name = "regex" version = "2024.11.6" @@ -1446,6 +1532,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bf/5e/ffee22bf9f9e4b2669d1f0179ae8804584939fb6502b51f2401e26b1e028/ruff-0.8.5-py3-none-win_arm64.whl", hash = "sha256:134ae019ef13e1b060ab7136e7828a6d83ea727ba123381307eb37c6bd5e01cb", size = 9124741 }, ] +[[package]] +name = "setuptools" +version = "78.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/5a/0db4da3bc908df06e5efae42b44e75c81dd52716e10192ff36d0c1c8e379/setuptools-78.1.0.tar.gz", hash = "sha256:18fd474d4a82a5f83dac888df697af65afa82dec7323d09c3e37d1f14288da54", size = 1367827 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/21/f43f0a1fa8b06b32812e0975981f4677d28e0f3271601dc88ac5a5b83220/setuptools-78.1.0-py3-none-any.whl", hash = "sha256:3e386e96793c8702ae83d17b853fb93d3e09ef82ec62722e61da5cd22376dcd8", size = 1256108 }, +] + [[package]] name = "shellingham" version = "1.5.4" @@ -1590,6 +1685,56 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ae/cc/15083dcde1252a663398b1b2a173637a3ec65adadfb95137dc95df1e6adc/typer-0.12.4-py3-none-any.whl", hash = "sha256:819aa03699f438397e876aa12b0d63766864ecba1b579092cc9fe35d886e34b6", size = 47402 }, 
] +[[package]] +name = "types-cffi" +version = "1.17.0.20250326" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "types-setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3f/3b/d29491d754b9e42edd4890648311ffa5d4d000b7d97b92ac4d04faad40d8/types_cffi-1.17.0.20250326.tar.gz", hash = "sha256:6c8fea2c2f34b55e5fb77b1184c8ad849d57cf0ddccbc67a62121ac4b8b32254", size = 16887 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/49/ce473d7fbc2c80931ef9f7530fd3ddf31b8a5bca56340590334ce6ffbfb1/types_cffi-1.17.0.20250326-py3-none-any.whl", hash = "sha256:5af4ecd7374ae0d5fa9e80864e8d4b31088cc32c51c544e3af7ed5b5ed681447", size = 20133 }, +] + +[[package]] +name = "types-pyopenssl" +version = "24.1.0.20240722" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, + { name = "types-cffi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/93/29/47a346550fd2020dac9a7a6d033ea03fccb92fa47c726056618cc889745e/types-pyOpenSSL-24.1.0.20240722.tar.gz", hash = "sha256:47913b4678a01d879f503a12044468221ed8576263c1540dcb0484ca21b08c39", size = 8458 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/05/c868a850b6fbb79c26f5f299b768ee0adc1f9816d3461dcf4287916f655b/types_pyOpenSSL-24.1.0.20240722-py3-none-any.whl", hash = "sha256:6a7a5d2ec042537934cfb4c9d4deb0e16c4c6250b09358df1f083682fe6fda54", size = 7499 }, +] + +[[package]] +name = "types-redis" +version = "4.6.0.20241004" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, + { name = "types-pyopenssl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3a/95/c054d3ac940e8bac4ca216470c80c26688a0e79e09f520a942bb27da3386/types-redis-4.6.0.20241004.tar.gz", hash = "sha256:5f17d2b3f9091ab75384153bfa276619ffa1cf6a38da60e10d5e6749cc5b902e", size = 49679 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/55/82/7d25dce10aad92d2226b269bce2f85cfd843b4477cd50245d7d40ecf8f89/types_redis-4.6.0.20241004-py3-none-any.whl", hash = "sha256:ef5da68cb827e5f606c8f9c0b49eeee4c2669d6d97122f301d3a55dc6a63f6ed", size = 58737 }, +] + +[[package]] +name = "types-setuptools" +version = "78.1.0.20250329" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/6e/c54e6705e5fe67c3606e4c7c91123ecf10d7e1e6d7a9c11b52970cf2196c/types_setuptools-78.1.0.20250329.tar.gz", hash = "sha256:31e62950c38b8cc1c5114b077504e36426860a064287cac11b9666ab3a483234", size = 43942 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7d/31/85d0264705d8ef47680d28f4dc9bb1e27d8cace785fbe3f8d009fad6cb88/types_setuptools-78.1.0.20250329-py3-none-any.whl", hash = "sha256:ea47eab891afb506f470eee581dcde44d64dc99796665da794da6f83f50f6776", size = 66985 }, +] + [[package]] name = "typing-extensions" version = "4.12.2" From e0d443c95e2bb62d0ac670529b03b3dd3ec373bd Mon Sep 17 00:00:00 2001 From: tim-watcha <92134765+tim-watcha@users.noreply.github.com> Date: Wed, 7 May 2025 19:14:25 +0900 Subject: [PATCH 13/21] Add mount_path support for proper SSE endpoint routing with multiple FastMCP servers (#540) Co-authored-by: ihrpr --- README.md | 37 +++++++++++ src/mcp/server/fastmcp/server.py | 54 ++++++++++++++-- tests/server/fastmcp/test_server.py | 95 ++++++++++++++++++++++++++++- 3 files changed, 178 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 56aa3609ae..a0d47606ef 100644 --- a/README.md +++ b/README.md @@ -410,6 +410,43 @@ app = Starlette( app.router.routes.append(Host('mcp.acme.corp', app=mcp.sse_app())) ``` +When mounting multiple MCP servers under different paths, you can configure the mount path in several ways: + +```python +from starlette.applications import Starlette +from starlette.routing import Mount +from mcp.server.fastmcp import 
FastMCP + +# Create multiple MCP servers +github_mcp = FastMCP("GitHub API") +browser_mcp = FastMCP("Browser") +curl_mcp = FastMCP("Curl") +search_mcp = FastMCP("Search") + +# Method 1: Configure mount paths via settings (recommended for persistent configuration) +github_mcp.settings.mount_path = "/github" +browser_mcp.settings.mount_path = "/browser" + +# Method 2: Pass mount path directly to sse_app (preferred for ad-hoc mounting) +# This approach doesn't modify the server's settings permanently + +# Create Starlette app with multiple mounted servers +app = Starlette( + routes=[ + # Using settings-based configuration + Mount("/github", app=github_mcp.sse_app()), + Mount("/browser", app=browser_mcp.sse_app()), + # Using direct mount path parameter + Mount("/curl", app=curl_mcp.sse_app("/curl")), + Mount("/search", app=search_mcp.sse_app("/search")), + ] +) + +# Method 3: For direct execution, you can also pass the mount path to run() +if __name__ == "__main__": + search_mcp.run(transport="sse", mount_path="/search") +``` + For more information on mounting applications in Starlette, see the [Starlette documentation](https://www.starlette.io/routing/#submounting-routes). #### Message Dispatch Options diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 1e5b69eba0..24fe971057 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -88,6 +88,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # HTTP settings host: str = "0.0.0.0" port: int = 8000 + mount_path: str = "/" # Mount path (e.g. 
"/github", defaults to root path) sse_path: str = "/sse" message_path: str = "/messages/" @@ -184,11 +185,16 @@ def name(self) -> str: def instructions(self) -> str | None: return self._mcp_server.instructions - def run(self, transport: Literal["stdio", "sse"] = "stdio") -> None: + def run( + self, + transport: Literal["stdio", "sse"] = "stdio", + mount_path: str | None = None, + ) -> None: """Run the FastMCP server. Note this is a synchronous function. Args: transport: Transport protocol to use ("stdio" or "sse") + mount_path: Optional mount path for SSE transport """ TRANSPORTS = Literal["stdio", "sse"] if transport not in TRANSPORTS.__args__: # type: ignore @@ -197,7 +203,7 @@ def run(self, transport: Literal["stdio", "sse"] = "stdio") -> None: if transport == "stdio": anyio.run(self.run_stdio_async) else: # transport == "sse" - anyio.run(self.run_sse_async) + anyio.run(lambda: self.run_sse_async(mount_path)) def _setup_handlers(self) -> None: """Set up core MCP protocol handlers.""" @@ -558,11 +564,11 @@ async def run_stdio_async(self) -> None: self._mcp_server.create_initialization_options(), ) - async def run_sse_async(self) -> None: + async def run_sse_async(self, mount_path: str | None = None) -> None: """Run the server using SSE transport.""" import uvicorn - starlette_app = self.sse_app() + starlette_app = self.sse_app(mount_path) config = uvicorn.Config( starlette_app, @@ -573,7 +579,33 @@ async def run_sse_async(self) -> None: server = uvicorn.Server(config) await server.serve() - def sse_app(self) -> Starlette: + def _normalize_path(self, mount_path: str, endpoint: str) -> str: + """ + Combine mount path and endpoint to return a normalized path. + + Args: + mount_path: The mount path (e.g. "/github" or "/") + endpoint: The endpoint path (e.g. "/messages/") + + Returns: + Normalized path (e.g. 
"/github/messages/") + """ + # Special case: root path + if mount_path == "/": + return endpoint + + # Remove trailing slash from mount path + if mount_path.endswith("/"): + mount_path = mount_path[:-1] + + # Ensure endpoint starts with slash + if not endpoint.startswith("/"): + endpoint = "/" + endpoint + + # Combine paths + return mount_path + endpoint + + def sse_app(self, mount_path: str | None = None) -> Starlette: """Return an instance of the SSE server app.""" message_dispatch = self.settings.message_dispatch if message_dispatch is None: @@ -585,10 +617,20 @@ def sse_app(self) -> Starlette: from starlette.middleware import Middleware from starlette.routing import Mount, Route + # Update mount_path in settings if provided + if mount_path is not None: + self.settings.mount_path = mount_path + + # Create normalized endpoint considering the mount path + normalized_message_endpoint = self._normalize_path( + self.settings.mount_path, self.settings.message_path + ) + # Set up auth context and dependencies sse = SseServerTransport( - self.settings.message_path, message_dispatch=message_dispatch + normalized_message_endpoint, + message_dispatch=message_dispatch ) async def handle_sse(scope: Scope, receive: Receive, send: Send): diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index 772c41529b..64700d9590 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -1,9 +1,11 @@ import base64 from pathlib import Path from typing import TYPE_CHECKING +from unittest.mock import patch import pytest from pydantic import AnyUrl +from starlette.routing import Mount, Route from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage @@ -31,6 +33,97 @@ async def test_create_server(self): assert mcp.name == "FastMCP" assert mcp.instructions == "Server instructions" + @pytest.mark.anyio + async def test_normalize_path(self): + """Test path 
normalization for mount paths.""" + mcp = FastMCP() + + # Test root path + assert mcp._normalize_path("/", "/messages/") == "/messages/" + + # Test path with trailing slash + assert mcp._normalize_path("/github/", "/messages/") == "/github/messages/" + + # Test path without trailing slash + assert mcp._normalize_path("/github", "/messages/") == "/github/messages/" + + # Test endpoint without leading slash + assert mcp._normalize_path("/github", "messages/") == "/github/messages/" + + # Test both with trailing/leading slashes + assert mcp._normalize_path("/api/", "/v1/") == "/api/v1/" + + @pytest.mark.anyio + async def test_sse_app_with_mount_path(self): + """Test SSE app creation with different mount paths.""" + # Test with default mount path + mcp = FastMCP() + with patch.object( + mcp, "_normalize_path", return_value="/messages/" + ) as mock_normalize: + mcp.sse_app() + # Verify _normalize_path was called with correct args + mock_normalize.assert_called_once_with("/", "/messages/") + + # Test with custom mount path in settings + mcp = FastMCP() + mcp.settings.mount_path = "/custom" + with patch.object( + mcp, "_normalize_path", return_value="/custom/messages/" + ) as mock_normalize: + mcp.sse_app() + # Verify _normalize_path was called with correct args + mock_normalize.assert_called_once_with("/custom", "/messages/") + + # Test with mount_path parameter + mcp = FastMCP() + with patch.object( + mcp, "_normalize_path", return_value="/param/messages/" + ) as mock_normalize: + mcp.sse_app(mount_path="/param") + # Verify _normalize_path was called with correct args + mock_normalize.assert_called_once_with("/param", "/messages/") + + @pytest.mark.anyio + async def test_starlette_routes_with_mount_path(self): + """Test that Starlette routes are correctly configured with mount path.""" + # Test with mount path in settings + mcp = FastMCP() + mcp.settings.mount_path = "/api" + app = mcp.sse_app() + + # Find routes by type + sse_routes = [r for r in app.routes if 
isinstance(r, Route)] + mount_routes = [r for r in app.routes if isinstance(r, Mount)] + + # Verify routes exist + assert len(sse_routes) == 1, "Should have one SSE route" + assert len(mount_routes) == 1, "Should have one mount route" + + # Verify path values + assert sse_routes[0].path == "/sse", "SSE route path should be /sse" + assert ( + mount_routes[0].path == "/messages" + ), "Mount route path should be /messages" + + # Test with mount path as parameter + mcp = FastMCP() + app = mcp.sse_app(mount_path="/param") + + # Find routes by type + sse_routes = [r for r in app.routes if isinstance(r, Route)] + mount_routes = [r for r in app.routes if isinstance(r, Mount)] + + # Verify routes exist + assert len(sse_routes) == 1, "Should have one SSE route" + assert len(mount_routes) == 1, "Should have one mount route" + + # Verify path values + assert sse_routes[0].path == "/sse", "SSE route path should be /sse" + assert ( + mount_routes[0].path == "/messages" + ), "Mount route path should be /messages" + @pytest.mark.anyio async def test_non_ascii_description(self): """Test that FastMCP handles non-ASCII characters in descriptions correctly""" @@ -518,8 +611,6 @@ async def async_tool(x: int, ctx: Context) -> str: @pytest.mark.anyio async def test_context_logging(self): - from unittest.mock import patch - import mcp.server.session """Test that context logging methods work.""" From c8a14c9dba444b3b4c394729d341d7a804e47232 Mon Sep 17 00:00:00 2001 From: Samad Yar Khan <70485812+samad-yar-khan@users.noreply.github.com> Date: Wed, 7 May 2025 20:47:11 +0530 Subject: [PATCH 14/21] docs: fix broken link to OAuthServerProvider in Authentication section of README (#651) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a0d47606ef..bfe970888e 100644 --- a/README.md +++ b/README.md @@ -334,7 +334,7 @@ mcp = FastMCP("My App", ) ``` -See [OAuthServerProvider](mcp/server/auth/provider.py) for more details. 
+See [OAuthServerProvider](src/mcp/server/auth/provider.py) for more details. ## Running Your Server From 9d99aee0148817e3c18806ea926e7e57d419c3d4 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 7 May 2025 16:35:20 +0100 Subject: [PATCH 15/21] Revert "Add message queue for SSE messages POST endpoint (#459)" (#649) --- README.md | 24 -- .../simple-prompt/mcp_simple_prompt/server.py | 5 +- pyproject.toml | 2 - src/mcp/client/sse.py | 6 +- src/mcp/client/stdio/__init__.py | 2 +- src/mcp/client/streamable_http.py | 6 +- src/mcp/client/websocket.py | 2 +- src/mcp/server/fastmcp/server.py | 32 +- src/mcp/server/message_queue/__init__.py | 16 - src/mcp/server/message_queue/base.py | 116 ------ src/mcp/server/message_queue/redis.py | 198 ---------- src/mcp/server/sse.py | 42 +-- src/mcp/server/stdio.py | 2 +- src/mcp/server/streamable_http.py | 6 +- src/mcp/server/websocket.py | 2 +- src/mcp/shared/message.py | 14 +- tests/client/test_session.py | 6 +- tests/client/test_stdio.py | 2 +- tests/issues/test_192_request_id.py | 8 +- tests/server/message_dispatch/__init__.py | 1 - tests/server/message_dispatch/conftest.py | 28 -- tests/server/message_dispatch/test_redis.py | 355 ------------------ .../test_redis_integration.py | 260 ------------- tests/server/test_lifespan.py | 12 +- tests/server/test_stdio.py | 2 +- uv.lock | 149 +------- 26 files changed, 51 insertions(+), 1247 deletions(-) delete mode 100644 src/mcp/server/message_queue/__init__.py delete mode 100644 src/mcp/server/message_queue/base.py delete mode 100644 src/mcp/server/message_queue/redis.py delete mode 100644 tests/server/message_dispatch/__init__.py delete mode 100644 tests/server/message_dispatch/conftest.py delete mode 100644 tests/server/message_dispatch/test_redis.py delete mode 100644 tests/server/message_dispatch/test_redis_integration.py diff --git a/README.md b/README.md index bfe970888e..8f19aea1a8 100644 --- a/README.md +++ b/README.md @@ -449,30 +449,6 @@ if __name__ == "__main__": For more 
information on mounting applications in Starlette, see the [Starlette documentation](https://www.starlette.io/routing/#submounting-routes). -#### Message Dispatch Options - -By default, the SSE server uses an in-memory message dispatch system for incoming POST messages. For production deployments or distributed scenarios, you can use Redis or implement your own message dispatch system that conforms to the `MessageDispatch` protocol: - -```python -# Using the built-in Redis message dispatch -from mcp.server.fastmcp import FastMCP -from mcp.server.message_queue import RedisMessageDispatch - -# Create a Redis message dispatch -redis_dispatch = RedisMessageDispatch( - redis_url="redis://localhost:6379/0", prefix="mcp:pubsub:" -) - -# Pass the message dispatch instance to the server -mcp = FastMCP("My App", message_queue=redis_dispatch) -``` - -To use Redis, add the Redis dependency: - -```bash -uv add "mcp[redis]" -``` - ## Examples ### Echo Server diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index 04b10ac75d..bc14b7cd0a 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -88,15 +88,12 @@ async def get_prompt( ) if transport == "sse": - from mcp.server.message_queue.redis import RedisMessageDispatch from mcp.server.sse import SseServerTransport from starlette.applications import Starlette from starlette.responses import Response from starlette.routing import Mount, Route - message_dispatch = RedisMessageDispatch("redis://localhost:6379/0") - - sse = SseServerTransport("/messages/", message_dispatch=message_dispatch) + sse = SseServerTransport("/messages/") async def handle_sse(request): async with sse.connect_sse( diff --git a/pyproject.toml b/pyproject.toml index 6ff2601e9d..2b86fb3772 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ dependencies = [ rich = ["rich>=13.9.4"] cli 
= ["typer>=0.12.4", "python-dotenv>=1.0.0"] ws = ["websockets>=15.0.1"] -redis = ["redis>=5.2.1", "types-redis>=4.6.0.20241004"] [project.scripts] mcp = "mcp.cli:app [cli]" @@ -56,7 +55,6 @@ dev = [ "pytest-xdist>=3.6.1", "pytest-examples>=0.0.14", "pytest-pretty>=1.2.0", - "fakeredis==2.28.1", ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 7df251f792..ff04d2f961 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -98,9 +98,7 @@ async def sse_reader( await read_stream_writer.send(exc) continue - session_message = SessionMessage( - message=message - ) + session_message = SessionMessage(message) await read_stream_writer.send(session_message) case _: logger.warning( @@ -150,5 +148,3 @@ async def post_writer(endpoint_url: str): finally: await read_stream_writer.aclose() await write_stream.aclose() - await read_stream.aclose() - await write_stream_reader.aclose() diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 21c7764e75..e8be5aff5b 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -144,7 +144,7 @@ async def stdout_reader(): await read_stream_writer.send(exc) continue - session_message = SessionMessage(message=message) + session_message = SessionMessage(message) await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index ca26046b93..ef424e3b33 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -153,7 +153,7 @@ async def _handle_sse_event( ): message.root.id = original_request_id - session_message = SessionMessage(message=message) + session_message = SessionMessage(message) await read_stream_writer.send(session_message) # Call resumption token callback if we have an ID @@ -286,7 +286,7 @@ async def _handle_json_response( try: content = await 
response.aread() message = JSONRPCMessage.model_validate_json(content) - session_message = SessionMessage(message=message) + session_message = SessionMessage(message) await read_stream_writer.send(session_message) except Exception as exc: logger.error(f"Error parsing JSON response: {exc}") @@ -333,7 +333,7 @@ async def _send_session_terminated_error( id=request_id, error=ErrorData(code=32600, message="Session terminated"), ) - session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) + session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) await read_stream_writer.send(session_message) async def post_writer( diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 598fdaf252..ac542fb3f6 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -60,7 +60,7 @@ async def ws_reader(): async for raw_text in ws: try: message = types.JSONRPCMessage.model_validate_json(raw_text) - session_message = SessionMessage(message=message) + session_message = SessionMessage(message) await read_stream_writer.send(session_message) except ValidationError as exc: # If JSON parse or model validation fails, send the exception diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 24fe971057..ea0214f0fe 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -44,7 +44,6 @@ from mcp.server.lowlevel.server import LifespanResultT from mcp.server.lowlevel.server import Server as MCPServer from mcp.server.lowlevel.server import lifespan as default_lifespan -from mcp.server.message_queue import MessageDispatch from mcp.server.session import ServerSession, ServerSessionT from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server @@ -92,11 +91,6 @@ class Settings(BaseSettings, Generic[LifespanResultT]): sse_path: str = "/sse" message_path: str = "/messages/" - # SSE message queue settings - message_dispatch: MessageDispatch | None = Field( - 
None, description="Custom message dispatch instance" - ) - # resource settings warn_on_duplicate_resources: bool = True @@ -607,13 +601,6 @@ def _normalize_path(self, mount_path: str, endpoint: str) -> str: def sse_app(self, mount_path: str | None = None) -> Starlette: """Return an instance of the SSE server app.""" - message_dispatch = self.settings.message_dispatch - if message_dispatch is None: - from mcp.server.message_queue import InMemoryMessageDispatch - - message_dispatch = InMemoryMessageDispatch() - logger.info("Using default in-memory message dispatch") - from starlette.middleware import Middleware from starlette.routing import Mount, Route @@ -625,12 +612,11 @@ def sse_app(self, mount_path: str | None = None) -> Starlette: normalized_message_endpoint = self._normalize_path( self.settings.mount_path, self.settings.message_path ) - + # Set up auth context and dependencies sse = SseServerTransport( - normalized_message_endpoint, - message_dispatch=message_dispatch + normalized_message_endpoint, ) async def handle_sse(scope: Scope, receive: Receive, send: Send): @@ -646,14 +632,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): streams[1], self._mcp_server.create_initialization_options(), ) - return Response() - - @asynccontextmanager - async def lifespan(app: Starlette): - try: - yield - finally: - await message_dispatch.close() + return Response() # Create routes routes: list[Route | Mount] = [] @@ -730,10 +709,7 @@ async def sse_endpoint(request: Request) -> None: # Create Starlette app with routes and middleware return Starlette( - debug=self.settings.debug, - routes=routes, - middleware=middleware, - lifespan=lifespan, + debug=self.settings.debug, routes=routes, middleware=middleware ) async def list_prompts(self) -> list[MCPPrompt]: diff --git a/src/mcp/server/message_queue/__init__.py b/src/mcp/server/message_queue/__init__.py deleted file mode 100644 index f4a8b9dfaf..0000000000 --- a/src/mcp/server/message_queue/__init__.py +++ 
/dev/null @@ -1,16 +0,0 @@ -""" -Message Dispatch Module for MCP Server - -This module implements dispatch interfaces for handling -messages between clients and servers. -""" - -from mcp.server.message_queue.base import InMemoryMessageDispatch, MessageDispatch - -# Try to import Redis implementation if available -try: - from mcp.server.message_queue.redis import RedisMessageDispatch -except ImportError: - RedisMessageDispatch = None - -__all__ = ["MessageDispatch", "InMemoryMessageDispatch", "RedisMessageDispatch"] diff --git a/src/mcp/server/message_queue/base.py b/src/mcp/server/message_queue/base.py deleted file mode 100644 index 20c7145505..0000000000 --- a/src/mcp/server/message_queue/base.py +++ /dev/null @@ -1,116 +0,0 @@ -import logging -from collections.abc import Awaitable, Callable -from contextlib import asynccontextmanager -from typing import Protocol, runtime_checkable -from uuid import UUID - -from pydantic import ValidationError - -from mcp.shared.message import SessionMessage - -logger = logging.getLogger(__name__) - -MessageCallback = Callable[[SessionMessage | Exception], Awaitable[None]] - - -@runtime_checkable -class MessageDispatch(Protocol): - """Abstract interface for SSE message dispatching. - - This interface allows messages to be published to sessions and callbacks to be - registered for message handling, enabling multiple servers to handle requests. - """ - - async def publish_message( - self, session_id: UUID, message: SessionMessage | str - ) -> bool: - """Publish a message for the specified session. - - Args: - session_id: The UUID of the session this message is for - message: The message to publish (SessionMessage or str for invalid JSON) - - Returns: - bool: True if message was published, False if session not found - """ - ... - - @asynccontextmanager - async def subscribe(self, session_id: UUID, callback: MessageCallback): - """Request-scoped context manager that subscribes to messages for a session. 
- - Args: - session_id: The UUID of the session to subscribe to - callback: Async callback function to handle messages for this session - """ - yield - - async def session_exists(self, session_id: UUID) -> bool: - """Check if a session exists. - - Args: - session_id: The UUID of the session to check - - Returns: - bool: True if the session is active, False otherwise - """ - ... - - async def close(self) -> None: - """Close the message dispatch.""" - ... - - -class InMemoryMessageDispatch: - """Default in-memory implementation of the MessageDispatch interface. - - This implementation immediately dispatches messages to registered callbacks when - messages are received without any queuing behavior. - """ - - def __init__(self) -> None: - self._callbacks: dict[UUID, MessageCallback] = {} - - async def publish_message( - self, session_id: UUID, message: SessionMessage | str - ) -> bool: - """Publish a message for the specified session.""" - if session_id not in self._callbacks: - logger.warning(f"Message dropped: unknown session {session_id}") - return False - - # Parse string messages or recreate original ValidationError - if isinstance(message, str): - try: - callback_argument = SessionMessage.model_validate_json(message) - except ValidationError as exc: - callback_argument = exc - else: - callback_argument = message - - # Call the callback with either valid message or recreated ValidationError - await self._callbacks[session_id](callback_argument) - - logger.debug(f"Message dispatched to session {session_id}") - return True - - @asynccontextmanager - async def subscribe(self, session_id: UUID, callback: MessageCallback): - """Request-scoped context manager that subscribes to messages for a session.""" - self._callbacks[session_id] = callback - logger.debug(f"Subscribing to messages for session {session_id}") - - try: - yield - finally: - if session_id in self._callbacks: - del self._callbacks[session_id] - logger.debug(f"Unsubscribed from session {session_id}") - - 
async def session_exists(self, session_id: UUID) -> bool: - """Check if a session exists.""" - return session_id in self._callbacks - - async def close(self) -> None: - """Close the message dispatch.""" - pass diff --git a/src/mcp/server/message_queue/redis.py b/src/mcp/server/message_queue/redis.py deleted file mode 100644 index 628ce026c5..0000000000 --- a/src/mcp/server/message_queue/redis.py +++ /dev/null @@ -1,198 +0,0 @@ -import logging -from contextlib import asynccontextmanager -from typing import Any, cast -from uuid import UUID - -import anyio -from anyio import CancelScope, CapacityLimiter, lowlevel -from anyio.abc import TaskGroup -from pydantic import ValidationError - -from mcp.server.message_queue.base import MessageCallback -from mcp.shared.message import SessionMessage - -try: - import redis.asyncio as redis -except ImportError: - raise ImportError( - "Redis support requires the 'redis' package. " - "Install it with: 'uv add redis' or 'uv add \"mcp[redis]\"'" - ) - -logger = logging.getLogger(__name__) - - -class RedisMessageDispatch: - """Redis implementation of the MessageDispatch interface using pubsub. - - This implementation uses Redis pubsub for real-time message distribution across - multiple servers handling the same sessions. - """ - - def __init__( - self, - redis_url: str = "redis://localhost:6379/0", - prefix: str = "mcp:pubsub:", - session_ttl: int = 3600, # 1 hour default TTL for sessions - ) -> None: - """Initialize Redis message dispatch. - - Args: - redis_url: Redis connection string - prefix: Key prefix for Redis channels to avoid collisions - session_ttl: TTL in seconds for session keys (default: 1 hour) - """ - self._redis = redis.from_url(redis_url, decode_responses=True) # type: ignore - self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True) # type: ignore - self._prefix = prefix - self._session_ttl = session_ttl - # Maps session IDs to the callback and task group for that SSE session. 
- self._session_state: dict[UUID, tuple[MessageCallback, TaskGroup]] = {} - # Ensures only one polling task runs at a time for message handling - self._limiter = CapacityLimiter(1) - logger.debug(f"Redis message dispatch initialized: {redis_url}") - - async def close(self): - await self._pubsub.aclose() # type: ignore - await self._redis.aclose() # type: ignore - - def _session_channel(self, session_id: UUID) -> str: - """Get the Redis channel for a session.""" - return f"{self._prefix}session:{session_id.hex}" - - def _session_key(self, session_id: UUID) -> str: - """Get the Redis key for a session.""" - return f"{self._prefix}session_active:{session_id.hex}" - - @asynccontextmanager - async def subscribe(self, session_id: UUID, callback: MessageCallback): - """Request-scoped context manager that subscribes to messages for a session.""" - session_key = self._session_key(session_id) - await self._redis.setex(session_key, self._session_ttl, "1") # type: ignore - - channel = self._session_channel(session_id) - await self._pubsub.subscribe(channel) # type: ignore - - logger.debug(f"Subscribing to Redis channel for session {session_id}") - async with anyio.create_task_group() as tg: - self._session_state[session_id] = (callback, tg) - tg.start_soon(self._listen_for_messages) - # Start heartbeat for this session - tg.start_soon(self._session_heartbeat, session_id) - try: - yield - finally: - with anyio.CancelScope(shield=True): - tg.cancel_scope.cancel() - await self._pubsub.unsubscribe(channel) # type: ignore - await self._redis.delete(session_key) # type: ignore - del self._session_state[session_id] - logger.debug(f"Unsubscribed from Redis channel: {session_id}") - - async def _session_heartbeat(self, session_id: UUID) -> None: - """Periodically refresh the TTL for a session.""" - session_key = self._session_key(session_id) - while True: - await lowlevel.checkpoint() - try: - # Refresh TTL at half the TTL interval to avoid expiration - await 
anyio.sleep(self._session_ttl / 2) - with anyio.CancelScope(shield=True): - await self._redis.expire(session_key, self._session_ttl) # type: ignore - except anyio.get_cancelled_exc_class(): - break - except Exception as e: - logger.error(f"Error refreshing TTL for session {session_id}: {e}") - - def _extract_session_id(self, channel: str) -> UUID | None: - """Extract and validate session ID from channel.""" - expected_prefix = f"{self._prefix}session:" - if not channel.startswith(expected_prefix): - return None - - session_hex = channel[len(expected_prefix) :] - try: - session_id = UUID(hex=session_hex) - if channel != self._session_channel(session_id): - logger.error(f"Channel format mismatch: {channel}") - return None - return session_id - except ValueError: - logger.error(f"Invalid UUID in channel: {channel}") - return None - - async def _listen_for_messages(self) -> None: - """Background task that listens for messages on subscribed channels.""" - async with self._limiter: - while True: - await lowlevel.checkpoint() - with CancelScope(shield=True): - message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore - ignore_subscribe_messages=True, - timeout=0.1, # type: ignore - ) - if message is None: - continue - - channel: str = cast(str, message["channel"]) - session_id = self._extract_session_id(channel) - if session_id is None: - logger.debug( - f"Ignoring message from non-MCP channel: {channel}" - ) - continue - - data: str = cast(str, message["data"]) - try: - if session_state := self._session_state.get(session_id): - session_state[1].start_soon( - self._handle_message, session_id, data - ) - else: - logger.warning( - f"Message dropped: unknown session {session_id}" - ) - except Exception as e: - logger.error(f"Error processing message for {session_id}: {e}") - - async def _handle_message(self, session_id: UUID, data: str) -> None: - """Process a message from Redis in the session's task group.""" - if (session_state := 
self._session_state.get(session_id)) is None: - logger.warning(f"Message dropped: callback removed for {session_id}") - return - - try: - # Parse message or pass validation error to callback - msg_or_error = None - try: - msg_or_error = SessionMessage.model_validate_json(data) - except ValidationError as exc: - msg_or_error = exc - - await session_state[0](msg_or_error) - except Exception as e: - logger.error(f"Error in message handler for {session_id}: {e}") - - async def publish_message( - self, session_id: UUID, message: SessionMessage | str - ) -> bool: - """Publish a message for the specified session.""" - if not await self.session_exists(session_id): - logger.warning(f"Message dropped: unknown session {session_id}") - return False - - # Pass raw JSON strings directly, preserving validation errors - if isinstance(message, str): - data = message - else: - data = message.model_dump_json() - - channel = self._session_channel(session_id) - await self._redis.publish(channel, data) # type: ignore[attr-defined] - logger.debug(f"Message published to Redis channel for session {session_id}") - return True - - async def session_exists(self, session_id: UUID) -> bool: - """Check if a session exists.""" - session_key = self._session_key(session_id) - return bool(await self._redis.exists(session_key)) # type: ignore diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 98f32629ef..cc41a80d67 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -52,11 +52,9 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types -from mcp.server.message_queue import InMemoryMessageDispatch, MessageDispatch from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.DEBUG) class SseServerTransport: @@ -72,24 +70,17 @@ class SseServerTransport: """ _endpoint: str - _message_dispatch: MessageDispatch _read_stream_writers: dict[UUID, 
MemoryObjectSendStream[SessionMessage | Exception]] - def __init__( - self, endpoint: str, message_dispatch: MessageDispatch | None = None - ) -> None: + def __init__(self, endpoint: str) -> None: """ Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL given. - - Args: - endpoint: The endpoint URL for SSE connections - message_dispatch: Optional message dispatch to use """ super().__init__() self._endpoint = endpoint - self._message_dispatch = message_dispatch or InMemoryMessageDispatch() + self._read_stream_writers = {} logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager @@ -110,12 +101,7 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): session_id = uuid4() session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}" - - async def message_callback(message: SessionMessage | Exception) -> None: - """Callback that receives messages from the message queue""" - logger.debug(f"Got message from queue for session {session_id}") - await read_stream_writer.send(message) - + self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ @@ -152,16 +138,13 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): )(scope, receive, send) await read_stream_writer.aclose() await write_stream_reader.aclose() - await sse_stream_writer.aclose() - await sse_stream_reader.aclose() logging.debug(f"Client session disconnected {session_id}") logger.debug("Starting SSE response task") tg.start_soon(response_wrapper, scope, receive, send) - async with self._message_dispatch.subscribe(session_id, message_callback): - logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) async def 
handle_post_message( self, scope: Scope, receive: Receive, send: Send @@ -183,7 +166,8 @@ async def handle_post_message( response = Response("Invalid session ID", status_code=400) return await response(scope, receive, send) - if not await self._message_dispatch.session_exists(session_id): + writer = self._read_stream_writers.get(session_id) + if not writer: logger.warning(f"Could not find session for ID: {session_id}") response = Response("Could not find session", status_code=404) return await response(scope, receive, send) @@ -198,15 +182,11 @@ async def handle_post_message( logger.error(f"Failed to parse message: {err}") response = Response("Could not parse message", status_code=400) await response(scope, receive, send) - # Pass raw JSON string; receiver will recreate identical ValidationError - # when parsing the same invalid JSON - await self._message_dispatch.publish_message(session_id, body.decode()) + await writer.send(err) return - logger.debug(f"Publishing message for session {session_id}: {message}") + session_message = SessionMessage(message) + logger.debug(f"Sending session message to writer: {session_message}") response = Response("Accepted", status_code=202) await response(scope, receive, send) - await self._message_dispatch.publish_message( - session_id, SessionMessage(message=message) - ) - logger.debug(f"Sending session message to writer: {message}") + await writer.send(session_message) diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index 11c8f7ee4d..f0bbe5a316 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -67,7 +67,7 @@ async def stdin_reader(): await read_stream_writer.send(exc) continue - session_message = SessionMessage(message=message) + session_message = SessionMessage(message) await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 79c8a89139..ace74b33b4 
100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -398,7 +398,7 @@ async def _handle_post_request( await response(scope, receive, send) # Process the message after sending the response - session_message = SessionMessage(message=message) + session_message = SessionMessage(message) await writer.send(session_message) return @@ -413,7 +413,7 @@ async def _handle_post_request( if self.is_json_response_enabled: # Process the message - session_message = SessionMessage(message=message) + session_message = SessionMessage(message) await writer.send(session_message) try: # Process messages from the request-specific stream @@ -512,7 +512,7 @@ async def sse_writer(): async with anyio.create_task_group() as tg: tg.start_soon(response, scope, receive, send) # Then send the message to be processed by the server - session_message = SessionMessage(message=message) + session_message = SessionMessage(message) await writer.send(session_message) except Exception: logger.exception("SSE response error") diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index bb0b1ca6ea..9dc3f2a25e 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -42,7 +42,7 @@ async def ws_reader(): await read_stream_writer.send(exc) continue - session_message = SessionMessage(message=client_message) + session_message = SessionMessage(client_message) await read_stream_writer.send(session_message) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index c96a0a1e63..5583f47951 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -6,8 +6,7 @@ """ from collections.abc import Awaitable, Callable - -from pydantic import BaseModel +from dataclasses import dataclass from mcp.types import JSONRPCMessage, RequestId @@ -16,7 +15,8 @@ ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]] -class ClientMessageMetadata(BaseModel): 
+@dataclass +class ClientMessageMetadata: """Metadata specific to client messages.""" resumption_token: ResumptionToken | None = None @@ -25,7 +25,8 @@ class ClientMessageMetadata(BaseModel): ) -class ServerMessageMetadata(BaseModel): +@dataclass +class ServerMessageMetadata: """Metadata specific to server messages.""" related_request_id: RequestId | None = None @@ -34,8 +35,9 @@ class ServerMessageMetadata(BaseModel): MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None -class SessionMessage(BaseModel): +@dataclass +class SessionMessage: """A message with specific metadata for transport-specific features.""" message: JSONRPCMessage - metadata: MessageMetadata | None = None + metadata: MessageMetadata = None diff --git a/tests/client/test_session.py b/tests/client/test_session.py index cd3dae293d..6abcf70cbc 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -62,7 +62,7 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - message=JSONRPCMessage( + JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, @@ -153,7 +153,7 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - message=JSONRPCMessage( + JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, @@ -220,7 +220,7 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - message=JSONRPCMessage( + JSONRPCMessage( JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index d93c63aefe..523ba199a4 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -23,7 +23,7 @@ async def test_stdio_client(): async with write_stream: for message in messages: - session_message = SessionMessage(message=message) + session_message = SessionMessage(message) await 
write_stream.send(session_message) read_messages = [] diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index c05f08f8cd..cf5eb6083e 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -65,7 +65,7 @@ async def run_server(): jsonrpc="2.0", ) - await client_writer.send(SessionMessage(message=JSONRPCMessage(root=init_req))) + await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req))) response = ( await server_reader.receive() ) # Get init response but don't need to check it @@ -77,7 +77,7 @@ async def run_server(): jsonrpc="2.0", ) await client_writer.send( - SessionMessage(message=JSONRPCMessage(root=initialized_notification)) + SessionMessage(JSONRPCMessage(root=initialized_notification)) ) # Send ping request with custom ID @@ -85,9 +85,7 @@ async def run_server(): id=custom_request_id, method="ping", params={}, jsonrpc="2.0" ) - await client_writer.send( - SessionMessage(message=JSONRPCMessage(root=ping_request)) - ) + await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request))) # Read response response = await server_reader.receive() diff --git a/tests/server/message_dispatch/__init__.py b/tests/server/message_dispatch/__init__.py deleted file mode 100644 index df0d26c3e6..0000000000 --- a/tests/server/message_dispatch/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Message queue tests module diff --git a/tests/server/message_dispatch/conftest.py b/tests/server/message_dispatch/conftest.py deleted file mode 100644 index 3422da2aab..0000000000 --- a/tests/server/message_dispatch/conftest.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Shared fixtures for message queue tests.""" - -from collections.abc import AsyncGenerator -from unittest.mock import patch - -import pytest - -from mcp.server.message_queue.redis import RedisMessageDispatch - -# Set up fakeredis for testing -try: - from fakeredis import aioredis as fake_redis -except ImportError: - pytest.skip( - "fakeredis is 
required for testing Redis functionality", allow_module_level=True - ) - - -@pytest.fixture -async def message_dispatch() -> AsyncGenerator[RedisMessageDispatch, None]: - """Create a shared Redis message dispatch with a fake Redis client.""" - with patch("mcp.server.message_queue.redis.redis", fake_redis.FakeRedis): - # Shorter TTL for testing - message_dispatch = RedisMessageDispatch(session_ttl=5) - try: - yield message_dispatch - finally: - await message_dispatch.close() diff --git a/tests/server/message_dispatch/test_redis.py b/tests/server/message_dispatch/test_redis.py deleted file mode 100644 index d355f9e688..0000000000 --- a/tests/server/message_dispatch/test_redis.py +++ /dev/null @@ -1,355 +0,0 @@ -from unittest.mock import AsyncMock -from uuid import uuid4 - -import anyio -import pytest -from pydantic import ValidationError - -import mcp.types as types -from mcp.server.message_queue.redis import RedisMessageDispatch -from mcp.shared.message import SessionMessage - - -@pytest.mark.anyio -async def test_session_heartbeat(message_dispatch): - """Test that session heartbeat refreshes TTL.""" - session_id = uuid4() - - async with message_dispatch.subscribe(session_id, AsyncMock()): - session_key = message_dispatch._session_key(session_id) - - # Initial TTL - initial_ttl = await message_dispatch._redis.ttl(session_key) # type: ignore - assert initial_ttl > 0 - - # Wait for heartbeat to run - await anyio.sleep(message_dispatch._session_ttl / 2 + 0.5) - - # TTL should be refreshed - refreshed_ttl = await message_dispatch._redis.ttl(session_key) # type: ignore - assert refreshed_ttl > 0 - assert refreshed_ttl <= message_dispatch._session_ttl - - -@pytest.mark.anyio -async def test_subscribe_unsubscribe(message_dispatch): - """Test subscribing and unsubscribing from a session.""" - session_id = uuid4() - callback = AsyncMock() - - # Subscribe - async with message_dispatch.subscribe(session_id, callback): - # Check that session is tracked - assert session_id in 
message_dispatch._session_state - assert await message_dispatch.session_exists(session_id) - - # After context exit, session should be cleaned up - assert session_id not in message_dispatch._session_state - assert not await message_dispatch.session_exists(session_id) - - -@pytest.mark.anyio -async def test_publish_message_valid_json(message_dispatch: RedisMessageDispatch): - """Test publishing a valid JSON-RPC message.""" - session_id = uuid4() - callback = AsyncMock() - message = types.JSONRPCMessage.model_validate( - {"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1} - ) - - # Subscribe to messages - async with message_dispatch.subscribe(session_id, callback): - # Publish message - published = await message_dispatch.publish_message( - session_id, SessionMessage(message=message) - ) - assert published - - # Give some time for the message to be processed - await anyio.sleep(0.1) - - # Callback should have been called with the message - callback.assert_called_once() - call_args = callback.call_args[0][0] - assert isinstance(call_args, SessionMessage) - assert isinstance(call_args.message.root, types.JSONRPCRequest) - assert ( - call_args.message.root.method == "test" - ) # Access method through root attribute - - -@pytest.mark.anyio -async def test_publish_message_invalid_json(message_dispatch): - """Test publishing an invalid JSON string.""" - session_id = uuid4() - callback = AsyncMock() - invalid_json = '{"invalid": "json",,}' # Invalid JSON - - # Subscribe to messages - async with message_dispatch.subscribe(session_id, callback): - # Publish invalid message - published = await message_dispatch.publish_message(session_id, invalid_json) - assert published - - # Give some time for the message to be processed - await anyio.sleep(0.1) - - # Callback should have been called with a ValidationError - callback.assert_called_once() - error = callback.call_args[0][0] - assert isinstance(error, ValidationError) - - -@pytest.mark.anyio -async def 
test_publish_to_nonexistent_session(message_dispatch: RedisMessageDispatch): - """Test publishing to a session that doesn't exist.""" - session_id = uuid4() - message = SessionMessage( - message=types.JSONRPCMessage.model_validate( - {"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1} - ) - ) - - published = await message_dispatch.publish_message(session_id, message) - assert not published - - -@pytest.mark.anyio -async def test_extract_session_id(message_dispatch): - """Test extracting session ID from channel name.""" - session_id = uuid4() - channel = message_dispatch._session_channel(session_id) - - # Valid channel - extracted_id = message_dispatch._extract_session_id(channel) - assert extracted_id == session_id - - # Invalid channel format - extracted_id = message_dispatch._extract_session_id("invalid_channel_name") - assert extracted_id is None - - # Invalid UUID in channel - invalid_channel = f"{message_dispatch._prefix}session:invalid_uuid" - extracted_id = message_dispatch._extract_session_id(invalid_channel) - assert extracted_id is None - - -@pytest.mark.anyio -async def test_multiple_sessions(message_dispatch: RedisMessageDispatch): - """Test handling multiple concurrent sessions.""" - session1 = uuid4() - session2 = uuid4() - callback1 = AsyncMock() - callback2 = AsyncMock() - - async with message_dispatch.subscribe(session1, callback1): - async with message_dispatch.subscribe(session2, callback2): - # Both sessions should exist - assert await message_dispatch.session_exists(session1) - assert await message_dispatch.session_exists(session2) - - # Publish to session1 - message1 = types.JSONRPCMessage.model_validate( - {"jsonrpc": "2.0", "method": "test1", "params": {}, "id": 1} - ) - await message_dispatch.publish_message( - session1, SessionMessage(message=message1) - ) - - # Publish to session2 - message2 = types.JSONRPCMessage.model_validate( - {"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2} - ) - await 
message_dispatch.publish_message( - session2, SessionMessage(message=message2) - ) - - # Give some time for messages to be processed - await anyio.sleep(0.1) - - # Check callbacks - callback1.assert_called_once() - callback2.assert_called_once() - - call1_args = callback1.call_args[0][0] - assert isinstance(call1_args, SessionMessage) - assert call1_args.message.root.method == "test1" # type: ignore - - call2_args = callback2.call_args[0][0] - assert isinstance(call2_args, SessionMessage) - assert call2_args.message.root.method == "test2" # type: ignore - - -@pytest.mark.anyio -async def test_task_group_cancellation(message_dispatch): - """Test that task group is properly cancelled when context exits.""" - session_id = uuid4() - callback = AsyncMock() - - async with message_dispatch.subscribe(session_id, callback): - # Check that task group is active - _, task_group = message_dispatch._session_state[session_id] - assert task_group.cancel_scope.cancel_called is False - - # After context exit, task group should be cancelled - # And session state should be cleaned up - assert session_id not in message_dispatch._session_state - - -@pytest.mark.anyio -async def test_session_cancellation_isolation(message_dispatch): - """Test that cancelling one session doesn't affect other sessions.""" - session1 = uuid4() - session2 = uuid4() - - # Create a blocking callback for session1 to ensure it's running when cancelled - session1_event = anyio.Event() - session1_started = anyio.Event() - session1_cancelled = False - - async def blocking_callback1(msg): - session1_started.set() - try: - await session1_event.wait() - except anyio.get_cancelled_exc_class(): - nonlocal session1_cancelled - session1_cancelled = True - raise - - callback2 = AsyncMock() - - # Start session2 first - async with message_dispatch.subscribe(session2, callback2): - # Start session1 with a blocking callback - async with anyio.create_task_group() as tg: - - async def session1_runner(): - async with 
message_dispatch.subscribe(session1, blocking_callback1): - # Publish a message to trigger the blocking callback - message = types.JSONRPCMessage.model_validate( - {"jsonrpc": "2.0", "method": "test", "params": {}, "id": 1} - ) - await message_dispatch.publish_message(session1, message) - - # Wait for the callback to start - await session1_started.wait() - - # Keep the context alive while we test cancellation - await anyio.sleep_forever() - - tg.start_soon(session1_runner) - - # Wait for session1's callback to start - await session1_started.wait() - - # Cancel session1 - tg.cancel_scope.cancel() - - # Give some time for cancellation to propagate - await anyio.sleep(0.1) - - # Verify session1 was cancelled - assert session1_cancelled - assert session1 not in message_dispatch._session_state - - # Verify session2 is still active and can receive messages - assert await message_dispatch.session_exists(session2) - message2 = types.JSONRPCMessage.model_validate( - {"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2} - ) - await message_dispatch.publish_message(session2, message2) - - # Give some time for the message to be processed - await anyio.sleep(0.1) - - # Verify session2 received the message - callback2.assert_called_once() - call_args = callback2.call_args[0][0] - assert call_args.root.method == "test2" - - -@pytest.mark.anyio -async def test_listener_task_handoff_on_cancellation(message_dispatch): - """ - Test that the single listening task is properly - handed off when a session is cancelled. 
- """ - session1 = uuid4() - session2 = uuid4() - - session1_messages_received = 0 - session2_messages_received = 0 - - async def callback1(msg): - nonlocal session1_messages_received - session1_messages_received += 1 - - async def callback2(msg): - nonlocal session2_messages_received - session2_messages_received += 1 - - # Create a cancel scope for session1 - async with anyio.create_task_group() as tg: - session1_cancel_scope: anyio.CancelScope | None = None - - async def session1_runner(): - nonlocal session1_cancel_scope - with anyio.CancelScope() as cancel_scope: - session1_cancel_scope = cancel_scope - async with message_dispatch.subscribe(session1, callback1): - # Keep session alive until cancelled - await anyio.sleep_forever() - - # Start session1 - tg.start_soon(session1_runner) - - # Wait for session1 to be established - await anyio.sleep(0.1) - assert session1 in message_dispatch._session_state - - # Send message to session1 to verify it's working - message1 = types.JSONRPCMessage.model_validate( - {"jsonrpc": "2.0", "method": "test1", "params": {}, "id": 1} - ) - await message_dispatch.publish_message(session1, message1) - await anyio.sleep(0.1) - assert session1_messages_received == 1 - - # Start session2 while session1 is still active - async with message_dispatch.subscribe(session2, callback2): - # Both sessions should be active - assert session1 in message_dispatch._session_state - assert session2 in message_dispatch._session_state - - # Cancel session1 - assert session1_cancel_scope is not None - session1_cancel_scope.cancel() - - # Wait for cancellation to complete - await anyio.sleep(0.1) - - # Session1 should be gone, session2 should remain - assert session1 not in message_dispatch._session_state - assert session2 in message_dispatch._session_state - - # Send message to session2 to verify the listener was handed off - message2 = types.JSONRPCMessage.model_validate( - {"jsonrpc": "2.0", "method": "test2", "params": {}, "id": 2} - ) - await 
message_dispatch.publish_message(session2, message2) - await anyio.sleep(0.1) - - # Session2 should have received the message - assert session2_messages_received == 1 - - # Session1 shouldn't receive any more messages - assert session1_messages_received == 1 - - # Send another message to verify the listener is still working - message3 = types.JSONRPCMessage.model_validate( - {"jsonrpc": "2.0", "method": "test3", "params": {}, "id": 3} - ) - await message_dispatch.publish_message(session2, message3) - await anyio.sleep(0.1) - - assert session2_messages_received == 2 diff --git a/tests/server/message_dispatch/test_redis_integration.py b/tests/server/message_dispatch/test_redis_integration.py deleted file mode 100644 index f01113872d..0000000000 --- a/tests/server/message_dispatch/test_redis_integration.py +++ /dev/null @@ -1,260 +0,0 @@ -""" -Integration tests for Redis message dispatch functionality. - -These tests validate Redis message dispatch by making actual HTTP calls and testing -that messages flow correctly through the Redis backend. - -This version runs the server in a task instead of a separate process to allow -access to the fakeredis instance for verification of Redis keys. 
-""" - -import asyncio -import socket -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager - -import anyio -import pytest -import uvicorn -from sse_starlette.sse import AppStatus -from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import Response -from starlette.routing import Mount, Route - -from mcp.client.session import ClientSession -from mcp.client.sse import sse_client -from mcp.server import Server -from mcp.server.message_queue.redis import RedisMessageDispatch -from mcp.server.sse import SseServerTransport -from mcp.types import TextContent, Tool - -SERVER_NAME = "test_server_for_redis_integration_v3" - - -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - return f"http://127.0.0.1:{server_port}" - - -class RedisTestServer(Server): - """Test server with basic tool functionality.""" - - def __init__(self): - super().__init__(SERVER_NAME) - - @self.list_tools() - async def handle_list_tools() -> list[Tool]: - return [ - Tool( - name="test_tool", - description="A test tool", - inputSchema={"type": "object", "properties": {}}, - ), - Tool( - name="echo_message", - description="Echo a message back", - inputSchema={ - "type": "object", - "properties": {"message": {"type": "string"}}, - "required": ["message"], - }, - ), - ] - - @self.call_tool() - async def handle_call_tool(name: str, args: dict) -> list[TextContent]: - if name == "echo_message": - message = args.get("message", "") - return [TextContent(type="text", text=f"Echo: {message}")] - return [TextContent(type="text", text=f"Called {name}")] - - -@pytest.fixture() -async def redis_server_and_app(message_dispatch: RedisMessageDispatch): - """Create a mock Redis instance and Starlette app for testing.""" - - # Create SSE transport with Redis message dispatch - sse = 
SseServerTransport("/messages/", message_dispatch=message_dispatch) - server = RedisTestServer() - - async def handle_sse(request: Request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: - await server.run( - streams[0], streams[1], server.create_initialization_options() - ) - return Response() - - @asynccontextmanager - async def lifespan(app: Starlette) -> AsyncGenerator[None, None]: - """Manage the lifecycle of the application.""" - try: - yield - finally: - await message_dispatch.close() - - app = Starlette( - routes=[ - Route("/sse", endpoint=handle_sse), - Mount("/messages/", app=sse.handle_post_message), - ], - lifespan=lifespan, - ) - - return app, message_dispatch, message_dispatch._redis - - -@pytest.fixture() -async def server_and_redis(redis_server_and_app, server_port: int): - """Run the server in a task and return the Redis instance for inspection.""" - app, message_dispatch, mock_redis = redis_server_and_app - - # Create a server config - config = uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - server = uvicorn.Server(config=config) - try: - async with anyio.create_task_group() as tg: - # Start server in background - tg.start_soon(server.serve) - - # Wait for server to be ready - max_attempts = 20 - attempt = 0 - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", server_port)) - break - except ConnectionRefusedError: - await anyio.sleep(0.1) - attempt += 1 - else: - raise RuntimeError( - f"Server failed to start after {max_attempts} attempts" - ) - - try: - yield mock_redis, message_dispatch - finally: - server.should_exit = True - finally: - # These class variables are set top-level in starlette-sse - # It isn't designed to be run multiple times in a single - # Python process so we need to manually reset them. 
- AppStatus.should_exit = False - AppStatus.should_exit_event = None - - -@pytest.fixture() -async def client_session(server_and_redis, server_url: str): - """Create a client session for testing.""" - async with sse_client(server_url + "/sse") as streams: - async with ClientSession(*streams) as session: - result = await session.initialize() - assert result.serverInfo.name == SERVER_NAME - yield session - - -@pytest.mark.anyio -async def test_redis_integration_key_verification( - server_and_redis, client_session -) -> None: - """Test that Redis keys are created correctly for sessions.""" - mock_redis, _ = server_and_redis - - all_keys = await mock_redis.keys("*") # type: ignore - - assert len(all_keys) > 0 - - session_key = None - for key in all_keys: - if key.startswith("mcp:pubsub:session_active:"): - session_key = key - break - - assert session_key is not None, f"No session key found. Keys: {all_keys}" - - ttl = await mock_redis.ttl(session_key) # type: ignore - assert ttl > 0, f"Session key should have TTL, got: {ttl}" - - -@pytest.mark.anyio -async def test_tool_calls(server_and_redis, client_session) -> None: - """Test that messages are properly published through Redis.""" - mock_redis, _ = server_and_redis - - for i in range(3): - tool_result = await client_session.call_tool( - "echo_message", {"message": f"Test {i}"} - ) - assert tool_result.content[0].text == f"Echo: Test {i}" # type: ignore - - -@pytest.mark.anyio -async def test_session_cleanup(server_and_redis, server_url: str) -> None: - """Test Redis key cleanup when sessions end.""" - mock_redis, _ = server_and_redis - session_keys_seen = set() - - for i in range(3): - async with sse_client(server_url + "/sse") as streams: - async with ClientSession(*streams) as session: - await session.initialize() - - all_keys = await mock_redis.keys("*") # type: ignore - for key in all_keys: - if key.startswith("mcp:pubsub:session_active:"): - session_keys_seen.add(key) - value = await mock_redis.get(key) # type: 
ignore - assert value == "1" - - await anyio.sleep(0.1) # Give time for cleanup - all_keys = await mock_redis.keys("*") # type: ignore - assert ( - len(all_keys) == 0 - ), f"Session keys should be cleaned up, found: {all_keys}" - - # Verify we saw different session keys for each session - assert len(session_keys_seen) == 3, "Should have seen 3 unique session keys" - - -@pytest.mark.anyio -async def concurrent_tool_call(server_and_redis, server_url: str) -> None: - """Test multiple clients and verify Redis key management.""" - mock_redis, _ = server_and_redis - - async def client_task(client_id: int) -> str: - async with sse_client(server_url + "/sse") as streams: - async with ClientSession(*streams) as session: - await session.initialize() - - result = await session.call_tool( - "echo_message", - {"message": f"Message from client {client_id}"}, - ) - return result.content[0].text # type: ignore - - # Run multiple clients concurrently - client_tasks = [client_task(i) for i in range(3)] - results = await asyncio.gather(*client_tasks) - - # Verify all clients received their respective messages - assert len(results) == 3 - for i, result in enumerate(results): - assert result == f"Echo: Message from client {i}" - - # After all clients disconnect, keys should be cleaned up - await anyio.sleep(0.1) # Give time for cleanup - all_keys = await mock_redis.keys("*") # type: ignore - assert len(all_keys) == 0, f"Session keys should be cleaned up, found: {all_keys}" diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index d8e76de1ac..a3ff59bc1b 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -84,7 +84,7 @@ async def run_server(): ) await send_stream1.send( SessionMessage( - message=JSONRPCMessage( + JSONRPCMessage( root=JSONRPCRequest( jsonrpc="2.0", id=1, @@ -100,7 +100,7 @@ async def run_server(): # Send initialized notification await send_stream1.send( SessionMessage( - message=JSONRPCMessage( + JSONRPCMessage( 
root=JSONRPCNotification( jsonrpc="2.0", method="notifications/initialized", @@ -112,7 +112,7 @@ async def run_server(): # Call the tool to verify lifespan context await send_stream1.send( SessionMessage( - message=JSONRPCMessage( + JSONRPCMessage( root=JSONRPCRequest( jsonrpc="2.0", id=2, @@ -188,7 +188,7 @@ async def run_server(): ) await send_stream1.send( SessionMessage( - message=JSONRPCMessage( + JSONRPCMessage( root=JSONRPCRequest( jsonrpc="2.0", id=1, @@ -204,7 +204,7 @@ async def run_server(): # Send initialized notification await send_stream1.send( SessionMessage( - message=JSONRPCMessage( + JSONRPCMessage( root=JSONRPCNotification( jsonrpc="2.0", method="notifications/initialized", @@ -216,7 +216,7 @@ async def run_server(): # Call the tool to verify lifespan context await send_stream1.send( SessionMessage( - message=JSONRPCMessage( + JSONRPCMessage( root=JSONRPCRequest( jsonrpc="2.0", id=2, diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 570e4c1999..c546a7167b 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -51,7 +51,7 @@ async def test_stdio_server(): async with write_stream: for response in responses: - session_message = SessionMessage(message=response) + session_message = SessionMessage(response) await write_stream.send(session_message) stdout.seek(0) diff --git a/uv.lock b/uv.lock index e819dbfe87..06dd240b25 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" [options] @@ -38,15 +39,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3b/68/f9e9bf6324c46e6b8396610aef90ad423ec3e18c9079547ceafea3dce0ec/anyio-4.5.0-py3-none-any.whl", hash = "sha256:fdeb095b7cc5a5563175eedd926ec4ae55413bb4be5770c424af0ba46ccb4a78", size = 89250 }, ] -[[package]] -name = "async-timeout" -version = "5.0.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/a5/ae/136395dfbfe00dfc94da3f3e136d0b13f394cba8f4841120e34226265780/async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3", size = 9274 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233 }, -] - [[package]] name = "attrs" version = "24.3.0" @@ -275,51 +267,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, ] -[[package]] -name = "cryptography" -version = "44.0.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/cd/25/4ce80c78963834b8a9fd1cc1266be5ed8d1840785c0f2e1b73b8d128d505/cryptography-44.0.2.tar.gz", hash = "sha256:c63454aa261a0cf0c5b4718349629793e9e634993538db841165b3df74f37ec0", size = 710807 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/92/ef/83e632cfa801b221570c5f58c0369db6fa6cef7d9ff859feab1aae1a8a0f/cryptography-44.0.2-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:efcfe97d1b3c79e486554efddeb8f6f53a4cdd4cf6086642784fa31fc384e1d7", size = 6676361 }, - { url = "https://files.pythonhosted.org/packages/30/ec/7ea7c1e4c8fc8329506b46c6c4a52e2f20318425d48e0fe597977c71dbce/cryptography-44.0.2-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29ecec49f3ba3f3849362854b7253a9f59799e3763b0c9d0826259a88efa02f1", size = 3952350 }, - { url = 
"https://files.pythonhosted.org/packages/27/61/72e3afdb3c5ac510330feba4fc1faa0fe62e070592d6ad00c40bb69165e5/cryptography-44.0.2-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc821e161ae88bfe8088d11bb39caf2916562e0a2dc7b6d56714a48b784ef0bb", size = 4166572 }, - { url = "https://files.pythonhosted.org/packages/26/e4/ba680f0b35ed4a07d87f9e98f3ebccb05091f3bf6b5a478b943253b3bbd5/cryptography-44.0.2-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:3c00b6b757b32ce0f62c574b78b939afab9eecaf597c4d624caca4f9e71e7843", size = 3958124 }, - { url = "https://files.pythonhosted.org/packages/9c/e8/44ae3e68c8b6d1cbc59040288056df2ad7f7f03bbcaca6b503c737ab8e73/cryptography-44.0.2-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7bdcd82189759aba3816d1f729ce42ffded1ac304c151d0a8e89b9996ab863d5", size = 3678122 }, - { url = "https://files.pythonhosted.org/packages/27/7b/664ea5e0d1eab511a10e480baf1c5d3e681c7d91718f60e149cec09edf01/cryptography-44.0.2-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:4973da6ca3db4405c54cd0b26d328be54c7747e89e284fcff166132eb7bccc9c", size = 4191831 }, - { url = "https://files.pythonhosted.org/packages/2a/07/79554a9c40eb11345e1861f46f845fa71c9e25bf66d132e123d9feb8e7f9/cryptography-44.0.2-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4e389622b6927d8133f314949a9812972711a111d577a5d1f4bee5e58736b80a", size = 3960583 }, - { url = "https://files.pythonhosted.org/packages/bb/6d/858e356a49a4f0b591bd6789d821427de18432212e137290b6d8a817e9bf/cryptography-44.0.2-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:f514ef4cd14bb6fb484b4a60203e912cfcb64f2ab139e88c2274511514bf7308", size = 4191753 }, - { url = "https://files.pythonhosted.org/packages/b2/80/62df41ba4916067fa6b125aa8c14d7e9181773f0d5d0bd4dcef580d8b7c6/cryptography-44.0.2-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:1bc312dfb7a6e5d66082c87c34c8a62176e684b6fe3d90fcfe1568de675e6688", size = 4079550 }, - { url = 
"https://files.pythonhosted.org/packages/f3/cd/2558cc08f7b1bb40683f99ff4327f8dcfc7de3affc669e9065e14824511b/cryptography-44.0.2-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b721b8b4d948b218c88cb8c45a01793483821e709afe5f622861fc6182b20a7", size = 4298367 }, - { url = "https://files.pythonhosted.org/packages/71/59/94ccc74788945bc3bd4cf355d19867e8057ff5fdbcac781b1ff95b700fb1/cryptography-44.0.2-cp37-abi3-win32.whl", hash = "sha256:51e4de3af4ec3899d6d178a8c005226491c27c4ba84101bfb59c901e10ca9f79", size = 2772843 }, - { url = "https://files.pythonhosted.org/packages/ca/2c/0d0bbaf61ba05acb32f0841853cfa33ebb7a9ab3d9ed8bb004bd39f2da6a/cryptography-44.0.2-cp37-abi3-win_amd64.whl", hash = "sha256:c505d61b6176aaf982c5717ce04e87da5abc9a36a5b39ac03905c4aafe8de7aa", size = 3209057 }, - { url = "https://files.pythonhosted.org/packages/9e/be/7a26142e6d0f7683d8a382dd963745e65db895a79a280a30525ec92be890/cryptography-44.0.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:8e0ddd63e6bf1161800592c71ac794d3fb8001f2caebe0966e77c5234fa9efc3", size = 6677789 }, - { url = "https://files.pythonhosted.org/packages/06/88/638865be7198a84a7713950b1db7343391c6066a20e614f8fa286eb178ed/cryptography-44.0.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81276f0ea79a208d961c433a947029e1a15948966658cf6710bbabb60fcc2639", size = 3951919 }, - { url = "https://files.pythonhosted.org/packages/d7/fc/99fe639bcdf58561dfad1faa8a7369d1dc13f20acd78371bb97a01613585/cryptography-44.0.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a1e657c0f4ea2a23304ee3f964db058c9e9e635cc7019c4aa21c330755ef6fd", size = 4167812 }, - { url = "https://files.pythonhosted.org/packages/53/7b/aafe60210ec93d5d7f552592a28192e51d3c6b6be449e7fd0a91399b5d07/cryptography-44.0.2-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:6210c05941994290f3f7f175a4a57dbbb2afd9273657614c506d5976db061181", size = 3958571 }, - { url = 
"https://files.pythonhosted.org/packages/16/32/051f7ce79ad5a6ef5e26a92b37f172ee2d6e1cce09931646eef8de1e9827/cryptography-44.0.2-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d1c3572526997b36f245a96a2b1713bf79ce99b271bbcf084beb6b9b075f29ea", size = 3679832 }, - { url = "https://files.pythonhosted.org/packages/78/2b/999b2a1e1ba2206f2d3bca267d68f350beb2b048a41ea827e08ce7260098/cryptography-44.0.2-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:b042d2a275c8cee83a4b7ae30c45a15e6a4baa65a179a0ec2d78ebb90e4f6699", size = 4193719 }, - { url = "https://files.pythonhosted.org/packages/72/97/430e56e39a1356e8e8f10f723211a0e256e11895ef1a135f30d7d40f2540/cryptography-44.0.2-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:d03806036b4f89e3b13b6218fefea8d5312e450935b1a2d55f0524e2ed7c59d9", size = 3960852 }, - { url = "https://files.pythonhosted.org/packages/89/33/c1cf182c152e1d262cac56850939530c05ca6c8d149aa0dcee490b417e99/cryptography-44.0.2-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:c7362add18b416b69d58c910caa217f980c5ef39b23a38a0880dfd87bdf8cd23", size = 4193906 }, - { url = "https://files.pythonhosted.org/packages/e1/99/87cf26d4f125380dc674233971069bc28d19b07f7755b29861570e513650/cryptography-44.0.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:8cadc6e3b5a1f144a039ea08a0bdb03a2a92e19c46be3285123d32029f40a922", size = 4081572 }, - { url = "https://files.pythonhosted.org/packages/b3/9f/6a3e0391957cc0c5f84aef9fbdd763035f2b52e998a53f99345e3ac69312/cryptography-44.0.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6f101b1f780f7fc613d040ca4bdf835c6ef3b00e9bd7125a4255ec574c7916e4", size = 4298631 }, - { url = "https://files.pythonhosted.org/packages/e2/a5/5bc097adb4b6d22a24dea53c51f37e480aaec3465285c253098642696423/cryptography-44.0.2-cp39-abi3-win32.whl", hash = "sha256:3dc62975e31617badc19a906481deacdeb80b4bb454394b4098e3f2525a488c5", size = 2773792 }, - { url = 
"https://files.pythonhosted.org/packages/33/cf/1f7649b8b9a3543e042d3f348e398a061923ac05b507f3f4d95f11938aa9/cryptography-44.0.2-cp39-abi3-win_amd64.whl", hash = "sha256:5f6f90b72d8ccadb9c6e311c775c8305381db88374c65fa1a68250aa8a9cb3a6", size = 3210957 }, - { url = "https://files.pythonhosted.org/packages/99/10/173be140714d2ebaea8b641ff801cbcb3ef23101a2981cbf08057876f89e/cryptography-44.0.2-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:af4ff3e388f2fa7bff9f7f2b31b87d5651c45731d3e8cfa0944be43dff5cfbdb", size = 3396886 }, - { url = "https://files.pythonhosted.org/packages/2f/b4/424ea2d0fce08c24ede307cead3409ecbfc2f566725d4701b9754c0a1174/cryptography-44.0.2-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:0529b1d5a0105dd3731fa65680b45ce49da4d8115ea76e9da77a875396727b41", size = 3892387 }, - { url = "https://files.pythonhosted.org/packages/28/20/8eaa1a4f7c68a1cb15019dbaad59c812d4df4fac6fd5f7b0b9c5177f1edd/cryptography-44.0.2-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:7ca25849404be2f8e4b3c59483d9d3c51298a22c1c61a0e84415104dacaf5562", size = 4109922 }, - { url = "https://files.pythonhosted.org/packages/11/25/5ed9a17d532c32b3bc81cc294d21a36c772d053981c22bd678396bc4ae30/cryptography-44.0.2-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:268e4e9b177c76d569e8a145a6939eca9a5fec658c932348598818acf31ae9a5", size = 3895715 }, - { url = "https://files.pythonhosted.org/packages/63/31/2aac03b19c6329b62c45ba4e091f9de0b8f687e1b0cd84f101401bece343/cryptography-44.0.2-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:9eb9d22b0a5d8fd9925a7764a054dca914000607dff201a24c791ff5c799e1fa", size = 4109876 }, - { url = "https://files.pythonhosted.org/packages/99/ec/6e560908349843718db1a782673f36852952d52a55ab14e46c42c8a7690a/cryptography-44.0.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2bf7bf75f7df9715f810d1b038870309342bff3069c5bd8c6b96128cb158668d", size = 3131719 }, - { url = 
"https://files.pythonhosted.org/packages/d6/d7/f30e75a6aa7d0f65031886fa4a1485c2fbfe25a1896953920f6a9cfe2d3b/cryptography-44.0.2-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:909c97ab43a9c0c0b0ada7a1281430e4e5ec0458e6d9244c0e821bbf152f061d", size = 3887513 }, - { url = "https://files.pythonhosted.org/packages/9c/b4/7a494ce1032323ca9db9a3661894c66e0d7142ad2079a4249303402d8c71/cryptography-44.0.2-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:96e7a5e9d6e71f9f4fca8eebfd603f8e86c5225bb18eb621b2c1e50b290a9471", size = 4107432 }, - { url = "https://files.pythonhosted.org/packages/45/f8/6b3ec0bc56123b344a8d2b3264a325646d2dcdbdd9848b5e6f3d37db90b3/cryptography-44.0.2-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d1b3031093a366ac767b3feb8bcddb596671b3aaff82d4050f984da0c248b615", size = 3891421 }, - { url = "https://files.pythonhosted.org/packages/57/ff/f3b4b2d007c2a646b0f69440ab06224f9cf37a977a72cdb7b50632174e8a/cryptography-44.0.2-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:04abd71114848aa25edb28e225ab5f268096f44cf0127f3d36975bdf1bdf3390", size = 4107081 }, -] - [[package]] name = "cssselect2" version = "0.8.0" @@ -360,20 +307,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/09/2aea36ff60d16dd8879bdb2f5b3ee0ba8d08cbbdcdfe870e695ce3784385/execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc", size = 40612 }, ] -[[package]] -name = "fakeredis" -version = "2.28.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "redis" }, - { name = "sortedcontainers" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/60/32/8c1c215e50cb055e24a8d5a8981edab665d131ea9068c420bf81eb0fcb63/fakeredis-2.28.1.tar.gz", hash = "sha256:5e542200b945aa0a7afdc0396efefe3cdabab61bc0f41736cc45f68960255964", size = 161179 } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/e1/77/bca49c4960c22131da3acb647978983bea07f15c255fbef0a6559a774a7a/fakeredis-2.28.1-py3-none-any.whl", hash = "sha256:38c7c17fba5d5522af9d980a8f74a4da9900a3441e8f25c0fe93ea4205d695d1", size = 113685 }, -] - [[package]] name = "ghp-import" version = "2.1.0" @@ -574,10 +507,6 @@ cli = [ { name = "python-dotenv" }, { name = "typer" }, ] -redis = [ - { name = "redis" }, - { name = "types-redis" }, -] rich = [ { name = "rich" }, ] @@ -587,7 +516,6 @@ ws = [ [package.dev-dependencies] dev = [ - { name = "fakeredis" }, { name = "pyright" }, { name = "pytest" }, { name = "pytest-examples" }, @@ -613,19 +541,17 @@ requires-dist = [ { name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" }, { name = "python-multipart", specifier = ">=0.0.9" }, - { name = "redis", marker = "extra == 'redis'", specifier = ">=5.2.1" }, { name = "rich", marker = "extra == 'rich'", specifier = ">=13.9.4" }, { name = "sse-starlette", specifier = ">=1.6.1" }, { name = "starlette", specifier = ">=0.27" }, { name = "typer", marker = "extra == 'cli'", specifier = ">=0.12.4" }, - { name = "types-redis", marker = "extra == 'redis'", specifier = ">=4.6.0.20241004" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] +provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ - { name = "fakeredis", specifier = "==2.28.1" }, { name = "pyright", specifier = ">=1.1.391" }, { name = "pytest", specifier = ">=8.3.4" }, { name = "pytest-examples", specifier = ">=0.0.14" }, @@ -1397,18 +1323,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/66/bbb1dd374f5c870f59c5bb1db0e18cbe7fa739415a24cbd95b2d1f5ae0c4/pyyaml_env_tag-0.1-py3-none-any.whl", hash = "sha256:af31106dec8a4d68c60207c1886031cbf839b68aa7abccdb19868200532c2069", size = 3911 }, ] 
-[[package]] -name = "redis" -version = "5.2.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "async-timeout", marker = "python_full_version < '3.11.3'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/47/da/d283a37303a995cd36f8b92db85135153dc4f7a8e4441aa827721b442cfb/redis-5.2.1.tar.gz", hash = "sha256:16f2e22dff21d5125e8481515e386711a34cbec50f0e44413dd7d9c060a54e0f", size = 4608355 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3c/5f/fa26b9b2672cbe30e07d9a5bdf39cf16e3b80b42916757c5f92bca88e4ba/redis-5.2.1-py3-none-any.whl", hash = "sha256:ee7e1056b9aea0f04c6c2ed59452947f34c4940ee025f5dd83e6a6418b6989e4", size = 261502 }, -] - [[package]] name = "regex" version = "2024.11.6" @@ -1532,15 +1446,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bf/5e/ffee22bf9f9e4b2669d1f0179ae8804584939fb6502b51f2401e26b1e028/ruff-0.8.5-py3-none-win_arm64.whl", hash = "sha256:134ae019ef13e1b060ab7136e7828a6d83ea727ba123381307eb37c6bd5e01cb", size = 9124741 }, ] -[[package]] -name = "setuptools" -version = "78.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a9/5a/0db4da3bc908df06e5efae42b44e75c81dd52716e10192ff36d0c1c8e379/setuptools-78.1.0.tar.gz", hash = "sha256:18fd474d4a82a5f83dac888df697af65afa82dec7323d09c3e37d1f14288da54", size = 1367827 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/54/21/f43f0a1fa8b06b32812e0975981f4677d28e0f3271601dc88ac5a5b83220/setuptools-78.1.0-py3-none-any.whl", hash = "sha256:3e386e96793c8702ae83d17b853fb93d3e09ef82ec62722e61da5cd22376dcd8", size = 1256108 }, -] - [[package]] name = "shellingham" version = "1.5.4" @@ -1685,56 +1590,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ae/cc/15083dcde1252a663398b1b2a173637a3ec65adadfb95137dc95df1e6adc/typer-0.12.4-py3-none-any.whl", hash = "sha256:819aa03699f438397e876aa12b0d63766864ecba1b579092cc9fe35d886e34b6", size = 47402 }, 
] -[[package]] -name = "types-cffi" -version = "1.17.0.20250326" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "types-setuptools" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/3f/3b/d29491d754b9e42edd4890648311ffa5d4d000b7d97b92ac4d04faad40d8/types_cffi-1.17.0.20250326.tar.gz", hash = "sha256:6c8fea2c2f34b55e5fb77b1184c8ad849d57cf0ddccbc67a62121ac4b8b32254", size = 16887 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/61/49/ce473d7fbc2c80931ef9f7530fd3ddf31b8a5bca56340590334ce6ffbfb1/types_cffi-1.17.0.20250326-py3-none-any.whl", hash = "sha256:5af4ecd7374ae0d5fa9e80864e8d4b31088cc32c51c544e3af7ed5b5ed681447", size = 20133 }, -] - -[[package]] -name = "types-pyopenssl" -version = "24.1.0.20240722" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cryptography" }, - { name = "types-cffi" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/93/29/47a346550fd2020dac9a7a6d033ea03fccb92fa47c726056618cc889745e/types-pyOpenSSL-24.1.0.20240722.tar.gz", hash = "sha256:47913b4678a01d879f503a12044468221ed8576263c1540dcb0484ca21b08c39", size = 8458 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/98/05/c868a850b6fbb79c26f5f299b768ee0adc1f9816d3461dcf4287916f655b/types_pyOpenSSL-24.1.0.20240722-py3-none-any.whl", hash = "sha256:6a7a5d2ec042537934cfb4c9d4deb0e16c4c6250b09358df1f083682fe6fda54", size = 7499 }, -] - -[[package]] -name = "types-redis" -version = "4.6.0.20241004" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cryptography" }, - { name = "types-pyopenssl" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/3a/95/c054d3ac940e8bac4ca216470c80c26688a0e79e09f520a942bb27da3386/types-redis-4.6.0.20241004.tar.gz", hash = "sha256:5f17d2b3f9091ab75384153bfa276619ffa1cf6a38da60e10d5e6749cc5b902e", size = 49679 } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/55/82/7d25dce10aad92d2226b269bce2f85cfd843b4477cd50245d7d40ecf8f89/types_redis-4.6.0.20241004-py3-none-any.whl", hash = "sha256:ef5da68cb827e5f606c8f9c0b49eeee4c2669d6d97122f301d3a55dc6a63f6ed", size = 58737 }, -] - -[[package]] -name = "types-setuptools" -version = "78.1.0.20250329" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "setuptools" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/e9/6e/c54e6705e5fe67c3606e4c7c91123ecf10d7e1e6d7a9c11b52970cf2196c/types_setuptools-78.1.0.20250329.tar.gz", hash = "sha256:31e62950c38b8cc1c5114b077504e36426860a064287cac11b9666ab3a483234", size = 43942 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7d/31/85d0264705d8ef47680d28f4dc9bb1e27d8cace785fbe3f8d009fad6cb88/types_setuptools-78.1.0.20250329-py3-none-any.whl", hash = "sha256:ea47eab891afb506f470eee581dcde44d64dc99796665da794da6f83f50f6776", size = 66985 }, -] - [[package]] name = "typing-extensions" version = "4.12.2" From a1307abdedf8bb724485b559343752621d4b4cb9 Mon Sep 17 00:00:00 2001 From: yabea <847511885@qq.com> Date: Thu, 8 May 2025 00:42:02 +0800 Subject: [PATCH 16/21] Fix the issue of get Authorization header fails during bearer auth (#637) Co-authored-by: yangben --- src/mcp/server/auth/middleware/bearer_auth.py | 11 +++- .../auth/middleware/test_bearer_auth.py | 61 +++++++++++++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 295605af7b..30b5e2ba65 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -34,8 +34,15 @@ def __init__( self.provider = provider async def authenticate(self, conn: HTTPConnection): - auth_header = conn.headers.get("Authorization") - if not auth_header or not auth_header.startswith("Bearer "): + auth_header = next( + ( + conn.headers.get(key) + for key in 
conn.headers + if key.lower() == "authorization" + ), + None, + ) + if not auth_header or not auth_header.lower().startswith("bearer "): return None token = auth_header[7:] # Remove "Bearer " prefix diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index 9acb5ff095..e8c17a4c42 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -7,6 +7,7 @@ import pytest from starlette.authentication import AuthCredentials +from starlette.datastructures import Headers from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.types import Message, Receive, Scope, Send @@ -221,6 +222,66 @@ async def test_token_without_expiry( assert user.access_token == no_expiry_access_token assert user.scopes == ["read", "write"] + async def test_lowercase_bearer_prefix( + self, + mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], + valid_access_token: AccessToken, + ): + """Test with lowercase 'bearer' prefix in Authorization header""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) + headers = Headers({"Authorization": "bearer valid_token"}) + scope = {"type": "http", "headers": headers.raw} + request = Request(scope) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == valid_access_token + + async def test_mixed_case_bearer_prefix( + self, + mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], + valid_access_token: AccessToken, + ): + """Test with mixed 'BeArEr' prefix in Authorization header""" + backend = 
BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) + headers = Headers({"authorization": "BeArEr valid_token"}) + scope = {"type": "http", "headers": headers.raw} + request = Request(scope) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == valid_access_token + + async def test_mixed_case_authorization_header( + self, + mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], + valid_access_token: AccessToken, + ): + """Test authentication with mixed 'Authorization' header.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) + headers = Headers({"AuThOrIzAtIoN": "BeArEr valid_token"}) + scope = {"type": "http", "headers": headers.raw} + request = Request(scope) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == valid_access_token + @pytest.mark.anyio class TestRequireAuthMiddleware: From a027d75f609000378522c5873c2a16aa1963d487 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 7 May 2025 17:52:29 +0100 Subject: [PATCH 17/21] Auth SSE simple server example (#610) Co-authored-by: Peter Raboud Co-authored-by: David Soria Parra Co-authored-by: Basil Hosmer Co-authored-by: Paul Carleton Co-authored-by: Paul Carleton --- examples/servers/simple-auth/README.md | 65 ++++ .../simple-auth/mcp_simple_auth/__init__.py | 1 + .../simple-auth/mcp_simple_auth/__main__.py | 
7 + .../simple-auth/mcp_simple_auth/server.py | 368 ++++++++++++++++++ examples/servers/simple-auth/pyproject.toml | 31 ++ src/mcp/server/auth/routes.py | 2 +- uv.lock | 42 ++ 7 files changed, 515 insertions(+), 1 deletion(-) create mode 100644 examples/servers/simple-auth/README.md create mode 100644 examples/servers/simple-auth/mcp_simple_auth/__init__.py create mode 100644 examples/servers/simple-auth/mcp_simple_auth/__main__.py create mode 100644 examples/servers/simple-auth/mcp_simple_auth/server.py create mode 100644 examples/servers/simple-auth/pyproject.toml diff --git a/examples/servers/simple-auth/README.md b/examples/servers/simple-auth/README.md new file mode 100644 index 0000000000..1d0979d97d --- /dev/null +++ b/examples/servers/simple-auth/README.md @@ -0,0 +1,65 @@ +# Simple MCP Server with GitHub OAuth Authentication + +This is a simple example of an MCP server with GitHub OAuth authentication. It demonstrates the essential components needed for OAuth integration with just a single tool. + +This is just an example of a server that uses auth; the official GitHub MCP server is [here](https://github.com/github/github-mcp-server) + +## Overview + +This simple demo shows how to set up a server with: +- GitHub OAuth2 authorization flow +- Single tool: `get_user_profile` to retrieve GitHub user information + + +## Prerequisites + +1. 
Create a GitHub OAuth App: + - Go to GitHub Settings > Developer settings > OAuth Apps > New OAuth App + - Application name: Any name (e.g., "Simple MCP Auth Demo") + - Homepage URL: `http://localhost:8000` + - Authorization callback URL: `http://localhost:8000/github/callback` + - Click "Register application" + - Note down your Client ID and Client Secret + +## Required Environment Variables + +You MUST set these environment variables before running the server: + +```bash +export MCP_GITHUB_GITHUB_CLIENT_ID="your_client_id_here" +export MCP_GITHUB_GITHUB_CLIENT_SECRET="your_client_secret_here" +``` + +The server will not start without these environment variables properly set. + + +## Running the Server + +```bash +# Set environment variables first (see above) + +# Run the server +uv run mcp-simple-auth +``` + +The server will start on `http://localhost:8000`. + +## Available Tool + +### get_user_profile + +The only tool in this simple example. Returns the authenticated user's GitHub profile information. + +**Required scope**: `user` + +**Returns**: GitHub user profile data including username, email, bio, etc. + + +## Troubleshooting + +If the server fails to start, check: +1. Environment variables `MCP_GITHUB_GITHUB_CLIENT_ID` and `MCP_GITHUB_GITHUB_CLIENT_SECRET` are set +2. The GitHub OAuth app callback URL matches `http://localhost:8000/github/callback` +3. 
No other service is using port 8000 + +You can use [Inspector](https://github.com/modelcontextprotocol/inspector) to test Auth \ No newline at end of file diff --git a/examples/servers/simple-auth/mcp_simple_auth/__init__.py b/examples/servers/simple-auth/mcp_simple_auth/__init__.py new file mode 100644 index 0000000000..3e12b31832 --- /dev/null +++ b/examples/servers/simple-auth/mcp_simple_auth/__init__.py @@ -0,0 +1 @@ +"""Simple MCP server with GitHub OAuth authentication.""" diff --git a/examples/servers/simple-auth/mcp_simple_auth/__main__.py b/examples/servers/simple-auth/mcp_simple_auth/__main__.py new file mode 100644 index 0000000000..a8840780b8 --- /dev/null +++ b/examples/servers/simple-auth/mcp_simple_auth/__main__.py @@ -0,0 +1,7 @@ +"""Main entry point for simple MCP server with GitHub OAuth authentication.""" + +import sys + +from mcp_simple_auth.server import main + +sys.exit(main()) diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py new file mode 100644 index 0000000000..7cd92aa799 --- /dev/null +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -0,0 +1,368 @@ +"""Simple MCP Server with GitHub OAuth Authentication.""" + +import logging +import secrets +import time +from typing import Any + +import click +import httpx +from pydantic import AnyHttpUrl +from pydantic_settings import BaseSettings, SettingsConfigDict +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse, Response + +from mcp.server.auth.middleware.auth_context import get_access_token +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + OAuthAuthorizationServerProvider, + RefreshToken, + construct_redirect_uri, +) +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions +from mcp.server.fastmcp.server import FastMCP +from mcp.shared.auth 
import OAuthClientInformationFull, OAuthToken + +logger = logging.getLogger(__name__) + + +class ServerSettings(BaseSettings): + """Settings for the simple GitHub MCP server.""" + + model_config = SettingsConfigDict(env_prefix="MCP_GITHUB_") + + # Server settings + host: str = "localhost" + port: int = 8000 + server_url: AnyHttpUrl = AnyHttpUrl("http://localhost:8000") + + # GitHub OAuth settings - MUST be provided via environment variables + github_client_id: str # Type: MCP_GITHUB_GITHUB_CLIENT_ID env var + github_client_secret: str # Type: MCP_GITHUB_GITHUB_CLIENT_SECRET env var + github_callback_path: str = "http://localhost:8000/github/callback" + + # GitHub OAuth URLs + github_auth_url: str = "https://github.com/login/oauth/authorize" + github_token_url: str = "https://github.com/login/oauth/access_token" + + mcp_scope: str = "user" + github_scope: str = "read:user" + + def __init__(self, **data): + """Initialize settings with values from environment variables. + + Note: github_client_id and github_client_secret are required but can be + loaded automatically from environment variables (MCP_GITHUB_GITHUB_CLIENT_ID + and MCP_GITHUB_GITHUB_CLIENT_SECRET) and don't need to be passed explicitly. 
+ """ + super().__init__(**data) + + +class SimpleGitHubOAuthProvider(OAuthAuthorizationServerProvider): + """Simple GitHub OAuth provider with essential functionality.""" + + def __init__(self, settings: ServerSettings): + self.settings = settings + self.clients: dict[str, OAuthClientInformationFull] = {} + self.auth_codes: dict[str, AuthorizationCode] = {} + self.tokens: dict[str, AccessToken] = {} + self.state_mapping: dict[str, dict[str, str]] = {} + # Store GitHub tokens with MCP tokens using the format: + # {"mcp_token": "github_token"} + self.token_mapping: dict[str, str] = {} + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + """Get OAuth client information.""" + return self.clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull): + """Register a new OAuth client.""" + self.clients[client_info.client_id] = client_info + + async def authorize( + self, client: OAuthClientInformationFull, params: AuthorizationParams + ) -> str: + """Generate an authorization URL for GitHub OAuth flow.""" + state = params.state or secrets.token_hex(16) + + # Store the state mapping + self.state_mapping[state] = { + "redirect_uri": str(params.redirect_uri), + "code_challenge": params.code_challenge, + "redirect_uri_provided_explicitly": str( + params.redirect_uri_provided_explicitly + ), + "client_id": client.client_id, + } + + # Build GitHub authorization URL + auth_url = ( + f"{self.settings.github_auth_url}" + f"?client_id={self.settings.github_client_id}" + f"&redirect_uri={self.settings.github_callback_path}" + f"&scope={self.settings.github_scope}" + f"&state={state}" + ) + + return auth_url + + async def handle_github_callback(self, code: str, state: str) -> str: + """Handle GitHub OAuth callback.""" + state_data = self.state_mapping.get(state) + if not state_data: + raise HTTPException(400, "Invalid state parameter") + + redirect_uri = state_data["redirect_uri"] + code_challenge = 
state_data["code_challenge"] + redirect_uri_provided_explicitly = ( + state_data["redirect_uri_provided_explicitly"] == "True" + ) + client_id = state_data["client_id"] + + # Exchange code for token with GitHub + async with httpx.AsyncClient() as client: + response = await client.post( + self.settings.github_token_url, + data={ + "client_id": self.settings.github_client_id, + "client_secret": self.settings.github_client_secret, + "code": code, + "redirect_uri": self.settings.github_callback_path, + }, + headers={"Accept": "application/json"}, + ) + + if response.status_code != 200: + raise HTTPException(400, "Failed to exchange code for token") + + data = response.json() + + if "error" in data: + raise HTTPException(400, data.get("error_description", data["error"])) + + github_token = data["access_token"] + + # Create MCP authorization code + new_code = f"mcp_{secrets.token_hex(16)}" + auth_code = AuthorizationCode( + code=new_code, + client_id=client_id, + redirect_uri=AnyHttpUrl(redirect_uri), + redirect_uri_provided_explicitly=redirect_uri_provided_explicitly, + expires_at=time.time() + 300, + scopes=[self.settings.mcp_scope], + code_challenge=code_challenge, + ) + self.auth_codes[new_code] = auth_code + + # Store GitHub token - we'll map the MCP token to this later + self.tokens[github_token] = AccessToken( + token=github_token, + client_id=client_id, + scopes=[self.settings.github_scope], + expires_at=None, + ) + + del self.state_mapping[state] + return construct_redirect_uri(redirect_uri, code=new_code, state=state) + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: + """Load an authorization code.""" + return self.auth_codes.get(authorization_code) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + """Exchange authorization code for tokens.""" + if authorization_code.code not 
in self.auth_codes: + raise ValueError("Invalid authorization code") + + # Generate MCP access token + mcp_token = f"mcp_{secrets.token_hex(32)}" + + # Store MCP token + self.tokens[mcp_token] = AccessToken( + token=mcp_token, + client_id=client.client_id, + scopes=authorization_code.scopes, + expires_at=int(time.time()) + 3600, + ) + + # Find GitHub token for this client + github_token = next( + ( + token + for token, data in self.tokens.items() + # see https://github.blog/engineering/platform-security/behind-githubs-new-authentication-token-formats/ + # which you get depends on your GH app setup. + if (token.startswith("ghu_") or token.startswith("gho_")) + and data.client_id == client.client_id + ), + None, + ) + + # Store mapping between MCP token and GitHub token + if github_token: + self.token_mapping[mcp_token] = github_token + + del self.auth_codes[authorization_code.code] + + return OAuthToken( + access_token=mcp_token, + token_type="bearer", + expires_in=3600, + scope=" ".join(authorization_code.scopes), + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + """Load and validate an access token.""" + access_token = self.tokens.get(token) + if not access_token: + return None + + # Check if expired + if access_token.expires_at and access_token.expires_at < time.time(): + del self.tokens[token] + return None + + return access_token + + async def load_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: str + ) -> RefreshToken | None: + """Load a refresh token - not supported.""" + return None + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshToken, + scopes: list[str], + ) -> OAuthToken: + """Exchange refresh token""" + raise NotImplementedError("Not supported") + + async def revoke_token( + self, token: str, token_type_hint: str | None = None + ) -> None: + """Revoke a token.""" + if token in self.tokens: + del self.tokens[token] + + +def 
create_simple_mcp_server(settings: ServerSettings) -> FastMCP: + """Create a simple FastMCP server with GitHub OAuth.""" + oauth_provider = SimpleGitHubOAuthProvider(settings) + + auth_settings = AuthSettings( + issuer_url=settings.server_url, + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=[settings.mcp_scope], + default_scopes=[settings.mcp_scope], + ), + required_scopes=[settings.mcp_scope], + ) + + app = FastMCP( + name="Simple GitHub MCP Server", + instructions="A simple MCP server with GitHub OAuth authentication", + auth_server_provider=oauth_provider, + host=settings.host, + port=settings.port, + debug=True, + auth=auth_settings, + ) + + @app.custom_route("/github/callback", methods=["GET"]) + async def github_callback_handler(request: Request) -> Response: + """Handle GitHub OAuth callback.""" + code = request.query_params.get("code") + state = request.query_params.get("state") + + if not code or not state: + raise HTTPException(400, "Missing code or state parameter") + + try: + redirect_uri = await oauth_provider.handle_github_callback(code, state) + return RedirectResponse(status_code=302, url=redirect_uri) + except HTTPException: + raise + except Exception as e: + logger.error("Unexpected error", exc_info=e) + return JSONResponse( + status_code=500, + content={ + "error": "server_error", + "error_description": "Unexpected error", + }, + ) + + def get_github_token() -> str: + """Get the GitHub token for the authenticated user.""" + access_token = get_access_token() + if not access_token: + raise ValueError("Not authenticated") + + # Get GitHub token from mapping + github_token = oauth_provider.token_mapping.get(access_token.token) + + if not github_token: + raise ValueError("No GitHub token found for user") + + return github_token + + @app.tool() + async def get_user_profile() -> dict[str, Any]: + """Get the authenticated user's GitHub profile information. + + This is the only tool in our simple example. 
It requires the 'user' scope. + """ + github_token = get_github_token() + + async with httpx.AsyncClient() as client: + response = await client.get( + "https://api.github.com/user", + headers={ + "Authorization": f"Bearer {github_token}", + "Accept": "application/vnd.github.v3+json", + }, + ) + + if response.status_code != 200: + raise ValueError( + f"GitHub API error: {response.status_code} - {response.text}" + ) + + return response.json() + + return app + + +@click.command() +@click.option("--port", default=8000, help="Port to listen on") +@click.option("--host", default="localhost", help="Host to bind to") +def main(port: int, host: str) -> int: + """Run the simple GitHub MCP server.""" + logging.basicConfig(level=logging.INFO) + + try: + # No hardcoded credentials - all from environment variables + settings = ServerSettings(host=host, port=port) + except ValueError as e: + logger.error( + "Failed to load settings. Make sure environment variables are set:" + ) + logger.error(" MCP_GITHUB_GITHUB_CLIENT_ID=") + logger.error(" MCP_GITHUB_GITHUB_CLIENT_SECRET=") + logger.error(f"Error: {e}") + return 1 + + mcp_server = create_simple_mcp_server(settings) + mcp_server.run(transport="sse") + return 0 diff --git a/examples/servers/simple-auth/pyproject.toml b/examples/servers/simple-auth/pyproject.toml new file mode 100644 index 0000000000..40ae278a43 --- /dev/null +++ b/examples/servers/simple-auth/pyproject.toml @@ -0,0 +1,31 @@ +[project] +name = "mcp-simple-auth" +version = "0.1.0" +description = "A simple MCP server demonstrating OAuth authentication" +readme = "README.md" +requires-python = ">=3.10" +authors = [{ name = "Anthropic, PBC." 
}] +license = { text = "MIT" } +dependencies = [ + "anyio>=4.5", + "click>=8.1.0", + "httpx>=0.27", + "mcp", + "pydantic>=2.0", + "pydantic-settings>=2.5.2", + "sse-starlette>=1.6.1", + "uvicorn>=0.23.1; sys_platform != 'emscripten'", +] + +[project.scripts] +mcp-simple-auth = "mcp_simple_auth.server:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["mcp_simple_auth"] + +[tool.uv] +dev-dependencies = ["pyright>=1.1.391", "pytest>=8.3.4", "ruff>=0.8.5"] \ No newline at end of file diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index 29dd6a43a1..4c56ca2478 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -177,7 +177,7 @@ def build_metadata( issuer=issuer_url, authorization_endpoint=authorization_url, token_endpoint=token_url, - scopes_supported=None, + scopes_supported=client_registration_options.valid_scopes, response_types_supported=["code"], response_modes_supported=None, grant_types_supported=["authorization_code", "refresh_token"], diff --git a/uv.lock b/uv.lock index 06dd240b25..88869fa508 100644 --- a/uv.lock +++ b/uv.lock @@ -8,6 +8,7 @@ resolution-mode = "lowest-direct" [manifest] members = [ "mcp", + "mcp-simple-auth", "mcp-simple-prompt", "mcp-simple-resource", "mcp-simple-streamablehttp", @@ -568,6 +569,47 @@ docs = [ { name = "mkdocstrings-python", specifier = ">=1.12.2" }, ] +[[package]] +name = "mcp-simple-auth" +version = "0.1.0" +source = { editable = "examples/servers/simple-auth" } +dependencies = [ + { name = "anyio" }, + { name = "click" }, + { name = "httpx" }, + { name = "mcp" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "sse-starlette" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] + +[package.dev-dependencies] +dev = [ + { name = "pyright" }, + { name = "pytest" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", 
specifier = ">=4.5" }, + { name = "click", specifier = ">=8.1.0" }, + { name = "httpx", specifier = ">=0.27" }, + { name = "mcp", editable = "." }, + { name = "pydantic", specifier = ">=2.0" }, + { name = "pydantic-settings", specifier = ">=2.5.2" }, + { name = "sse-starlette", specifier = ">=1.6.1" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "pyright", specifier = ">=1.1.391" }, + { name = "pytest", specifier = ">=8.3.4" }, + { name = "ruff", specifier = ">=0.8.5" }, +] + [[package]] name = "mcp-simple-prompt" version = "0.1.0" From 280bab36f4557235842c25d3ccb7333a2aa13b13 Mon Sep 17 00:00:00 2001 From: inceptmyth <76823502+arcAman07@users.noreply.github.com> Date: Thu, 8 May 2025 20:27:41 +0530 Subject: [PATCH 18/21] Fix: Use absolute path to uv executable in Claude Desktop config (#440) --- src/mcp/cli/claude.py | 14 +++++++++++++- tests/client/test_config.py | 25 +++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/mcp/cli/claude.py b/src/mcp/cli/claude.py index 5a0ce0ab4f..17c957df29 100644 --- a/src/mcp/cli/claude.py +++ b/src/mcp/cli/claude.py @@ -2,6 +2,7 @@ import json import os +import shutil import sys from pathlib import Path from typing import Any @@ -30,6 +31,16 @@ def get_claude_config_path() -> Path | None: return path return None +def get_uv_path() -> str: + """Get the full path to the uv executable.""" + uv_path = shutil.which("uv") + if not uv_path: + logger.error( + "uv executable not found in PATH, falling back to 'uv'. " + "Please ensure uv is installed and in your PATH" + ) + return "uv" # Fall back to just "uv" if not found + return uv_path def update_claude_config( file_spec: str, @@ -54,6 +65,7 @@ def update_claude_config( Claude Desktop may not be installed or properly set up. 
""" config_dir = get_claude_config_path() + uv_path = get_uv_path() if not config_dir: raise RuntimeError( "Claude Desktop config directory not found. Please ensure Claude Desktop" @@ -117,7 +129,7 @@ def update_claude_config( # Add fastmcp run command args.extend(["mcp", "run", file_spec]) - server_config: dict[str, Any] = {"command": "uv", "args": args} + server_config: dict[str, Any] = {"command": uv_path, "args": args} # Add environment variables if specified if env_vars: diff --git a/tests/client/test_config.py b/tests/client/test_config.py index 97030e0691..6577d663c4 100644 --- a/tests/client/test_config.py +++ b/tests/client/test_config.py @@ -48,3 +48,28 @@ def test_command_execution(mock_config_path: Path): assert result.returncode == 0 assert "usage" in result.stdout.lower() + + +def test_absolute_uv_path(mock_config_path: Path): + """Test that the absolute path to uv is used when available.""" + # Mock the shutil.which function to return a fake path + mock_uv_path = "/usr/local/bin/uv" + + with patch("mcp.cli.claude.get_uv_path", return_value=mock_uv_path): + # Setup + server_name = "test_server" + file_spec = "test_server.py:app" + + # Update config + success = update_claude_config(file_spec=file_spec, server_name=server_name) + assert success + + # Read the generated config + config_file = mock_config_path / "claude_desktop_config.json" + config = json.loads(config_file.read_text()) + + # Verify the command is the absolute path + server_config = config["mcpServers"][server_name] + command = server_config["command"] + + assert command == mock_uv_path \ No newline at end of file From e4e119b32454b4bf7d72de28c1eb0559875caaf5 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 8 May 2025 20:43:25 +0100 Subject: [PATCH 19/21] Streamable HTTP - improve usability, fast mcp and auth (#641) --- .../server.py | 77 ++---- .../mcp_simple_streamablehttp/server.py | 124 +++------ src/mcp/server/fastmcp/server.py | 141 +++++++++- 
src/mcp/server/streamable_http_manager.py | 258 ++++++++++++++++++ tests/server/fastmcp/test_integration.py | 213 +++++++++++++++ tests/server/test_streamable_http_manager.py | 81 ++++++ tests/shared/test_streamable_http.py | 91 +----- 7 files changed, 753 insertions(+), 232 deletions(-) create mode 100644 src/mcp/server/streamable_http_manager.py create mode 100644 tests/server/test_streamable_http_manager.py diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index da8158a980..f718df8010 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -1,37 +1,17 @@ import contextlib import logging +from collections.abc import AsyncIterator import anyio import click import mcp.types as types from mcp.server.lowlevel import Server -from mcp.server.streamableHttp import ( - StreamableHTTPServerTransport, -) +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette from starlette.routing import Mount +from starlette.types import Receive, Scope, Send logger = logging.getLogger(__name__) -# Global task group that will be initialized in the lifespan -task_group = None - - -@contextlib.asynccontextmanager -async def lifespan(app): - """Application lifespan context manager for managing task group.""" - global task_group - - async with anyio.create_task_group() as tg: - task_group = tg - logger.info("Application started, task group initialized!") - try: - yield - finally: - logger.info("Application shutting down, cleaning up resources...") - if task_group: - tg.cancel_scope.cancel() - task_group = None - logger.info("Resources cleaned up successfully.") @click.command() @@ -122,35 +102,28 @@ async def list_tools() 
-> list[types.Tool]: ) ] - # ASGI handler for stateless HTTP connections - async def handle_streamable_http(scope, receive, send): - logger.debug("Creating new transport") - # Use lock to prevent race conditions when creating new sessions - http_transport = StreamableHTTPServerTransport( - mcp_session_id=None, - is_json_response_enabled=json_response, - ) - async with http_transport.connect() as streams: - read_stream, write_stream = streams - - if not task_group: - raise RuntimeError("Task group is not initialized") - - async def run_server(): - await app.run( - read_stream, - write_stream, - app.create_initialization_options(), - # Runs in standalone mode for stateless deployments - # where clients perform initialization with any node - standalone_mode=True, - ) - - # Start server task - task_group.start_soon(run_server) - - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) + # Create the session manager with true stateless mode + session_manager = StreamableHTTPSessionManager( + app=app, + event_store=None, + json_response=json_response, + stateless=True, + ) + + async def handle_streamable_http( + scope: Scope, receive: Receive, send: Send + ) -> None: + await session_manager.handle_request(scope, receive, send) + + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[None]: + """Context manager for session manager.""" + async with session_manager.run(): + logger.info("Application started with StreamableHTTP session manager!") + try: + yield + finally: + logger.info("Application shutting down...") # Create an ASGI application using the transport starlette_app = Starlette( diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index d36686720a..1a76097b52 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ 
b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -1,58 +1,22 @@ import contextlib import logging -from http import HTTPStatus -from uuid import uuid4 +from collections.abc import AsyncIterator import anyio import click import mcp.types as types from mcp.server.lowlevel import Server -from mcp.server.streamable_http import ( - MCP_SESSION_ID_HEADER, - StreamableHTTPServerTransport, -) +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from pydantic import AnyUrl from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import Response from starlette.routing import Mount +from starlette.types import Receive, Scope, Send from .event_store import InMemoryEventStore # Configure logging logger = logging.getLogger(__name__) -# Global task group that will be initialized in the lifespan -task_group = None - -# Event store for resumability -# The InMemoryEventStore enables resumability support for StreamableHTTP transport. -# It stores SSE events with unique IDs, allowing clients to: -# 1. Receive event IDs for each SSE message -# 2. Resume streams by sending Last-Event-ID in GET requests -# 3. Replay missed events after reconnection -# Note: This in-memory implementation is for demonstration ONLY. -# For production, use a persistent storage solution. 
-event_store = InMemoryEventStore() - - -@contextlib.asynccontextmanager -async def lifespan(app): - """Application lifespan context manager for managing task group.""" - global task_group - - async with anyio.create_task_group() as tg: - task_group = tg - logger.info("Application started, task group initialized!") - try: - yield - finally: - logger.info("Application shutting down, cleaning up resources...") - if task_group: - tg.cancel_scope.cancel() - task_group = None - logger.info("Resources cleaned up successfully.") - @click.command() @click.option("--port", default=3000, help="Port to listen on for HTTP") @@ -156,60 +120,38 @@ async def list_tools() -> list[types.Tool]: ) ] - # We need to store the server instances between requests - server_instances = {} - # Lock to prevent race conditions when creating new sessions - session_creation_lock = anyio.Lock() + # Create event store for resumability + # The InMemoryEventStore enables resumability support for StreamableHTTP transport. + # It stores SSE events with unique IDs, allowing clients to: + # 1. Receive event IDs for each SSE message + # 2. Resume streams by sending Last-Event-ID in GET requests + # 3. Replay missed events after reconnection + # Note: This in-memory implementation is for demonstration ONLY. + # For production, use a persistent storage solution. 
+ event_store = InMemoryEventStore() + + # Create the session manager with our app and event store + session_manager = StreamableHTTPSessionManager( + app=app, + event_store=event_store, # Enable resumability + json_response=json_response, + ) # ASGI handler for streamable HTTP connections - async def handle_streamable_http(scope, receive, send): - request = Request(scope, receive) - request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) - if ( - request_mcp_session_id is not None - and request_mcp_session_id in server_instances - ): - transport = server_instances[request_mcp_session_id] - logger.debug("Session already exists, handling request directly") - await transport.handle_request(scope, receive, send) - elif request_mcp_session_id is None: - # try to establish new session - logger.debug("Creating new transport") - # Use lock to prevent race conditions when creating new sessions - async with session_creation_lock: - new_session_id = uuid4().hex - http_transport = StreamableHTTPServerTransport( - mcp_session_id=new_session_id, - is_json_response_enabled=json_response, - event_store=event_store, # Enable resumability - ) - server_instances[http_transport.mcp_session_id] = http_transport - logger.info(f"Created new transport with session ID: {new_session_id}") - - async def run_server(task_status=None): - async with http_transport.connect() as streams: - read_stream, write_stream = streams - if task_status: - task_status.started() - await app.run( - read_stream, - write_stream, - app.create_initialization_options(), - ) - - if not task_group: - raise RuntimeError("Task group is not initialized") - - await task_group.start(run_server) - - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) - else: - response = Response( - "Bad Request: No valid session ID provided", - status_code=HTTPStatus.BAD_REQUEST, - ) - await response(scope, receive, send) + async def handle_streamable_http( + scope: 
Scope, receive: Receive, send: Send + ) -> None: + await session_manager.handle_request(scope, receive, send) + + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[None]: + """Context manager for managing session manager lifecycle.""" + async with session_manager.run(): + logger.info("Application started with StreamableHTTP session manager!") + try: + yield + finally: + logger.info("Application shutting down...") # Create an ASGI application using the transport starlette_app = Starlette( diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index ea0214f0fe..c31f29d4c3 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -47,6 +47,8 @@ from mcp.server.session import ServerSession, ServerSessionT from mcp.server.sse import SseServerTransport from mcp.server.stdio import stdio_server +from mcp.server.streamable_http import EventStore +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.shared.context import LifespanContextT, RequestContext from mcp.types import ( AnyFunction, @@ -90,6 +92,13 @@ class Settings(BaseSettings, Generic[LifespanResultT]): mount_path: str = "/" # Mount path (e.g. 
"/github", defaults to root path) sse_path: str = "/sse" message_path: str = "/messages/" + streamable_http_path: str = "/mcp" + + # StreamableHTTP settings + json_response: bool = False + stateless_http: bool = ( + False # If True, uses true stateless mode (new transport per request) + ) # resource settings warn_on_duplicate_resources: bool = True @@ -131,6 +140,7 @@ def __init__( instructions: str | None = None, auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, + event_store: EventStore | None = None, **settings: Any, ): self.settings = Settings(**settings) @@ -162,8 +172,10 @@ def __init__( "is specified" ) self._auth_server_provider = auth_server_provider + self._event_store = event_store self._custom_starlette_routes: list[Route] = [] self.dependencies = self.settings.dependencies + self._session_manager: StreamableHTTPSessionManager | None = None # Set up MCP protocol handlers self._setup_handlers() @@ -179,25 +191,47 @@ def name(self) -> str: def instructions(self) -> str | None: return self._mcp_server.instructions + @property + def session_manager(self) -> StreamableHTTPSessionManager: + """Get the StreamableHTTP session manager. + + This is exposed to enable advanced use cases like mounting multiple + FastMCP servers in a single FastAPI application. + + Raises: + RuntimeError: If called before streamable_http_app() has been called. + """ + if self._session_manager is None: + raise RuntimeError( + "Session manager can only be accessed after" + "calling streamable_http_app()." + "The session manager is created lazily" + "to avoid unnecessary initialization." + ) + return self._session_manager + def run( self, - transport: Literal["stdio", "sse"] = "stdio", + transport: Literal["stdio", "sse", "streamable-http"] = "stdio", mount_path: str | None = None, ) -> None: """Run the FastMCP server. Note this is a synchronous function. 
Args: - transport: Transport protocol to use ("stdio" or "sse") + transport: Transport protocol to use ("stdio", "sse", or "streamable-http") mount_path: Optional mount path for SSE transport """ - TRANSPORTS = Literal["stdio", "sse"] + TRANSPORTS = Literal["stdio", "sse", "streamable-http"] if transport not in TRANSPORTS.__args__: # type: ignore raise ValueError(f"Unknown transport: {transport}") - if transport == "stdio": - anyio.run(self.run_stdio_async) - else: # transport == "sse" - anyio.run(lambda: self.run_sse_async(mount_path)) + match transport: + case "stdio": + anyio.run(self.run_stdio_async) + case "sse": + anyio.run(lambda: self.run_sse_async(mount_path)) + case "streamable-http": + anyio.run(self.run_streamable_http_async) def _setup_handlers(self) -> None: """Set up core MCP protocol handlers.""" @@ -573,6 +607,21 @@ async def run_sse_async(self, mount_path: str | None = None) -> None: server = uvicorn.Server(config) await server.serve() + async def run_streamable_http_async(self) -> None: + """Run the server using StreamableHTTP transport.""" + import uvicorn + + starlette_app = self.streamable_http_app() + + config = uvicorn.Config( + starlette_app, + host=self.settings.host, + port=self.settings.port, + log_level=self.settings.log_level.lower(), + ) + server = uvicorn.Server(config) + await server.serve() + def _normalize_path(self, mount_path: str, endpoint: str) -> str: """ Combine mount path and endpoint to return a normalized path. 
@@ -687,9 +736,9 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): else: # Auth is disabled, no need for RequireAuthMiddleware # Since handle_sse is an ASGI app, we need to create a compatible endpoint - async def sse_endpoint(request: Request) -> None: + async def sse_endpoint(request: Request) -> Response: # Convert the Starlette request to ASGI parameters - await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage] + return await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage] routes.append( Route( @@ -712,6 +761,80 @@ async def sse_endpoint(request: Request) -> None: debug=self.settings.debug, routes=routes, middleware=middleware ) + def streamable_http_app(self) -> Starlette: + """Return an instance of the StreamableHTTP server app.""" + from starlette.middleware import Middleware + from starlette.routing import Mount + + # Create session manager on first call (lazy initialization) + if self._session_manager is None: + self._session_manager = StreamableHTTPSessionManager( + app=self._mcp_server, + event_store=self._event_store, + json_response=self.settings.json_response, + stateless=self.settings.stateless_http, # Use the stateless setting + ) + + # Create the ASGI handler + async def handle_streamable_http( + scope: Scope, receive: Receive, send: Send + ) -> None: + await self.session_manager.handle_request(scope, receive, send) + + # Create routes + routes: list[Route | Mount] = [] + middleware: list[Middleware] = [] + required_scopes = [] + + # Add auth endpoints if auth provider is configured + if self._auth_server_provider: + assert self.settings.auth + from mcp.server.auth.routes import create_auth_routes + + required_scopes = self.settings.auth.required_scopes or [] + + middleware = [ + Middleware( + AuthenticationMiddleware, + backend=BearerAuthBackend( + provider=self._auth_server_provider, + ), + ), + Middleware(AuthContextMiddleware), + ] + 
class StreamableHTTPSessionManager:
    """
    Manages StreamableHTTP sessions with optional resumability via event store.

    This class abstracts away the complexity of session management, event storage,
    and request handling for StreamableHTTP transports. It handles:

    1. Session tracking for clients
    2. Resumability via an optional event store
    3. Connection management and lifecycle
    4. Request handling and transport setup

    Important: Only one StreamableHTTPSessionManager instance should be created
    per application. The instance cannot be reused after its run() context has
    completed. If you need to restart the manager, create a new instance.

    Args:
        app: The MCP server instance
        event_store: Optional event store for resumability support.
            If provided, enables resumable connections where clients
            can reconnect and receive missed events.
            If None, sessions are still tracked but not resumable.
        json_response: Whether to use JSON responses instead of SSE streams
        stateless: If True, creates a completely fresh transport for each request
            with no session tracking or state persistence between requests.
    """

    def __init__(
        self,
        app: MCPServer[Any],
        event_store: EventStore | None = None,
        json_response: bool = False,
        stateless: bool = False,
    ):
        self.app = app
        self.event_store = event_store
        self.json_response = json_response
        self.stateless = stateless

        # Session tracking (only used if not stateless)
        # The lock serializes session creation so two concurrent requests
        # cannot race while registering new transports.
        self._session_creation_lock = anyio.Lock()
        # Maps session id -> transport for stateful mode.
        self._server_instances: dict[str, StreamableHTTPServerTransport] = {}

        # The task group will be set during lifespan
        # (anyio TaskGroup while run() is active, None otherwise).
        self._task_group = None
        # Thread-safe tracking of run() calls
        self._run_lock = threading.Lock()
        self._has_started = False

    @contextlib.asynccontextmanager
    async def run(self) -> AsyncIterator[None]:
        """
        Run the session manager with proper lifecycle management.

        This creates and manages the task group for all session operations.

        Important: This method can only be called once per instance. The same
        StreamableHTTPSessionManager instance cannot be reused after this
        context manager exits. Create a new instance if you need to restart.

        Use this in the lifespan context manager of your Starlette app:

        @contextlib.asynccontextmanager
        async def lifespan(app: Starlette) -> AsyncIterator[None]:
            async with session_manager.run():
                yield

        Raises:
            RuntimeError: If run() has already been called on this instance.
        """
        # Thread-safe check to ensure run() is only called once
        with self._run_lock:
            if self._has_started:
                raise RuntimeError(
                    "StreamableHTTPSessionManager .run() can only be called "
                    "once per instance. Create a new instance if you need to run again."
                )
            self._has_started = True

        async with anyio.create_task_group() as tg:
            # Store the task group for later use
            self._task_group = tg
            logger.info("StreamableHTTP session manager started")
            try:
                yield  # Let the application run
            finally:
                logger.info("StreamableHTTP session manager shutting down")
                # Cancel task group to stop all spawned tasks
                # (this also aborts any in-flight per-session server tasks).
                tg.cancel_scope.cancel()
                self._task_group = None
                # Clear any remaining server instances
                self._server_instances.clear()

    async def handle_request(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
    ) -> None:
        """
        Process ASGI request with proper session handling and transport setup.

        Dispatches to the appropriate handler based on stateless mode.

        Args:
            scope: ASGI scope
            receive: ASGI receive function
            send: ASGI send function

        Raises:
            RuntimeError: If called outside the run() lifespan context.
        """
        if self._task_group is None:
            raise RuntimeError("Task group is not initialized. Make sure to use run().")

        # Dispatch to the appropriate handler
        if self.stateless:
            await self._handle_stateless_request(scope, receive, send)
        else:
            await self._handle_stateful_request(scope, receive, send)

    async def _handle_stateless_request(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
    ) -> None:
        """
        Process request in stateless mode - creating a new transport for each request.

        Args:
            scope: ASGI scope
            receive: ASGI receive function
            send: ASGI send function
        """
        logger.debug("Stateless mode: Creating new transport for this request")
        # No session ID needed in stateless mode
        http_transport = StreamableHTTPServerTransport(
            mcp_session_id=None,  # No session tracking in stateless mode
            is_json_response_enabled=self.json_response,
            event_store=None,  # No event store in stateless mode
        )

        # Start server in a new task
        async def run_stateless_server(
            *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
        ):
            async with http_transport.connect() as streams:
                read_stream, write_stream = streams
                # Signal readiness only after the streams exist so the
                # request below cannot observe a half-connected transport.
                task_status.started()
                await self.app.run(
                    read_stream,
                    write_stream,
                    self.app.create_initialization_options(),
                    stateless=True,
                )

        # Assert task group is not None for type checking
        assert self._task_group is not None
        # Start the server task
        await self._task_group.start(run_stateless_server)

        # Handle the HTTP request and return the response
        await http_transport.handle_request(scope, receive, send)

    async def _handle_stateful_request(
        self,
        scope: Scope,
        receive: Receive,
        send: Send,
    ) -> None:
        """
        Process request in stateful mode - maintaining session state between requests.

        Args:
            scope: ASGI scope
            receive: ASGI receive function
            send: ASGI send function
        """
        request = Request(scope, receive)
        request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)

        # Existing session case
        if (
            request_mcp_session_id is not None
            and request_mcp_session_id in self._server_instances
        ):
            transport = self._server_instances[request_mcp_session_id]
            logger.debug("Session already exists, handling request directly")
            await transport.handle_request(scope, receive, send)
            return

        if request_mcp_session_id is None:
            # New session case
            logger.debug("Creating new transport")
            # Lock so concurrent first requests each get a distinct session.
            async with self._session_creation_lock:
                new_session_id = uuid4().hex
                http_transport = StreamableHTTPServerTransport(
                    mcp_session_id=new_session_id,
                    is_json_response_enabled=self.json_response,
                    event_store=self.event_store,  # May be None (no resumability)
                )

                assert http_transport.mcp_session_id is not None
                # NOTE(review): entries are only removed when the manager shuts
                # down; confirm session termination (e.g. DELETE) evicts them
                # elsewhere, otherwise this dict grows per session.
                self._server_instances[http_transport.mcp_session_id] = http_transport
                logger.info(f"Created new transport with session ID: {new_session_id}")

                # Define the server runner
                async def run_server(
                    *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
                ) -> None:
                    async with http_transport.connect() as streams:
                        read_stream, write_stream = streams
                        task_status.started()
                        await self.app.run(
                            read_stream,
                            write_stream,
                            self.app.create_initialization_options(),
                            stateless=False,  # Stateful mode
                        )

                # Assert task group is not None for type checking
                assert self._task_group is not None
                # Start the server task
                await self._task_group.start(run_server)

                # Handle the HTTP request and return the response
                await http_transport.handle_request(scope, receive, send)
        else:
            # Invalid session ID
            response = Response(
                "Bad Request: No valid session ID provided",
                status_code=HTTPStatus.BAD_REQUEST,
            )
            await response(scope, receive, send)
def make_fastmcp_streamable_http_app():
    """Create a FastMCP server with StreamableHTTP transport."""
    from starlette.applications import Starlette

    server = FastMCP(name="NoAuthServer")

    # Register a trivial tool so clients have something to call
    @server.tool(description="A simple echo tool")
    def echo(message: str) -> str:
        return f"Echo: {message}"

    # Build the StreamableHTTP ASGI application
    http_app: Starlette = server.streamable_http_app()

    return server, http_app
+ """Create a FastMCP server with stateless StreamableHTTP transport.""" + from starlette.applications import Starlette + + mcp = FastMCP(name="StatelessServer", stateless_http=True) + + # Add a simple tool + @mcp.tool(description="A simple echo tool") + def echo(message: str) -> str: + return f"Echo: {message}" + + # Create the StreamableHTTP app + app: Starlette = mcp.streamable_http_app() + + return mcp, app + + def run_server(server_port: int) -> None: """Run the server.""" _, app = make_fastmcp_app() @@ -63,6 +126,30 @@ def run_server(server_port: int) -> None: server.run() +def run_streamable_http_server(server_port: int) -> None: + """Run the StreamableHTTP server.""" + _, app = make_fastmcp_streamable_http_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting StreamableHTTP server on port {server_port}") + server.run() + + +def run_stateless_http_server(server_port: int) -> None: + """Run the stateless StreamableHTTP server.""" + _, app = make_fastmcp_stateless_http_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting stateless StreamableHTTP server on port {server_port}") + server.run() + + @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: """Start the server in a separate process and clean up after the test.""" @@ -94,6 +181,80 @@ def server(server_port: int) -> Generator[None, None, None]: print("Server process failed to terminate") +@pytest.fixture() +def streamable_http_server(http_server_port: int) -> Generator[None, None, None]: + """Start the StreamableHTTP server in a separate process.""" + proc = multiprocessing.Process( + target=run_streamable_http_server, args=(http_server_port,), daemon=True + ) + print("Starting StreamableHTTP server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 
@pytest.fixture()
def stateless_http_server(
    stateless_http_server_port: int,
) -> Generator[None, None, None]:
    """Start the stateless StreamableHTTP server in a separate process."""
    proc = multiprocessing.Process(
        target=run_stateless_http_server,
        args=(stateless_http_server_port,),
        daemon=True,
    )
    print("Starting stateless StreamableHTTP server process")
    proc.start()

    # Poll the port until the server accepts connections (up to ~2s).
    max_attempts = 20
    print("Waiting for stateless StreamableHTTP server to start")
    for _ in range(max_attempts):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.connect(("127.0.0.1", stateless_http_server_port))
            break
        except ConnectionRefusedError:
            time.sleep(0.1)
    else:
        raise RuntimeError(
            f"Stateless server failed to start after {max_attempts} attempts"
        )

    yield

    print("Killing stateless StreamableHTTP server")
    proc.kill()
    proc.join(timeout=2)
    if proc.is_alive():
        print("Stateless StreamableHTTP server process failed to terminate")
@pytest.mark.anyio
async def test_fastmcp_stateless_streamable_http(
    stateless_http_server: None, stateless_http_server_url: str
) -> None:
    """Test that FastMCP works with stateless StreamableHTTP transport."""
    endpoint = stateless_http_server_url + "/mcp"
    async with streamablehttp_client(endpoint) as (read_stream, write_stream, _):
        async with ClientSession(read_stream, write_stream) as session:
            init_result = await session.initialize()
            assert isinstance(init_result, InitializeResult)
            assert init_result.serverInfo.name == "StatelessServer"

            # One initial call plus several repeats: each request must succeed
            # independently, since the server keeps no session state.
            for message in ["hello"] + [f"test_{i}" for i in range(3)]:
                tool_result = await session.call_tool("echo", {"message": message})
                assert len(tool_result.content) == 1
                assert isinstance(tool_result.content[0], TextContent)
                assert tool_result.content[0].text == f"Echo: {message}"
@pytest.mark.anyio
async def test_run_prevents_concurrent_calls():
    """Test that concurrent calls to run() are prevented."""
    server = Server("test-server")
    manager = StreamableHTTPSessionManager(app=server)

    failures: list[RuntimeError] = []

    async def attempt_run() -> None:
        try:
            async with manager.run():
                await anyio.sleep(0.1)  # simulate some work inside the context
        except RuntimeError as exc:
            failures.append(exc)

    # Launch two runs at once; exactly one may win the race.
    async with anyio.create_task_group() as tg:
        for _ in range(2):
            tg.start_soon(attempt_run)

    assert len(failures) == 1
    assert (
        "StreamableHTTPSessionManager .run() can only be called once per instance"
        in str(failures[0])
    )
+ return {"type": "http.request", "body": b""} + + async def send(message): + pass + + # Should raise error because run() hasn't been called + with pytest.raises(RuntimeError) as excinfo: + await manager.handle_request(scope, receive, send) + + assert "Task group is not initialized. Make sure to use run()." in str( + excinfo.value + ) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index b1dc7ea338..28d29ac23f 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -4,13 +4,10 @@ Contains tests for both server and client sides of the StreamableHTTP transport. """ -import contextlib import multiprocessing import socket import time from collections.abc import Generator -from http import HTTPStatus -from uuid import uuid4 import anyio import httpx @@ -19,8 +16,6 @@ import uvicorn from pydantic import AnyUrl from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import Response from starlette.routing import Mount import mcp.types as types @@ -37,6 +32,7 @@ StreamableHTTPServerTransport, StreamId, ) +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.shared.exceptions import McpError from mcp.shared.message import ( ClientMessageMetadata, @@ -184,7 +180,7 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: def create_app( is_json_response_enabled=False, event_store: EventStore | None = None ) -> Starlette: - """Create a Starlette application for testing that matches the example server. + """Create a Starlette application for testing using the session manager. Args: is_json_response_enabled: If True, use JSON responses instead of SSE streams. 
@@ -193,85 +189,20 @@ def create_app( # Create server instance server = ServerTest() - server_instances = {} - # Lock to prevent race conditions when creating new sessions - session_creation_lock = anyio.Lock() - task_group = None - - @contextlib.asynccontextmanager - async def lifespan(app): - """Application lifespan context manager for managing task group.""" - nonlocal task_group - - async with anyio.create_task_group() as tg: - task_group = tg - try: - yield - finally: - if task_group: - tg.cancel_scope.cancel() - task_group = None - - async def handle_streamable_http(scope, receive, send): - request = Request(scope, receive) - request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) - - # Use existing transport if session ID matches - if ( - request_mcp_session_id is not None - and request_mcp_session_id in server_instances - ): - transport = server_instances[request_mcp_session_id] - - await transport.handle_request(scope, receive, send) - elif request_mcp_session_id is None: - async with session_creation_lock: - new_session_id = uuid4().hex - - http_transport = StreamableHTTPServerTransport( - mcp_session_id=new_session_id, - is_json_response_enabled=is_json_response_enabled, - event_store=event_store, - ) - - async def run_server(task_status=None): - async with http_transport.connect() as streams: - read_stream, write_stream = streams - if task_status: - task_status.started() - await server.run( - read_stream, - write_stream, - server.create_initialization_options(), - ) - - if task_group is None: - response = Response( - "Internal Server Error: Task group is not initialized", - status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - ) - await response(scope, receive, send) - return - - # Store the instance before starting the task to prevent races - server_instances[http_transport.mcp_session_id] = http_transport - await task_group.start(run_server) - - await http_transport.handle_request(scope, receive, send) - else: - response = Response( - "Bad 
Request: No valid session ID provided", - status_code=HTTPStatus.BAD_REQUEST, - ) - await response(scope, receive, send) + # Create the session manager + session_manager = StreamableHTTPSessionManager( + app=server, + event_store=event_store, + json_response=is_json_response_enabled, + ) - # Create an ASGI application + # Create an ASGI application that uses the session manager app = Starlette( debug=True, routes=[ - Mount("/mcp", app=handle_streamable_http), + Mount("/mcp", app=session_manager.handle_request), ], - lifespan=lifespan, + lifespan=lambda app: session_manager.run(), ) return app From 72003d9cc0969799efbe7bf41989ab0036b36918 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 8 May 2025 20:49:55 +0100 Subject: [PATCH 20/21] StreamableHttp - update docs (#664) --- README.md | 99 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 97 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8f19aea1a8..a63cb4056e 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ The Model Context Protocol allows applications to provide context for LLMs in a - Build MCP clients that can connect to any MCP server - Create MCP servers that expose resources, prompts and tools -- Use standard transports like stdio and SSE +- Use standard transports like stdio, SSE, and Streamable HTTP - Handle all MCP protocol messages and lifecycle events ## Installation @@ -387,8 +387,81 @@ python server.py mcp run server.py ``` +### Streamable HTTP Transport + +> **Note**: Streamable HTTP transport is superseding SSE transport for production deployments. 
+
+```python
+from mcp.server.fastmcp import FastMCP
+
+# Stateful server (maintains session state)
+mcp = FastMCP("StatefulServer")
+
+# Stateless server (no session persistence)
+mcp = FastMCP("StatelessServer", stateless_http=True)
+
+# Run server with streamable_http transport
+mcp.run(transport="streamable-http")
+```
+
+You can mount multiple FastMCP servers in a FastAPI application:
+
+```python
+# echo.py
+from mcp.server.fastmcp import FastMCP
+
+mcp = FastMCP(name="EchoServer", stateless_http=True)
+
+
+@mcp.tool(description="A simple echo tool")
+def echo(message: str) -> str:
+    return f"Echo: {message}"
+```
+
+```python
+# math.py
+from mcp.server.fastmcp import FastMCP
+
+mcp = FastMCP(name="MathServer", stateless_http=True)
+
+
+@mcp.tool(description="A simple add tool")
+def add_two(n: int) -> int:
+    return n + 2
+```
+
+```python
+# main.py
+from fastapi import FastAPI
+from mcp.echo import echo
+from mcp.math import math
+
+
+app = FastAPI()
+
+# Use the session manager's lifespan
+app = FastAPI(lifespan=lambda app: echo.mcp.session_manager.run())
+app.mount("/echo", echo.mcp.streamable_http_app())
+app.mount("/math", math.mcp.streamable_http_app())
+```
+
+For low level server with Streamable HTTP implementations, see:
+- Stateful server: [`examples/servers/simple-streamablehttp/`](examples/servers/simple-streamablehttp/)
+- Stateless server: [`examples/servers/simple-streamablehttp-stateless/`](examples/servers/simple-streamablehttp-stateless/)
+
+
+
+The streamable HTTP transport supports:
+- Stateful and stateless operation modes
+- Resumability with event stores
+- JSON or SSE response formats
+- Better scalability for multi-node deployments
+
+
 ### Mounting to an Existing ASGI Server
 
+> **Note**: SSE transport is being superseded by [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http).
+
 You can mount the SSE server to an existing ASGI server using the `sse_app` method.
This allows you to integrate the SSE server with other ASGI applications. ```python @@ -621,7 +694,7 @@ if __name__ == "__main__": ### Writing MCP Clients -The SDK provides a high-level client interface for connecting to MCP servers: +The SDK provides a high-level client interface for connecting to MCP servers using various [transports](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports): ```python from mcp import ClientSession, StdioServerParameters, types @@ -685,6 +758,28 @@ if __name__ == "__main__": asyncio.run(run()) ``` +Clients can also connect using [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http): + +```python +from mcp.client.streamable_http import streamablehttp_client +from mcp import ClientSession + + +async def main(): + # Connect to a streamable HTTP server + async with streamablehttp_client("example/mcp") as ( + read_stream, + write_stream, + _, + ): + # Create a session using the client streams + async with ClientSession(read_stream, write_stream) as session: + # Initialize the connection + await session.initialize() + # Call a tool + tool_result = await session.call_tool("echo", {"message": "hello"}) +``` + ### MCP Primitives The MCP protocol defines three core primitives that servers can implement: From ed25167fa5d715733437996682e20c24470e8177 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 8 May 2025 20:53:21 +0100 Subject: [PATCH 21/21] Introduce a function to create a standard AsyncClient with options (#655) --- .../simple-auth/mcp_simple_auth/server.py | 6 +- .../simple-tool/mcp_simple_tool/server.py | 4 +- src/mcp/client/sse.py | 3 +- src/mcp/client/streamable_http.py | 4 +- src/mcp/shared/_httpx_utils.py | 62 +++++++++++++++++++ tests/shared/test_httpx_utils.py | 24 +++++++ 6 files changed, 95 insertions(+), 8 deletions(-) create mode 100644 src/mcp/shared/_httpx_utils.py create mode 100644 tests/shared/test_httpx_utils.py diff --git 
a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 7cd92aa799..2f1e4086ff 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -6,7 +6,6 @@ from typing import Any import click -import httpx from pydantic import AnyHttpUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.exceptions import HTTPException @@ -24,6 +23,7 @@ ) from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions from mcp.server.fastmcp.server import FastMCP +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthToken logger = logging.getLogger(__name__) @@ -123,7 +123,7 @@ async def handle_github_callback(self, code: str, state: str) -> str: client_id = state_data["client_id"] # Exchange code for token with GitHub - async with httpx.AsyncClient() as client: + async with create_mcp_http_client() as client: response = await client.post( self.settings.github_token_url, data={ @@ -325,7 +325,7 @@ async def get_user_profile() -> dict[str, Any]: """ github_token = get_github_token() - async with httpx.AsyncClient() as client: + async with create_mcp_http_client() as client: response = await client.get( "https://api.github.com/user", headers={ diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index 04224af5d2..5f4e28bb73 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -1,8 +1,8 @@ import anyio import click -import httpx import mcp.types as types from mcp.server.lowlevel import Server +from mcp.shared._httpx_utils import create_mcp_http_client async def fetch_website( @@ -11,7 +11,7 @@ async def fetch_website( headers = { "User-Agent": "MCP Test Server 
(github.com/modelcontextprotocol/python-sdk)" } - async with httpx.AsyncClient(follow_redirects=True, headers=headers) as client: + async with create_mcp_http_client(headers=headers) as client: response = await client.get(url) response.raise_for_status() return [types.TextContent(type="text", text=response.text)] diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index ff04d2f961..29195cbd98 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -10,6 +10,7 @@ from httpx_sse import aconnect_sse import mcp.types as types +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -44,7 +45,7 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx.AsyncClient(headers=headers) as client: + async with create_mcp_http_client(headers=headers) as client: async with aconnect_sse( client, "GET", diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index ef424e3b33..183653b9ab 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -18,6 +18,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import EventSource, ServerSentEvent, aconnect_sse +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( ErrorData, @@ -446,12 +447,11 @@ async def streamablehttp_client( try: logger.info(f"Connecting to StreamableHTTP endpoint: {url}") - async with httpx.AsyncClient( + async with create_mcp_http_client( headers=transport.request_headers, timeout=httpx.Timeout( transport.timeout.seconds, read=transport.sse_read_timeout.seconds ), - follow_redirects=True, ) as client: # Define callbacks that need access to tg def start_get_stream() -> None: diff --git 
a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py new file mode 100644 index 0000000000..95080bde1c --- /dev/null +++ b/src/mcp/shared/_httpx_utils.py @@ -0,0 +1,62 @@ +"""Utilities for creating standardized httpx AsyncClient instances.""" + +from typing import Any + +import httpx + +__all__ = ["create_mcp_http_client"] + + +def create_mcp_http_client( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, +) -> httpx.AsyncClient: + """Create a standardized httpx AsyncClient with MCP defaults. + + This function provides common defaults used throughout the MCP codebase: + - follow_redirects=True (always enabled) + - Default timeout of 30 seconds if not specified + + Args: + headers: Optional headers to include with all requests. + timeout: Request timeout as httpx.Timeout object. + Defaults to 30 seconds if not specified. + + Returns: + Configured httpx.AsyncClient instance with MCP defaults. + + Note: + The returned AsyncClient must be used as a context manager to ensure + proper cleanup of connections. 
+ + Examples: + # Basic usage with MCP defaults + async with create_mcp_http_client() as client: + response = await client.get("https://api.example.com") + + # With custom headers + headers = {"Authorization": "Bearer token"} + async with create_mcp_http_client(headers) as client: + response = await client.get("https://api.example.com/endpoint") + + # With both custom headers and timeout + timeout = httpx.Timeout(60.0, read=300.0) + async with create_mcp_http_client(headers, timeout) as client: + response = await client.get("https://api.example.com/long-request") + """ + # Set MCP defaults + kwargs: dict[str, Any] = { + "follow_redirects": True, + } + + # Handle timeout + if timeout is None: + kwargs["timeout"] = httpx.Timeout(30.0) + else: + kwargs["timeout"] = timeout + + # Handle headers + if headers is not None: + kwargs["headers"] = headers + + return httpx.AsyncClient(**kwargs) diff --git a/tests/shared/test_httpx_utils.py b/tests/shared/test_httpx_utils.py new file mode 100644 index 0000000000..dcc6fd003c --- /dev/null +++ b/tests/shared/test_httpx_utils.py @@ -0,0 +1,24 @@ +"""Tests for httpx utility functions.""" + +import httpx + +from mcp.shared._httpx_utils import create_mcp_http_client + + +def test_default_settings(): + """Test that default settings are applied correctly.""" + client = create_mcp_http_client() + + assert client.follow_redirects is True + assert client.timeout.connect == 30.0 + + +def test_custom_parameters(): + """Test custom headers and timeout are set correctly.""" + headers = {"Authorization": "Bearer token"} + timeout = httpx.Timeout(60.0) + + client = create_mcp_http_client(headers, timeout) + + assert client.headers["Authorization"] == "Bearer token" + assert client.timeout.connect == 60.0