diff --git a/.github/workflows/shared.yml b/.github/workflows/shared.yml index 03c36a691..4c9023ae9 100644 --- a/.github/workflows/shared.yml +++ b/.github/workflows/shared.yml @@ -37,10 +37,11 @@ jobs: run: uv run --no-sync pyright test: - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} strategy: matrix: python-version: ["3.10", "3.11", "3.12", "3.13"] + os: [ubuntu-latest, windows-latest] steps: - uses: actions/checkout@v4 @@ -55,3 +56,4 @@ jobs: - name: Run pytest run: uv run --no-sync pytest + continue-on-error: true diff --git a/README.md b/README.md index a63cb4056..2611e25f0 100644 --- a/README.md +++ b/README.md @@ -318,7 +318,7 @@ providing an implementation of the `OAuthServerProvider` protocol. ``` mcp = FastMCP("My App", - auth_provider=MyOAuthServerProvider(), + auth_server_provider=MyOAuthServerProvider(), auth=AuthSettings( issuer_url="https://myapp.com", revocation_options=RevocationOptions( @@ -387,6 +387,8 @@ python server.py mcp run server.py ``` +Note that `mcp run` or `mcp dev` only supports server using FastMCP and not the low-level server variant. + ### Streamable HTTP Transport > **Note**: Streamable HTTP transport is superseding SSE transport for production deployments. @@ -400,6 +402,9 @@ mcp = FastMCP("StatefulServer") # Stateless server (no session persistence) mcp = FastMCP("StatelessServer", stateless_http=True) +# Stateless server (no session persistence, no sse stream with supported client) +mcp = FastMCP("StatelessServer", stateless_http=True, json_response=True) + # Run server with streamable_http transport mcp.run(transport="streamable-http") ``` @@ -426,21 +431,28 @@ mcp = FastMCP(name="MathServer", stateless_http=True) @mcp.tool(description="A simple add tool") -def add_two(n: int) -> str: +def add_two(n: int) -> int: return n + 2 ``` ```python # main.py +import contextlib from fastapi import FastAPI from mcp.echo import echo from mcp.math import math -app = FastAPI() +# Create a combined lifespan to manage both session managers +@contextlib.asynccontextmanager +async def lifespan(app: FastAPI): + async with contextlib.AsyncExitStack() as stack: + await stack.enter_async_context(echo.mcp.session_manager.run()) + await stack.enter_async_context(math.mcp.session_manager.run()) + yield + -# Use the session manager's lifespan -app = FastAPI(lifespan=lambda app: echo.mcp.session_manager.run()) +app = FastAPI(lifespan=lifespan) app.mount("/echo", echo.mcp.streamable_http_app()) app.mount("/math", math.mcp.streamable_http_app()) ``` @@ -462,6 +474,8 @@ The streamable HTTP transport supports: > **Note**: SSE transport is being superseded by [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http). +By default, SSE servers are mounted at `/sse` and Streamable HTTP servers are mounted at `/mcp`. You can customize these paths using the methods described below. + You can mount the SSE server to an existing ASGI server using the `sse_app` method. This allows you to integrate the SSE server with other ASGI applications. ```python @@ -617,7 +631,7 @@ server = Server("example-server", lifespan=server_lifespan) # Access lifespan context in handlers @server.call_tool() async def query_db(name: str, arguments: dict) -> list: - ctx = server.request_context + ctx = server.get_context() db = ctx.lifespan_context["db"] return await db.query(arguments["query"]) ``` @@ -692,6 +706,8 @@ if __name__ == "__main__": asyncio.run(run()) ``` +Caution: The `mcp run` and `mcp dev` tool doesn't support low-level server. 
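For context on the caution above, here is a minimal sketch (not part of the patch) of a server layout that the `mcp run`/`mcp dev` CLI can discover. It assumes a module-level FastMCP instance named `mcp` (the CLI also checks for `server` and `app`, per the `_check_server_object` helper added later in this diff); the file name `server.py` is illustrative only.

```python
# server.py - hypothetical example; `mcp run server.py` discovers the
# module-level FastMCP object, whereas a low-level `Server` instance is
# rejected by the CLI check introduced in this change.
from mcp.server.fastmcp import FastMCP

mcp = FastMCP("Demo")


@mcp.tool(description="A simple add tool")
def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b


if __name__ == "__main__":
    # Running the file directly works regardless of the CLI; defaults to stdio.
    mcp.run()
```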
+ ### Writing MCP Clients The SDK provides a high-level client interface for connecting to MCP servers using various [transports](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports): diff --git a/examples/clients/simple-chatbot/README.MD b/examples/clients/simple-chatbot/README.MD index 683e4f3f5..22996d962 100644 --- a/examples/clients/simple-chatbot/README.MD +++ b/examples/clients/simple-chatbot/README.MD @@ -25,6 +25,7 @@ This example demonstrates how to integrate the Model Context Protocol (MCP) into ```plaintext LLM_API_KEY=your_api_key_here ``` + **Note:** The current implementation is configured to use the Groq API endpoint (`https://api.groq.com/openai/v1/chat/completions`) with the `llama-3.2-90b-vision-preview` model. If you plan to use a different LLM provider, you'll need to modify the `LLMClient` class in `main.py` to use the appropriate endpoint URL and model parameters. 3. **Configure servers:** diff --git a/examples/servers/simple-auth/README.md b/examples/servers/simple-auth/README.md index 1d0979d97..9906c4d36 100644 --- a/examples/servers/simple-auth/README.md +++ b/examples/servers/simple-auth/README.md @@ -44,6 +44,31 @@ uv run mcp-simple-auth The server will start on `http://localhost:8000`. +### Transport Options + +This server supports multiple transport protocols that can run on the same port: + +#### SSE (Server-Sent Events) - Default +```bash +uv run mcp-simple-auth +# or explicitly: +uv run mcp-simple-auth --transport sse +``` + +SSE transport provides endpoint: +- `/sse` + +#### Streamable HTTP +```bash +uv run mcp-simple-auth --transport streamable-http +``` + +Streamable HTTP transport provides endpoint: +- `/mcp` + + +This ensures backward compatibility without needing multiple server instances. When using SSE transport (`--transport sse`), only the `/sse` endpoint is available. + ## Available Tool ### get_user_profile @@ -61,5 +86,6 @@ If the server fails to start, check: 1. Environment variables `MCP_GITHUB_GITHUB_CLIENT_ID` and `MCP_GITHUB_GITHUB_CLIENT_SECRET` are set 2. The GitHub OAuth app callback URL matches `http://localhost:8000/github/callback` 3. No other service is using port 8000 +4. 
The transport specified is valid (`sse` or `streamable-http`) You can use [Inspector](https://github.com/modelcontextprotocol/inspector) to test Auth \ No newline at end of file diff --git a/examples/servers/simple-auth/mcp_simple_auth/__main__.py b/examples/servers/simple-auth/mcp_simple_auth/__main__.py index a8840780b..2365ff5a1 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/__main__.py +++ b/examples/servers/simple-auth/mcp_simple_auth/__main__.py @@ -4,4 +4,4 @@ from mcp_simple_auth.server import main -sys.exit(main()) +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 2f1e4086f..51f449113 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -3,7 +3,7 @@ import logging import secrets import time -from typing import Any +from typing import Any, Literal import click from pydantic import AnyHttpUrl @@ -347,7 +347,13 @@ async def get_user_profile() -> dict[str, Any]: @click.command() @click.option("--port", default=8000, help="Port to listen on") @click.option("--host", default="localhost", help="Host to bind to") -def main(port: int, host: str) -> int: +@click.option( + "--transport", + default="sse", + type=click.Choice(["sse", "streamable-http"]), + help="Transport protocol to use ('sse' or 'streamable-http')", +) +def main(port: int, host: str, transport: Literal["sse", "streamable-http"]) -> int: """Run the simple GitHub MCP server.""" logging.basicConfig(level=logging.INFO) @@ -364,5 +370,6 @@ def main(port: int, host: str) -> int: return 1 mcp_server = create_simple_mcp_server(settings) - mcp_server.run(transport="sse") + logger.info(f"Starting server with {transport} transport") + mcp_server.run(transport=transport) return 0 diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/__main__.py b/examples/servers/simple-prompt/mcp_simple_prompt/__main__.py index 8b345fa2e..e7ef16530 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/__main__.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/__main__.py @@ -2,4 +2,4 @@ from .server import main -sys.exit(main()) +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-resource/mcp_simple_resource/__main__.py b/examples/servers/simple-resource/mcp_simple_resource/__main__.py index 8b345fa2e..e7ef16530 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/__main__.py +++ b/examples/servers/simple-resource/mcp_simple_resource/__main__.py @@ -2,4 +2,4 @@ from .server import main -sys.exit(main()) +sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index 06f567fbe..3e3adf108 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -2,7 +2,7 @@ import click import mcp.types as types from mcp.server.lowlevel import Server -from pydantic import FileUrl +from pydantic import AnyUrl SAMPLE_RESOURCES = { "greeting": "Hello! 
This is a sample text resource.", @@ -26,7 +26,7 @@ def main(port: int, transport: str) -> int: async def list_resources() -> list[types.Resource]: return [ types.Resource( - uri=FileUrl(f"file:///{name}.txt"), + uri=AnyUrl(f"file:///{name}.txt"), name=name, description=f"A sample text resource named {name}", mimeType="text/plain", @@ -35,7 +35,9 @@ async def list_resources() -> list[types.Resource]: ] @app.read_resource() - async def read_resource(uri: FileUrl) -> str | bytes: + async def read_resource(uri: AnyUrl) -> str | bytes: + if uri.path is None: + raise ValueError(f"Invalid resource path: {uri}") name = uri.path.replace(".txt", "").lstrip("/") if name not in SAMPLE_RESOURCES: diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py index f5f6e402d..1664737e3 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py @@ -1,4 +1,7 @@ from .server import main if __name__ == "__main__": - main() + # Click will handle CLI arguments + import sys + + sys.exit(main()) # type: ignore[call-arg] diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py index f5f6e402d..21862e45f 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/__main__.py @@ -1,4 +1,4 @@ from .server import main if __name__ == "__main__": - main() + main() # type: ignore[call-arg] diff --git a/examples/servers/simple-tool/mcp_simple_tool/__main__.py b/examples/servers/simple-tool/mcp_simple_tool/__main__.py index 8b345fa2e..e7ef16530 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/__main__.py +++ b/examples/servers/simple-tool/mcp_simple_tool/__main__.py @@ -2,4 +2,4 @@ from .server import main -sys.exit(main()) +sys.exit(main()) # type: ignore[call-arg] diff --git a/pyproject.toml b/pyproject.toml index 2b86fb377..1b3f86994 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,7 @@ Issues = "https://github.com/modelcontextprotocol/python-sdk/issues" packages = ["src/mcp"] [tool.pyright] -include = ["src/mcp", "tests"] +include = ["src/mcp", "tests", "examples/servers"] venvPath = "." 
venv = ".venv" strict = ["src/mcp/**/*.py"] diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 0d3c372ce..e93b95c90 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -1,4 +1,5 @@ from .client.session import ClientSession +from .client.session_group import ClientSessionGroup from .client.stdio import StdioServerParameters, stdio_client from .server.session import ServerSession from .server.stdio import stdio_server @@ -63,6 +64,7 @@ "ClientRequest", "ClientResult", "ClientSession", + "ClientSessionGroup", "CreateMessageRequest", "CreateMessageResult", "ErrorData", diff --git a/src/mcp/cli/cli.py b/src/mcp/cli/cli.py index cb0830600..b2632f1d9 100644 --- a/src/mcp/cli/cli.py +++ b/src/mcp/cli/cli.py @@ -6,7 +6,10 @@ import subprocess import sys from pathlib import Path -from typing import Annotated +from typing import Annotated, Any + +from mcp.server import FastMCP +from mcp.server import Server as LowLevelServer try: import typer @@ -141,17 +144,48 @@ def _import_server(file: Path, server_object: str | None = None): module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) + def _check_server_object(server_object: Any, object_name: str): + """Helper function to check that the server object is supported + + Args: + server_object: The server object to check. + + Returns: + True if it's supported. + """ + if not isinstance(server_object, FastMCP): + logger.error( + f"The server object {object_name} is of type " + f"{type(server_object)} (expecting {FastMCP})." + ) + if isinstance(server_object, LowLevelServer): + logger.warning( + "Note that only FastMCP server is supported. Low level " + "Server class is not yet supported." + ) + return False + return True + # If no object specified, try common server names if not server_object: # Look for the most common server object names for name in ["mcp", "server", "app"]: if hasattr(module, name): + if not _check_server_object(getattr(module, name), f"{file}:{name}"): + logger.error( + f"Ignoring object '{file}:{name}' as it's not a valid " + "server object" + ) + continue return getattr(module, name) logger.error( f"No server object found in {file}. Please either:\n" "1. Use a standard variable name (mcp, server, or app)\n" - "2. Specify the object name with file:object syntax", + "2. Specify the object name with file:object syntax" + "3. 
If the server creates the FastMCP object within main() " + " or another function, refactor the FastMCP object to be a " + " global variable named mcp, server, or app.", extra={"file": str(file)}, ) sys.exit(1) @@ -179,6 +213,9 @@ def _import_server(file: Path, server_object: str | None = None): ) sys.exit(1) + if not _check_server_object(server, server_object): + sys.exit(1) + return server diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7bb8821f7..c714c44bb 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -8,7 +8,7 @@ import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, RequestResponder +from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -168,7 +168,11 @@ async def send_ping(self) -> types.EmptyResult: ) async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -179,6 +183,7 @@ async def send_progress_notification( progressToken=progress_token, progress=progress, total=total, + message=message, ), ), ) @@ -196,23 +201,29 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul types.EmptyResult, ) - async def list_resources(self) -> types.ListResourcesResult: + async def list_resources( + self, cursor: str | None = None + ) -> types.ListResourcesResult: """Send a resources/list request.""" return await self.send_request( types.ClientRequest( types.ListResourcesRequest( method="resources/list", + cursor=cursor, ) ), types.ListResourcesResult, ) - async def list_resource_templates(self) -> types.ListResourceTemplatesResult: + async def list_resource_templates( + self, cursor: str | None = None + ) -> types.ListResourceTemplatesResult: """Send a resources/templates/list request.""" return await self.send_request( types.ClientRequest( types.ListResourceTemplatesRequest( method="resources/templates/list", + cursor=cursor, ) ), types.ListResourceTemplatesResult, @@ -259,26 +270,32 @@ async def call_tool( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CallToolResult: - """Send a tools/call request.""" + """Send a tools/call request with optional progress callback support.""" return await self.send_request( types.ClientRequest( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams(name=name, arguments=arguments), + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + ), ) ), types.CallToolResult, request_read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, ) - async def list_prompts(self) -> types.ListPromptsResult: + async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult: """Send a prompts/list request.""" return await self.send_request( types.ClientRequest( types.ListPromptsRequest( method="prompts/list", + cursor=cursor, ) ), types.ListPromptsResult, @@ -317,12 +334,13 @@ async def complete( types.CompleteResult, ) - async def list_tools(self) -> types.ListToolsResult: + async def 
list_tools(self, cursor: str | None = None) -> types.ListToolsResult: """Send a tools/list request.""" return await self.send_request( types.ClientRequest( types.ListToolsRequest( method="tools/list", + cursor=cursor, ) ), types.ListToolsResult, diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py new file mode 100644 index 000000000..c23f2523e --- /dev/null +++ b/src/mcp/client/session_group.py @@ -0,0 +1,372 @@ +""" +SessionGroup concurrently manages multiple MCP session connections. + +Tools, resources, and prompts are aggregated across servers. Servers may +be connected to or disconnected from at any point after initialization. + +This abstractions can handle naming collisions using a custom user-provided +hook. +""" + +import contextlib +import logging +from collections.abc import Callable +from datetime import timedelta +from types import TracebackType +from typing import Any, TypeAlias + +import anyio +from pydantic import BaseModel +from typing_extensions import Self + +import mcp +from mcp import types +from mcp.client.sse import sse_client +from mcp.client.stdio import StdioServerParameters +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.exceptions import McpError + + +class SseServerParameters(BaseModel): + """Parameters for intializing a sse_client.""" + + # The endpoint URL. + url: str + + # Optional headers to include in requests. + headers: dict[str, Any] | None = None + + # HTTP timeout for regular operations. + timeout: float = 5 + + # Timeout for SSE read operations. + sse_read_timeout: float = 60 * 5 + + +class StreamableHttpParameters(BaseModel): + """Parameters for intializing a streamablehttp_client.""" + + # The endpoint URL. + url: str + + # Optional headers to include in requests. + headers: dict[str, Any] | None = None + + # HTTP timeout for regular operations. + timeout: timedelta = timedelta(seconds=30) + + # Timeout for SSE read operations. + sse_read_timeout: timedelta = timedelta(seconds=60 * 5) + + # Close the client session when the transport closes. + terminate_on_close: bool = True + + +ServerParameters: TypeAlias = ( + StdioServerParameters | SseServerParameters | StreamableHttpParameters +) + + +class ClientSessionGroup: + """Client for managing connections to multiple MCP servers. + + This class is responsible for encapsulating management of server connections. + It aggregates tools, resources, and prompts from all connected servers. + + For auxiliary handlers, such as resource subscription, this is delegated to + the client and can be accessed via the session. + + Example Usage: + name_fn = lambda name, server_info: f"{(server_info.name)}-{name}" + async with ClientSessionGroup(component_name_hook=name_fn) as group: + for server_params in server_params: + group.connect_to_server(server_param) + ... + + """ + + class _ComponentNames(BaseModel): + """Used for reverse index to find components.""" + + prompts: set[str] = set() + resources: set[str] = set() + tools: set[str] = set() + + # Standard MCP components. + _prompts: dict[str, types.Prompt] + _resources: dict[str, types.Resource] + _tools: dict[str, types.Tool] + + # Client-server connection management. + _sessions: dict[mcp.ClientSession, _ComponentNames] + _tool_to_session: dict[str, mcp.ClientSession] + _exit_stack: contextlib.AsyncExitStack + _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] + + # Optional fn consuming (component_name, serverInfo) for custom names. 
+ # This is provide a means to mitigate naming conflicts across servers. + # Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}" + _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str] + _component_name_hook: _ComponentNameHook | None + + def __init__( + self, + exit_stack: contextlib.AsyncExitStack | None = None, + component_name_hook: _ComponentNameHook | None = None, + ) -> None: + """Initializes the MCP client.""" + + self._tools = {} + self._resources = {} + self._prompts = {} + + self._sessions = {} + self._tool_to_session = {} + if exit_stack is None: + self._exit_stack = contextlib.AsyncExitStack() + self._owns_exit_stack = True + else: + self._exit_stack = exit_stack + self._owns_exit_stack = False + self._session_exit_stacks = {} + self._component_name_hook = component_name_hook + + async def __aenter__(self) -> Self: + # Enter the exit stack only if we created it ourselves + if self._owns_exit_stack: + await self._exit_stack.__aenter__() + return self + + async def __aexit__( + self, + _exc_type: type[BaseException] | None, + _exc_val: BaseException | None, + _exc_tb: TracebackType | None, + ) -> bool | None: + """Closes session exit stacks and main exit stack upon completion.""" + + # Concurrently close session stacks. + async with anyio.create_task_group() as tg: + for exit_stack in self._session_exit_stacks.values(): + tg.start_soon(exit_stack.aclose) + + # Only close the main exit stack if we created it + if self._owns_exit_stack: + await self._exit_stack.aclose() + + @property + def sessions(self) -> list[mcp.ClientSession]: + """Returns the list of sessions being managed.""" + return list(self._sessions.keys()) + + @property + def prompts(self) -> dict[str, types.Prompt]: + """Returns the prompts as a dictionary of names to prompts.""" + return self._prompts + + @property + def resources(self) -> dict[str, types.Resource]: + """Returns the resources as a dictionary of names to resources.""" + return self._resources + + @property + def tools(self) -> dict[str, types.Tool]: + """Returns the tools as a dictionary of names to tools.""" + return self._tools + + async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: + """Executes a tool given its name and arguments.""" + session = self._tool_to_session[name] + session_tool_name = self.tools[name].name + return await session.call_tool(session_tool_name, args) + + async def disconnect_from_server(self, session: mcp.ClientSession) -> None: + """Disconnects from a single MCP server.""" + + session_known_for_components = session in self._sessions + session_known_for_stack = session in self._session_exit_stacks + + if not session_known_for_components and not session_known_for_stack: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message="Provided session is not managed or already disconnected.", + ) + ) + + if session_known_for_components: + component_names = self._sessions.pop(session) # Pop from _sessions tracking + + # Remove prompts associated with the session. + for name in component_names.prompts: + if name in self._prompts: + del self._prompts[name] + # Remove resources associated with the session. + for name in component_names.resources: + if name in self._resources: + del self._resources[name] + # Remove tools associated with the session. 
+ for name in component_names.tools: + if name in self._tools: + del self._tools[name] + if name in self._tool_to_session: + del self._tool_to_session[name] + + # Clean up the session's resources via its dedicated exit stack + if session_known_for_stack: + session_stack_to_close = self._session_exit_stacks.pop(session) + await session_stack_to_close.aclose() + + async def connect_with_session( + self, server_info: types.Implementation, session: mcp.ClientSession + ) -> mcp.ClientSession: + """Connects to a single MCP server.""" + await self._aggregate_components(server_info, session) + return session + + async def connect_to_server( + self, + server_params: ServerParameters, + ) -> mcp.ClientSession: + """Connects to a single MCP server.""" + server_info, session = await self._establish_session(server_params) + return await self.connect_with_session(server_info, session) + + async def _establish_session( + self, server_params: ServerParameters + ) -> tuple[types.Implementation, mcp.ClientSession]: + """Establish a client session to an MCP server.""" + + session_stack = contextlib.AsyncExitStack() + try: + # Create read and write streams that facilitate io with the server. + if isinstance(server_params, StdioServerParameters): + client = mcp.stdio_client(server_params) + read, write = await session_stack.enter_async_context(client) + elif isinstance(server_params, SseServerParameters): + client = sse_client( + url=server_params.url, + headers=server_params.headers, + timeout=server_params.timeout, + sse_read_timeout=server_params.sse_read_timeout, + ) + read, write = await session_stack.enter_async_context(client) + else: + client = streamablehttp_client( + url=server_params.url, + headers=server_params.headers, + timeout=server_params.timeout, + sse_read_timeout=server_params.sse_read_timeout, + terminate_on_close=server_params.terminate_on_close, + ) + read, write, _ = await session_stack.enter_async_context(client) + + session = await session_stack.enter_async_context( + mcp.ClientSession(read, write) + ) + result = await session.initialize() + + # Session successfully initialized. + # Store its stack and register the stack with the main group stack. + self._session_exit_stacks[session] = session_stack + # session_stack itself becomes a resource managed by the + # main _exit_stack. + await self._exit_stack.enter_async_context(session_stack) + + return result.serverInfo, session + except Exception: + # If anything during this setup fails, ensure the session-specific + # stack is closed. + await session_stack.aclose() + raise + + async def _aggregate_components( + self, server_info: types.Implementation, session: mcp.ClientSession + ) -> None: + """Aggregates prompts, resources, and tools from a given session.""" + + # Create a reverse index so we can find all prompts, resources, and + # tools belonging to this session. Used for removing components from + # the session group via self.disconnect_from_server. + component_names = self._ComponentNames() + + # Temporary components dicts. We do not want to modify the aggregate + # lists in case of an intermediate failure. + prompts_temp: dict[str, types.Prompt] = {} + resources_temp: dict[str, types.Resource] = {} + tools_temp: dict[str, types.Tool] = {} + tool_to_session_temp: dict[str, mcp.ClientSession] = {} + + # Query the server for its prompts and aggregate to list. 
+ try: + prompts = (await session.list_prompts()).prompts + for prompt in prompts: + name = self._component_name(prompt.name, server_info) + prompts_temp[name] = prompt + component_names.prompts.add(name) + except McpError as err: + logging.warning(f"Could not fetch prompts: {err}") + + # Query the server for its resources and aggregate to list. + try: + resources = (await session.list_resources()).resources + for resource in resources: + name = self._component_name(resource.name, server_info) + resources_temp[name] = resource + component_names.resources.add(name) + except McpError as err: + logging.warning(f"Could not fetch resources: {err}") + + # Query the server for its tools and aggregate to list. + try: + tools = (await session.list_tools()).tools + for tool in tools: + name = self._component_name(tool.name, server_info) + tools_temp[name] = tool + tool_to_session_temp[name] = session + component_names.tools.add(name) + except McpError as err: + logging.warning(f"Could not fetch tools: {err}") + + # Clean up exit stack for session if we couldn't retrieve anything + # from the server. + if not any((prompts_temp, resources_temp, tools_temp)): + del self._session_exit_stacks[session] + + # Check for duplicates. + matching_prompts = prompts_temp.keys() & self._prompts.keys() + if matching_prompts: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"{matching_prompts} already exist in group prompts.", + ) + ) + matching_resources = resources_temp.keys() & self._resources.keys() + if matching_resources: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"{matching_resources} already exist in group resources.", + ) + ) + matching_tools = tools_temp.keys() & self._tools.keys() + if matching_tools: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"{matching_tools} already exist in group tools.", + ) + ) + + # Aggregate components. 
+ self._sessions[session] = component_names + self._prompts.update(prompts_temp) + self._resources.update(resources_temp) + self._tools.update(tools_temp) + self._tool_to_session.update(tool_to_session_temp) + + def _component_name(self, name: str, server_info: types.Implementation) -> str: + if self._component_name_hook: + return self._component_name_hook(name, server_info) + return name diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index e8be5aff5..6d815b43a 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -181,6 +181,8 @@ async def stdin_writer(): await terminate_windows_process(process) else: process.terminate() + await read_stream.aclose() + await write_stream.aclose() def _get_executable_command(command: str) -> str: diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 893aeb84a..3324dab5a 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -410,7 +410,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: if response.status_code == 405: logger.debug("Server does not allow session termination") - elif response.status_code != 200: + elif response.status_code not in (200, 204): logger.warning(f"Session termination failed: {response.status_code}") except Exception as exc: logger.warning(f"Session termination failed: {exc}") diff --git a/src/mcp/server/fastmcp/resources/types.py b/src/mcp/server/fastmcp/resources/types.py index 2ab39b078..d3f10211d 100644 --- a/src/mcp/server/fastmcp/resources/types.py +++ b/src/mcp/server/fastmcp/resources/types.py @@ -11,7 +11,7 @@ import httpx import pydantic import pydantic_core -from pydantic import Field, ValidationInfo +from pydantic import AnyUrl, Field, ValidationInfo, validate_call from mcp.server.fastmcp.resources.base import Resource @@ -68,6 +68,31 @@ async def read(self) -> str | bytes: except Exception as e: raise ValueError(f"Error reading resource {self.uri}: {e}") + @classmethod + def from_function( + cls, + fn: Callable[..., Any], + uri: str, + name: str | None = None, + description: str | None = None, + mime_type: str | None = None, + ) -> "FunctionResource": + """Create a FunctionResource from a function.""" + func_name = name or fn.__name__ + if func_name == "": + raise ValueError("You must provide a name for lambda functions") + + # ensure the arguments are properly cast + fn = validate_call(fn) + + return cls( + uri=AnyUrl(uri), + name=func_name, + description=description or fn.__doc__ or "", + mime_type=mime_type or "text/plain", + fn=fn, + ) + class FileResource(Resource): """A resource that reads from a file. 
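As a rough usage sketch of the `FunctionResource.from_function` helper introduced above (the function name, URI, and resource name below are made up for illustration; FastMCP's `@mcp.resource()` decorator now routes through this same path, as shown in the `server.py` hunk that follows):

```python
# Hypothetical standalone usage of the new from_function classmethod.
from mcp.server.fastmcp.resources.types import FunctionResource


def greeting() -> str:
    """Return a static greeting."""
    return "Hello from a function-backed resource!"


# Arguments are cast via validate_call and the URI is coerced to AnyUrl,
# mirroring the classmethod body added above.
resource = FunctionResource.from_function(
    fn=greeting,
    uri="resource://demo/greeting",
    name="greeting",
    mime_type="text/plain",
)
```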
diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index c31f29d4c..21c31b0b3 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -148,9 +148,11 @@ def __init__( self._mcp_server = MCPServer( name=name or "FastMCP", instructions=instructions, - lifespan=lifespan_wrapper(self, self.settings.lifespan) - if self.settings.lifespan - else default_lifespan, + lifespan=( + lifespan_wrapper(self, self.settings.lifespan) + if self.settings.lifespan + else default_lifespan + ), ) self._tool_manager = ToolManager( warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools @@ -465,16 +467,16 @@ def decorator(fn: AnyFunction) -> AnyFunction: uri_template=uri, name=name, description=description, - mime_type=mime_type or "text/plain", + mime_type=mime_type, ) else: # Register as regular resource - resource = FunctionResource( - uri=AnyUrl(uri), + resource = FunctionResource.from_function( + fn=fn, + uri=uri, name=name, description=description, - mime_type=mime_type or "text/plain", - fn=fn, + mime_type=mime_type, ) self.add_resource(resource) return fn @@ -952,15 +954,15 @@ def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]: return self._request_context async def report_progress( - self, progress: float, total: float | None = None + self, progress: float, total: float | None = None, message: str | None = None ) -> None: """Report progress for the current operation. Args: progress: Current progress value e.g. 24 total: Optional total value e.g. 100 + message: Optional message e.g. Starting render... """ - progress_token = ( self.request_context.meta.progressToken if self.request_context.meta @@ -971,7 +973,10 @@ async def report_progress( return await self.request_context.session.send_progress_notification( - progress_token=progress_token, progress=progress, total=total + progress_token=progress_token, + progress=progress, + total=total, + message=message, ) async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 4b97b33da..876aef817 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -37,7 +37,8 @@ async def handle_list_resource_templates() -> list[types.ResourceTemplate]: 3. 
Define notification handlers if needed: @server.progress_notification() async def handle_progress( - progress_token: str | int, progress: float, total: float | None + progress_token: str | int, progress: float, total: float | None, + message: str | None ) -> None: # Implementation @@ -427,13 +428,18 @@ async def handler(req: types.CallToolRequest): def progress_notification(self): def decorator( - func: Callable[[str | int, float, float | None], Awaitable[None]], + func: Callable[ + [str | int, float, float | None, str | None], Awaitable[None] + ], ): logger.debug("Registering handler for ProgressNotification") async def handler(req: types.ProgressNotification): await func( - req.params.progressToken, req.params.progress, req.params.total + req.params.progressToken, + req.params.progress, + req.params.total, + req.params.message, ) self.notification_handlers[types.ProgressNotification] = handler diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index f4e72eac1..ef5c5a3c3 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -52,6 +52,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: BaseSession, RequestResponder, ) +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS class InitializationState(Enum): @@ -150,13 +151,16 @@ async def _received_request( ): match responder.request.root: case types.InitializeRequest(params=params): + requested_version = params.protocolVersion self._initialization_state = InitializationState.Initializing self._client_params = params with responder: await responder.respond( types.ServerResult( types.InitializeResult( - protocolVersion=types.LATEST_PROTOCOL_VERSION, + protocolVersion=requested_version + if requested_version in SUPPORTED_PROTOCOL_VERSIONS + else types.LATEST_PROTOCOL_VERSION, capabilities=self._init_options.capabilities, serverInfo=types.Implementation( name=self._init_options.server_name, @@ -282,6 +286,7 @@ async def send_progress_notification( progress_token: str | int, progress: float, total: float | None = None, + message: str | None = None, related_request_id: str | None = None, ) -> None: """Send a progress notification.""" @@ -293,6 +298,7 @@ async def send_progress_notification( progressToken=progress_token, progress=progress, total=total, + message=message, ), ) ), diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index cc41a80d6..a6350a39b 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -100,10 +100,26 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): write_stream, write_stream_reader = anyio.create_memory_object_stream(0) session_id = uuid4() - session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}" self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") + # Determine the full path for the message endpoint to be sent to the client. + # scope['root_path'] is the prefix where the current Starlette app + # instance is mounted. + # e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix". + root_path = scope.get("root_path", "") + + # self._endpoint is the path *within* this app, e.g., "/messages". + # Concatenating them gives the full absolute path from the server root. + # e.g., "" + "/messages" -> "/messages" + # e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages" + full_message_path_for_client = root_path.rstrip("/") + self._endpoint + + # This is the URI (path + query) the client will use to POST messages. 
+ client_post_uri_data = ( + f"{quote(full_message_path_for_client)}?session_id={session_id.hex}" + ) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ dict[str, Any] ](0) @@ -111,8 +127,10 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): async def sse_writer(): logger.debug("Starting SSE writer") async with sse_stream_writer, write_stream_reader: - await sse_stream_writer.send({"event": "endpoint", "data": session_uri}) - logger.debug(f"Sent endpoint event: {session_uri}") + await sse_stream_writer.send( + {"event": "endpoint", "data": client_post_uri_data} + ) + logger.debug(f"Sent endpoint event: {client_post_uri_data}") async for session_message in write_stream_reader: logger.debug(f"Sending message via SSE: {session_message}") diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 52e0017d0..856a8d3b6 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -43,11 +43,11 @@ class ProgressContext( total: float | None current: float = field(default=0.0, init=False) - async def progress(self, amount: float) -> None: + async def progress(self, amount: float, message: str | None = None) -> None: self.current += amount await self.session.send_progress_notification( - self.progress_token, self.current, total=self.total + self.progress_token, self.current, total=self.total, message=message ) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index c390386a9..90b4eb27c 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -3,7 +3,7 @@ from contextlib import AsyncExitStack from datetime import timedelta from types import TracebackType -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Protocol, TypeVar import anyio import httpx @@ -24,6 +24,7 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + ProgressNotification, RequestParams, ServerNotification, ServerRequest, @@ -42,6 +43,14 @@ RequestId = str | int +class ProgressFnT(Protocol): + """Protocol for progress notification callbacks.""" + + async def __call__( + self, progress: float, total: float | None, message: str | None + ) -> None: ... + + class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -169,6 +178,7 @@ class BaseSession( ] _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] + _progress_callbacks: dict[RequestId, ProgressFnT] def __init__( self, @@ -187,6 +197,7 @@ def __init__( self._receive_notification_type = receive_notification_type self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} + self._progress_callbacks = {} self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -214,6 +225,7 @@ async def send_request( result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. 
Raises an McpError if the @@ -231,15 +243,25 @@ async def send_request( ](1) self._response_streams[request_id] = response_stream + # Set up progress token if progress callback is provided + request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) + if progress_callback is not None: + # Use request_id as progress token + if "params" not in request_data: + request_data["params"] = {} + if "_meta" not in request_data["params"]: + request_data["params"]["_meta"] = {} + request_data["params"]["_meta"]["progressToken"] = request_id + # Store the callback for this request + self._progress_callbacks[request_id] = progress_callback + try: jsonrpc_request = JSONRPCRequest( jsonrpc="2.0", id=request_id, - **request.model_dump(by_alias=True, mode="json", exclude_none=True), + **request_data, ) - # TODO: Support progress callbacks - await self._write_stream.send( SessionMessage( message=JSONRPCMessage(jsonrpc_request), metadata=metadata @@ -275,6 +297,7 @@ async def send_request( finally: self._response_streams.pop(request_id, None) + self._progress_callbacks.pop(request_id, None) await response_stream.aclose() await response_stream_reader.aclose() @@ -333,7 +356,6 @@ async def _receive_loop(self) -> None: by_alias=True, mode="json", exclude_none=True ) ) - responder = RequestResponder( request_id=message.message.root.id, request_meta=validated_request.root.params.meta @@ -363,6 +385,18 @@ async def _receive_loop(self) -> None: if cancelled_id in self._in_flight: await self._in_flight[cancelled_id].cancel() else: + # Handle progress notifications callback + if isinstance(notification.root, ProgressNotification): + progress_token = notification.root.params.progressToken + # If there is a progress callback for this token, + # call it with the progress information + if progress_token in self._progress_callbacks: + callback = self._progress_callbacks[progress_token] + await callback( + notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + ) await self._received_notification(notification) await self._handle_incoming(notification) except Exception as e: @@ -401,7 +435,11 @@ async def _received_notification(self, notification: ReceiveNotificationT) -> No """ async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, ) -> None: """ Sends a progress notification for a request that is currently being diff --git a/src/mcp/shared/version.py b/src/mcp/shared/version.py index 8fd13b992..d00077705 100644 --- a/src/mcp/shared/version.py +++ b/src/mcp/shared/version.py @@ -1,3 +1,3 @@ from mcp.types import LATEST_PROTOCOL_VERSION -SUPPORTED_PROTOCOL_VERSIONS: tuple[int, str] = (1, LATEST_PROTOCOL_VERSION) +SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION] diff --git a/src/mcp/types.py b/src/mcp/types.py index 6ab7fba5c..d864b19da 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -29,7 +29,7 @@ not separate types in the schema. """ -LATEST_PROTOCOL_VERSION = "2024-11-05" +LATEST_PROTOCOL_VERSION = "2025-03-26" ProgressToken = str | int Cursor = str @@ -337,6 +337,11 @@ class ProgressNotificationParams(NotificationParams): total is unknown. """ total: float | None = None + """ + Message related to progress. This should provide relevant human readable + progress information. 
+ """ + message: str | None = None """Total number of items to process (or total progress required), if known.""" model_config = ConfigDict(extra="allow") diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py new file mode 100644 index 000000000..b0d6e36b8 --- /dev/null +++ b/tests/client/test_list_methods_cursor.py @@ -0,0 +1,142 @@ +import pytest + +from mcp.server.fastmcp import FastMCP +from mcp.shared.memory import ( + create_connected_server_and_client_session as create_session, +) + +# Mark the whole module for async tests +pytestmark = pytest.mark.anyio + + +async def test_list_tools_cursor_parameter(): + """Test that the cursor parameter is accepted for list_tools. + + Note: FastMCP doesn't currently implement pagination, so this test + only verifies that the cursor parameter is accepted by the client. + """ + server = FastMCP("test") + + # Create a couple of test tools + @server.tool(name="test_tool_1") + async def test_tool_1() -> str: + """First test tool""" + return "Result 1" + + @server.tool(name="test_tool_2") + async def test_tool_2() -> str: + """Second test tool""" + return "Result 2" + + async with create_session(server._mcp_server) as client_session: + # Test without cursor parameter (omitted) + result1 = await client_session.list_tools() + assert len(result1.tools) == 2 + + # Test with cursor=None + result2 = await client_session.list_tools(cursor=None) + assert len(result2.tools) == 2 + + # Test with cursor as string + result3 = await client_session.list_tools(cursor="some_cursor_value") + assert len(result3.tools) == 2 + + # Test with empty string cursor + result4 = await client_session.list_tools(cursor="") + assert len(result4.tools) == 2 + + +async def test_list_resources_cursor_parameter(): + """Test that the cursor parameter is accepted for list_resources. + + Note: FastMCP doesn't currently implement pagination, so this test + only verifies that the cursor parameter is accepted by the client. + """ + server = FastMCP("test") + + # Create a test resource + @server.resource("resource://test/data") + async def test_resource() -> str: + """Test resource""" + return "Test data" + + async with create_session(server._mcp_server) as client_session: + # Test without cursor parameter (omitted) + result1 = await client_session.list_resources() + assert len(result1.resources) >= 1 + + # Test with cursor=None + result2 = await client_session.list_resources(cursor=None) + assert len(result2.resources) >= 1 + + # Test with cursor as string + result3 = await client_session.list_resources(cursor="some_cursor") + assert len(result3.resources) >= 1 + + # Test with empty string cursor + result4 = await client_session.list_resources(cursor="") + assert len(result4.resources) >= 1 + + +async def test_list_prompts_cursor_parameter(): + """Test that the cursor parameter is accepted for list_prompts. + + Note: FastMCP doesn't currently implement pagination, so this test + only verifies that the cursor parameter is accepted by the client. + """ + server = FastMCP("test") + + # Create a test prompt + @server.prompt() + async def test_prompt(name: str) -> str: + """Test prompt""" + return f"Hello, {name}!" 
+ + async with create_session(server._mcp_server) as client_session: + # Test without cursor parameter (omitted) + result1 = await client_session.list_prompts() + assert len(result1.prompts) >= 1 + + # Test with cursor=None + result2 = await client_session.list_prompts(cursor=None) + assert len(result2.prompts) >= 1 + + # Test with cursor as string + result3 = await client_session.list_prompts(cursor="some_cursor") + assert len(result3.prompts) >= 1 + + # Test with empty string cursor + result4 = await client_session.list_prompts(cursor="") + assert len(result4.prompts) >= 1 + + +async def test_list_resource_templates_cursor_parameter(): + """Test that the cursor parameter is accepted for list_resource_templates. + + Note: FastMCP doesn't currently implement pagination, so this test + only verifies that the cursor parameter is accepted by the client. + """ + server = FastMCP("test") + + # Create a test resource template + @server.resource("resource://test/{name}") + async def test_template(name: str) -> str: + """Test resource template""" + return f"Data for {name}" + + async with create_session(server._mcp_server) as client_session: + # Test without cursor parameter (omitted) + result1 = await client_session.list_resource_templates() + assert len(result1.resourceTemplates) >= 1 + + # Test with cursor=None + result2 = await client_session.list_resource_templates(cursor=None) + assert len(result2.resourceTemplates) >= 1 + + # Test with cursor as string + result3 = await client_session.list_resource_templates(cursor="some_cursor") + assert len(result3.resourceTemplates) >= 1 + + # Test with empty string cursor + result4 = await client_session.list_resource_templates(cursor="") + assert len(result4.resourceTemplates) >= 1 diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 6abcf70cb..cad89f217 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -5,6 +5,7 @@ from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientNotification, @@ -250,3 +251,132 @@ async def mock_server(): # Assert that the default client info was sent assert received_client_info == DEFAULT_CLIENT_INFO + + +@pytest.mark.anyio +async def test_client_session_version_negotiation_success(): + """Test successful version negotiation with supported version""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + + # Verify client sent the latest protocol version + assert request.root.params.protocolVersion == LATEST_PROTOCOL_VERSION + + # Server responds with a supported older version + result = ServerResult( + InitializeResult( + protocolVersion="2024-11-05", + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await 
server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + result = await session.initialize() + + # Assert the result with negotiated version + assert isinstance(result, InitializeResult) + assert result.protocolVersion == "2024-11-05" + assert result.protocolVersion in SUPPORTED_PROTOCOL_VERSIONS + + +@pytest.mark.anyio +async def test_client_session_version_negotiation_failure(): + """Test version negotiation failure with unsupported version""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + + # Server responds with an unsupported version + result = ServerResult( + InitializeResult( + protocolVersion="2020-01-01", # Unsupported old version + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + ) + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + # Should raise RuntimeError for unsupported version + with pytest.raises(RuntimeError, match="Unsupported protocol version"): + await session.initialize() diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py new file mode 100644 index 000000000..924ef7a06 --- /dev/null +++ b/tests/client/test_session_group.py @@ -0,0 +1,397 @@ +import contextlib +from unittest import mock + +import pytest + +import mcp +from mcp import types +from mcp.client.session_group import ( + ClientSessionGroup, + SseServerParameters, + StreamableHttpParameters, +) +from mcp.client.stdio import StdioServerParameters +from mcp.shared.exceptions import McpError + + +@pytest.fixture +def mock_exit_stack(): + """Fixture for a mocked AsyncExitStack.""" + # Use unittest.mock.Mock directly if needed, or just a plain object + # if only attribute access/existence is needed. + # For AsyncExitStack, Mock or MagicMock is usually fine. 
+ return mock.MagicMock(spec=contextlib.AsyncExitStack) + + +@pytest.mark.anyio +class TestClientSessionGroup: + def test_init(self): + mcp_session_group = ClientSessionGroup() + assert not mcp_session_group._tools + assert not mcp_session_group._resources + assert not mcp_session_group._prompts + assert not mcp_session_group._tool_to_session + + def test_component_properties(self): + # --- Mock Dependencies --- + mock_prompt = mock.Mock() + mock_resource = mock.Mock() + mock_tool = mock.Mock() + + # --- Prepare Session Group --- + mcp_session_group = ClientSessionGroup() + mcp_session_group._prompts = {"my_prompt": mock_prompt} + mcp_session_group._resources = {"my_resource": mock_resource} + mcp_session_group._tools = {"my_tool": mock_tool} + + # --- Assertions --- + assert mcp_session_group.prompts == {"my_prompt": mock_prompt} + assert mcp_session_group.resources == {"my_resource": mock_resource} + assert mcp_session_group.tools == {"my_tool": mock_tool} + + async def test_call_tool(self): + # --- Mock Dependencies --- + mock_session = mock.AsyncMock() + + # --- Prepare Session Group --- + def hook(name, server_info): + return f"{(server_info.name)}-{name}" + + mcp_session_group = ClientSessionGroup(component_name_hook=hook) + mcp_session_group._tools = { + "server1-my_tool": types.Tool(name="my_tool", inputSchema={}) + } + mcp_session_group._tool_to_session = {"server1-my_tool": mock_session} + text_content = types.TextContent(type="text", text="OK") + mock_session.call_tool.return_value = types.CallToolResult( + content=[text_content] + ) + + # --- Test Execution --- + result = await mcp_session_group.call_tool( + name="server1-my_tool", + args={ + "name": "value1", + "args": {}, + }, + ) + + # --- Assertions --- + assert result.content == [text_content] + mock_session.call_tool.assert_called_once_with( + "my_tool", + {"name": "value1", "args": {}}, + ) + + async def test_connect_to_server(self, mock_exit_stack): + """Test connecting to a server and aggregating components.""" + # --- Mock Dependencies --- + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "TestServer1" + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_tool1 = mock.Mock(spec=types.Tool) + mock_tool1.name = "tool_a" + mock_resource1 = mock.Mock(spec=types.Resource) + mock_resource1.name = "resource_b" + mock_prompt1 = mock.Mock(spec=types.Prompt) + mock_prompt1.name = "prompt_c" + mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1]) + mock_session.list_resources.return_value = mock.AsyncMock( + resources=[mock_resource1] + ) + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1]) + + # --- Test Execution --- + group = ClientSessionGroup(exit_stack=mock_exit_stack) + with mock.patch.object( + group, "_establish_session", return_value=(mock_server_info, mock_session) + ): + await group.connect_to_server(StdioServerParameters(command="test")) + + # --- Assertions --- + assert mock_session in group._sessions + assert len(group.tools) == 1 + assert "tool_a" in group.tools + assert group.tools["tool_a"] == mock_tool1 + assert group._tool_to_session["tool_a"] == mock_session + assert len(group.resources) == 1 + assert "resource_b" in group.resources + assert group.resources["resource_b"] == mock_resource1 + assert len(group.prompts) == 1 + assert "prompt_c" in group.prompts + assert group.prompts["prompt_c"] == mock_prompt1 + mock_session.list_tools.assert_awaited_once() + mock_session.list_resources.assert_awaited_once() + 
mock_session.list_prompts.assert_awaited_once() + + async def test_connect_to_server_with_name_hook(self, mock_exit_stack): + """Test connecting with a component name hook.""" + # --- Mock Dependencies --- + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "HookServer" + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_tool = mock.Mock(spec=types.Tool) + mock_tool.name = "base_tool" + mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool]) + mock_session.list_resources.return_value = mock.AsyncMock(resources=[]) + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[]) + + # --- Test Setup --- + def name_hook(name: str, server_info: types.Implementation) -> str: + return f"{server_info.name}.{name}" + + # --- Test Execution --- + group = ClientSessionGroup( + exit_stack=mock_exit_stack, component_name_hook=name_hook + ) + with mock.patch.object( + group, "_establish_session", return_value=(mock_server_info, mock_session) + ): + await group.connect_to_server(StdioServerParameters(command="test")) + + # --- Assertions --- + assert mock_session in group._sessions + assert len(group.tools) == 1 + expected_tool_name = "HookServer.base_tool" + assert expected_tool_name in group.tools + assert group.tools[expected_tool_name] == mock_tool + assert group._tool_to_session[expected_tool_name] == mock_session + + async def test_disconnect_from_server(self): # No mock arguments needed + """Test disconnecting from a server.""" + # --- Test Setup --- + group = ClientSessionGroup() + server_name = "ServerToDisconnect" + + # Manually populate state using standard mocks + mock_session1 = mock.MagicMock(spec=mcp.ClientSession) + mock_session2 = mock.MagicMock(spec=mcp.ClientSession) + mock_tool1 = mock.Mock(spec=types.Tool) + mock_tool1.name = "tool1" + mock_resource1 = mock.Mock(spec=types.Resource) + mock_resource1.name = "res1" + mock_prompt1 = mock.Mock(spec=types.Prompt) + mock_prompt1.name = "prm1" + mock_tool2 = mock.Mock(spec=types.Tool) + mock_tool2.name = "tool2" + mock_component_named_like_server = mock.Mock() + mock_session = mock.Mock(spec=mcp.ClientSession) + + group._tools = { + "tool1": mock_tool1, + "tool2": mock_tool2, + server_name: mock_component_named_like_server, + } + group._tool_to_session = { + "tool1": mock_session1, + "tool2": mock_session2, + server_name: mock_session1, + } + group._resources = { + "res1": mock_resource1, + server_name: mock_component_named_like_server, + } + group._prompts = { + "prm1": mock_prompt1, + server_name: mock_component_named_like_server, + } + group._sessions = { + mock_session: ClientSessionGroup._ComponentNames( + prompts=set({"prm1"}), + resources=set({"res1"}), + tools=set({"tool1", "tool2"}), + ) + } + + # --- Assertions --- + assert mock_session in group._sessions + assert "tool1" in group._tools + assert "tool2" in group._tools + assert "res1" in group._resources + assert "prm1" in group._prompts + + # --- Test Execution --- + await group.disconnect_from_server(mock_session) + + # --- Assertions --- + assert mock_session not in group._sessions + assert "tool1" not in group._tools + assert "tool2" not in group._tools + assert "res1" not in group._resources + assert "prm1" not in group._prompts + + async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_stack): + """Test McpError raised when connecting a server with a dup name.""" + # --- Setup Pre-existing State --- + group = ClientSessionGroup(exit_stack=mock_exit_stack) + existing_tool_name = 
"shared_tool" + # Manually add a tool to simulate a previous connection + group._tools[existing_tool_name] = mock.Mock(spec=types.Tool) + group._tools[existing_tool_name].name = existing_tool_name + # Need a dummy session associated with the existing tool + mock_session = mock.MagicMock(spec=mcp.ClientSession) + group._tool_to_session[existing_tool_name] = mock_session + group._session_exit_stacks[mock_session] = mock.Mock( + spec=contextlib.AsyncExitStack + ) + + # --- Mock New Connection Attempt --- + mock_server_info_new = mock.Mock(spec=types.Implementation) + mock_server_info_new.name = "ServerWithDuplicate" + mock_session_new = mock.AsyncMock(spec=mcp.ClientSession) + + # Configure the new session to return a tool with the *same name* + duplicate_tool = mock.Mock(spec=types.Tool) + duplicate_tool.name = existing_tool_name + mock_session_new.list_tools.return_value = mock.AsyncMock( + tools=[duplicate_tool] + ) + # Keep other lists empty for simplicity + mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[]) + mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[]) + + # --- Test Execution and Assertion --- + with pytest.raises(McpError) as excinfo: + with mock.patch.object( + group, + "_establish_session", + return_value=(mock_server_info_new, mock_session_new), + ): + await group.connect_to_server(StdioServerParameters(command="test")) + + # Assert details about the raised error + assert excinfo.value.error.code == types.INVALID_PARAMS + assert existing_tool_name in excinfo.value.error.message + assert "already exist " in excinfo.value.error.message + + # Verify the duplicate tool was *not* added again (state should be unchanged) + assert len(group._tools) == 1 # Should still only have the original + assert ( + group._tools[existing_tool_name] is not duplicate_tool + ) # Ensure it's the original mock + + # No patching needed here + async def test_disconnect_non_existent_server(self): + """Test disconnecting a server that isn't connected.""" + session = mock.Mock(spec=mcp.ClientSession) + group = ClientSessionGroup() + with pytest.raises(McpError): + await group.disconnect_from_server(session) + + @pytest.mark.parametrize( + "server_params_instance, client_type_name, patch_target_for_client_func", + [ + ( + StdioServerParameters(command="test_stdio_cmd"), + "stdio", + "mcp.client.session_group.mcp.stdio_client", + ), + ( + SseServerParameters(url="http://test.com/sse", timeout=10), + "sse", + "mcp.client.session_group.sse_client", + ), # url, headers, timeout, sse_read_timeout + ( + StreamableHttpParameters( + url="http://test.com/stream", terminate_on_close=False + ), + "streamablehttp", + "mcp.client.session_group.streamablehttp_client", + ), # url, headers, timeout, sse_read_timeout, terminate_on_close + ], + ) + async def test_establish_session_parameterized( + self, + server_params_instance, + client_type_name, # Just for clarity or conditional logic if needed + patch_target_for_client_func, + ): + with mock.patch( + "mcp.client.session_group.mcp.ClientSession" + ) as mock_ClientSession_class: + with mock.patch(patch_target_for_client_func) as mock_specific_client_func: + mock_client_cm_instance = mock.AsyncMock( + name=f"{client_type_name}ClientCM" + ) + mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read") + mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write") + + # streamablehttp_client's __aenter__ returns three values + if client_type_name == "streamablehttp": + mock_extra_stream_val = 
mock.AsyncMock(name="StreamableExtra") + mock_client_cm_instance.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_extra_stream_val, + ) + else: + mock_client_cm_instance.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + ) + + mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None) + mock_specific_client_func.return_value = mock_client_cm_instance + + # --- Mock mcp.ClientSession (class) --- + # mock_ClientSession_class is already provided by the outer patch + mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM") + mock_ClientSession_class.return_value = mock_raw_session_cm + + mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance") + mock_raw_session_cm.__aenter__.return_value = mock_entered_session + mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None) + + # Mock session.initialize() + mock_initialize_result = mock.AsyncMock(name="InitializeResult") + mock_initialize_result.serverInfo = types.Implementation( + name="foo", version="1" + ) + mock_entered_session.initialize.return_value = mock_initialize_result + + # --- Test Execution --- + group = ClientSessionGroup() + returned_server_info = None + returned_session = None + + async with contextlib.AsyncExitStack() as stack: + group._exit_stack = stack + ( + returned_server_info, + returned_session, + ) = await group._establish_session(server_params_instance) + + # --- Assertions --- + # 1. Assert the correct specific client function was called + if client_type_name == "stdio": + mock_specific_client_func.assert_called_once_with( + server_params_instance + ) + elif client_type_name == "sse": + mock_specific_client_func.assert_called_once_with( + url=server_params_instance.url, + headers=server_params_instance.headers, + timeout=server_params_instance.timeout, + sse_read_timeout=server_params_instance.sse_read_timeout, + ) + elif client_type_name == "streamablehttp": + mock_specific_client_func.assert_called_once_with( + url=server_params_instance.url, + headers=server_params_instance.headers, + timeout=server_params_instance.timeout, + sse_read_timeout=server_params_instance.sse_read_timeout, + terminate_on_close=server_params_instance.terminate_on_close, + ) + + mock_client_cm_instance.__aenter__.assert_awaited_once() + + # 2. Assert ClientSession was called correctly + mock_ClientSession_class.assert_called_once_with( + mock_read_stream, mock_write_stream + ) + mock_raw_session_cm.__aenter__.assert_awaited_once() + mock_entered_session.initialize.assert_awaited_once() + + # 3. 
Assert returned values + assert returned_server_info is mock_initialize_result.serverInfo + assert returned_session is mock_entered_session diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 523ba199a..33d90e769 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -9,6 +9,13 @@ tee: str = shutil.which("tee") # type: ignore +@pytest.mark.anyio +@pytest.mark.skipif(tee is None, reason="could not find tee command") +async def test_stdio_context_manager_exiting(): + async with stdio_client(StdioServerParameters(command=tee)) as (_, _): + pass + + @pytest.mark.anyio @pytest.mark.skipif(tee is None, reason="could not find tee command") async def test_stdio_client(): diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 7f9131a1e..4ad22f294 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call(): mock_session.send_progress_notification.call_count == 3 ), "All progress notifications should be sent" mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=0.0, total=10.0 + progress_token=0, progress=0.0, total=10.0, message=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=5.0, total=10.0 + progress_token=0, progress=5.0, total=10.0, message=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=10.0, total=10.0 + progress_token=0, progress=10.0, total=10.0, message=None ) diff --git a/tests/server/fastmcp/resources/test_function_resources.py b/tests/server/fastmcp/resources/test_function_resources.py index f0fe22bfb..f59436ae3 100644 --- a/tests/server/fastmcp/resources/test_function_resources.py +++ b/tests/server/fastmcp/resources/test_function_resources.py @@ -136,3 +136,22 @@ async def get_data() -> str: content = await resource.read() assert content == "Hello, world!" assert resource.mime_type == "text/plain" + + @pytest.mark.anyio + async def test_from_function(self): + """Test creating a FunctionResource from a function.""" + + async def get_data() -> str: + """get_data returns a string""" + return "Hello, world!" 
+ + resource = FunctionResource.from_function( + fn=get_data, + uri="function://test", + name="test", + ) + + assert resource.description == "get_data returns a string" + assert resource.mime_type == "text/plain" + assert resource.name == "test" + assert resource.uri == AnyUrl("function://test") diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 67911e9e7..79285ecb1 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -12,12 +12,25 @@ import pytest import uvicorn +from pydantic import AnyUrl +import mcp.types as types from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client from mcp.server.fastmcp import FastMCP -from mcp.types import InitializeResult, TextContent +from mcp.server.fastmcp.resources import FunctionResource +from mcp.shared.context import RequestContext +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + GetPromptResult, + InitializeResult, + ReadResourceResult, + SamplingMessage, + TextContent, + TextResourceContents, +) @pytest.fixture @@ -80,6 +93,119 @@ def echo(message: str) -> str: return mcp, app +def make_everything_fastmcp() -> FastMCP: + """Create a FastMCP server with all features enabled for testing.""" + from mcp.server.fastmcp import Context + + mcp = FastMCP(name="EverythingServer") + + # Tool with context for logging and progress + @mcp.tool(description="A tool that demonstrates logging and progress") + async def tool_with_progress(message: str, ctx: Context, steps: int = 3) -> str: + await ctx.info(f"Starting processing of '{message}' with {steps} steps") + + # Send progress notifications + for i in range(steps): + progress_value = (i + 1) / steps + await ctx.report_progress( + progress=progress_value, + total=1.0, + message=f"Processing step {i + 1} of {steps}", + ) + await ctx.debug(f"Completed step {i + 1}") + + return f"Processed '{message}' in {steps} steps" + + # Simple tool for basic functionality + @mcp.tool(description="A simple echo tool") + def echo(message: str) -> str: + return f"Echo: {message}" + + # Tool with sampling capability + @mcp.tool(description="A tool that uses sampling to generate content") + async def sampling_tool(prompt: str, ctx: Context) -> str: + await ctx.info(f"Requesting sampling for prompt: {prompt}") + + # Request sampling from the client + result = await ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", content=TextContent(type="text", text=prompt) + ) + ], + max_tokens=100, + temperature=0.7, + ) + + await ctx.info(f"Received sampling result from model: {result.model}") + # Handle different content types + if result.content.type == "text": + return f"Sampling result: {result.content.text[:100]}..." + else: + return f"Sampling result: {str(result.content)[:100]}..." 
+ + # Tool that sends notifications and logging + @mcp.tool(description="A tool that demonstrates notifications and logging") + async def notification_tool(message: str, ctx: Context) -> str: + # Send different log levels + await ctx.debug("Debug: Starting notification tool") + await ctx.info(f"Info: Processing message '{message}'") + await ctx.warning("Warning: This is a test warning") + + # Send resource change notifications + await ctx.session.send_resource_list_changed() + await ctx.session.send_tool_list_changed() + + await ctx.info("Completed notification tool successfully") + return f"Sent notifications and logs for: {message}" + + # Resource - static + def get_static_info() -> str: + return "This is static resource content" + + static_resource = FunctionResource( + uri=AnyUrl("resource://static/info"), + name="Static Info", + description="Static information resource", + fn=get_static_info, + ) + mcp.add_resource(static_resource) + + # Resource - dynamic function + @mcp.resource("resource://dynamic/{category}") + def dynamic_resource(category: str) -> str: + return f"Dynamic resource content for category: {category}" + + # Resource template + @mcp.resource("resource://template/{id}/data") + def template_resource(id: str) -> str: + return f"Template resource data for ID: {id}" + + # Prompt - simple + @mcp.prompt(description="A simple prompt") + def simple_prompt(topic: str) -> str: + return f"Tell me about {topic}" + + # Prompt - complex with multiple messages + @mcp.prompt(description="Complex prompt with context") + def complex_prompt(user_query: str, context: str = "general") -> str: + # For simplicity, return a single string that incorporates the context + # Since FastMCP doesn't support system messages in the same way + return f"Context: {context}. 
Query: {user_query}" + + return mcp + + +def make_everything_fastmcp_app(): + """Create a comprehensive FastMCP server with SSE transport.""" + from starlette.applications import Starlette + + mcp = make_everything_fastmcp() + # Create the SSE app + app: Starlette = mcp.sse_app() + return mcp, app + + def make_fastmcp_streamable_http_app(): """Create a FastMCP server with StreamableHTTP transport.""" from starlette.applications import Starlette @@ -97,6 +223,18 @@ def echo(message: str) -> str: return mcp, app +def make_everything_fastmcp_streamable_http_app(): + """Create a comprehensive FastMCP server with StreamableHTTP transport.""" + from starlette.applications import Starlette + + # Create a new instance with different name for HTTP transport + mcp = make_everything_fastmcp() + # We can't change the name after creation, so we'll use the same name + # Create the StreamableHTTP app + app: Starlette = mcp.streamable_http_app() + return mcp, app + + def make_fastmcp_stateless_http_app(): """Create a FastMCP server with stateless StreamableHTTP transport.""" from starlette.applications import Starlette @@ -126,6 +264,18 @@ def run_server(server_port: int) -> None: server.run() +def run_everything_legacy_sse_http_server(server_port: int) -> None: + """Run the comprehensive server with all features.""" + _, app = make_everything_fastmcp_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting comprehensive server on port {server_port}") + server.run() + + def run_streamable_http_server(server_port: int) -> None: """Run the StreamableHTTP server.""" _, app = make_fastmcp_streamable_http_app() @@ -138,6 +288,18 @@ def run_streamable_http_server(server_port: int) -> None: server.run() +def run_everything_server(server_port: int) -> None: + """Run the comprehensive StreamableHTTP server with all features.""" + _, app = make_everything_fastmcp_streamable_http_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting comprehensive StreamableHTTP server on port {server_port}") + server.run() + + def run_stateless_http_server(server_port: int) -> None: """Run the stateless StreamableHTTP server.""" _, app = make_fastmcp_stateless_http_app() @@ -323,3 +485,400 @@ async def test_fastmcp_stateless_streamable_http( assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) assert tool_result.content[0].text == f"Echo: test_{i}" + + +@pytest.fixture +def everything_server_port() -> int: + """Get a free port for testing the comprehensive server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def everything_server_url(everything_server_port: int) -> str: + """Get the comprehensive server URL for testing.""" + return f"http://127.0.0.1:{everything_server_port}" + + +@pytest.fixture +def everything_http_server_port() -> int: + """Get a free port for testing the comprehensive StreamableHTTP server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def everything_http_server_url(everything_http_server_port: int) -> str: + """Get the comprehensive StreamableHTTP server URL for testing.""" + return f"http://127.0.0.1:{everything_http_server_port}" + + +@pytest.fixture() +def everything_server(everything_server_port: int) -> Generator[None, None, None]: + """Start the 
comprehensive server in a separate process and clean up after.""" + proc = multiprocessing.Process( + target=run_everything_legacy_sse_http_server, + args=(everything_server_port,), + daemon=True, + ) + print("Starting comprehensive server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("Waiting for comprehensive server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", everything_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"Comprehensive server failed to start after {max_attempts} attempts" + ) + + yield + + print("Killing comprehensive server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("Comprehensive server process failed to terminate") + + +@pytest.fixture() +def everything_streamable_http_server( + everything_http_server_port: int, +) -> Generator[None, None, None]: + """Start the comprehensive StreamableHTTP server in a separate process.""" + proc = multiprocessing.Process( + target=run_everything_server, + args=(everything_http_server_port,), + daemon=True, + ) + print("Starting comprehensive StreamableHTTP server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("Waiting for comprehensive StreamableHTTP server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", everything_http_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"Comprehensive StreamableHTTP server failed to start after " + f"{max_attempts} attempts" + ) + + yield + + print("Killing comprehensive StreamableHTTP server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("Comprehensive StreamableHTTP server process failed to terminate") + + +class NotificationCollector: + def __init__(self): + self.progress_notifications: list = [] + self.log_messages: list = [] + self.resource_notifications: list = [] + self.tool_notifications: list = [] + + async def handle_progress(self, params) -> None: + self.progress_notifications.append(params) + + async def handle_log(self, params) -> None: + self.log_messages.append(params) + + async def handle_resource_list_changed(self, params) -> None: + self.resource_notifications.append(params) + + async def handle_tool_list_changed(self, params) -> None: + self.tool_notifications.append(params) + + async def handle_generic_notification(self, message) -> None: + # Check if this is a ServerNotification + if isinstance(message, types.ServerNotification): + # Check the specific notification type + if isinstance(message.root, types.ProgressNotification): + await self.handle_progress(message.root.params) + elif isinstance(message.root, types.LoggingMessageNotification): + await self.handle_log(message.root.params) + elif isinstance(message.root, types.ResourceListChangedNotification): + await self.handle_resource_list_changed(message.root.params) + elif isinstance(message.root, types.ToolListChangedNotification): + await self.handle_tool_list_changed(message.root.params) + + +async def call_all_mcp_features( + session: ClientSession, collector: NotificationCollector +) -> None: + """ + Test all MCP features using the provided session. 
+ + Args: + session: The MCP client session to test with + collector: Notification collector for capturing server notifications + """ + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "EverythingServer" + + # Check server features are reported + assert result.capabilities.prompts is not None + assert result.capabilities.resources is not None + assert result.capabilities.tools is not None + # Note: logging capability may be None if no tools use context logging + + # Test tools + # 1. Simple echo tool + tool_result = await session.call_tool("echo", {"message": "hello"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "Echo: hello" + + # 2. Tool with context (logging and progress) + # Test progress callback functionality + progress_updates = [] + + async def progress_callback( + progress: float, total: float | None, message: str | None + ) -> None: + """Collect progress updates for testing (async version).""" + progress_updates.append((progress, total, message)) + print(f"Progress: {progress}/{total} - {message}") + + test_message = "test" + steps = 3 + params = { + "message": test_message, + "steps": steps, + } + tool_result = await session.call_tool( + "tool_with_progress", + params, + progress_callback=progress_callback, + ) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert f"Processed '{test_message}' in {steps} steps" in tool_result.content[0].text + + # Verify progress callback was called + assert len(progress_updates) == steps + for i, (progress, total, message) in enumerate(progress_updates): + expected_progress = (i + 1) / steps + assert abs(progress - expected_progress) < 0.01 + assert total == 1.0 + assert message is not None + assert f"step {i + 1} of {steps}" in message + + # Verify we received log messages from the tool + # Note: Progress notifications require special handling in the MCP client + # that's not implemented by default, so we focus on testing logging + assert len(collector.log_messages) > 0 + + # 3. Test sampling tool + prompt = "What is the meaning of life?" + sampling_result = await session.call_tool("sampling_tool", {"prompt": prompt}) + assert len(sampling_result.content) == 1 + assert isinstance(sampling_result.content[0], TextContent) + assert "Sampling result:" in sampling_result.content[0].text + assert "This is a simulated LLM response" in sampling_result.content[0].text + + # Verify we received log messages from the sampling tool + assert len(collector.log_messages) > 0 + assert any( + "Requesting sampling for prompt" in msg.data for msg in collector.log_messages + ) + assert any( + "Received sampling result from model" in msg.data + for msg in collector.log_messages + ) + + # 4. 
Test notification tool + notification_message = "test_notifications" + notification_result = await session.call_tool( + "notification_tool", {"message": notification_message} + ) + assert len(notification_result.content) == 1 + assert isinstance(notification_result.content[0], TextContent) + assert "Sent notifications and logs" in notification_result.content[0].text + + # Verify we received various notification types + assert len(collector.log_messages) > 3 # Should have logs from both tools + assert len(collector.resource_notifications) > 0 + assert len(collector.tool_notifications) > 0 + + # Check that we got different log levels + log_levels = [msg.level for msg in collector.log_messages] + assert "debug" in log_levels + assert "info" in log_levels + assert "warning" in log_levels + + # Test resources + # 1. Static resource + resources = await session.list_resources() + # Try using string comparison since AnyUrl might not match directly + static_resource = next( + (r for r in resources.resources if str(r.uri) == "resource://static/info"), + None, + ) + assert static_resource is not None + assert static_resource.name == "Static Info" + + static_content = await session.read_resource(AnyUrl("resource://static/info")) + assert isinstance(static_content, ReadResourceResult) + assert len(static_content.contents) == 1 + assert isinstance(static_content.contents[0], TextResourceContents) + assert static_content.contents[0].text == "This is static resource content" + + # 2. Dynamic resource + resource_category = "test" + dynamic_content = await session.read_resource( + AnyUrl(f"resource://dynamic/{resource_category}") + ) + assert isinstance(dynamic_content, ReadResourceResult) + assert len(dynamic_content.contents) == 1 + assert isinstance(dynamic_content.contents[0], TextResourceContents) + assert ( + f"Dynamic resource content for category: {resource_category}" + in dynamic_content.contents[0].text + ) + + # 3. Template resource + resource_id = "456" + template_content = await session.read_resource( + AnyUrl(f"resource://template/{resource_id}/data") + ) + assert isinstance(template_content, ReadResourceResult) + assert len(template_content.contents) == 1 + assert isinstance(template_content.contents[0], TextResourceContents) + assert ( + f"Template resource data for ID: {resource_id}" + in template_content.contents[0].text + ) + + # Test prompts + # 1. Simple prompt + prompts = await session.list_prompts() + simple_prompt = next( + (p for p in prompts.prompts if p.name == "simple_prompt"), None + ) + assert simple_prompt is not None + + prompt_topic = "AI" + prompt_result = await session.get_prompt("simple_prompt", {"topic": prompt_topic}) + assert isinstance(prompt_result, GetPromptResult) + assert len(prompt_result.messages) >= 1 + # The actual message structure depends on the prompt implementation + + # 2. Complex prompt + complex_prompt = next( + (p for p in prompts.prompts if p.name == "complex_prompt"), None + ) + assert complex_prompt is not None + + query = "What is AI?" 
+ context = "technical" + complex_result = await session.get_prompt( + "complex_prompt", {"user_query": query, "context": context} + ) + assert isinstance(complex_result, GetPromptResult) + assert len(complex_result.messages) >= 1 + + +async def sampling_callback( + context: RequestContext[ClientSession, None], + params: CreateMessageRequestParams, +) -> CreateMessageResult: + # Simulate LLM response based on the input + if params.messages and isinstance(params.messages[0].content, TextContent): + input_text = params.messages[0].content.text + else: + input_text = "No input" + response_text = f"This is a simulated LLM response to: {input_text}" + + model_name = "test-llm-model" + return CreateMessageResult( + role="assistant", + content=TextContent(type="text", text=response_text), + model=model_name, + stopReason="endTurn", + ) + + +@pytest.mark.anyio +async def test_fastmcp_all_features_sse( + everything_server: None, everything_server_url: str +) -> None: + """Test all MCP features work correctly with SSE transport.""" + + # Create notification collector + collector = NotificationCollector() + + # Create a sampling callback that simulates an LLM + + # Connect to the server with callbacks + async with sse_client(everything_server_url + "/sse") as streams: + # Set up message handler to capture notifications + async def message_handler(message): + print(f"Received message: {message}") + await collector.handle_generic_notification(message) + if isinstance(message, Exception): + raise message + + async with ClientSession( + *streams, + sampling_callback=sampling_callback, + message_handler=message_handler, + ) as session: + # Run the common test suite + await call_all_mcp_features(session, collector) + + +@pytest.mark.anyio +async def test_fastmcp_all_features_streamable_http( + everything_streamable_http_server: None, everything_http_server_url: str +) -> None: + """Test all MCP features work correctly with StreamableHTTP transport.""" + + # Create notification collector + collector = NotificationCollector() + + # Connect to the server using StreamableHTTP + async with streamablehttp_client(everything_http_server_url + "/mcp") as ( + read_stream, + write_stream, + _, + ): + # Set up message handler to capture notifications + async def message_handler(message): + print(f"Received message: {message}") + await collector.handle_generic_notification(message) + if isinstance(message, Exception): + raise message + + async with ClientSession( + read_stream, + write_stream, + sampling_callback=sampling_callback, + message_handler=message_handler, + ) as session: + # Run the common test suite with HTTP-specific test suffix + await call_all_mcp_features(session, collector) diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index 64700d959..b817761ea 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -441,6 +441,24 @@ async def test_file_resource_binary(self, tmp_path: Path): == base64.b64encode(b"Binary file data").decode() ) + @pytest.mark.anyio + async def test_function_resource(self): + mcp = FastMCP() + + @mcp.resource("function://test", name="test_get_data") + def get_data() -> str: + """get_data returns a string""" + return "Hello, world!" 
+ + async with client_session(mcp._mcp_server) as client: + resources = await client.list_resources() + assert len(resources.resources) == 1 + resource = resources.resources[0] + assert resource.description == "get_data returns a string" + assert resource.uri == AnyUrl("function://test") + assert resource.name == "test_get_data" + assert resource.mimeType == "text/plain" + class TestServerResourceTemplates: @pytest.mark.anyio diff --git a/tests/server/test_session.py b/tests/server/test_session.py index f2f033588..1375df12f 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -106,3 +106,97 @@ async def list_resources(): caps = server.get_capabilities(notification_options, experimental_capabilities) assert caps.prompts == PromptsCapability(listChanged=False) assert caps.resources == ResourcesCapability(subscribe=False, listChanged=False) + + +@pytest.mark.anyio +async def test_server_session_initialize_with_older_protocol_version(): + """Test that server accepts and responds with older protocol (2024-11-05).""" + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage | Exception + ](1) + + received_initialized = False + received_protocol_version = None + + async def run_server(): + nonlocal received_initialized + + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="mcp", + server_version="0.1.0", + capabilities=ServerCapabilities(), + ), + ) as server_session: + async for message in server_session.incoming_messages: + if isinstance(message, Exception): + raise message + + if isinstance(message, types.ClientNotification) and isinstance( + message.root, InitializedNotification + ): + received_initialized = True + return + + async def mock_client(): + nonlocal received_protocol_version + + # Send initialization request with older protocol version (2024-11-05) + await client_to_server_send.send( + SessionMessage( + types.JSONRPCMessage( + types.JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=types.InitializeRequestParams( + protocolVersion="2024-11-05", + capabilities=types.ClientCapabilities(), + clientInfo=types.Implementation( + name="test-client", version="1.0.0" + ), + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Wait for the initialize response + init_response_message = await server_to_client_receive.receive() + assert isinstance(init_response_message.message.root, types.JSONRPCResponse) + result_data = init_response_message.message.root.result + init_result = types.InitializeResult.model_validate(result_data) + + # Check that the server responded with the requested protocol version + received_protocol_version = init_result.protocolVersion + assert received_protocol_version == "2024-11-05" + + # Send initialized notification + await client_to_server_send.send( + SessionMessage( + types.JSONRPCMessage( + types.JSONRPCNotification( + jsonrpc="2.0", + method="notifications/initialized", + ) + ) + ) + ) + + async with ( + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + tg.start_soon(mock_client) + + assert received_initialized + assert received_protocol_version == "2024-11-05" diff --git a/tests/shared/test_progress_notifications.py 
b/tests/shared/test_progress_notifications.py new file mode 100644 index 000000000..1e0409e14 --- /dev/null +++ b/tests/shared/test_progress_notifications.py @@ -0,0 +1,349 @@ +from typing import Any, cast + +import anyio +import pytest + +import mcp.types as types +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.context import RequestContext +from mcp.shared.progress import progress +from mcp.shared.session import ( + BaseSession, + RequestResponder, + SessionMessage, +) + + +@pytest.mark.anyio +async def test_bidirectional_progress_notifications(): + """Test that both client and server can send progress notifications.""" + # Create memory streams for client/server + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](5) + + # Run a server session so we can send progress updates in tool + async def run_server(): + # Create a server session + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="ProgressTestServer", + server_version="0.1.0", + capabilities=server.get_capabilities(NotificationOptions(), {}), + ), + ) as server_session: + global serv_sesh + + serv_sesh = server_session + async for message in server_session.incoming_messages: + try: + await server._handle_message(message, server_session, ()) + except Exception as e: + raise e + + # Track progress updates + server_progress_updates = [] + client_progress_updates = [] + + # Progress tokens + server_progress_token = "server_token_123" + client_progress_token = "client_token_456" + + # Create a server with progress capability + server = Server(name="ProgressTestServer") + + # Register progress handler + @server.progress_notification() + async def handle_progress( + progress_token: str | int, + progress: float, + total: float | None, + message: str | None, + ): + server_progress_updates.append( + { + "token": progress_token, + "progress": progress, + "total": total, + "message": message, + } + ) + + # Register list tool handler + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="test_tool", + description="A tool that sends progress notifications", + inputSchema={}, + ) + ] + + # Register call tool handler + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict | None) -> list: + # Make sure we received a progress token + if name == "test_tool": + if arguments and "_meta" in arguments: + progressToken = arguments["_meta"]["progressToken"] + + if not progressToken: + raise ValueError("Empty progress token received") + + if progressToken != client_progress_token: + raise ValueError("Server sending back incorrect progressToken") + + # Send progress notifications + await serv_sesh.send_progress_notification( + progress_token=progressToken, + progress=0.25, + total=1.0, + message="Server progress 25%", + ) + + await serv_sesh.send_progress_notification( + progress_token=progressToken, + progress=0.5, + total=1.0, + message="Server progress 50%", + ) + + await serv_sesh.send_progress_notification( + progress_token=progressToken, + progress=1.0, + total=1.0, + message="Server progress 100%", + ) + + else: + raise ValueError("Progress token not sent.") + + return ["Tool executed successfully"] + + raise ValueError(f"Unknown tool: {name}") + + # Client message handler to store progress
notifications + async def handle_client_message( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + if isinstance(message, types.ServerNotification): + if isinstance(message.root, types.ProgressNotification): + params = message.root.params + client_progress_updates.append( + { + "token": params.progressToken, + "progress": params.progress, + "total": params.total, + "message": params.message, + } + ) + + # Test using client + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=handle_client_message, + ) as client_session, + anyio.create_task_group() as tg, + ): + # Start the server in a background task + tg.start_soon(run_server) + + # Initialize the client connection + await client_session.initialize() + + # Call list_tools with progress token + await client_session.list_tools() + + # Call test_tool with progress token + await client_session.call_tool( + "test_tool", {"_meta": {"progressToken": client_progress_token}} + ) + + # Send progress notifications from client to server + await client_session.send_progress_notification( + progress_token=server_progress_token, + progress=0.33, + total=1.0, + message="Client progress 33%", + ) + + await client_session.send_progress_notification( + progress_token=server_progress_token, + progress=0.66, + total=1.0, + message="Client progress 66%", + ) + + await client_session.send_progress_notification( + progress_token=server_progress_token, + progress=1.0, + total=1.0, + message="Client progress 100%", + ) + + # Wait and exit + await anyio.sleep(0.5) + tg.cancel_scope.cancel() + + # Verify client received progress updates from server + assert len(client_progress_updates) == 3 + assert client_progress_updates[0]["token"] == client_progress_token + assert client_progress_updates[0]["progress"] == 0.25 + assert client_progress_updates[0]["message"] == "Server progress 25%" + assert client_progress_updates[2]["progress"] == 1.0 + + # Verify server received progress updates from client + assert len(server_progress_updates) == 3 + assert server_progress_updates[0]["token"] == server_progress_token + assert server_progress_updates[0]["progress"] == 0.33 + assert server_progress_updates[0]["message"] == "Client progress 33%" + assert server_progress_updates[2]["progress"] == 1.0 + + +@pytest.mark.anyio +async def test_progress_context_manager(): + """Test client using progress context manager for sending progress notifications.""" + # Create memory streams for client/server + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](5) + + # Track progress updates + server_progress_updates = [] + + server = Server(name="ProgressContextTestServer") + + # Register progress handler + @server.progress_notification() + async def handle_progress( + progress_token: str | int, + progress: float, + total: float | None, + message: str | None, + ): + server_progress_updates.append( + { + "token": progress_token, + "progress": progress, + "total": total, + "message": message, + } + ) + + # Run server session to receive progress updates + async def run_server(): + # Create a server session + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="ProgressContextTestServer", 
+ server_version="0.1.0", + capabilities=server.get_capabilities(NotificationOptions(), {}), + ), + ) as server_session: + async for message in server_session.incoming_messages: + try: + await server._handle_message(message, server_session, ()) + except Exception as e: + raise e + + # Client message handler + async def handle_client_message( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + # run client session + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=handle_client_message, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + + await client_session.initialize() + + progress_token = "client_token_456" + + # Create request context + meta = types.RequestParams.Meta(progressToken=progress_token) + request_context = RequestContext( + request_id="test-request", + session=client_session, + meta=meta, + lifespan_context=None, + ) + + # cast for type checker + typed_context = cast( + RequestContext[ + BaseSession[Any, Any, Any, Any, Any], + Any, + ], + request_context, + ) + + # Utilize progress context manager + with progress(typed_context, total=100) as p: + await p.progress(10, message="Loading configuration...") + await p.progress(30, message="Connecting to database...") + await p.progress(40, message="Fetching data...") + await p.progress(20, message="Processing results...") + + # Wait for all messages to be processed + await anyio.sleep(0.5) + tg.cancel_scope.cancel() + + # Verify progress updates were received by server + assert len(server_progress_updates) == 4 + + # first update + assert server_progress_updates[0]["token"] == progress_token + assert server_progress_updates[0]["progress"] == 10 + assert server_progress_updates[0]["total"] == 100 + assert server_progress_updates[0]["message"] == "Loading configuration..." + + # second update + assert server_progress_updates[1]["token"] == progress_token + assert server_progress_updates[1]["progress"] == 40 + assert server_progress_updates[1]["total"] == 100 + assert server_progress_updates[1]["message"] == "Connecting to database..." + + # third update + assert server_progress_updates[2]["token"] == progress_token + assert server_progress_updates[2]["progress"] == 80 + assert server_progress_updates[2]["total"] == 100 + assert server_progress_updates[2]["message"] == "Fetching data..." + + # final update + assert server_progress_updates[3]["token"] == progress_token + assert server_progress_updates[3]["progress"] == 100 + assert server_progress_updates[3]["total"] == 100 + assert server_progress_updates[3]["message"] == "Processing results..." 
diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 4558bb88c..e55983e01 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -252,3 +252,69 @@ async def test_sse_client_timeout( return pytest.fail("the client should have timed out and returned an error already") + + +def run_mounted_server(server_port: int) -> None: + app = make_server_app() + main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) + server = uvicorn.Server( + config=uvicorn.Config( + app=main_app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"starting server on {server_port}") + server.run() + + # Give server time to start + while not server.started: + print("waiting for server to start") + time.sleep(0.5) + + +@pytest.fixture() +def mounted_server(server_port: int) -> Generator[None, None, None]: + proc = multiprocessing.Process( + target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True + ) + print("starting process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("waiting for server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + print("killing server") + # Signal the server to stop + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.mark.anyio +async def test_sse_client_basic_connection_mounted_app( + mounted_server: None, server_url: str +) -> None: + async with sse_client(server_url + "/mounted_app/sse") as streams: + async with ClientSession(*streams) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + # Test ping + ping_result = await session.send_ping() + assert isinstance(ping_result, EmptyResult) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 9b32254a9..f1c7ef809 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -960,6 +960,72 @@ async def test_streamablehttp_client_session_termination( await session.list_tools() +@pytest.mark.anyio +async def test_streamablehttp_client_session_termination_204( + basic_server, basic_server_url, monkeypatch +): + """Test client session termination functionality with a 204 response. + + This test patches the httpx client to return a 204 response for DELETEs. 
+ """ + + # Save the original delete method to restore later + original_delete = httpx.AsyncClient.delete + + # Mock the client's delete method to return a 204 + async def mock_delete(self, *args, **kwargs): + # Call the original method to get the real response + response = await original_delete(self, *args, **kwargs) + + # Create a new response with 204 status code but same headers + mocked_response = httpx.Response( + 204, + headers=response.headers, + content=response.content, + request=response.request, + ) + return mocked_response + + # Apply the patch to the httpx client + monkeypatch.setattr(httpx.AsyncClient, "delete", mock_delete) + + captured_session_id = None + + # Create the streamablehttp_client with a custom httpx client to capture headers + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + get_session_id, + ): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + captured_session_id = get_session_id() + assert captured_session_id is not None + + # Make a request to confirm session is working + tools = await session.list_tools() + assert len(tools.tools) == 4 + + headers = {} + if captured_session_id: + headers[MCP_SESSION_ID_HEADER] = captured_session_id + + async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + # Attempt to make a request after termination + with pytest.raises( + McpError, + match="Session terminated", + ): + await session.list_tools() + + @pytest.mark.anyio async def test_streamablehttp_client_resumption(event_server): """Test client session to resume a long running tool.""" diff --git a/tests/test_examples.py b/tests/test_examples.py index c5e8ec9d7..b2fff1a91 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,5 +1,7 @@ """Tests for example servers""" +import sys + import pytest from pytest_examples import CodeExample, EvalExample, find_examples @@ -69,8 +71,15 @@ async def test_desktop(monkeypatch): content = result.contents[0] assert isinstance(content, TextResourceContents) assert isinstance(content.text, str) - assert "/fake/path/file1.txt" in content.text - assert "/fake/path/file2.txt" in content.text + if sys.platform == "win32": + file_1 = "/fake/path/file1.txt".replace("/", "\\\\") # might be a bug + file_2 = "/fake/path/file2.txt".replace("/", "\\\\") # might be a bug + assert file_1 in content.text + assert file_2 in content.text + # might be a bug, but the test is passing + else: + assert "/fake/path/file1.txt" in content.text + assert "/fake/path/file2.txt" in content.text @pytest.mark.parametrize("example", find_examples("README.md"), ids=str)