diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index e69de29bb..57cba171f 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# Applied 120 line-length rule to all files: https://github.com/modelcontextprotocol/python-sdk/pull/856 +543961968c0634e93d919d509cce23a1d6a56c21 diff --git a/.github/ISSUE_TEMPLATE/bug.yaml b/.github/ISSUE_TEMPLATE/bug.yaml new file mode 100644 index 000000000..e52277a2a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.yaml @@ -0,0 +1,55 @@ +name: πŸ› MCP Python SDK Bug +description: Report a bug or unexpected behavior in the MCP Python SDK +labels: ["need confirmation"] + +body: + - type: markdown + attributes: + value: Thank you for contributing to the MCP Python SDK! ✊ + + - type: checkboxes + id: checks + attributes: + label: Initial Checks + description: Just making sure you're using the latest version of MCP Python SDK. + options: + - label: I confirm that I'm using the latest version of MCP Python SDK + required: true + - label: I confirm that I searched for my issue in https://github.com/modelcontextprotocol/python-sdk/issues before opening this issue + required: true + + - type: textarea + id: description + attributes: + label: Description + description: | + Please explain what you're seeing and what you would expect to see. + + Please provide as much detail as possible to make understanding and solving your problem as quick as possible. πŸ™ + validations: + required: true + + - type: textarea + id: example + attributes: + label: Example Code + description: > + If applicable, please add a self-contained, + [minimal, reproducible, example](https://stackoverflow.com/help/minimal-reproducible-example) + demonstrating the bug. + + placeholder: | + from mcp.server.fastmcp import FastMCP + + ... + render: Python + + - type: textarea + id: version + attributes: + label: Python & MCP Python SDK + description: | + Which version of Python and MCP Python SDK are you using? + render: Text + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index dd84ea782..000000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,38 +0,0 @@ ---- -name: Bug report -about: Create a report to help us improve -title: '' -labels: '' -assignees: '' - ---- - -**Describe the bug** -A clear and concise description of what the bug is. - -**To Reproduce** -Steps to reproduce the behavior: -1. Go to '...' -2. Click on '....' -3. Scroll down to '....' -4. See error - -**Expected behavior** -A clear and concise description of what you expected to happen. - -**Screenshots** -If applicable, add screenshots to help explain your problem. - -**Desktop (please complete the following information):** - - OS: [e.g. iOS] - - Browser [e.g. chrome, safari] - - Version [e.g. 22] - -**Smartphone (please complete the following information):** - - Device: [e.g. iPhone6] - - OS: [e.g. iOS8.1] - - Browser [e.g. stock browser, safari] - - Version [e.g. 22] - -**Additional context** -Add any other context about the problem here. 
diff --git a/.github/ISSUE_TEMPLATE/config.yaml b/.github/ISSUE_TEMPLATE/config.yaml new file mode 100644 index 000000000..0086358db --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yaml @@ -0,0 +1 @@ +blank_issues_enabled: true diff --git a/.github/ISSUE_TEMPLATE/feature-request.yaml b/.github/ISSUE_TEMPLATE/feature-request.yaml new file mode 100644 index 000000000..bec9b77b1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.yaml @@ -0,0 +1,29 @@ +name: πŸš€ MCP Python SDK Feature Request +description: "Suggest a new feature for the MCP Python SDK" +labels: ["feature request"] + +body: + - type: markdown + attributes: + value: Thank you for contributing to the MCP Python SDK! ✊ + + - type: textarea + id: description + attributes: + label: Description + description: | + Please give as much detail as possible about the feature you would like to suggest. πŸ™ + + You might like to add: + * A demo of how code might look when using the feature + * Your use case(s) for the feature + * Reference to other projects that have a similar feature + validations: + required: true + + - type: textarea + id: references + attributes: + label: References + description: | + Please add any links or references that might help us understand your feature request better. πŸ“š diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index bbcbbe7d6..000000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,20 +0,0 @@ ---- -name: Feature request -about: Suggest an idea for this project -title: '' -labels: '' -assignees: '' - ---- - -**Is your feature request related to a problem? Please describe.** -A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] - -**Describe the solution you'd like** -A clear and concise description of what you want to happen. - -**Describe alternatives you've considered** -A clear and concise description of any alternative solutions or features you've considered. - -**Additional context** -Add any other context or screenshots about the feature request here. diff --git a/.github/ISSUE_TEMPLATE/question.yaml b/.github/ISSUE_TEMPLATE/question.yaml new file mode 100644 index 000000000..87a7894f1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.yaml @@ -0,0 +1,33 @@ +name: ❓ MCP Python SDK Question +description: "Ask a question about the MCP Python SDK" +labels: ["question"] + +body: + - type: markdown + attributes: + value: Thank you for reaching out to the MCP Python SDK community! We're here to help! 🀝 + + - type: textarea + id: question + attributes: + label: Question + description: | + Please provide as much detail as possible about your question. 
πŸ™ + + You might like to include: + * Code snippets showing what you've tried + * Error messages you're encountering (if any) + * Expected vs actual behavior + * Your use case and what you're trying to achieve + validations: + required: true + + - type: textarea + id: context + attributes: + label: Additional Context + description: | + Please provide any additional context that might help us better understand your question, such as: + * Your MCP Python SDK version + * Your Python version + * Relevant configuration or environment details πŸ“ diff --git a/.github/workflows/check-lock.yml b/.github/workflows/check-lock.yml deleted file mode 100644 index 805b0f3cc..000000000 --- a/.github/workflows/check-lock.yml +++ /dev/null @@ -1,25 +0,0 @@ -name: Check uv.lock - -on: - pull_request: - paths: - - "pyproject.toml" - - "uv.lock" - push: - paths: - - "pyproject.toml" - - "uv.lock" - -jobs: - check-lock: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - - name: Install uv - run: | - curl -LsSf https://astral.sh/uv/install.sh | sh - echo "$HOME/.cargo/bin" >> $GITHUB_PATH - - - name: Check uv.lock is up to date - run: uv lock --check diff --git a/.github/workflows/shared.yml b/.github/workflows/shared.yml index e3fbe73bf..499871ca1 100644 --- a/.github/workflows/shared.yml +++ b/.github/workflows/shared.yml @@ -4,42 +4,28 @@ on: workflow_call: jobs: - format: + pre-commit: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Install uv - uses: astral-sh/setup-uv@v3 + - uses: astral-sh/setup-uv@v5 with: enable-cache: true version: 0.7.2 - - name: Install the project - run: uv sync --frozen --all-extras --dev --python 3.12 - - - name: Run ruff format check - run: uv run --no-sync ruff check . - - typecheck: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 + - name: Install dependencies + run: uv sync --frozen --all-extras --python 3.10 - - name: Install uv - uses: astral-sh/setup-uv@v3 + - uses: pre-commit/action@v3.0.0 with: - enable-cache: true - version: 0.7.2 - - - name: Install the project - run: uv sync --frozen --all-extras --dev --python 3.12 - - - name: Run pyright - run: uv run --no-sync pyright + extra_args: --all-files --verbose + env: + SKIP: no-commit-to-branch test: runs-on: ${{ matrix.os }} + timeout-minutes: 10 strategy: matrix: python-version: ["3.10", "3.11", "3.12", "3.13"] @@ -55,7 +41,7 @@ jobs: version: 0.7.2 - name: Install the project - run: uv sync --frozen --all-extras --dev --python ${{ matrix.python-version }} + run: uv sync --frozen --all-extras --python ${{ matrix.python-version }} - name: Run pytest run: uv run --no-sync pytest diff --git a/README.md b/README.md index d76d3d267..d8a2db2b6 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,9 @@ - [Prompts](#prompts) - [Images](#images) - [Context](#context) + - [Completions](#completions) + - [Elicitation](#elicitation) + - [Authentication](#authentication) - [Running Your Server](#running-your-server) - [Development Mode](#development-mode) - [Claude Desktop Integration](#claude-desktop-integration) @@ -73,7 +76,7 @@ The Model Context Protocol allows applications to provide context for LLMs in a ### Adding MCP to your python project -We recommend using [uv](https://docs.astral.sh/uv/) to manage your Python projects. +We recommend using [uv](https://docs.astral.sh/uv/) to manage your Python projects. 
If you haven't created a uv-managed project yet, create one: @@ -209,13 +212,13 @@ from mcp.server.fastmcp import FastMCP mcp = FastMCP("My App") -@mcp.resource("config://app") +@mcp.resource("config://app", title="Application Configuration") def get_config() -> str: """Static configuration data""" return "App configuration here" -@mcp.resource("users://{user_id}/profile") +@mcp.resource("users://{user_id}/profile", title="User Profile") def get_user_profile(user_id: str) -> str: """Dynamic user data""" return f"Profile data for user {user_id}" @@ -232,13 +235,13 @@ from mcp.server.fastmcp import FastMCP mcp = FastMCP("My App") -@mcp.tool() +@mcp.tool(title="BMI Calculator") def calculate_bmi(weight_kg: float, height_m: float) -> float: """Calculate BMI given weight in kg and height in meters""" return weight_kg / (height_m**2) -@mcp.tool() +@mcp.tool(title="Weather Fetcher") async def fetch_weather(city: str) -> str: """Fetch current weather for a city""" async with httpx.AsyncClient() as client: @@ -257,12 +260,12 @@ from mcp.server.fastmcp.prompts import base mcp = FastMCP("My App") -@mcp.prompt() +@mcp.prompt(title="Code Review") def review_code(code: str) -> str: return f"Please review this code:\n\n{code}" -@mcp.prompt() +@mcp.prompt(title="Debug Assistant") def debug_error(error: str) -> list[base.Message]: return [ base.UserMessage("I'm seeing this error:"), @@ -310,6 +313,112 @@ async def long_task(files: list[str], ctx: Context) -> str: return "Processing complete" ``` +### Completions + +MCP supports providing completion suggestions for prompt arguments and resource template parameters. With the context parameter, servers can provide completions based on previously resolved values: + +Client usage: +```python +from mcp.client.session import ClientSession +from mcp.types import ResourceTemplateReference + + +async def use_completion(session: ClientSession): + # Complete without context + result = await session.complete( + ref=ResourceTemplateReference( + type="ref/resource", uri="github://repos/{owner}/{repo}" + ), + argument={"name": "owner", "value": "model"}, + ) + + # Complete with context - repo suggestions based on owner + result = await session.complete( + ref=ResourceTemplateReference( + type="ref/resource", uri="github://repos/{owner}/{repo}" + ), + argument={"name": "repo", "value": "test"}, + context_arguments={"owner": "modelcontextprotocol"}, + ) +``` + +Server implementation: +```python +from mcp.server import Server +from mcp.types import ( + Completion, + CompletionArgument, + CompletionContext, + PromptReference, + ResourceTemplateReference, +) + +server = Server("example-server") + + +@server.completion() +async def handle_completion( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, +) -> Completion | None: + if isinstance(ref, ResourceTemplateReference): + if ref.uri == "github://repos/{owner}/{repo}" and argument.name == "repo": + # Use context to provide owner-specific repos + if context and context.arguments: + owner = context.arguments.get("owner") + if owner == "modelcontextprotocol": + repos = ["python-sdk", "typescript-sdk", "specification"] + # Filter based on partial input + filtered = [r for r in repos if r.startswith(argument.value)] + return Completion(values=filtered) + return None +``` +### Elicitation + +Request additional information from users during tool execution: + +```python +from mcp.server.fastmcp import FastMCP, Context +from mcp.server.elicitation import ( + 
AcceptedElicitation, + DeclinedElicitation, + CancelledElicitation, +) +from pydantic import BaseModel, Field + +mcp = FastMCP("Booking System") + + +@mcp.tool() +async def book_table(date: str, party_size: int, ctx: Context) -> str: + """Book a table with confirmation""" + + # Schema must only contain primitive types (str, int, float, bool) + class ConfirmBooking(BaseModel): + confirm: bool = Field(description="Confirm booking?") + notes: str = Field(default="", description="Special requests") + + result = await ctx.elicit( + message=f"Confirm booking for {party_size} on {date}?", schema=ConfirmBooking + ) + + match result: + case AcceptedElicitation(data=data): + if data.confirm: + return f"Booked! Notes: {data.notes or 'None'}" + return "Booking cancelled" + case DeclinedElicitation(): + return "Booking declined" + case CancelledElicitation(): + return "Booking cancelled" +``` + +The `elicit()` method returns an `ElicitationResult` with: +- `action`: "accept", "decline", or "cancel" +- `data`: The validated response (only when accepted) +- `validation_error`: Any validation error message + ### Authentication Authentication can be used by servers that want to expose tools accessing protected resources. @@ -809,6 +918,42 @@ async def main(): tool_result = await session.call_tool("echo", {"message": "hello"}) ``` +### Client Display Utilities + +When building MCP clients, the SDK provides utilities to help display human-readable names for tools, resources, and prompts: + +```python +from mcp.shared.metadata_utils import get_display_name +from mcp.client.session import ClientSession + + +async def display_tools(session: ClientSession): + """Display available tools with human-readable names""" + tools_response = await session.list_tools() + + for tool in tools_response.tools: + # get_display_name() returns the title if available, otherwise the name + display_name = get_display_name(tool) + print(f"Tool: {display_name}") + if tool.description: + print(f" {tool.description}") + + +async def display_resources(session: ClientSession): + """Display available resources with human-readable names""" + resources_response = await session.list_resources() + + for resource in resources_response.resources: + display_name = get_display_name(resource) + print(f"Resource: {display_name} ({resource.uri})") +``` + +The `get_display_name()` function implements the proper precedence rules for displaying names: +- For tools: `title` > `annotations.title` > `name` +- For other objects: `title` > `name` + +This ensures your client UI shows the most user-friendly names that servers provide. 
+ ### OAuth Authentication for Clients The SDK includes [authorization support](https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization) for connecting to protected MCP servers: diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py index 180d709aa..b97b85080 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py @@ -123,7 +123,7 @@ async def list_tools(self) -> list[Any]: for item in tools_response: if isinstance(item, tuple) and item[0] == "tools": tools.extend( - Tool(tool.name, tool.description, tool.inputSchema) + Tool(tool.name, tool.description, tool.inputSchema, tool.title) for tool in item[1] ) @@ -189,9 +189,14 @@ class Tool: """Represents a tool with its properties and formatting.""" def __init__( - self, name: str, description: str, input_schema: dict[str, Any] + self, + name: str, + description: str, + input_schema: dict[str, Any], + title: str | None = None, ) -> None: self.name: str = name + self.title: str | None = title self.description: str = description self.input_schema: dict[str, Any] = input_schema @@ -211,13 +216,20 @@ def format_for_llm(self) -> str: arg_desc += " (required)" args_desc.append(arg_desc) - return f""" -Tool: {self.name} -Description: {self.description} + # Build the formatted output with title as a separate field + output = f"Tool: {self.name}\n" + + # Add human-readable title if available + if self.title: + output += f"User-readable title: {self.title}\n" + + output += f"""Description: {self.description} Arguments: {chr(10).join(args_desc)} """ + return output + class LLMClient: """Manages communication with the LLM provider.""" diff --git a/examples/fastmcp/memory.py b/examples/fastmcp/memory.py index dbc890815..0f97babf1 100644 --- a/examples/fastmcp/memory.py +++ b/examples/fastmcp/memory.py @@ -47,18 +47,14 @@ DB_DSN = "postgresql://postgres:postgres@localhost:54320/memory_db" # reset memory with rm ~/.fastmcp/{USER}/memory/* -PROFILE_DIR = ( - Path.home() / ".fastmcp" / os.environ.get("USER", "anon") / "memory" -).resolve() +PROFILE_DIR = (Path.home() / ".fastmcp" / os.environ.get("USER", "anon") / "memory").resolve() PROFILE_DIR.mkdir(parents=True, exist_ok=True) def cosine_similarity(a: list[float], b: list[float]) -> float: a_array = np.array(a, dtype=np.float64) b_array = np.array(b, dtype=np.float64) - return np.dot(a_array, b_array) / ( - np.linalg.norm(a_array) * np.linalg.norm(b_array) - ) + return np.dot(a_array, b_array) / (np.linalg.norm(a_array) * np.linalg.norm(b_array)) async def do_ai[T]( @@ -97,9 +93,7 @@ class MemoryNode(BaseModel): summary: str = "" importance: float = 1.0 access_count: int = 0 - timestamp: float = Field( - default_factory=lambda: datetime.now(timezone.utc).timestamp() - ) + timestamp: float = Field(default_factory=lambda: datetime.now(timezone.utc).timestamp()) embedding: list[float] @classmethod @@ -152,9 +146,7 @@ async def merge_with(self, other: Self, deps: Deps): self.importance += other.importance self.access_count += other.access_count self.embedding = [(a + b) / 2 for a, b in zip(self.embedding, other.embedding)] - self.summary = await do_ai( - self.content, "Summarize the following text concisely.", str, deps - ) + self.summary = await do_ai(self.content, "Summarize the following text concisely.", str, deps) await self.save(deps) # Delete the merged node from the database if other.id is not None: @@ -221,9 +213,7 @@ async 
def find_similar_memories(embedding: list[float], deps: Deps) -> list[Memo async def update_importance(user_embedding: list[float], deps: Deps): async with deps.pool.acquire() as conn: - rows = await conn.fetch( - "SELECT id, importance, access_count, embedding FROM memories" - ) + rows = await conn.fetch("SELECT id, importance, access_count, embedding FROM memories") for row in rows: memory_embedding = row["embedding"] similarity = cosine_similarity(user_embedding, memory_embedding) @@ -273,9 +263,7 @@ async def display_memory_tree(deps: Deps) -> str: ) result = "" for row in rows: - effective_importance = row["importance"] * ( - 1 + math.log(row["access_count"] + 1) - ) + effective_importance = row["importance"] * (1 + math.log(row["access_count"] + 1)) summary = row["summary"] or row["content"] result += f"- {summary} (Importance: {effective_importance:.2f})\n" return result @@ -283,15 +271,11 @@ async def display_memory_tree(deps: Deps) -> str: @mcp.tool() async def remember( - contents: list[str] = Field( - description="List of observations or memories to store" - ), + contents: list[str] = Field(description="List of observations or memories to store"), ): deps = Deps(openai=AsyncOpenAI(), pool=await get_db_pool()) try: - return "\n".join( - await asyncio.gather(*[add_memory(content, deps) for content in contents]) - ) + return "\n".join(await asyncio.gather(*[add_memory(content, deps) for content in contents])) finally: await deps.pool.close() @@ -305,9 +289,7 @@ async def read_profile() -> str: async def initialize_database(): - pool = await asyncpg.create_pool( - "postgresql://postgres:postgres@localhost:54320/postgres" - ) + pool = await asyncpg.create_pool("postgresql://postgres:postgres@localhost:54320/postgres") try: async with pool.acquire() as conn: await conn.execute(""" diff --git a/examples/fastmcp/text_me.py b/examples/fastmcp/text_me.py index 8053c6cc5..2434dcddd 100644 --- a/examples/fastmcp/text_me.py +++ b/examples/fastmcp/text_me.py @@ -28,15 +28,11 @@ class SurgeSettings(BaseSettings): - model_config: SettingsConfigDict = SettingsConfigDict( - env_prefix="SURGE_", env_file=".env" - ) + model_config: SettingsConfigDict = SettingsConfigDict(env_prefix="SURGE_", env_file=".env") api_key: str account_id: str - my_phone_number: Annotated[ - str, BeforeValidator(lambda v: "+" + v if not v.startswith("+") else v) - ] + my_phone_number: Annotated[str, BeforeValidator(lambda v: "+" + v if not v.startswith("+") else v)] my_first_name: str my_last_name: str diff --git a/examples/fastmcp/unicode_example.py b/examples/fastmcp/unicode_example.py index a69f586a5..94ef628bb 100644 --- a/examples/fastmcp/unicode_example.py +++ b/examples/fastmcp/unicode_example.py @@ -8,10 +8,7 @@ mcp = FastMCP() -@mcp.tool( - description="🌟 A tool that uses various Unicode characters in its description: " - "Γ‘ Γ© Γ­ Γ³ ΓΊ Γ± ζΌ’ε­— πŸŽ‰" -) +@mcp.tool(description="🌟 A tool that uses various Unicode characters in its description: " "Γ‘ Γ© Γ­ Γ³ ΓΊ Γ± ζΌ’ε­— πŸŽ‰") def hello_unicode(name: str = "δΈ–η•Œ", greeting: str = "Β‘Hola") -> str: """ A simple tool that demonstrates Unicode handling in: diff --git a/examples/servers/simple-auth/mcp_simple_auth/server.py b/examples/servers/simple-auth/mcp_simple_auth/server.py index 51f449113..6e16f8b9d 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/server.py +++ b/examples/servers/simple-auth/mcp_simple_auth/server.py @@ -82,9 +82,7 @@ async def register_client(self, client_info: OAuthClientInformationFull): """Register a new OAuth client.""" 
self.clients[client_info.client_id] = client_info - async def authorize( - self, client: OAuthClientInformationFull, params: AuthorizationParams - ) -> str: + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: """Generate an authorization URL for GitHub OAuth flow.""" state = params.state or secrets.token_hex(16) @@ -92,9 +90,7 @@ async def authorize( self.state_mapping[state] = { "redirect_uri": str(params.redirect_uri), "code_challenge": params.code_challenge, - "redirect_uri_provided_explicitly": str( - params.redirect_uri_provided_explicitly - ), + "redirect_uri_provided_explicitly": str(params.redirect_uri_provided_explicitly), "client_id": client.client_id, } @@ -117,9 +113,7 @@ async def handle_github_callback(self, code: str, state: str) -> str: redirect_uri = state_data["redirect_uri"] code_challenge = state_data["code_challenge"] - redirect_uri_provided_explicitly = ( - state_data["redirect_uri_provided_explicitly"] == "True" - ) + redirect_uri_provided_explicitly = state_data["redirect_uri_provided_explicitly"] == "True" client_id = state_data["client_id"] # Exchange code for token with GitHub @@ -200,8 +194,7 @@ async def exchange_authorization_code( for token, data in self.tokens.items() # see https://github.blog/engineering/platform-security/behind-githubs-new-authentication-token-formats/ # which you get depends on your GH app setup. - if (token.startswith("ghu_") or token.startswith("gho_")) - and data.client_id == client.client_id + if (token.startswith("ghu_") or token.startswith("gho_")) and data.client_id == client.client_id ), None, ) @@ -214,7 +207,7 @@ async def exchange_authorization_code( return OAuthToken( access_token=mcp_token, - token_type="bearer", + token_type="Bearer", expires_in=3600, scope=" ".join(authorization_code.scopes), ) @@ -232,9 +225,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: return access_token - async def load_refresh_token( - self, client: OAuthClientInformationFull, refresh_token: str - ) -> RefreshToken | None: + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: """Load a refresh token - not supported.""" return None @@ -247,9 +238,7 @@ async def exchange_refresh_token( """Exchange refresh token""" raise NotImplementedError("Not supported") - async def revoke_token( - self, token: str, token_type_hint: str | None = None - ) -> None: + async def revoke_token(self, token: str, token_type_hint: str | None = None) -> None: """Revoke a token.""" if token in self.tokens: del self.tokens[token] @@ -335,9 +324,7 @@ async def get_user_profile() -> dict[str, Any]: ) if response.status_code != 200: - raise ValueError( - f"GitHub API error: {response.status_code} - {response.text}" - ) + raise ValueError(f"GitHub API error: {response.status_code} - {response.text}") return response.json() @@ -361,9 +348,7 @@ def main(port: int, host: str, transport: Literal["sse", "streamable-http"]) -> # No hardcoded credentials - all from environment variables settings = ServerSettings(host=host, port=port) except ValueError as e: - logger.error( - "Failed to load settings. Make sure environment variables are set:" - ) + logger.error("Failed to load settings. 
Make sure environment variables are set:") logger.error(" MCP_GITHUB_GITHUB_CLIENT_ID=") logger.error(" MCP_GITHUB_GITHUB_CLIENT_SECRET=") logger.error(f"Error: {e}") diff --git a/examples/servers/simple-prompt/mcp_simple_prompt/server.py b/examples/servers/simple-prompt/mcp_simple_prompt/server.py index eca0bbcf3..b562cc932 100644 --- a/examples/servers/simple-prompt/mcp_simple_prompt/server.py +++ b/examples/servers/simple-prompt/mcp_simple_prompt/server.py @@ -53,6 +53,7 @@ async def list_prompts() -> list[types.Prompt]: return [ types.Prompt( name="simple", + title="Simple Assistant Prompt", description="A simple prompt that can take optional context and topic " "arguments", arguments=[ diff --git a/examples/servers/simple-resource/mcp_simple_resource/server.py b/examples/servers/simple-resource/mcp_simple_resource/server.py index 85c29cb7d..cef29b851 100644 --- a/examples/servers/simple-resource/mcp_simple_resource/server.py +++ b/examples/servers/simple-resource/mcp_simple_resource/server.py @@ -2,12 +2,21 @@ import click import mcp.types as types from mcp.server.lowlevel import Server -from pydantic import AnyUrl +from pydantic import AnyUrl, FileUrl SAMPLE_RESOURCES = { - "greeting": "Hello! This is a sample text resource.", - "help": "This server provides a few sample text resources for testing.", - "about": "This is the simple-resource MCP server implementation.", + "greeting": { + "content": "Hello! This is a sample text resource.", + "title": "Welcome Message", + }, + "help": { + "content": "This server provides a few sample text resources for testing.", + "title": "Help Documentation", + }, + "about": { + "content": "This is the simple-resource MCP server implementation.", + "title": "About This Server", + }, } @@ -26,8 +35,9 @@ def main(port: int, transport: str) -> int: async def list_resources() -> list[types.Resource]: return [ types.Resource( - uri=AnyUrl(f"file:///{name}.txt"), + uri=FileUrl(f"file:///{name}.txt"), name=name, + title=SAMPLE_RESOURCES[name]["title"], description=f"A sample text resource named {name}", mimeType="text/plain", ) @@ -43,7 +53,7 @@ async def read_resource(uri: AnyUrl) -> str | bytes: if name not in SAMPLE_RESOURCES: raise ValueError(f"Unknown resource: {uri}") - return SAMPLE_RESOURCES[name] + return SAMPLE_RESOURCES[name]["content"] if transport == "sse": from mcp.server.sse import SseServerTransport diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py index bbf3dc64c..6a9ff9364 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/server.py @@ -41,9 +41,7 @@ def main( app = Server("mcp-streamable-http-stateless-demo") @app.call_tool() - async def call_tool( - name: str, arguments: dict - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]: ctx = app.request_context interval = arguments.get("interval", 1.0) count = arguments.get("count", 5) diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index bf6f51e5c..85eb1369f 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ 
b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -45,9 +45,7 @@ def main( app = Server("mcp-streamable-http-demo") @app.call_tool() - async def call_tool( - name: str, arguments: dict - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]: ctx = app.request_context interval = arguments.get("interval", 1.0) count = arguments.get("count", 5) diff --git a/examples/servers/simple-tool/mcp_simple_tool/server.py b/examples/servers/simple-tool/mcp_simple_tool/server.py index cd574ad5e..bf3683c9e 100644 --- a/examples/servers/simple-tool/mcp_simple_tool/server.py +++ b/examples/servers/simple-tool/mcp_simple_tool/server.py @@ -7,7 +7,7 @@ async def fetch_website( url: str, -) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: +) -> list[types.ContentBlock]: headers = { "User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)" } @@ -29,9 +29,7 @@ def main(port: int, transport: str) -> int: app = Server("mcp-website-fetcher") @app.call_tool() - async def fetch_tool( - name: str, arguments: dict - ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + async def fetch_tool(name: str, arguments: dict) -> list[types.ContentBlock]: if name != "fetch": raise ValueError(f"Unknown tool: {name}") if "url" not in arguments: @@ -43,6 +41,7 @@ async def list_tools() -> list[types.Tool]: return [ types.Tool( name="fetch", + title="Website Fetcher", description="Fetches a website and returns its content", inputSchema={ "type": "object", diff --git a/pyproject.toml b/pyproject.toml index 0a11a3b15..9ad50ab58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,7 @@ select = ["C4", "E", "F", "I", "PERF", "UP"] ignore = ["PERF203"] [tool.ruff] -line-length = 88 +line-length = 120 target-version = "py310" [tool.ruff.lint.per-file-ignores] diff --git a/src/mcp/cli/claude.py b/src/mcp/cli/claude.py index 1629f9287..e6eab2851 100644 --- a/src/mcp/cli/claude.py +++ b/src/mcp/cli/claude.py @@ -21,9 +21,7 @@ def get_claude_config_path() -> Path | None: elif sys.platform == "darwin": path = Path(Path.home(), "Library", "Application Support", "Claude") elif sys.platform.startswith("linux"): - path = Path( - os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude" - ) + path = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude") else: return None @@ -37,8 +35,7 @@ def get_uv_path() -> str: uv_path = shutil.which("uv") if not uv_path: logger.error( - "uv executable not found in PATH, falling back to 'uv'. " - "Please ensure uv is installed and in your PATH" + "uv executable not found in PATH, falling back to 'uv'. 
" "Please ensure uv is installed and in your PATH" ) return "uv" # Fall back to just "uv" if not found return uv_path @@ -94,10 +91,7 @@ def update_claude_config( config["mcpServers"] = {} # Always preserve existing env vars and merge with new ones - if ( - server_name in config["mcpServers"] - and "env" in config["mcpServers"][server_name] - ): + if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: existing_env = config["mcpServers"][server_name]["env"] if env_vars: # New vars take precedence over existing ones diff --git a/src/mcp/cli/cli.py b/src/mcp/cli/cli.py index b2632f1d9..69e2921f1 100644 --- a/src/mcp/cli/cli.py +++ b/src/mcp/cli/cli.py @@ -45,9 +45,7 @@ def _get_npx_command(): # Try both npx.cmd and npx.exe on Windows for cmd in ["npx.cmd", "npx.exe", "npx"]: try: - subprocess.run( - [cmd, "--version"], check=True, capture_output=True, shell=True - ) + subprocess.run([cmd, "--version"], check=True, capture_output=True, shell=True) return cmd except subprocess.CalledProcessError: continue @@ -58,9 +56,7 @@ def _get_npx_command(): def _parse_env_var(env_var: str) -> tuple[str, str]: """Parse environment variable string in format KEY=VALUE.""" if "=" not in env_var: - logger.error( - f"Invalid environment variable format: {env_var}. Must be KEY=VALUE" - ) + logger.error(f"Invalid environment variable format: {env_var}. Must be KEY=VALUE") sys.exit(1) key, value = env_var.split("=", 1) return key.strip(), value.strip() @@ -154,14 +150,10 @@ def _check_server_object(server_object: Any, object_name: str): True if it's supported. """ if not isinstance(server_object, FastMCP): - logger.error( - f"The server object {object_name} is of type " - f"{type(server_object)} (expecting {FastMCP})." - ) + logger.error(f"The server object {object_name} is of type " f"{type(server_object)} (expecting {FastMCP}).") if isinstance(server_object, LowLevelServer): logger.warning( - "Note that only FastMCP server is supported. Low level " - "Server class is not yet supported." + "Note that only FastMCP server is supported. Low level " "Server class is not yet supported." ) return False return True @@ -172,10 +164,7 @@ def _check_server_object(server_object: Any, object_name: str): for name in ["mcp", "server", "app"]: if hasattr(module, name): if not _check_server_object(getattr(module, name), f"{file}:{name}"): - logger.error( - f"Ignoring object '{file}:{name}' as it's not a valid " - "server object" - ) + logger.error(f"Ignoring object '{file}:{name}' as it's not a valid " "server object") continue return getattr(module, name) @@ -280,8 +269,7 @@ def dev( npx_cmd = _get_npx_command() if not npx_cmd: logger.error( - "npx not found. Please ensure Node.js and npm are properly installed " - "and added to your system PATH." + "npx not found. Please ensure Node.js and npm are properly installed " "and added to your system PATH." 
) sys.exit(1) @@ -383,8 +371,7 @@ def install( typer.Option( "--name", "-n", - help="Custom name for the server (defaults to server's name attribute or" - " file name)", + help="Custom name for the server (defaults to server's name attribute or" " file name)", ), ] = None, with_editable: Annotated[ @@ -458,8 +445,7 @@ def install( name = server.name except (ImportError, ModuleNotFoundError) as e: logger.debug( - "Could not import server (likely missing dependencies), using file" - " name", + "Could not import server (likely missing dependencies), using file" " name", extra={"error": str(e)}, ) name = file.stem @@ -477,11 +463,7 @@ def install( if env_file: if dotenv: try: - env_dict |= { - k: v - for k, v in dotenv.dotenv_values(env_file).items() - if v is not None - } + env_dict |= {k: v for k, v in dotenv.dotenv_values(env_file).items() if v is not None} except Exception as e: logger.error(f"Failed to load .env file: {e}") sys.exit(1) diff --git a/src/mcp/client/__main__.py b/src/mcp/client/__main__.py index 2ec68e56c..2efe05d53 100644 --- a/src/mcp/client/__main__.py +++ b/src/mcp/client/__main__.py @@ -24,9 +24,7 @@ async def message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): logger.error("Error: %s", message) @@ -60,9 +58,7 @@ async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]) await run_session(*streams) else: # Use stdio client for commands - server_parameters = StdioServerParameters( - command=command_or_url, args=args, env=env_dict - ) + server_parameters = StdioServerParameters(command=command_or_url, args=args, env=env_dict) async with stdio_client(server_parameters) as streams: await run_session(*streams) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index fc6c96a43..4e777d600 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -17,6 +17,7 @@ import anyio import httpx +from mcp.client.streamable_http import MCP_PROTOCOL_VERSION from mcp.shared.auth import ( OAuthClientInformationFull, OAuthClientMetadata, @@ -100,10 +101,7 @@ def __init__( def _generate_code_verifier(self) -> str: """Generate a cryptographically random code verifier for PKCE.""" - return "".join( - secrets.choice(string.ascii_letters + string.digits + "-._~") - for _ in range(128) - ) + return "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128)) def _generate_code_challenge(self, code_verifier: str) -> str: """Generate a code challenge from a code verifier using SHA256.""" @@ -129,7 +127,7 @@ async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | Non # Extract base URL per MCP spec auth_base_url = self._get_authorization_base_url(server_url) url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server") - headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION} + headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} async with httpx.AsyncClient() as client: try: @@ -148,9 +146,7 @@ async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | Non return None response.raise_for_status() metadata_json = response.json() - logger.debug( - f"OAuth metadata discovered (no MCP header): {metadata_json}" - ) + logger.debug(f"OAuth metadata discovered (no MCP header): {metadata_json}") return OAuthMetadata.model_validate(metadata_json) except 
Exception: logger.exception("Failed to discover OAuth metadata") @@ -176,17 +172,11 @@ async def _register_oauth_client( registration_url = urljoin(auth_base_url, "/register") # Handle default scope - if ( - client_metadata.scope is None - and metadata - and metadata.scopes_supported is not None - ): + if client_metadata.scope is None and metadata and metadata.scopes_supported is not None: client_metadata.scope = " ".join(metadata.scopes_supported) # Serialize client metadata - registration_data = client_metadata.model_dump( - by_alias=True, mode="json", exclude_none=True - ) + registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) async with httpx.AsyncClient() as client: try: @@ -213,9 +203,7 @@ async def _register_oauth_client( logger.exception("Registration error") raise - async def async_auth_flow( - self, request: httpx.Request - ) -> AsyncGenerator[httpx.Request, httpx.Response]: + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: """ HTTPX auth flow integration. """ @@ -225,9 +213,7 @@ async def async_auth_flow( await self.ensure_token() # Add Bearer token if available if self._current_tokens and self._current_tokens.access_token: - request.headers["Authorization"] = ( - f"Bearer {self._current_tokens.access_token}" - ) + request.headers["Authorization"] = f"Bearer {self._current_tokens.access_token}" response = yield request @@ -305,11 +291,7 @@ async def ensure_token(self) -> None: return # Try refreshing existing token - if ( - self._current_tokens - and self._current_tokens.refresh_token - and await self._refresh_access_token() - ): + if self._current_tokens and self._current_tokens.refresh_token and await self._refresh_access_token(): return # Fall back to full OAuth flow @@ -361,12 +343,8 @@ async def _perform_oauth_flow(self) -> None: auth_code, returned_state = await self.callback_handler() # Validate state parameter for CSRF protection - if returned_state is None or not secrets.compare_digest( - returned_state, self._auth_state - ): - raise Exception( - f"State parameter mismatch: {returned_state} != {self._auth_state}" - ) + if returned_state is None or not secrets.compare_digest(returned_state, self._auth_state): + raise Exception(f"State parameter mismatch: {returned_state} != {self._auth_state}") # Clear state after validation self._auth_state = None @@ -377,9 +355,7 @@ async def _perform_oauth_flow(self) -> None: # Exchange authorization code for tokens await self._exchange_code_for_token(auth_code, client_info) - async def _exchange_code_for_token( - self, auth_code: str, client_info: OAuthClientInformationFull - ) -> None: + async def _exchange_code_for_token(self, auth_code: str, client_info: OAuthClientInformationFull) -> None: """Exchange authorization code for access token.""" # Get token endpoint if self._metadata and self._metadata.token_endpoint: @@ -412,17 +388,10 @@ async def _exchange_code_for_token( # Parse OAuth error response try: error_data = response.json() - error_msg = error_data.get( - "error_description", error_data.get("error", "Unknown error") - ) - raise Exception( - f"Token exchange failed: {error_msg} " - f"(HTTP {response.status_code})" - ) + error_msg = error_data.get("error_description", error_data.get("error", "Unknown error")) + raise Exception(f"Token exchange failed: {error_msg} " f"(HTTP {response.status_code})") except Exception: - raise Exception( - f"Token exchange failed: {response.status_code} {response.text}" - ) + raise 
Exception(f"Token exchange failed: {response.status_code} {response.text}") # Parse token response token_response = OAuthToken.model_validate(response.json()) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3b7fc3fae..948817140 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -22,6 +22,14 @@ async def __call__( ) -> types.CreateMessageResult | types.ErrorData: ... +class ElicitationFnT(Protocol): + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.ElicitRequestParams, + ) -> types.ElicitResult | types.ErrorData: ... + + class ListRootsFnT(Protocol): async def __call__( self, context: RequestContext["ClientSession", Any] @@ -38,16 +46,12 @@ async def __call__( class MessageHandlerFnT(Protocol): async def __call__( self, - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: ... async def _default_message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: await anyio.lowlevel.checkpoint() @@ -62,6 +66,16 @@ async def _default_sampling_callback( ) +async def _default_elicitation_callback( + context: RequestContext["ClientSession", Any], + params: types.ElicitRequestParams, +) -> types.ElicitResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Elicitation not supported", + ) + + async def _default_list_roots_callback( context: RequestContext["ClientSession", Any], ) -> types.ListRootsResult | types.ErrorData: @@ -77,9 +91,7 @@ async def _default_logging_callback( pass -ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter( - types.ClientResult | types.ErrorData -) +ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) class ClientSession( @@ -97,6 +109,7 @@ def __init__( write_stream: MemoryObjectSendStream[SessionMessage], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, @@ -111,15 +124,15 @@ def __init__( ) self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback + self._elicitation_callback = elicitation_callback or _default_elicitation_callback self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler async def initialize(self) -> types.InitializeResult: - sampling = ( - types.SamplingCapability() - if self._sampling_callback is not _default_sampling_callback - else None + sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None + elicitation = ( + types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None ) roots = ( # TODO: Should this be based on whether we @@ -138,6 +151,7 @@ async def initialize(self) -> types.InitializeResult: 
protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=types.ClientCapabilities( sampling=sampling, + elicitation=elicitation, experimental=None, roots=roots, ), @@ -149,15 +163,10 @@ async def initialize(self) -> types.InitializeResult: ) if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: - raise RuntimeError( - "Unsupported protocol version from the server: " - f"{result.protocolVersion}" - ) + raise RuntimeError("Unsupported protocol version from the server: " f"{result.protocolVersion}") await self.send_notification( - types.ClientNotification( - types.InitializedNotification(method="notifications/initialized") - ) + types.ClientNotification(types.InitializedNotification(method="notifications/initialized")) ) return result @@ -207,33 +216,25 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul types.EmptyResult, ) - async def list_resources( - self, cursor: str | None = None - ) -> types.ListResourcesResult: + async def list_resources(self, cursor: str | None = None) -> types.ListResourcesResult: """Send a resources/list request.""" return await self.send_request( types.ClientRequest( types.ListResourcesRequest( method="resources/list", - params=types.PaginatedRequestParams(cursor=cursor) - if cursor is not None - else None, + params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, ) ), types.ListResourcesResult, ) - async def list_resource_templates( - self, cursor: str | None = None - ) -> types.ListResourceTemplatesResult: + async def list_resource_templates(self, cursor: str | None = None) -> types.ListResourceTemplatesResult: """Send a resources/templates/list request.""" return await self.send_request( types.ClientRequest( types.ListResourceTemplatesRequest( method="resources/templates/list", - params=types.PaginatedRequestParams(cursor=cursor) - if cursor is not None - else None, + params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, ) ), types.ListResourceTemplatesResult, @@ -305,17 +306,13 @@ async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResu types.ClientRequest( types.ListPromptsRequest( method="prompts/list", - params=types.PaginatedRequestParams(cursor=cursor) - if cursor is not None - else None, + params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, ) ), types.ListPromptsResult, ) - async def get_prompt( - self, name: str, arguments: dict[str, str] | None = None - ) -> types.GetPromptResult: + async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: """Send a prompts/get request.""" return await self.send_request( types.ClientRequest( @@ -329,10 +326,15 @@ async def get_prompt( async def complete( self, - ref: types.ResourceReference | types.PromptReference, + ref: types.ResourceTemplateReference | types.PromptReference, argument: dict[str, str], + context_arguments: dict[str, str] | None = None, ) -> types.CompleteResult: """Send a completion/complete request.""" + context = None + if context_arguments is not None: + context = types.CompletionContext(arguments=context_arguments) + return await self.send_request( types.ClientRequest( types.CompleteRequest( @@ -340,6 +342,7 @@ async def complete( params=types.CompleteRequestParams( ref=ref, argument=types.CompletionArgument(**argument), + context=context, ), ) ), @@ -352,9 +355,7 @@ async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult: types.ClientRequest( types.ListToolsRequest( 
method="tools/list", - params=types.PaginatedRequestParams(cursor=cursor) - if cursor is not None - else None, + params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, ) ), types.ListToolsResult, @@ -370,9 +371,7 @@ async def send_roots_list_changed(self) -> None: ) ) - async def _received_request( - self, responder: RequestResponder[types.ServerRequest, types.ClientResult] - ) -> None: + async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: ctx = RequestContext[ClientSession, Any]( request_id=responder.request_id, meta=responder.request_meta, @@ -387,6 +386,12 @@ async def _received_request( client_response = ClientResponse.validate_python(response) await responder.respond(client_response) + case types.ElicitRequest(params=params): + with responder: + response = await self._elicitation_callback(ctx, params) + client_response = ClientResponse.validate_python(response) + await responder.respond(client_response) + case types.ListRootsRequest(): with responder: response = await self._list_roots_callback(ctx) @@ -395,22 +400,16 @@ async def _received_request( case types.PingRequest(): with responder: - return await responder.respond( - types.ClientResult(root=types.EmptyResult()) - ) + return await responder.respond(types.ClientResult(root=types.EmptyResult())) async def _handle_incoming( self, - req: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: """Handle incoming messages by forwarding to the message handler.""" await self._message_handler(req) - async def _received_notification( - self, notification: types.ServerNotification - ) -> None: + async def _received_notification(self, notification: types.ServerNotification) -> None: """Handle notifications from the server.""" # Process specific notification types match notification.root: diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index a77dc7a1e..700b5417f 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -62,9 +62,7 @@ class StreamableHttpParameters(BaseModel): terminate_on_close: bool = True -ServerParameters: TypeAlias = ( - StdioServerParameters | SseServerParameters | StreamableHttpParameters -) +ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters class ClientSessionGroup: @@ -261,9 +259,7 @@ async def _establish_session( ) read, write, _ = await session_stack.enter_async_context(client) - session = await session_stack.enter_async_context( - mcp.ClientSession(read, write) - ) + session = await session_stack.enter_async_context(mcp.ClientSession(read, write)) result = await session.initialize() # Session successfully initialized. 
@@ -280,9 +276,7 @@ async def _establish_session( await session_stack.aclose() raise - async def _aggregate_components( - self, server_info: types.Implementation, session: mcp.ClientSession - ) -> None: + async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None: """Aggregates prompts, resources, and tools from a given session.""" # Create a reverse index so we can find all prompts, resources, and diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 2013e4199..0c05c6def 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -54,12 +54,13 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx_client_factory(headers=headers, auth=auth) as client: + async with httpx_client_factory( + headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) + ) as client: async with aconnect_sse( client, "GET", url, - timeout=httpx.Timeout(timeout, read=sse_read_timeout), ) as event_source: event_source.response.raise_for_status() logger.debug("SSE connection established") @@ -73,20 +74,16 @@ async def sse_reader( match sse.event: case "endpoint": endpoint_url = urljoin(url, sse.data) - logger.debug( - f"Received endpoint URL: {endpoint_url}" - ) + logger.debug(f"Received endpoint URL: {endpoint_url}") url_parsed = urlparse(url) endpoint_parsed = urlparse(endpoint_url) if ( url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme - != endpoint_parsed.scheme + or url_parsed.scheme != endpoint_parsed.scheme ): error_msg = ( - "Endpoint origin does not match " - f"connection origin: {endpoint_url}" + "Endpoint origin does not match " f"connection origin: {endpoint_url}" ) logger.error(error_msg) raise ValueError(error_msg) @@ -98,22 +95,16 @@ async def sse_reader( message = types.JSONRPCMessage.model_validate_json( # noqa: E501 sse.data ) - logger.debug( - f"Received server message: {message}" - ) + logger.debug(f"Received server message: {message}") except Exception as exc: - logger.error( - f"Error parsing server message: {exc}" - ) + logger.error(f"Error parsing server message: {exc}") await read_stream_writer.send(exc) continue session_message = SessionMessage(message) await read_stream_writer.send(session_message) case _: - logger.warning( - f"Unknown SSE event: {sse.event}" - ) + logger.warning(f"Unknown SSE event: {sse.event}") except Exception as exc: logger.error(f"Error in sse_reader: {exc}") await read_stream_writer.send(exc) @@ -124,9 +115,7 @@ async def post_writer(endpoint_url: str): try: async with write_stream_reader: async for session_message in write_stream_reader: - logger.debug( - f"Sending client message: {session_message}" - ) + logger.debug(f"Sending client message: {session_message}") response = await client.post( endpoint_url, json=session_message.message.model_dump( @@ -136,19 +125,14 @@ async def post_writer(endpoint_url: str): ), ) response.raise_for_status() - logger.debug( - "Client message sent successfully: " - f"{response.status_code}" - ) + logger.debug("Client message sent successfully: " f"{response.status_code}") except Exception as exc: logger.error(f"Error in post_writer: {exc}") finally: await write_stream.aclose() endpoint_url = await tg.start(sse_reader) - logger.debug( - f"Starting post writer with endpoint URL: {endpoint_url}" - ) + logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") tg.start_soon(post_writer, endpoint_url) try: diff 
--git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index fce605633..a75cfd764 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -115,11 +115,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder process = await _create_platform_compatible_process( command=command, args=server.args, - env=( - {**get_default_environment(), **server.env} - if server.env is not None - else get_default_environment() - ), + env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()), errlog=errlog, cwd=server.cwd, ) @@ -163,9 +159,7 @@ async def stdin_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - json = session_message.message.model_dump_json( - by_alias=True, exclude_none=True - ) + json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) await process.stdin.send( (json + "\n").encode( encoding=server.encoding, @@ -229,8 +223,6 @@ async def _create_platform_compatible_process( if sys.platform == "win32": process = await create_windows_process(command, args, env, errlog, cwd) else: - process = await anyio.open_process( - [command, *args], env=env, stderr=errlog, cwd=cwd - ) + process = await anyio.open_process([command, *args], env=env, stderr=errlog, cwd=cwd) return process diff --git a/src/mcp/client/stdio/win32.py b/src/mcp/client/stdio/win32.py index 825a0477d..e4f252dc9 100644 --- a/src/mcp/client/stdio/win32.py +++ b/src/mcp/client/stdio/win32.py @@ -82,9 +82,7 @@ async def create_windows_process( return process except Exception: # Don't raise, let's try to create the process without creation flags - process = await anyio.open_process( - [command, *args], env=env, stderr=errlog, cwd=cwd - ) + process = await anyio.open_process([command, *args], env=env, stderr=errlog, cwd=cwd) return process diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 2855f606d..39ac34d8a 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -11,7 +11,6 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import timedelta -from typing import Any import anyio import httpx @@ -23,6 +22,7 @@ from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( ErrorData, + InitializeResult, JSONRPCError, JSONRPCMessage, JSONRPCNotification, @@ -40,6 +40,7 @@ GetSessionIdCallback = Callable[[], str | None] MCP_SESSION_ID = "mcp-session-id" +MCP_PROTOCOL_VERSION = "mcp-protocol-version" LAST_EVENT_ID = "last-event-id" CONTENT_TYPE = "content-type" ACCEPT = "Accept" @@ -52,14 +53,10 @@ class StreamableHTTPError(Exception): """Base exception for StreamableHTTP transport errors.""" - pass - class ResumptionError(StreamableHTTPError): """Raised when resumption request is invalid.""" - pass - @dataclass class RequestContext: @@ -71,7 +68,7 @@ class RequestContext: session_message: SessionMessage metadata: ClientMessageMetadata | None read_stream_writer: StreamWriter - sse_read_timeout: timedelta + sse_read_timeout: float class StreamableHTTPTransport: @@ -80,9 +77,9 @@ class StreamableHTTPTransport: def __init__( self, url: str, - headers: dict[str, Any] | None = None, - timeout: timedelta = timedelta(seconds=30), - sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + headers: dict[str, str] | None = None, + timeout: float | timedelta = 30, + sse_read_timeout: float | timedelta = 60 * 
5, auth: httpx.Auth | None = None, ) -> None: """Initialize the StreamableHTTP transport. @@ -96,38 +93,35 @@ def __init__( """ self.url = url self.headers = headers or {} - self.timeout = timeout - self.sse_read_timeout = sse_read_timeout + self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + self.sse_read_timeout = ( + sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout + ) self.auth = auth - self.session_id: str | None = None + self.session_id = None + self.protocol_version = None self.request_headers = { ACCEPT: f"{JSON}, {SSE}", CONTENT_TYPE: JSON, **self.headers, } - def _update_headers_with_session( - self, base_headers: dict[str, str] - ) -> dict[str, str]: - """Update headers with session ID if available.""" + def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]: + """Update headers with session ID and protocol version if available.""" headers = base_headers.copy() if self.session_id: headers[MCP_SESSION_ID] = self.session_id + if self.protocol_version: + headers[MCP_PROTOCOL_VERSION] = self.protocol_version return headers def _is_initialization_request(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialization request.""" - return ( - isinstance(message.root, JSONRPCRequest) - and message.root.method == "initialize" - ) + return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialized notification.""" - return ( - isinstance(message.root, JSONRPCNotification) - and message.root.method == "notifications/initialized" - ) + return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized" def _maybe_extract_session_id_from_response( self, @@ -139,12 +133,28 @@ def _maybe_extract_session_id_from_response( self.session_id = new_session_id logger.info(f"Received session ID: {self.session_id}") + def _maybe_extract_protocol_version_from_message( + self, + message: JSONRPCMessage, + ) -> None: + """Extract protocol version from initialization response message.""" + if isinstance(message.root, JSONRPCResponse) and message.root.result: + try: + # Parse the result as InitializeResult for type safety + init_result = InitializeResult.model_validate(message.root.result) + self.protocol_version = str(init_result.protocolVersion) + logger.info(f"Negotiated protocol version: {self.protocol_version}") + except Exception as exc: + logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}") + logger.warning(f"Raw result: {message.root.result}") + async def _handle_sse_event( self, sse: ServerSentEvent, read_stream_writer: StreamWriter, original_request_id: RequestId | None = None, resumption_callback: Callable[[str], Awaitable[None]] | None = None, + is_initialization: bool = False, ) -> bool: """Handle an SSE event, returning True if the response is complete.""" if sse.event == "message": @@ -152,10 +162,12 @@ async def _handle_sse_event( message = JSONRPCMessage.model_validate_json(sse.data) logger.debug(f"SSE message: {message}") + # Extract protocol version from initialization response + if is_initialization: + self._maybe_extract_protocol_version_from_message(message) + # If this is a response and we have original_request_id, replace it - if original_request_id is not None and isinstance( - message.root, JSONRPCResponse | JSONRPCError - ): + if 
original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError): message.root.id = original_request_id session_message = SessionMessage(message) @@ -170,7 +182,7 @@ async def _handle_sse_event( return isinstance(message.root, JSONRPCResponse | JSONRPCError) except Exception as exc: - logger.error(f"Error parsing SSE message: {exc}") + logger.exception("Error parsing SSE message") await read_stream_writer.send(exc) return False else: @@ -187,16 +199,14 @@ async def handle_get_stream( if not self.session_id: return - headers = self._update_headers_with_session(self.request_headers) + headers = self._prepare_request_headers(self.request_headers) async with aconnect_sse( client, "GET", self.url, headers=headers, - timeout=httpx.Timeout( - self.timeout.seconds, read=self.sse_read_timeout.seconds - ), + timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), ) as event_source: event_source.response.raise_for_status() logger.debug("GET SSE connection established") @@ -209,7 +219,7 @@ async def handle_get_stream( async def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" - headers = self._update_headers_with_session(ctx.headers) + headers = self._prepare_request_headers(ctx.headers) if ctx.metadata and ctx.metadata.resumption_token: headers[LAST_EVENT_ID] = ctx.metadata.resumption_token else: @@ -225,9 +235,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: "GET", self.url, headers=headers, - timeout=httpx.Timeout( - self.timeout.seconds, read=ctx.sse_read_timeout.seconds - ), + timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), ) as event_source: event_source.response.raise_for_status() logger.debug("Resumption GET SSE connection established") @@ -244,7 +252,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" - headers = self._update_headers_with_session(ctx.headers) + headers = self._prepare_request_headers(ctx.headers) message = ctx.session_message.message is_initialization = self._is_initialization_request(message) @@ -273,9 +281,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: content_type = response.headers.get(CONTENT_TYPE, "").lower() if content_type.startswith(JSON): - await self._handle_json_response(response, ctx.read_stream_writer) + await self._handle_json_response(response, ctx.read_stream_writer, is_initialization) elif content_type.startswith(SSE): - await self._handle_sse_response(response, ctx) + await self._handle_sse_response(response, ctx, is_initialization) else: await self._handle_unexpected_content_type( content_type, @@ -286,11 +294,17 @@ async def _handle_json_response( self, response: httpx.Response, read_stream_writer: StreamWriter, + is_initialization: bool = False, ) -> None: """Handle JSON response from the server.""" try: content = await response.aread() message = JSONRPCMessage.model_validate_json(content) + + # Extract protocol version from initialization response + if is_initialization: + self._maybe_extract_protocol_version_from_message(message) + session_message = SessionMessage(message) await read_stream_writer.send(session_message) except Exception as exc: @@ -298,7 +312,10 @@ async def _handle_json_response( await read_stream_writer.send(exc) async def _handle_sse_response( - self, response: httpx.Response, ctx: RequestContext + self, + response: 
httpx.Response, + ctx: RequestContext, + is_initialization: bool = False, ) -> None: """Handle SSE response from the server.""" try: @@ -307,11 +324,8 @@ async def _handle_sse_response( is_complete = await self._handle_sse_event( sse, ctx.read_stream_writer, - resumption_callback=( - ctx.metadata.on_resumption_token_update - if ctx.metadata - else None - ), + resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), + is_initialization=is_initialization, ) # If the SSE event indicates completion, like returning respose/error # break the loop @@ -408,7 +422,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: return try: - headers = self._update_headers_with_session(self.request_headers) + headers = self._prepare_request_headers(self.request_headers) response = await client.delete(self.url, headers=headers) if response.status_code == 405: @@ -426,9 +440,9 @@ def get_session_id(self) -> str | None: @asynccontextmanager async def streamablehttp_client( url: str, - headers: dict[str, Any] | None = None, - timeout: timedelta = timedelta(seconds=30), - sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + headers: dict[str, str] | None = None, + timeout: float | timedelta = 30, + sse_read_timeout: float | timedelta = 60 * 5, terminate_on_close: bool = True, httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, @@ -454,12 +468,8 @@ async def streamablehttp_client( """ transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth) - read_stream_writer, read_stream = anyio.create_memory_object_stream[ - SessionMessage | Exception - ](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[ - SessionMessage - ](0) + read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) async with anyio.create_task_group() as tg: try: @@ -467,16 +477,12 @@ async def streamablehttp_client( async with httpx_client_factory( headers=transport.request_headers, - timeout=httpx.Timeout( - transport.timeout.seconds, read=transport.sse_read_timeout.seconds - ), + timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), auth=transport.auth, ) as client: # Define callbacks that need access to tg def start_get_stream() -> None: - tg.start_soon( - transport.handle_get_stream, client, read_stream_writer - ) + tg.start_soon(transport.handle_get_stream, client, read_stream_writer) tg.start_soon( transport.post_writer, diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index ac542fb3f..0a371610b 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -19,10 +19,7 @@ async def websocket_client( url: str, ) -> AsyncGenerator[ - tuple[ - MemoryObjectReceiveStream[SessionMessage | Exception], - MemoryObjectSendStream[SessionMessage], - ], + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], None, ]: """ @@ -74,9 +71,7 @@ async def ws_writer(): async with write_stream_reader: async for session_message in write_stream_reader: # Convert to a dict, then to JSON - msg_dict = session_message.message.model_dump( - by_alias=True, mode="json", exclude_none=True - ) + msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True) await ws.send(json.dumps(msg_dict)) async with anyio.create_task_group() as tg: diff --git 
a/src/mcp/server/auth/errors.py b/src/mcp/server/auth/errors.py index 053c2fd2e..117deea83 100644 --- a/src/mcp/server/auth/errors.py +++ b/src/mcp/server/auth/errors.py @@ -2,7 +2,4 @@ def stringify_pydantic_error(validation_error: ValidationError) -> str: - return "\n".join( - f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" - for e in validation_error.errors() - ) + return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors()) diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index 8f3768908..8d5e2622f 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -2,14 +2,12 @@ from dataclasses import dataclass from typing import Any, Literal -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError +from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError from starlette.datastructures import FormData, QueryParams from starlette.requests import Request from starlette.responses import RedirectResponse, Response -from mcp.server.auth.errors import ( - stringify_pydantic_error, -) +from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse from mcp.server.auth.provider import ( AuthorizationErrorCode, @@ -18,10 +16,7 @@ OAuthAuthorizationServerProvider, construct_redirect_uri, ) -from mcp.shared.auth import ( - InvalidRedirectUriError, - InvalidScopeError, -) +from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError logger = logging.getLogger(__name__) @@ -29,23 +24,16 @@ class AuthorizationRequest(BaseModel): # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 client_id: str = Field(..., description="The client ID") - redirect_uri: AnyHttpUrl | None = Field( - None, description="URL to redirect to after authorization" - ) + redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization") # see OAuthClientMetadata; we only support `code` - response_type: Literal["code"] = Field( - ..., description="Must be 'code' for authorization code flow" - ) + response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow") code_challenge: str = Field(..., description="PKCE code challenge") - code_challenge_method: Literal["S256"] = Field( - "S256", description="PKCE code challenge method, must be S256" - ) + code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256") state: str | None = Field(None, description="Optional state parameter") scope: str | None = Field( None, - description="Optional scope; if specified, should be " - "a space-separated list of scope strings", + description="Optional scope; if specified, should be " "a space-separated list of scope strings", ) @@ -57,9 +45,7 @@ class AuthorizationErrorResponse(BaseModel): state: str | None = None -def best_effort_extract_string( - key: str, params: None | FormData | QueryParams -) -> str | None: +def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None: if params is None: return None value = params.get(key) @@ -68,8 +54,8 @@ def best_effort_extract_string( return None -class AnyHttpUrlModel(RootModel[AnyHttpUrl]): - root: AnyHttpUrl +class AnyUrlModel(RootModel[AnyUrl]): + root: AnyUrl @dataclass @@ -116,7 +102,7 @@ async def error_response( if params is not None and "redirect_uri" not in params: 
raw_redirect_uri = None else: - raw_redirect_uri = AnyHttpUrlModel.model_validate( + raw_redirect_uri = AnyUrlModel.model_validate( best_effort_extract_string("redirect_uri", params) ).root redirect_uri = client.validate_redirect_uri(raw_redirect_uri) @@ -138,9 +124,7 @@ async def error_response( if redirect_uri and client: return RedirectResponse( - url=construct_redirect_uri( - str(redirect_uri), **error_resp.model_dump(exclude_none=True) - ), + url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)), status_code=302, headers={"Cache-Control": "no-store"}, ) @@ -172,9 +156,7 @@ async def error_response( if e["loc"] == ("response_type",) and e["type"] == "literal_error": error = "unsupported_response_type" break - return await error_response( - error, stringify_pydantic_error(validation_error) - ) + return await error_response(error, stringify_pydantic_error(validation_error)) # Get client information client = await self.provider.get_client( @@ -229,16 +211,9 @@ async def error_response( ) except AuthorizeError as e: # Handle authorization errors as defined in RFC 6749 Section 4.1.2.1 - return await error_response( - error=e.error, - error_description=e.error_description, - ) + return await error_response(error=e.error, error_description=e.error_description) except Exception as validation_error: # Catch-all for unexpected errors - logger.exception( - "Unexpected error in authorization_handler", exc_info=validation_error - ) - return await error_response( - error="server_error", error_description="An unexpected error occurred" - ) + logger.exception("Unexpected error in authorization_handler", exc_info=validation_error) + return await error_response(error="server_error", error_description="An unexpected error occurred") diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index 2e25c779a..61e403aca 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -10,11 +10,7 @@ from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.server.auth.provider import ( - OAuthAuthorizationServerProvider, - RegistrationError, - RegistrationErrorCode, -) +from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode from mcp.server.auth.settings import ClientRegistrationOptions from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata @@ -60,9 +56,7 @@ async def handle(self, request: Request) -> Response: if client_metadata.scope is None and self.options.default_scopes is not None: client_metadata.scope = " ".join(self.options.default_scopes) - elif ( - client_metadata.scope is not None and self.options.valid_scopes is not None - ): + elif client_metadata.scope is not None and self.options.valid_scopes is not None: requested_scopes = set(client_metadata.scope.split()) valid_scopes = set(self.options.valid_scopes) if not requested_scopes.issubset(valid_scopes): @@ -78,8 +72,7 @@ async def handle(self, request: Request) -> Response: return PydanticJSONResponse( content=RegistrationErrorResponse( error="invalid_client_metadata", - error_description="grant_types must be authorization_code " - "and refresh_token", + error_description="grant_types must be authorization_code " "and refresh_token", ), status_code=400, ) @@ -122,8 +115,6 @@ async def handle(self, request: Request) -> Response: except RegistrationError as e: # Handle registration 
errors as defined in RFC 7591 Section 3.2.2 return PydanticJSONResponse( - content=RegistrationErrorResponse( - error=e.error, error_description=e.error_description - ), + content=RegistrationErrorResponse(error=e.error, error_description=e.error_description), status_code=400, ) diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py index 43b4dded9..478ad7a01 100644 --- a/src/mcp/server/auth/handlers/revoke.py +++ b/src/mcp/server/auth/handlers/revoke.py @@ -10,15 +10,8 @@ stringify_pydantic_error, ) from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.server.auth.middleware.client_auth import ( - AuthenticationError, - ClientAuthenticator, -) -from mcp.server.auth.provider import ( - AccessToken, - OAuthAuthorizationServerProvider, - RefreshToken, -) +from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator +from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken class RevocationRequest(BaseModel): diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 94a5c4de3..d73455200 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -4,22 +4,13 @@ from dataclasses import dataclass from typing import Annotated, Any, Literal -from pydantic import AnyHttpUrl, BaseModel, Field, RootModel, ValidationError +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError from starlette.requests import Request -from mcp.server.auth.errors import ( - stringify_pydantic_error, -) +from mcp.server.auth.errors import stringify_pydantic_error from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.server.auth.middleware.client_auth import ( - AuthenticationError, - ClientAuthenticator, -) -from mcp.server.auth.provider import ( - OAuthAuthorizationServerProvider, - TokenError, - TokenErrorCode, -) +from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator +from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode from mcp.shared.auth import OAuthToken @@ -27,9 +18,7 @@ class AuthorizationCodeRequest(BaseModel): # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 grant_type: Literal["authorization_code"] code: str = Field(..., description="The authorization code") - redirect_uri: AnyHttpUrl | None = Field( - None, description="Must be the same as redirect URI provided in /authorize" - ) + redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize") client_id: str # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 client_secret: str | None = None @@ -127,8 +116,7 @@ async def handle(self, request: Request): TokenErrorResponse( error="unsupported_grant_type", error_description=( - f"Unsupported grant type (supported grant types are " - f"{client_info.grant_types})" + f"Unsupported grant type (supported grant types are " f"{client_info.grant_types})" ), ) ) @@ -137,9 +125,7 @@ async def handle(self, request: Request): match token_request: case AuthorizationCodeRequest(): - auth_code = await self.provider.load_authorization_code( - client_info, token_request.code - ) + auth_code = await self.provider.load_authorization_code(client_info, token_request.code) if auth_code is None or auth_code.client_id != token_request.client_id: # if code belongs to different client, pretend it 
doesn't exist return self.response( @@ -169,18 +155,13 @@ async def handle(self, request: Request): return self.response( TokenErrorResponse( error="invalid_request", - error_description=( - "redirect_uri did not match the one " - "used when creating auth code" - ), + error_description=("redirect_uri did not match the one " "used when creating auth code"), ) ) # Verify PKCE code verifier sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() - hashed_code_verifier = ( - base64.urlsafe_b64encode(sha256).decode().rstrip("=") - ) + hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=") if hashed_code_verifier != auth_code.code_challenge: # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6 @@ -193,9 +174,7 @@ async def handle(self, request: Request): try: # Exchange authorization code for tokens - tokens = await self.provider.exchange_authorization_code( - client_info, auth_code - ) + tokens = await self.provider.exchange_authorization_code(client_info, auth_code) except TokenError as e: return self.response( TokenErrorResponse( @@ -205,13 +184,8 @@ async def handle(self, request: Request): ) case RefreshTokenRequest(): - refresh_token = await self.provider.load_refresh_token( - client_info, token_request.refresh_token - ) - if ( - refresh_token is None - or refresh_token.client_id != token_request.client_id - ): + refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) + if refresh_token is None or refresh_token.client_id != token_request.client_id: # if token belongs to different client, pretend it doesn't exist return self.response( TokenErrorResponse( @@ -230,29 +204,20 @@ async def handle(self, request: Request): ) # Parse scopes if provided - scopes = ( - token_request.scope.split(" ") - if token_request.scope - else refresh_token.scopes - ) + scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes for scope in scopes: if scope not in refresh_token.scopes: return self.response( TokenErrorResponse( error="invalid_scope", - error_description=( - f"cannot request scope `{scope}` " - "not provided by refresh token" - ), + error_description=(f"cannot request scope `{scope}` " "not provided by refresh token"), ) ) try: # Exchange refresh token for new tokens - tokens = await self.provider.exchange_refresh_token( - client_info, refresh_token, scopes - ) + tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes) except TokenError as e: return self.response( TokenErrorResponse( diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 1073c07ad..e2116c3bf 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -7,9 +7,7 @@ # Create a contextvar to store the authenticated user # The default is None, indicating no authenticated user is present -auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]( - "auth_context", default=None -) +auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None) def get_access_token() -> AccessToken | None: diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 30b5e2ba6..2fe1342b7 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -1,11 +1,7 @@ import time from typing import Any -from starlette.authentication import ( - 
AuthCredentials, - AuthenticationBackend, - SimpleUser, -) +from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection from starlette.types import Receive, Scope, Send @@ -35,11 +31,7 @@ def __init__( async def authenticate(self, conn: HTTPConnection): auth_header = next( - ( - conn.headers.get(key) - for key in conn.headers - if key.lower() == "authorization" - ), + (conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"), None, ) if not auth_header or not auth_header.lower().startswith("bearer "): @@ -87,10 +79,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: for required_scope in self.required_scopes: # auth_credentials should always be provided; this is just paranoia - if ( - auth_credentials is None - or required_scope not in auth_credentials.scopes - ): + if auth_credentials is None or required_scope not in auth_credentials.scopes: raise HTTPException(status_code=403, detail="Insufficient scope") await self.app(scope, receive, send) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py index 37f7f5066..d5f473b48 100644 --- a/src/mcp/server/auth/middleware/client_auth.py +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -30,9 +30,7 @@ def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """ self.provider = provider - async def authenticate( - self, client_id: str, client_secret: str | None - ) -> OAuthClientInformationFull: + async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull: # Look up client information client = await self.provider.get_client(client_id) if not client: @@ -47,10 +45,7 @@ async def authenticate( if client.client_secret != client_secret: raise AuthenticationError("Invalid client_secret") - if ( - client.client_secret_expires_at - and client.client_secret_expires_at < int(time.time()) - ): + if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()): raise AuthenticationError("Client secret has expired") return client diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index be1ac1dbc..da18d7a71 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -2,19 +2,16 @@ from typing import Generic, Literal, Protocol, TypeVar from urllib.parse import parse_qs, urlencode, urlparse, urlunparse -from pydantic import AnyHttpUrl, BaseModel +from pydantic import AnyUrl, BaseModel -from mcp.shared.auth import ( - OAuthClientInformationFull, - OAuthToken, -) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken class AuthorizationParams(BaseModel): state: str | None scopes: list[str] | None code_challenge: str - redirect_uri: AnyHttpUrl + redirect_uri: AnyUrl redirect_uri_provided_explicitly: bool @@ -24,7 +21,7 @@ class AuthorizationCode(BaseModel): expires_at: float client_id: str code_challenge: str - redirect_uri: AnyHttpUrl + redirect_uri: AnyUrl redirect_uri_provided_explicitly: bool @@ -96,9 +93,7 @@ class TokenError(Exception): AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken) -class OAuthAuthorizationServerProvider( - Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT] -): +class OAuthAuthorizationServerProvider(Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]): async def get_client(self, client_id: str) -> 
OAuthClientInformationFull | None: """ Retrieves client information by client ID. @@ -129,9 +124,7 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None """ ... - async def authorize( - self, client: OAuthClientInformationFull, params: AuthorizationParams - ) -> str: + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: """ Called as part of the /authorize endpoint, and returns a URL that the client will be redirected to. @@ -207,9 +200,7 @@ async def exchange_authorization_code( """ ... - async def load_refresh_token( - self, client: OAuthClientInformationFull, refresh_token: str - ) -> RefreshTokenT | None: + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None: """ Loads a RefreshToken by its token string. diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index d588d78ee..8647334e0 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -16,6 +16,7 @@ from mcp.server.auth.middleware.client_auth import ClientAuthenticator from mcp.server.auth.provider import OAuthAuthorizationServerProvider from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions +from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.auth import OAuthMetadata @@ -31,11 +32,7 @@ def validate_issuer_url(url: AnyHttpUrl): """ # RFC 8414 requires HTTPS, but we allow localhost HTTP for testing - if ( - url.scheme != "https" - and url.host != "localhost" - and not url.host.startswith("127.0.0.1") - ): + if url.scheme != "https" and url.host != "localhost" and not url.host.startswith("127.0.0.1"): raise ValueError("Issuer URL must be HTTPS") # No fragments or query parameters allowed @@ -59,7 +56,7 @@ def cors_middleware( app=request_response(handler), allow_origins="*", allow_methods=allow_methods, - allow_headers=["mcp-protocol-version"], + allow_headers=[MCP_PROTOCOL_VERSION_HEADER], ) return cors_app @@ -73,9 +70,7 @@ def create_auth_routes( ) -> list[Route]: validate_issuer_url(issuer_url) - client_registration_options = ( - client_registration_options or ClientRegistrationOptions() - ) + client_registration_options = client_registration_options or ClientRegistrationOptions() revocation_options = revocation_options or RevocationOptions() metadata = build_metadata( issuer_url, @@ -177,15 +172,11 @@ def build_metadata( # Add registration endpoint if supported if client_registration_options.enabled: - metadata.registration_endpoint = AnyHttpUrl( - str(issuer_url).rstrip("/") + REGISTRATION_PATH - ) + metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH) # Add revocation endpoint if supported if revocation_options.enabled: - metadata.revocation_endpoint = AnyHttpUrl( - str(issuer_url).rstrip("/") + REVOCATION_PATH - ) + metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH) metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"] return metadata diff --git a/src/mcp/server/auth/settings.py b/src/mcp/server/auth/settings.py index 1086bb77e..7306d91af 100644 --- a/src/mcp/server/auth/settings.py +++ b/src/mcp/server/auth/settings.py @@ -15,8 +15,7 @@ class RevocationOptions(BaseModel): class AuthSettings(BaseModel): issuer_url: AnyHttpUrl = Field( ..., - description="URL advertised as OAuth issuer; this should be the URL the server " - "is reachable at", + description="URL advertised as OAuth 
issuer; this should be the URL the server " "is reachable at", ) service_documentation_url: AnyHttpUrl | None = None client_registration_options: ClientRegistrationOptions | None = None diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py new file mode 100644 index 000000000..1e48738c8 --- /dev/null +++ b/src/mcp/server/elicitation.py @@ -0,0 +1,111 @@ +"""Elicitation utilities for MCP servers.""" + +from __future__ import annotations + +import types +from typing import Generic, Literal, TypeVar, Union, get_args, get_origin + +from pydantic import BaseModel +from pydantic.fields import FieldInfo + +from mcp.server.session import ServerSession +from mcp.types import RequestId + +ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) + + +class AcceptedElicitation(BaseModel, Generic[ElicitSchemaModelT]): + """Result when user accepts the elicitation.""" + + action: Literal["accept"] = "accept" + data: ElicitSchemaModelT + + +class DeclinedElicitation(BaseModel): + """Result when user declines the elicitation.""" + + action: Literal["decline"] = "decline" + + +class CancelledElicitation(BaseModel): + """Result when user cancels the elicitation.""" + + action: Literal["cancel"] = "cancel" + + +ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation + + +# Primitive types allowed in elicitation schemas +_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) + + +def _validate_elicitation_schema(schema: type[BaseModel]) -> None: + """Validate that a Pydantic model only contains primitive field types.""" + for field_name, field_info in schema.model_fields.items(): + if not _is_primitive_field(field_info): + raise TypeError( + f"Elicitation schema field '{field_name}' must be a primitive type " + f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. " + f"Complex types like lists, dicts, or nested models are not allowed." + ) + + +def _is_primitive_field(field_info: FieldInfo) -> bool: + """Check if a field is a primitive type allowed in elicitation schemas.""" + annotation = field_info.annotation + + # Handle None type + if annotation is types.NoneType: + return True + + # Handle basic primitive types + if annotation in _ELICITATION_PRIMITIVE_TYPES: + return True + + # Handle Union types + origin = get_origin(annotation) + if origin is Union or origin is types.UnionType: + args = get_args(annotation) + # All args must be primitive types or None + return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) + + return False + + +async def elicit_with_validation( + session: ServerSession, + message: str, + schema: type[ElicitSchemaModelT], + related_request_id: RequestId | None = None, +) -> ElicitationResult[ElicitSchemaModelT]: + """Elicit information from the client/user with schema validation. + + This method can be used to interactively ask for additional information from the + client within a tool's execution. The client might display the message to the + user and collect a response according to the provided schema. Or in case a + client is an agent, it might decide how to handle the elicitation -- either by asking + the user or automatically generating a response. 
+ """ + # Validate that schema only contains primitive types and fail loudly if not + _validate_elicitation_schema(schema) + + json_schema = schema.model_json_schema() + + result = await session.elicit( + message=message, + requestedSchema=json_schema, + related_request_id=related_request_id, + ) + + if result.action == "accept" and result.content: + # Validate and parse the content using the schema + validated_data = schema.model_validate(result.content) + return AcceptedElicitation(data=validated_data) + elif result.action == "decline": + return DeclinedElicitation() + elif result.action == "cancel": + return CancelledElicitation() + else: + # This should never happen, but handle it just in case + raise ValueError(f"Unexpected elicitation action: {result.action}") diff --git a/src/mcp/server/fastmcp/prompts/base.py b/src/mcp/server/fastmcp/prompts/base.py index aa3d1eac9..b45cfc917 100644 --- a/src/mcp/server/fastmcp/prompts/base.py +++ b/src/mcp/server/fastmcp/prompts/base.py @@ -7,18 +7,16 @@ import pydantic_core from pydantic import BaseModel, Field, TypeAdapter, validate_call -from mcp.types import EmbeddedResource, ImageContent, TextContent - -CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource +from mcp.types import ContentBlock, TextContent class Message(BaseModel): """Base class for all prompt messages.""" role: Literal["user", "assistant"] - content: CONTENT_TYPES + content: ContentBlock - def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any): + def __init__(self, content: str | ContentBlock, **kwargs: Any): if isinstance(content, str): content = TextContent(type="text", text=content) super().__init__(content=content, **kwargs) @@ -29,7 +27,7 @@ class UserMessage(Message): role: Literal["user", "assistant"] = "user" - def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any): + def __init__(self, content: str | ContentBlock, **kwargs: Any): super().__init__(content=content, **kwargs) @@ -38,17 +36,13 @@ class AssistantMessage(Message): role: Literal["user", "assistant"] = "assistant" - def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any): + def __init__(self, content: str | ContentBlock, **kwargs: Any): super().__init__(content=content, **kwargs) -message_validator = TypeAdapter[UserMessage | AssistantMessage]( - UserMessage | AssistantMessage -) +message_validator = TypeAdapter[UserMessage | AssistantMessage](UserMessage | AssistantMessage) -SyncPromptResult = ( - str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]] -) +SyncPromptResult = str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]] PromptResult = SyncPromptResult | Awaitable[SyncPromptResult] @@ -56,24 +50,17 @@ class PromptArgument(BaseModel): """An argument that can be passed to a prompt.""" name: str = Field(description="Name of the argument") - description: str | None = Field( - None, description="Description of what the argument does" - ) - required: bool = Field( - default=False, description="Whether the argument is required" - ) + description: str | None = Field(None, description="Description of what the argument does") + required: bool = Field(default=False, description="Whether the argument is required") class Prompt(BaseModel): """A prompt template that can be rendered with parameters.""" name: str = Field(description="Name of the prompt") - description: str | None = Field( - None, description="Description of what the prompt does" - ) - arguments: list[PromptArgument] | None = Field( - None, description="Arguments that can be 
passed to the prompt" - ) + title: str | None = Field(None, description="Human-readable title of the prompt") + description: str | None = Field(None, description="Description of what the prompt does") + arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt") fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True) @classmethod @@ -81,6 +68,7 @@ def from_function( cls, fn: Callable[..., PromptResult | Awaitable[PromptResult]], name: str | None = None, + title: str | None = None, description: str | None = None, ) -> "Prompt": """Create a Prompt from a function. @@ -117,6 +105,7 @@ def from_function( return cls( name=func_name, + title=title, description=description or fn.__doc__ or "", arguments=arguments, fn=fn, @@ -154,14 +143,10 @@ async def render(self, arguments: dict[str, Any] | None = None) -> list[Message] content = TextContent(type="text", text=msg) messages.append(UserMessage(content=content)) else: - content = pydantic_core.to_json( - msg, fallback=str, indent=2 - ).decode() + content = pydantic_core.to_json(msg, fallback=str, indent=2).decode() messages.append(Message(role="user", content=content)) except Exception: - raise ValueError( - f"Could not convert prompt result to message: {msg}" - ) + raise ValueError(f"Could not convert prompt result to message: {msg}") return messages except Exception as e: diff --git a/src/mcp/server/fastmcp/prompts/manager.py b/src/mcp/server/fastmcp/prompts/manager.py index 7ccbdef36..6b01d91cd 100644 --- a/src/mcp/server/fastmcp/prompts/manager.py +++ b/src/mcp/server/fastmcp/prompts/manager.py @@ -39,9 +39,7 @@ def add_prompt( self._prompts[prompt.name] = prompt return prompt - async def render_prompt( - self, name: str, arguments: dict[str, Any] | None = None - ) -> list[Message]: + async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]: """Render a prompt by name with arguments.""" prompt = self.get_prompt(name) if not prompt: diff --git a/src/mcp/server/fastmcp/resources/base.py b/src/mcp/server/fastmcp/resources/base.py index b2050e7f8..f57631cc1 100644 --- a/src/mcp/server/fastmcp/resources/base.py +++ b/src/mcp/server/fastmcp/resources/base.py @@ -19,13 +19,10 @@ class Resource(BaseModel, abc.ABC): model_config = ConfigDict(validate_default=True) - uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field( - default=..., description="URI of the resource" - ) + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(default=..., description="URI of the resource") name: str | None = Field(description="Name of the resource", default=None) - description: str | None = Field( - description="Description of the resource", default=None - ) + title: str | None = Field(description="Human-readable title of the resource", default=None) + description: str | None = Field(description="Description of the resource", default=None) mime_type: str = Field( default="text/plain", description="MIME type of the resource content", diff --git a/src/mcp/server/fastmcp/resources/resource_manager.py b/src/mcp/server/fastmcp/resources/resource_manager.py index d27e6ac12..35e4ec04d 100644 --- a/src/mcp/server/fastmcp/resources/resource_manager.py +++ b/src/mcp/server/fastmcp/resources/resource_manager.py @@ -51,6 +51,7 @@ def add_template( fn: Callable[..., Any], uri_template: str, name: str | None = None, + title: str | None = None, description: str | None = None, mime_type: str | None = None, ) -> ResourceTemplate: @@ -59,6 
+60,7 @@ def add_template( fn, uri_template=uri_template, name=name, + title=title, description=description, mime_type=mime_type, ) diff --git a/src/mcp/server/fastmcp/resources/templates.py b/src/mcp/server/fastmcp/resources/templates.py index a30b18253..b1c7b2711 100644 --- a/src/mcp/server/fastmcp/resources/templates.py +++ b/src/mcp/server/fastmcp/resources/templates.py @@ -15,18 +15,13 @@ class ResourceTemplate(BaseModel): """A template for dynamically creating resources.""" - uri_template: str = Field( - description="URI template with parameters (e.g. weather://{city}/current)" - ) + uri_template: str = Field(description="URI template with parameters (e.g. weather://{city}/current)") name: str = Field(description="Name of the resource") + title: str | None = Field(description="Human-readable title of the resource", default=None) description: str | None = Field(description="Description of what the resource does") - mime_type: str = Field( - default="text/plain", description="MIME type of the resource content" - ) + mime_type: str = Field(default="text/plain", description="MIME type of the resource content") fn: Callable[..., Any] = Field(exclude=True) - parameters: dict[str, Any] = Field( - description="JSON schema for function parameters" - ) + parameters: dict[str, Any] = Field(description="JSON schema for function parameters") @classmethod def from_function( @@ -34,6 +29,7 @@ def from_function( fn: Callable[..., Any], uri_template: str, name: str | None = None, + title: str | None = None, description: str | None = None, mime_type: str | None = None, ) -> ResourceTemplate: @@ -51,6 +47,7 @@ def from_function( return cls( uri_template=uri_template, name=func_name, + title=title, description=description or fn.__doc__ or "", mime_type=mime_type or "text/plain", fn=fn, @@ -77,6 +74,7 @@ async def create_resource(self, uri: str, params: dict[str, Any]) -> Resource: return FunctionResource( uri=uri, # type: ignore name=self.name, + title=self.title, description=self.description, mime_type=self.mime_type, fn=lambda: result, # Capture result in closure diff --git a/src/mcp/server/fastmcp/resources/types.py b/src/mcp/server/fastmcp/resources/types.py index d3f10211d..9c980dff1 100644 --- a/src/mcp/server/fastmcp/resources/types.py +++ b/src/mcp/server/fastmcp/resources/types.py @@ -54,9 +54,7 @@ class FunctionResource(Resource): async def read(self) -> str | bytes: """Read the resource by calling the wrapped function.""" try: - result = ( - await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn() - ) + result = await self.fn() if inspect.iscoroutinefunction(self.fn) else self.fn() if isinstance(result, Resource): return await result.read() elif isinstance(result, bytes): @@ -74,6 +72,7 @@ def from_function( fn: Callable[..., Any], uri: str, name: str | None = None, + title: str | None = None, description: str | None = None, mime_type: str | None = None, ) -> "FunctionResource": @@ -88,6 +87,7 @@ def from_function( return cls( uri=AnyUrl(uri), name=func_name, + title=title, description=description or fn.__doc__ or "", mime_type=mime_type or "text/plain", fn=fn, @@ -141,9 +141,7 @@ class HttpResource(Resource): """A resource that reads from an HTTP endpoint.""" url: str = Field(description="URL to fetch content from") - mime_type: str = Field( - default="application/json", description="MIME type of the resource content" - ) + mime_type: str = Field(default="application/json", description="MIME type of the resource content") async def read(self) -> str | bytes: """Read the HTTP 
content.""" @@ -157,15 +155,9 @@ class DirectoryResource(Resource): """A resource that lists files in a directory.""" path: Path = Field(description="Path to the directory") - recursive: bool = Field( - default=False, description="Whether to list files recursively" - ) - pattern: str | None = Field( - default=None, description="Optional glob pattern to filter files" - ) - mime_type: str = Field( - default="application/json", description="MIME type of the resource content" - ) + recursive: bool = Field(default=False, description="Whether to list files recursively") + pattern: str | None = Field(default=None, description="Optional glob pattern to filter files") + mime_type: str = Field(default="application/json", description="MIME type of the resource content") @pydantic.field_validator("path") @classmethod @@ -184,16 +176,8 @@ def list_files(self) -> list[Path]: try: if self.pattern: - return ( - list(self.path.glob(self.pattern)) - if not self.recursive - else list(self.path.rglob(self.pattern)) - ) - return ( - list(self.path.glob("*")) - if not self.recursive - else list(self.path.rglob("*")) - ) + return list(self.path.glob(self.pattern)) if not self.recursive else list(self.path.rglob(self.pattern)) + return list(self.path.glob("*")) if not self.recursive else list(self.path.rglob("*")) except Exception as e: raise ValueError(f"Error listing directory {self.path}: {e}") diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index e5b6c3acc..1b761e917 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -34,6 +34,7 @@ from mcp.server.auth.settings import ( AuthSettings, ) +from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager @@ -49,12 +50,12 @@ from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import ( AnyFunction, - EmbeddedResource, + ContentBlock, GetPromptResult, - ImageContent, TextContent, ToolAnnotations, ) @@ -96,9 +97,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # StreamableHTTP settings json_response: bool = False - stateless_http: bool = ( - False # If True, uses true stateless mode (new transport per request) - ) + stateless_http: bool = False # If True, uses true stateless mode (new transport per request) # resource settings warn_on_duplicate_resources: bool = True @@ -114,19 +113,20 @@ class Settings(BaseSettings, Generic[LifespanResultT]): description="List of dependencies to install in the server environment", ) - lifespan: ( - Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None - ) = Field(None, description="Lifespan context manager") + lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None = Field( + None, description="Lifespan context manager" + ) auth: AuthSettings | None = None + # Transport security settings (DNS rebinding protection) + transport_security: TransportSecuritySettings | None = None + def lifespan_wrapper( app: FastMCP, lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[ - 
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object] -]: +) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]]: @asynccontextmanager async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]: async with lifespan(app) as context: @@ -140,8 +140,7 @@ def __init__( self, name: str | None = None, instructions: str | None = None, - auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] - | None = None, + auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, event_store: EventStore | None = None, *, tools: list[Tool] | None = None, @@ -152,31 +151,18 @@ def __init__( self._mcp_server = MCPServer( name=name or "FastMCP", instructions=instructions, - lifespan=( - lifespan_wrapper(self, self.settings.lifespan) - if self.settings.lifespan - else default_lifespan - ), - ) - self._tool_manager = ToolManager( - tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools - ) - self._resource_manager = ResourceManager( - warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources - ) - self._prompt_manager = PromptManager( - warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts + lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), ) + self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) + self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) + self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) if (self.settings.auth is not None) != (auth_server_provider is not None): # TODO: after we support separate authorization servers (see # https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284) # we should validate that if auth is enabled, we have either an # auth_server_provider to host our own authorization server, # OR the URL of a 3rd party authorization server. 
- raise ValueError( - "settings.auth must be specified if and only if auth_server_provider " - "is specified" - ) + raise ValueError("settings.auth must be specified if and only if auth_server_provider " "is specified") self._auth_server_provider = auth_server_provider self._event_store = event_store self._custom_starlette_routes: list[Route] = [] @@ -255,6 +241,7 @@ async def list_tools(self) -> list[MCPTool]: return [ MCPTool( name=info.name, + title=info.title, description=info.description, inputSchema=info.parameters, annotations=info.annotations, @@ -273,9 +260,7 @@ def get_context(self) -> Context[ServerSession, object, Request]: request_context = None return Context(request_context=request_context, fastmcp=self) - async def call_tool( - self, name: str, arguments: dict[str, Any] - ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock]: """Call a tool by name with arguments.""" context = self.get_context() result = await self._tool_manager.call_tool(name, arguments, context=context) @@ -290,6 +275,7 @@ async def list_resources(self) -> list[MCPResource]: MCPResource( uri=resource.uri, name=resource.name or "", + title=resource.title, description=resource.description, mimeType=resource.mime_type, ) @@ -302,6 +288,7 @@ async def list_resource_templates(self) -> list[MCPResourceTemplate]: MCPResourceTemplate( uriTemplate=template.uri_template, name=template.name, + title=template.title, description=template.description, ) for template in templates @@ -325,6 +312,7 @@ def add_tool( self, fn: AnyFunction, name: str | None = None, + title: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, ) -> None: @@ -336,16 +324,16 @@ def add_tool( Args: fn: The function to register as a tool name: Optional name for the tool (defaults to function name) + title: Optional human-readable title for the tool description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information """ - self._tool_manager.add_tool( - fn, name=name, description=description, annotations=annotations - ) + self._tool_manager.add_tool(fn, name=name, title=title, description=description, annotations=annotations) def tool( self, name: str | None = None, + title: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, ) -> Callable[[AnyFunction], AnyFunction]: @@ -357,6 +345,7 @@ def tool( Args: name: Optional name for the tool (defaults to function name) + title: Optional human-readable title for the tool description: Optional description of what the tool does annotations: Optional ToolAnnotations providing additional tool information @@ -378,18 +367,33 @@ async def async_tool(x: int, context: Context) -> str: # Check if user passed function directly instead of calling decorator if callable(name): raise TypeError( - "The @tool decorator was used incorrectly. " - "Did you forget to call it? Use @tool() instead of @tool" + "The @tool decorator was used incorrectly. " "Did you forget to call it? Use @tool() instead of @tool" ) def decorator(fn: AnyFunction) -> AnyFunction: - self.add_tool( - fn, name=name, description=description, annotations=annotations - ) + self.add_tool(fn, name=name, title=title, description=description, annotations=annotations) return fn return decorator + def completion(self): + """Decorator to register a completion handler. 
+ + The completion handler receives: + - ref: PromptReference or ResourceTemplateReference + - argument: CompletionArgument with name and partial value + - context: Optional CompletionContext with previously resolved arguments + + Example: + @mcp.completion() + async def handle_completion(ref, argument, context): + if isinstance(ref, ResourceTemplateReference): + # Return completions based on ref, argument, and context + return Completion(values=["option1", "option2"]) + return None + """ + return self._mcp_server.completion() + def add_resource(self, resource: Resource) -> None: """Add a resource to the server. @@ -403,6 +407,7 @@ def resource( uri: str, *, name: str | None = None, + title: str | None = None, description: str | None = None, mime_type: str | None = None, ) -> Callable[[AnyFunction], AnyFunction]: @@ -420,6 +425,7 @@ def resource( Args: uri: URI for the resource (e.g. "resource://my-resource" or "resource://{param}") name: Optional name for the resource + title: Optional human-readable title for the resource description: Optional description of the resource mime_type: Optional MIME type for the resource @@ -461,8 +467,7 @@ def decorator(fn: AnyFunction) -> AnyFunction: if uri_params != func_params: raise ValueError( - f"Mismatch between URI parameters {uri_params} " - f"and function parameters {func_params}" + f"Mismatch between URI parameters {uri_params} " f"and function parameters {func_params}" ) # Register as template @@ -470,6 +475,7 @@ def decorator(fn: AnyFunction) -> AnyFunction: fn=fn, uri_template=uri, name=name, + title=title, description=description, mime_type=mime_type, ) @@ -479,6 +485,7 @@ def decorator(fn: AnyFunction) -> AnyFunction: fn=fn, uri=uri, name=name, + title=title, description=description, mime_type=mime_type, ) @@ -496,12 +503,13 @@ def add_prompt(self, prompt: Prompt) -> None: self._prompt_manager.add_prompt(prompt) def prompt( - self, name: str | None = None, description: str | None = None + self, name: str | None = None, title: str | None = None, description: str | None = None ) -> Callable[[AnyFunction], AnyFunction]: """Decorator to register a prompt. 
Args: name: Optional name for the prompt (defaults to function name) + title: Optional human-readable title for the prompt description: Optional description of what the prompt does Example: @@ -539,7 +547,7 @@ async def analyze_file(path: str) -> list[Message]: ) def decorator(func: AnyFunction) -> AnyFunction: - prompt = Prompt.from_function(func, name=name, description=description) + prompt = Prompt.from_function(func, name=name, title=title, description=description) self.add_prompt(prompt) return func @@ -664,14 +672,13 @@ def sse_app(self, mount_path: str | None = None) -> Starlette: self.settings.mount_path = mount_path # Create normalized endpoint considering the mount path - normalized_message_endpoint = self._normalize_path( - self.settings.mount_path, self.settings.message_path - ) + normalized_message_endpoint = self._normalize_path(self.settings.mount_path, self.settings.message_path) # Set up auth context and dependencies sse = SseServerTransport( normalized_message_endpoint, + security_settings=self.settings.transport_security, ) async def handle_sse(scope: Scope, receive: Receive, send: Send): @@ -763,9 +770,7 @@ async def sse_endpoint(request: Request) -> Response: routes.extend(self._custom_starlette_routes) # Create Starlette app with routes and middleware - return Starlette( - debug=self.settings.debug, routes=routes, middleware=middleware - ) + return Starlette(debug=self.settings.debug, routes=routes, middleware=middleware) def streamable_http_app(self) -> Starlette: """Return an instance of the StreamableHTTP server app.""" @@ -779,12 +784,11 @@ def streamable_http_app(self) -> Starlette: event_store=self._event_store, json_response=self.settings.json_response, stateless=self.settings.stateless_http, # Use the stateless setting + security_settings=self.settings.transport_security, ) # Create the ASGI handler - async def handle_streamable_http( - scope: Scope, receive: Receive, send: Send - ) -> None: + async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: await self.session_manager.handle_request(scope, receive, send) # Create routes @@ -847,6 +851,7 @@ async def list_prompts(self) -> list[MCPPrompt]: return [ MCPPrompt( name=prompt.name, + title=prompt.title, description=prompt.description, arguments=[ MCPPromptArgument( @@ -860,9 +865,7 @@ async def list_prompts(self) -> list[MCPPrompt]: for prompt in prompts ] - async def get_prompt( - self, name: str, arguments: dict[str, Any] | None = None - ) -> GetPromptResult: + async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult: """Get a prompt by name with arguments.""" try: messages = await self._prompt_manager.render_prompt(name, arguments) @@ -875,12 +878,12 @@ async def get_prompt( def _convert_to_content( result: Any, -) -> Sequence[TextContent | ImageContent | EmbeddedResource]: +) -> Sequence[ContentBlock]: """Convert a result to a sequence of content objects.""" if result is None: return [] - if isinstance(result, TextContent | ImageContent | EmbeddedResource): + if isinstance(result, ContentBlock): return [result] if isinstance(result, Image): @@ -935,9 +938,7 @@ def my_tool(x: int, ctx: Context) -> str: def __init__( self, *, - request_context: ( - RequestContext[ServerSessionT, LifespanContextT, RequestT] | None - ) = None, + request_context: (RequestContext[ServerSessionT, LifespanContextT, RequestT] | None) = None, fastmcp: FastMCP | None = None, **kwargs: Any, ): @@ -961,9 +962,7 @@ def request_context( raise ValueError("Context is 
not available outside of a request") return self._request_context - async def report_progress( - self, progress: float, total: float | None = None, message: str | None = None - ) -> None: + async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: """Report progress for the current operation. Args: @@ -971,11 +970,7 @@ async def report_progress( total: Optional total value e.g. 100 message: Optional message e.g. Starting render... """ - progress_token = ( - self.request_context.meta.progressToken - if self.request_context.meta - else None - ) + progress_token = self.request_context.meta.progressToken if self.request_context.meta else None if progress_token is None: return @@ -996,11 +991,40 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent Returns: The resource content as either text or bytes """ - assert ( - self._fastmcp is not None - ), "Context is not available outside of a request" + assert self._fastmcp is not None, "Context is not available outside of a request" return await self._fastmcp.read_resource(uri) + async def elicit( + self, + message: str, + schema: type[ElicitSchemaModelT], + ) -> ElicitationResult[ElicitSchemaModelT]: + """Elicit information from the client/user. + + This method can be used to interactively ask for additional information from the + client within a tool's execution. The client might display the message to the + user and collect a response according to the provided schema. Or, if the + client is an agent, it might decide how to handle the elicitation -- either by asking + the user or by generating a response automatically. + + Args: + schema: A Pydantic model class defining the expected response structure; according to the specification, + only primitive types are allowed. + message: The message to present to the user when asking for the + information defined by the schema. + + Returns: + An ElicitationResult containing the action taken and the data if accepted + + Note: + Check the result.action to determine if the user accepted, declined, or cancelled. + The result.data will only be populated if action is "accept" and validation succeeded. 
+ """ + + return await elicit_with_validation( + session=self.request_context.session, message=message, schema=schema, related_request_id=self.request_id + ) + async def log( self, level: Literal["debug", "info", "warning", "error"], @@ -1026,11 +1050,7 @@ async def log( @property def client_id(self) -> str | None: """Get the client ID if available.""" - return ( - getattr(self.request_context.meta, "client_id", None) - if self.request_context.meta - else None - ) + return getattr(self.request_context.meta, "client_id", None) if self.request_context.meta else None @property def request_id(self) -> str: diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index f32eb15bd..2f7c48e8b 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -22,25 +22,22 @@ class Tool(BaseModel): fn: Callable[..., Any] = Field(exclude=True) name: str = Field(description="Name of the tool") + title: str | None = Field(None, description="Human-readable title of the tool") description: str = Field(description="Description of what the tool does") parameters: dict[str, Any] = Field(description="JSON schema for tool parameters") fn_metadata: FuncMetadata = Field( - description="Metadata about the function including a pydantic model for tool" - " arguments" + description="Metadata about the function including a pydantic model for tool" " arguments" ) is_async: bool = Field(description="Whether the tool is async") - context_kwarg: str | None = Field( - None, description="Name of the kwarg that should receive context" - ) - annotations: ToolAnnotations | None = Field( - None, description="Optional annotations for the tool" - ) + context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context") + annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool") @classmethod def from_function( cls, fn: Callable[..., Any], name: str | None = None, + title: str | None = None, description: str | None = None, context_kwarg: str | None = None, annotations: ToolAnnotations | None = None, @@ -74,6 +71,7 @@ def from_function( return cls( fn=fn, name=func_name, + title=title, description=func_doc, parameters=parameters, fn_metadata=func_arg_metadata, @@ -93,9 +91,7 @@ async def run( self.fn, self.is_async, arguments, - {self.context_kwarg: context} - if self.context_kwarg is not None - else None, + {self.context_kwarg: context} if self.context_kwarg is not None else None, ) except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index 153249379..b9ca1655d 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -46,13 +46,12 @@ def add_tool( self, fn: Callable[..., Any], name: str | None = None, + title: str | None = None, description: str | None = None, annotations: ToolAnnotations | None = None, ) -> Tool: """Add a tool to the server.""" - tool = Tool.from_function( - fn, name=name, description=description, annotations=annotations - ) + tool = Tool.from_function(fn, name=name, title=title, description=description, annotations=annotations) existing = self._tools.get(tool.name) if existing: if self.warn_on_duplicate_tools: diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index 374391325..9f8d9177a 100644 --- 
a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -102,9 +102,7 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: ) -def func_metadata( - func: Callable[..., Any], skip_names: Sequence[str] = () -) -> FuncMetadata: +def func_metadata(func: Callable[..., Any], skip_names: Sequence[str] = ()) -> FuncMetadata: """Given a function, return metadata including a pydantic model representing its signature. @@ -131,9 +129,7 @@ def func_metadata( globalns = getattr(func, "__globals__", {}) for param in params.values(): if param.name.startswith("_"): - raise InvalidSignature( - f"Parameter {param.name} of {func.__name__} cannot start with '_'" - ) + raise InvalidSignature(f"Parameter {param.name} of {func.__name__} cannot start with '_'") if param.name in skip_names: continue annotation = param.annotation @@ -142,11 +138,7 @@ def func_metadata( if annotation is None: annotation = Annotated[ None, - Field( - default=param.default - if param.default is not inspect.Parameter.empty - else PydanticUndefined - ), + Field(default=param.default if param.default is not inspect.Parameter.empty else PydanticUndefined), ] # Untyped field @@ -160,9 +152,7 @@ def func_metadata( field_info = FieldInfo.from_annotated_attribute( _get_typed_annotation(annotation, globalns), - param.default - if param.default is not inspect.Parameter.empty - else PydanticUndefined, + param.default if param.default is not inspect.Parameter.empty else PydanticUndefined, ) dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info) continue @@ -177,9 +167,7 @@ def func_metadata( def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: - def try_eval_type( - value: Any, globalns: dict[str, Any], localns: dict[str, Any] - ) -> tuple[Any, bool]: + def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any]) -> tuple[Any, bool]: try: return eval_type_backport(value, globalns, localns), True except NameError: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index b98e3dd1a..0a8ab7f97 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -95,9 +95,7 @@ async def main(): RequestT = TypeVar("RequestT", default=Any) # This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = ( - contextvars.ContextVar("request_ctx") -) +request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx") class NotificationOptions: @@ -140,14 +138,12 @@ def __init__( self.version = version self.instructions = instructions self.lifespan = lifespan - self.request_handlers: dict[ - type, Callable[..., Awaitable[types.ServerResult]] - ] = { + self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = { types.PingRequest: _ping_handler, } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_options = NotificationOptions() - logger.debug(f"Initializing server '{name}'") + logger.debug("Initializing server %r", name) def create_initialization_options( self, @@ -189,9 +185,7 @@ def get_capabilities( # Set prompt capabilities if handler exists if types.ListPromptsRequest in self.request_handlers: - prompts_capability = types.PromptsCapability( - listChanged=notification_options.prompts_changed - ) + prompts_capability = 
types.PromptsCapability(listChanged=notification_options.prompts_changed) # Set resource capabilities if handler exists if types.ListResourcesRequest in self.request_handlers: @@ -201,9 +195,7 @@ def get_capabilities( # Set tool capabilities if handler exists if types.ListToolsRequest in self.request_handlers: - tools_capability = types.ToolsCapability( - listChanged=notification_options.tools_changed - ) + tools_capability = types.ToolsCapability(listChanged=notification_options.tools_changed) # Set logging capabilities if handler exists if types.SetLevelRequest in self.request_handlers: @@ -239,9 +231,7 @@ async def handler(_: Any): def get_prompt(self): def decorator( - func: Callable[ - [str, dict[str, str] | None], Awaitable[types.GetPromptResult] - ], + func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]], ): logger.debug("Registering handler for GetPromptRequest") @@ -260,9 +250,7 @@ def decorator(func: Callable[[], Awaitable[list[types.Resource]]]): async def handler(_: Any): resources = await func() - return types.ServerResult( - types.ListResourcesResult(resources=resources) - ) + return types.ServerResult(types.ListResourcesResult(resources=resources)) self.request_handlers[types.ListResourcesRequest] = handler return func @@ -275,9 +263,7 @@ def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): async def handler(_: Any): templates = await func() - return types.ServerResult( - types.ListResourceTemplatesResult(resourceTemplates=templates) - ) + return types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=templates)) self.request_handlers[types.ListResourceTemplatesRequest] = handler return func @@ -286,9 +272,7 @@ async def handler(_: Any): def read_resource(self): def decorator( - func: Callable[ - [AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]] - ], + func: Callable[[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]], ): logger.debug("Registering handler for ReadResourceRequest") @@ -323,8 +307,7 @@ def create_content(data: str | bytes, mime_type: str | None): content = create_content(data, None) case Iterable() as contents: contents_list = [ - create_content(content_item.content, content_item.mime_type) - for content_item in contents + create_content(content_item.content, content_item.mime_type) for content_item in contents ] return types.ServerResult( types.ReadResourceResult( @@ -332,9 +315,7 @@ def create_content(data: str | bytes, mime_type: str | None): ) ) case _: - raise ValueError( - f"Unexpected return type from read_resource: {type(result)}" - ) + raise ValueError(f"Unexpected return type from read_resource: {type(result)}") return types.ServerResult( types.ReadResourceResult( @@ -403,11 +384,7 @@ def call_tool(self): def decorator( func: Callable[ ..., - Awaitable[ - Iterable[ - types.TextContent | types.ImageContent | types.EmbeddedResource - ] - ], + Awaitable[Iterable[types.ContentBlock]], ], ): logger.debug("Registering handler for CallToolRequest") @@ -415,9 +392,7 @@ def decorator( async def handler(req: types.CallToolRequest): try: results = await func(req.params.name, (req.params.arguments or {})) - return types.ServerResult( - types.CallToolResult(content=list(results), isError=False) - ) + return types.ServerResult(types.CallToolResult(content=list(results), isError=False)) except Exception as e: return types.ServerResult( types.CallToolResult( @@ -433,9 +408,7 @@ async def handler(req: types.CallToolRequest): def progress_notification(self): def decorator( - 
func: Callable[ - [str | int, float, float | None, str | None], Awaitable[None] - ], + func: Callable[[str | int, float, float | None, str | None], Awaitable[None]], ): logger.debug("Registering handler for ProgressNotification") @@ -458,8 +431,9 @@ def completion(self): def decorator( func: Callable[ [ - types.PromptReference | types.ResourceReference, + types.PromptReference | types.ResourceTemplateReference, types.CompletionArgument, + types.CompletionContext | None, ], Awaitable[types.Completion | None], ], @@ -467,7 +441,7 @@ def decorator( logger.debug("Registering handler for CompleteRequest") async def handler(req: types.CompleteRequest): - completion = await func(req.params.ref, req.params.argument) + completion = await func(req.params.ref, req.params.argument, req.params.context) return types.ServerResult( types.CompleteResult( completion=completion @@ -510,7 +484,7 @@ async def run( async with anyio.create_task_group() as tg: async for message in session.incoming_messages: - logger.debug(f"Received message: {message}") + logger.debug("Received message: %s", message) tg.start_soon( self._handle_message, @@ -522,9 +496,7 @@ async def run( async def _handle_message( self, - message: RequestResponder[types.ClientRequest, types.ServerResult] - | types.ClientNotification - | Exception, + message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, @@ -532,18 +504,14 @@ async def _handle_message( with warnings.catch_warnings(record=True) as w: # TODO(Marcelo): We should be checking if message is Exception here. match message: # type: ignore[reportMatchNotExhaustive] - case ( - RequestResponder(request=types.ClientRequest(root=req)) as responder - ): + case RequestResponder(request=types.ClientRequest(root=req)) as responder: with responder: - await self._handle_request( - message, req, session, lifespan_context, raise_exceptions - ) + await self._handle_request(message, req, session, lifespan_context, raise_exceptions) case types.ClientNotification(root=notify): await self._handle_notification(notify) for warning in w: - logger.info(f"Warning: {warning.category.__name__}: {warning.message}") + logger.info("Warning: %s: %s", warning.category.__name__, warning.message) async def _handle_request( self, @@ -553,18 +521,15 @@ async def _handle_request( lifespan_context: LifespanResultT, raise_exceptions: bool, ): - logger.info(f"Processing request of type {type(req).__name__}") - if type(req) in self.request_handlers: - handler = self.request_handlers[type(req)] - logger.debug(f"Dispatching request of type {type(req).__name__}") + logger.info("Processing request of type %s", type(req).__name__) + if handler := self.request_handlers.get(type(req)): # type: ignore + logger.debug("Dispatching request of type %s", type(req).__name__) token = None try: # Extract request context from message metadata request_data = None - if message.message_metadata is not None and isinstance( - message.message_metadata, ServerMessageMetadata - ): + if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata): request_data = message.message_metadata.request_context # Set our global state that can be retrieved via @@ -602,16 +567,13 @@ async def _handle_request( logger.debug("Response sent") async def _handle_notification(self, notify: Any): - if type(notify) in self.notification_handlers: - assert type(notify) in 
self.notification_handlers - - handler = self.notification_handlers[type(notify)] - logger.debug(f"Dispatching notification of type {type(notify).__name__}") + if handler := self.notification_handlers.get(type(notify)): # type: ignore + logger.debug("Dispatching notification of type %s", type(notify).__name__) try: await handler(notify) - except Exception as err: - logger.error(f"Uncaught exception in notification handler: {err}") + except Exception: + logger.exception("Uncaught exception in notification handler") async def _ping_handler(request: types.PingRequest) -> types.ServerResult: diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index ef5c5a3c3..5c696b136 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -64,9 +64,7 @@ class InitializationState(Enum): ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") ServerRequestResponder = ( - RequestResponder[types.ClientRequest, types.ServerResult] - | types.ClientNotification - | Exception + RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception ) @@ -89,22 +87,16 @@ def __init__( init_options: InitializationOptions, stateless: bool = False, ) -> None: - super().__init__( - read_stream, write_stream, types.ClientRequest, types.ClientNotification - ) + super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) self._initialization_state = ( - InitializationState.Initialized - if stateless - else InitializationState.NotInitialized + InitializationState.Initialized if stateless else InitializationState.NotInitialized ) self._init_options = init_options - self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( - anyio.create_memory_object_stream[ServerRequestResponder](0) - ) - self._exit_stack.push_async_callback( - lambda: self._incoming_message_stream_reader.aclose() - ) + self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ + ServerRequestResponder + ](0) + self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) @property def client_params(self) -> types.InitializeRequestParams | None: @@ -129,15 +121,16 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: if client_caps.sampling is None: return False + if capability.elicitation is not None: + if client_caps.elicitation is None: + return False + if capability.experimental is not None: if client_caps.experimental is None: return False # Check each experimental capability for exp_key, exp_value in capability.experimental.items(): - if ( - exp_key not in client_caps.experimental - or client_caps.experimental[exp_key] != exp_value - ): + if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: return False return True @@ -146,9 +139,7 @@ async def _receive_loop(self) -> None: async with self._incoming_message_stream_writer: await super()._receive_loop() - async def _received_request( - self, responder: RequestResponder[types.ClientRequest, types.ServerResult] - ): + async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): match responder.request.root: case types.InitializeRequest(params=params): requested_version = params.protocolVersion @@ -172,13 +163,9 @@ async def _received_request( ) case _: if self._initialization_state != InitializationState.Initialized: - raise RuntimeError( - "Received request before initialization was complete" - 
) + raise RuntimeError("Received request before initialization was complete") - async def _received_notification( - self, notification: types.ClientNotification - ) -> None: + async def _received_notification(self, notification: types.ClientNotification) -> None: # Need this to avoid ASYNC910 await anyio.lowlevel.checkpoint() match notification.root: @@ -186,9 +173,7 @@ async def _received_notification( self._initialization_state = InitializationState.Initialized case _: if self._initialization_state != InitializationState.Initialized: - raise RuntimeError( - "Received notification before initialization was complete" - ) + raise RuntimeError("Received notification before initialization was complete") async def send_log_message( self, @@ -270,6 +255,35 @@ async def list_roots(self) -> types.ListRootsResult: types.ListRootsResult, ) + async def elicit( + self, + message: str, + requestedSchema: types.ElicitRequestedSchema, + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send an elicitation/create request. + + Args: + message: The message to present to the user + requestedSchema: Schema defining the expected response structure + + Returns: + The client's response + """ + return await self.send_request( + types.ServerRequest( + types.ElicitRequest( + method="elicitation/create", + params=types.ElicitRequestParams( + message=message, + requestedSchema=requestedSchema, + ), + ) + ), + types.ElicitResult, + metadata=ServerMessageMetadata(related_request_id=related_request_id), + ) + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 192c1290b..41145e49f 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -52,6 +52,10 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types +from mcp.server.transport_security import ( + TransportSecurityMiddleware, + TransportSecuritySettings, +) from mcp.shared.message import ServerMessageMetadata, SessionMessage logger = logging.getLogger(__name__) @@ -71,16 +75,22 @@ class SseServerTransport: _endpoint: str _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] + _security: TransportSecurityMiddleware - def __init__(self, endpoint: str) -> None: + def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: """ Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL given. + + Args: + endpoint: The relative or absolute URL for POST messages. + security_settings: Optional security settings for DNS rebinding protection. 
""" super().__init__() self._endpoint = endpoint self._read_stream_writers = {} + self._security = TransportSecurityMiddleware(security_settings) logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager @@ -89,6 +99,13 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") + # Validate request headers for DNS rebinding protection + request = Request(scope, receive) + error_response = await self._security.validate_request(request, is_post=False) + if error_response: + await error_response(scope, receive, send) + raise ValueError("Request validation failed") + logger.debug("Setting up SSE connection") read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] @@ -116,20 +133,14 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): full_message_path_for_client = root_path.rstrip("/") + self._endpoint # This is the URI (path + query) the client will use to POST messages. - client_post_uri_data = ( - f"{quote(full_message_path_for_client)}?session_id={session_id.hex}" - ) + client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}" - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ - dict[str, Any] - ](0) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0) async def sse_writer(): logger.debug("Starting SSE writer") async with sse_stream_writer, write_stream_reader: - await sse_stream_writer.send( - {"event": "endpoint", "data": client_post_uri_data} - ) + await sse_stream_writer.send({"event": "endpoint", "data": client_post_uri_data}) logger.debug(f"Sent endpoint event: {client_post_uri_data}") async for session_message in write_stream_reader: @@ -137,9 +148,7 @@ async def sse_writer(): await sse_stream_writer.send( { "event": "message", - "data": session_message.message.model_dump_json( - by_alias=True, exclude_none=True - ), + "data": session_message.message.model_dump_json(by_alias=True, exclude_none=True), } ) @@ -151,9 +160,9 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): In this case we close our side of the streams to signal the client that the connection has been closed. 
""" - await EventSourceResponse( - content=sse_stream_reader, data_sender_callable=sse_writer - )(scope, receive, send) + await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( + scope, receive, send + ) await read_stream_writer.aclose() await write_stream_reader.aclose() logging.debug(f"Client session disconnected {session_id}") @@ -164,12 +173,15 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): logger.debug("Yielding read and write streams") yield (read_stream, write_stream) - async def handle_post_message( - self, scope: Scope, receive: Receive, send: Send - ) -> None: + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") request = Request(scope, receive) + # Validate request headers for DNS rebinding protection + error_response = await self._security.validate_request(request, is_post=True) + if error_response: + return await error_response(scope, receive, send) + session_id_param = request.query_params.get("session_id") if session_id_param is None: logger.warning("Received request without session_id") diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index f0bbe5a31..d1618a371 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -76,9 +76,7 @@ async def stdout_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - json = session_message.message.model_dump_json( - by_alias=True, exclude_none=True - ) + json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) await stdout.write(json + "\n") await stdout.flush() except anyio.ClosedResourceError: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index a94cc2834..d46549929 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -24,8 +24,14 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.server.transport_security import ( + TransportSecurityMiddleware, + TransportSecuritySettings, +) from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( + DEFAULT_NEGOTIATED_VERSION, INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, @@ -45,6 +51,7 @@ # Header names MCP_SESSION_ID_HEADER = "mcp-session-id" +MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" LAST_EVENT_ID_HEADER = "last-event-id" # Content types @@ -82,9 +89,7 @@ class EventStore(ABC): """ @abstractmethod - async def store_event( - self, stream_id: StreamId, message: JSONRPCMessage - ) -> EventId: + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId: """ Stores an event for later retrieval. 
@@ -125,18 +130,18 @@ class StreamableHTTPServerTransport: """ # Server notification streams for POST requests as well as standalone SSE stream - _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = ( - None - ) + _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None _write_stream: MemoryObjectSendStream[SessionMessage] | None = None _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None + _security: TransportSecurityMiddleware def __init__( self, mcp_session_id: str | None, is_json_response_enabled: bool = False, event_store: EventStore | None = None, + security_settings: TransportSecuritySettings | None = None, ) -> None: """ Initialize a new StreamableHTTP server transport. @@ -149,20 +154,18 @@ def __init__( event_store: Event store for resumability support. If provided, resumability will be enabled, allowing clients to reconnect and resume messages. + security_settings: Optional security settings for DNS rebinding protection. Raises: ValueError: If the session ID contains invalid characters. """ - if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch( - mcp_session_id - ): - raise ValueError( - "Session ID must only contain visible ASCII characters (0x21-0x7E)" - ) + if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(mcp_session_id): + raise ValueError("Session ID must only contain visible ASCII characters (0x21-0x7E)") self.mcp_session_id = mcp_session_id self.is_json_response_enabled = is_json_response_enabled self._event_store = event_store + self._security = TransportSecurityMiddleware(security_settings) self._request_streams: dict[ RequestId, tuple[ @@ -218,9 +221,7 @@ def _create_json_response( response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id return Response( - response_message.model_dump_json(by_alias=True, exclude_none=True) - if response_message - else None, + response_message.model_dump_json(by_alias=True, exclude_none=True) if response_message else None, status_code=status_code, headers=response_headers, ) @@ -233,9 +234,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: """Create event data dictionary from an EventMessage.""" event_data = { "event": "message", - "data": event_message.message.model_dump_json( - by_alias=True, exclude_none=True - ), + "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True), } # If an event ID was provided, include it @@ -260,6 +259,14 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None: async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Application entry point that handles all HTTP requests""" request = Request(scope, receive) + + # Validate request headers for DNS rebinding protection + is_post = request.method == "POST" + error_response = await self._security.validate_request(request, is_post=is_post) + if error_response: + await error_response(scope, receive, send) + return + if self._terminated: # If the session has been terminated, return 404 Not Found response = self._create_error_response( @@ -283,42 +290,29 @@ def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: accept_header = request.headers.get("accept", "") accept_types = [media_type.strip() for media_type in accept_header.split(",")] - has_json = any( - media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types - ) - has_sse = any( 
- media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types - ) + has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types) + has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types) return has_json, has_sse def _check_content_type(self, request: Request) -> bool: """Check if the request has the correct Content-Type.""" content_type = request.headers.get("content-type", "") - content_type_parts = [ - part.strip() for part in content_type.split(";")[0].split(",") - ] + content_type_parts = [part.strip() for part in content_type.split(";")[0].split(",")] return any(part == CONTENT_TYPE_JSON for part in content_type_parts) - async def _handle_post_request( - self, scope: Scope, request: Request, receive: Receive, send: Send - ) -> None: + async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: """Handle POST requests containing JSON-RPC messages.""" writer = self._read_stream_writer if writer is None: - raise ValueError( - "No read stream writer available. Ensure connect() is called first." - ) + raise ValueError("No read stream writer available. Ensure connect() is called first.") try: # Check Accept headers has_json, has_sse = self._check_accept_headers(request) if not (has_json and has_sse): response = self._create_error_response( - ( - "Not Acceptable: Client must accept both application/json and " - "text/event-stream" - ), + ("Not Acceptable: Client must accept both application/json and text/event-stream"), HTTPStatus.NOT_ACCEPTABLE, ) await response(scope, receive, send) @@ -346,9 +340,7 @@ async def _handle_post_request( try: raw_message = json.loads(body) except json.JSONDecodeError as e: - response = self._create_error_response( - f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR - ) + response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR) await response(scope, receive, send) return @@ -364,10 +356,7 @@ async def _handle_post_request( return # Check if this is an initialization request - is_initialization_request = ( - isinstance(message.root, JSONRPCRequest) - and message.root.method == "initialize" - ) + is_initialization_request = isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" if is_initialization_request: # Check if the server already has an established session @@ -383,8 +372,7 @@ async def _handle_post_request( ) await response(scope, receive, send) return - # For non-initialization requests, validate the session - elif not await self._validate_session(request, send): + elif not await self._validate_request_headers(request, send): return # For notifications and responses only, return 202 Accepted @@ -406,9 +394,7 @@ async def _handle_post_request( # Extract the request ID outside the try block for proper scope request_id = str(message.root.id) # Register this stream for the request ID - self._request_streams[request_id] = anyio.create_memory_object_stream[ - EventMessage - ](0) + self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) request_stream_reader = self._request_streams[request_id][1] if self.is_json_response_enabled: @@ -424,16 +410,12 @@ async def _handle_post_request( # Use similar approach to SSE writer for consistency async for event_message in request_stream_reader: # If it's a response, this is what we're waiting for - if isinstance( - event_message.message.root, JSONRPCResponse | JSONRPCError - ): + if 
isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError): response_message = event_message.message break # For notifications and request, keep waiting else: - logger.debug( - f"received: {event_message.message.root.method}" - ) + logger.debug(f"received: {event_message.message.root.method}") # At this point we should have a response if response_message: @@ -442,9 +424,7 @@ async def _handle_post_request( await response(scope, receive, send) else: # This shouldn't happen in normal operation - logger.error( - "No response message received before stream closed" - ) + logger.error("No response message received before stream closed") response = self._create_error_response( "Error processing request: No response received", HTTPStatus.INTERNAL_SERVER_ERROR, @@ -462,9 +442,7 @@ async def _handle_post_request( await self._clean_up_memory_streams(request_id) else: # Create SSE stream - sse_stream_writer, sse_stream_reader = ( - anyio.create_memory_object_stream[dict[str, str]](0) - ) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) async def sse_writer(): # Get the request ID from the incoming request message @@ -495,11 +473,7 @@ async def sse_writer(): "Cache-Control": "no-cache, no-transform", "Connection": "keep-alive", "Content-Type": CONTENT_TYPE_SSE, - **( - {MCP_SESSION_ID_HEADER: self.mcp_session_id} - if self.mcp_session_id - else {} - ), + **({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}), } response = EventSourceResponse( content=sse_stream_reader, @@ -544,9 +518,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: """ writer = self._read_stream_writer if writer is None: - raise ValueError( - "No read stream writer available. Ensure connect() is called first." - ) + raise ValueError("No read stream writer available. 
Ensure connect() is called first.") # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) @@ -559,8 +531,9 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - if not await self._validate_session(request, send): + if not await self._validate_request_headers(request, send): return + # Handle resumability: check for Last-Event-ID header if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): await self._replay_events(last_event_id, request, send) @@ -585,17 +558,13 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: return # Create SSE stream - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ - dict[str, str] - ](0) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) async def standalone_sse_writer(): try: # Create a standalone message stream for server-initiated messages - self._request_streams[GET_STREAM_KEY] = ( - anyio.create_memory_object_stream[EventMessage](0) - ) + self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0) standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1] async with sse_stream_writer, standalone_stream_reader: @@ -643,7 +612,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - if not await self._validate_session(request, send): + if not await self._validate_request_headers(request, send): return await self._terminate_session() @@ -703,6 +672,13 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non ) await response(request.scope, request.receive, send) + async def _validate_request_headers(self, request: Request, send: Send) -> bool: + if not await self._validate_session(request, send): + return False + if not await self._validate_protocol_version(request, send): + return False + return True + async def _validate_session(self, request: Request, send: Send) -> bool: """Validate the session ID in the request.""" if not self.mcp_session_id: @@ -732,9 +708,29 @@ async def _validate_session(self, request: Request, send: Send) -> bool: return True - async def _replay_events( - self, last_event_id: str, request: Request, send: Send - ) -> None: + async def _validate_protocol_version(self, request: Request, send: Send) -> bool: + """Validate the protocol version header in the request.""" + # Get the protocol version from the request headers + protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) + + # If no protocol version provided, assume default version + if protocol_version is None: + protocol_version = DEFAULT_NEGOTIATED_VERSION + + # Check if the protocol version is supported + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: + supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) + response = self._create_error_response( + f"Bad Request: Unsupported protocol version: {protocol_version}. " + + f"Supported versions: {supported_versions}", + HTTPStatus.BAD_REQUEST, + ) + await response(request.scope, request.receive, send) + return False + + return True + + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: """ Replays events that would have been sent after the specified event ID. Only used when resumability is enabled. 
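To illustrate the protocol-version check added above, here is a simplified standalone restatement (the real implementation reads SUPPORTED_PROTOCOL_VERSIONS and DEFAULT_NEGOTIATED_VERSION from the SDK; the version strings below are placeholders):

    # A missing mcp-protocol-version header falls back to the default negotiated
    # version; anything outside the supported list is rejected with 400 Bad Request.
    def is_supported_version(headers: dict[str, str], supported: list[str], default: str) -> bool:
        version = headers.get("mcp-protocol-version", default)
        return version in supported

    assert is_supported_version({}, supported=["v1"], default="v1")
    assert not is_supported_version({"mcp-protocol-version": "v0"}, supported=["v1"], default="v1")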
@@ -754,9 +750,7 @@ async def _replay_events( headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Create SSE stream for replay - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ - dict[str, str] - ](0) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) async def replay_sender(): try: @@ -767,15 +761,11 @@ async def send_event(event_message: EventMessage) -> None: await sse_stream_writer.send(event_data) # Replay past events and get the stream ID - stream_id = await event_store.replay_events_after( - last_event_id, send_event - ) + stream_id = await event_store.replay_events_after(last_event_id, send_event) # If stream ID not in mapping, create it if stream_id and stream_id not in self._request_streams: - self._request_streams[stream_id] = ( - anyio.create_memory_object_stream[EventMessage](0) - ) + self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0) msg_reader = self._request_streams[stream_id][1] # Forward messages to SSE @@ -829,12 +819,8 @@ async def connect( # Create the memory streams for this connection - read_stream_writer, read_stream = anyio.create_memory_object_stream[ - SessionMessage | Exception - ](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[ - SessionMessage - ](0) + read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) # Store the streams self._read_stream_writer = read_stream_writer @@ -867,35 +853,24 @@ async def message_router(): session_message.metadata, ServerMessageMetadata, ) - and session_message.metadata.related_request_id - is not None + and session_message.metadata.related_request_id is not None ): - target_request_id = str( - session_message.metadata.related_request_id - ) + target_request_id = str(session_message.metadata.related_request_id) - request_stream_id = ( - target_request_id - if target_request_id is not None - else GET_STREAM_KEY - ) + request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY # Store the event if we have an event store, # regardless of whether a client is connected # messages will be replayed on the re-connect event_id = None if self._event_store: - event_id = await self._event_store.store_event( - request_stream_id, message - ) + event_id = await self._event_store.store_event(request_stream_id, message) logger.debug(f"Stored {event_id} from {request_stream_id}") if request_stream_id in self._request_streams: try: # Send both the message and the event ID - await self._request_streams[request_stream_id][0].send( - EventMessage(message, event_id) - ) + await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) except ( anyio.BrokenResourceError, anyio.ClosedResourceError, diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 8188c2f3b..41b807388 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -22,6 +22,7 @@ EventStore, StreamableHTTPServerTransport, ) +from mcp.server.transport_security import TransportSecuritySettings logger = logging.getLogger(__name__) @@ -60,11 +61,13 @@ def __init__( event_store: EventStore | None = None, json_response: bool = False, stateless: bool = False, + security_settings: TransportSecuritySettings | None = None, ): self.app = app self.event_store = event_store 
self.json_response = json_response self.stateless = stateless + self.security_settings = security_settings # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() @@ -162,12 +165,11 @@ async def _handle_stateless_request( mcp_session_id=None, # No session tracking in stateless mode is_json_response_enabled=self.json_response, event_store=None, # No event store in stateless mode + security_settings=self.security_settings, ) # Start server in a new task - async def run_stateless_server( - *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED - ): + async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED): async with http_transport.connect() as streams: read_stream, write_stream = streams task_status.started() @@ -204,10 +206,7 @@ async def _handle_stateful_request( request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) # Existing session case - if ( - request_mcp_session_id is not None - and request_mcp_session_id in self._server_instances - ): + if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: transport = self._server_instances[request_mcp_session_id] logger.debug("Session already exists, handling request directly") await transport.handle_request(scope, receive, send) @@ -222,6 +221,7 @@ async def _handle_stateful_request( mcp_session_id=new_session_id, is_json_response_enabled=self.json_response, event_store=self.event_store, # May be None (no resumability) + security_settings=self.security_settings, ) assert http_transport.mcp_session_id is not None @@ -229,9 +229,7 @@ async def _handle_stateful_request( logger.info(f"Created new transport with session ID: {new_session_id}") # Define the server runner - async def run_server( - *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED - ) -> None: + async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: async with http_transport.connect() as streams: read_stream, write_stream = streams task_status.started() diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py index 54a2fdb8c..a74751312 100644 --- a/src/mcp/server/streaming_asgi_transport.py +++ b/src/mcp/server/streaming_asgi_transport.py @@ -93,12 +93,8 @@ async def handle_async_request( initial_response_ready = anyio.Event() # Synchronization for streaming response - asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[ - dict[str, Any] - ](100) - content_send_channel, content_receive_channel = ( - anyio.create_memory_object_stream[bytes](100) - ) + asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[dict[str, Any]](100) + content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100) # ASGI callables. 
async def receive() -> dict[str, Any]: @@ -124,21 +120,15 @@ async def send(message: dict[str, Any]) -> None: async def run_app() -> None: try: # Cast the receive and send functions to the ASGI types - await self.app( - cast(Scope, scope), cast(Receive, receive), cast(Send, send) - ) + await self.app(cast(Scope, scope), cast(Receive, receive), cast(Send, send)) except Exception: if self.raise_app_exceptions: raise if not response_started: - await asgi_send_channel.send( - {"type": "http.response.start", "status": 500, "headers": []} - ) + await asgi_send_channel.send({"type": "http.response.start", "status": 500, "headers": []}) - await asgi_send_channel.send( - {"type": "http.response.body", "body": b"", "more_body": False} - ) + await asgi_send_channel.send({"type": "http.response.body", "body": b"", "more_body": False}) finally: await asgi_send_channel.aclose() diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py new file mode 100644 index 000000000..3a884ee2b --- /dev/null +++ b/src/mcp/server/transport_security.py @@ -0,0 +1,127 @@ +"""DNS rebinding protection for MCP server transports.""" + +import logging + +from pydantic import BaseModel, Field +from starlette.requests import Request +from starlette.responses import Response + +logger = logging.getLogger(__name__) + + +class TransportSecuritySettings(BaseModel): + """Settings for MCP transport security features. + + These settings help protect against DNS rebinding attacks by validating + incoming request headers. + """ + + enable_dns_rebinding_protection: bool = Field( + default=True, + description="Enable DNS rebinding protection (recommended for production)", + ) + + allowed_hosts: list[str] = Field( + default=[], + description="List of allowed Host header values. Only applies when " + + "enable_dns_rebinding_protection is True.", + ) + + allowed_origins: list[str] = Field( + default=[], + description="List of allowed Origin header values. 
Only applies when " + + "enable_dns_rebinding_protection is True.", + ) + + +class TransportSecurityMiddleware: + """Middleware to enforce DNS rebinding protection for MCP transport endpoints.""" + + def __init__(self, settings: TransportSecuritySettings | None = None): + # If not specified, disable DNS rebinding protection by default + # for backwards compatibility + self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) + + def _validate_host(self, host: str | None) -> bool: + """Validate the Host header against allowed values.""" + if not host: + logger.warning("Missing Host header in request") + return False + + # Check exact match first + if host in self.settings.allowed_hosts: + return True + + # Check wildcard port patterns + for allowed in self.settings.allowed_hosts: + if allowed.endswith(":*"): + # Extract base host from pattern + base_host = allowed[:-2] + # Check if the actual host starts with base host and has a port + if host.startswith(base_host + ":"): + return True + + logger.warning(f"Invalid Host header: {host}") + return False + + def _validate_origin(self, origin: str | None) -> bool: + """Validate the Origin header against allowed values.""" + # Origin can be absent for same-origin requests + if not origin: + return True + + # Check exact match first + if origin in self.settings.allowed_origins: + return True + + # Check wildcard port patterns + for allowed in self.settings.allowed_origins: + if allowed.endswith(":*"): + # Extract base origin from pattern + base_origin = allowed[:-2] + # Check if the actual origin starts with base origin and has a port + if origin.startswith(base_origin + ":"): + return True + + logger.warning(f"Invalid Origin header: {origin}") + return False + + def _validate_content_type(self, content_type: str | None) -> bool: + """Validate the Content-Type header for POST requests.""" + if not content_type: + logger.warning("Missing Content-Type header in POST request") + return False + + # Content-Type must start with application/json + if not content_type.lower().startswith("application/json"): + logger.warning(f"Invalid Content-Type header: {content_type}") + return False + + return True + + async def validate_request(self, request: Request, is_post: bool = False) -> Response | None: + """Validate request headers for DNS rebinding protection. + + Returns None if validation passes, or an error Response if validation fails. 
+ """ + # Always validate Content-Type for POST requests + if is_post: + content_type = request.headers.get("content-type") + if not self._validate_content_type(content_type): + return Response("Invalid Content-Type header", status_code=400) + + # Skip remaining validation if DNS rebinding protection is disabled + if not self.settings.enable_dns_rebinding_protection: + return None + + # Validate Host header + host = request.headers.get("host") + if not self._validate_host(host): + return Response("Invalid Host header", status_code=421) + + # Validate Origin header + origin = request.headers.get("origin") + if not self._validate_origin(origin): + return Response("Invalid Origin header", status_code=400) + + return None diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 9dc3f2a25..7c0d8789c 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -51,9 +51,7 @@ async def ws_writer(): try: async with write_stream_reader: async for session_message in write_stream_reader: - obj = session_message.message.model_dump_json( - by_alias=True, exclude_none=True - ) + obj = session_message.message.model_dump_json(by_alias=True, exclude_none=True) await websocket.send_text(obj) except anyio.ClosedResourceError: await websocket.close() diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 22f8a971d..4d2d57221 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from pydantic import AnyHttpUrl, BaseModel, Field +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_validator class OAuthToken(BaseModel): @@ -9,11 +9,20 @@ class OAuthToken(BaseModel): """ access_token: str - token_type: Literal["bearer"] = "bearer" + token_type: Literal["Bearer"] = "Bearer" expires_in: int | None = None scope: str | None = None refresh_token: str | None = None + @field_validator("token_type", mode="before") + @classmethod + def normalize_token_type(cls, v: str | None) -> str | None: + if isinstance(v, str): + # Bearer is title-cased in the spec, so we normalize it + # https://datatracker.ietf.org/doc/html/rfc6750#section-4 + return v.title() + return v + class InvalidScopeError(Exception): def __init__(self, message: str): @@ -32,13 +41,11 @@ class OAuthClientMetadata(BaseModel): for the full specification. 
""" - redirect_uris: list[AnyHttpUrl] = Field(..., min_length=1) + redirect_uris: list[AnyUrl] = Field(..., min_length=1) # token_endpoint_auth_method: this implementation only supports none & # client_secret_post; # ie: we do not support client_secret_basic - token_endpoint_auth_method: Literal["none", "client_secret_post"] = ( - "client_secret_post" - ) + token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post" # grant_types: this implementation only supports authorization_code & refresh_token grant_types: list[Literal["authorization_code", "refresh_token"]] = [ "authorization_code", @@ -71,21 +78,16 @@ def validate_scope(self, requested_scope: str | None) -> list[str] | None: raise InvalidScopeError(f"Client was not registered with scope {scope}") return requested_scopes - def validate_redirect_uri(self, redirect_uri: AnyHttpUrl | None) -> AnyHttpUrl: + def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: if redirect_uri is not None: # Validate redirect_uri against client's registered redirect URIs if redirect_uri not in self.redirect_uris: - raise InvalidRedirectUriError( - f"Redirect URI '{redirect_uri}' not registered for client" - ) + raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client") return redirect_uri elif len(self.redirect_uris) == 1: return self.redirect_uris[0] else: - raise InvalidRedirectUriError( - "redirect_uri must be specified when client " - "has multiple registered URIs" - ) + raise InvalidRedirectUriError("redirect_uri must be specified when client " "has multiple registered URIs") class OAuthClientInformationFull(OAuthClientMetadata): @@ -111,27 +113,19 @@ class OAuthMetadata(BaseModel): token_endpoint: AnyHttpUrl registration_endpoint: AnyHttpUrl | None = None scopes_supported: list[str] | None = None - response_types_supported: list[Literal["code"]] = ["code"] + response_types_supported: list[str] = ["code"] response_modes_supported: list[Literal["query", "fragment"]] | None = None - grant_types_supported: ( - list[Literal["authorization_code", "refresh_token"]] | None - ) = None - token_endpoint_auth_methods_supported: ( - list[Literal["none", "client_secret_post"]] | None - ) = None + grant_types_supported: list[str] | None = None + token_endpoint_auth_methods_supported: list[str] | None = None token_endpoint_auth_signing_alg_values_supported: None = None service_documentation: AnyHttpUrl | None = None ui_locales_supported: list[str] | None = None op_policy_uri: AnyHttpUrl | None = None op_tos_uri: AnyHttpUrl | None = None revocation_endpoint: AnyHttpUrl | None = None - revocation_endpoint_auth_methods_supported: ( - list[Literal["client_secret_post"]] | None - ) = None + revocation_endpoint_auth_methods_supported: list[str] | None = None revocation_endpoint_auth_signing_alg_values_supported: None = None introspection_endpoint: AnyHttpUrl | None = None - introspection_endpoint_auth_methods_supported: ( - list[Literal["client_secret_post"]] | None - ) = None + introspection_endpoint_auth_methods_supported: list[str] | None = None introspection_endpoint_auth_signing_alg_values_supported: None = None - code_challenge_methods_supported: list[Literal["S256"]] | None = None + code_challenge_methods_supported: list[str] | None = None diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index b53f8dd63..c94e5e6ac 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -13,6 +13,7 @@ import mcp.types as types from mcp.client.session import ( ClientSession, 
+ ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, @@ -21,16 +22,11 @@ from mcp.server import Server from mcp.shared.message import SessionMessage -MessageStream = tuple[ - MemoryObjectReceiveStream[SessionMessage | Exception], - MemoryObjectSendStream[SessionMessage], -] +MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] @asynccontextmanager -async def create_client_server_memory_streams() -> ( - AsyncGenerator[tuple[MessageStream, MessageStream], None] -): +async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]: """ Creates a pair of bidirectional memory streams for client-server communication. @@ -39,12 +35,8 @@ async def create_client_server_memory_streams() -> ( (read_stream, write_stream) """ # Create streams for both directions - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage | Exception - ](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage | Exception - ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) client_streams = (server_to_client_receive, client_to_server_send) server_streams = (client_to_server_receive, server_to_client_send) @@ -68,6 +60,7 @@ async def create_connected_server_and_client_session( message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, raise_exceptions: bool = False, + elicitation_callback: ElicitationFnT | None = None, ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" async with create_client_server_memory_streams() as ( @@ -98,6 +91,7 @@ async def create_connected_server_and_client_session( logging_callback=logging_callback, message_handler=message_handler, client_info=client_info, + elicitation_callback=elicitation_callback, ) as client_session: await client_session.initialize() yield client_session diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 6b0233714..4b6df23eb 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -20,9 +20,7 @@ class ClientMessageMetadata: """Metadata specific to client messages.""" resumption_token: ResumptionToken | None = None - on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = ( - None - ) + on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = None @dataclass diff --git a/src/mcp/shared/metadata_utils.py b/src/mcp/shared/metadata_utils.py new file mode 100644 index 000000000..e3f49daf4 --- /dev/null +++ b/src/mcp/shared/metadata_utils.py @@ -0,0 +1,45 @@ +"""Utility functions for working with metadata in MCP types. + +These utilities are primarily intended for client-side usage to properly display +human-readable names in user interfaces in a spec compliant way. +""" + +from mcp.types import Implementation, Prompt, Resource, ResourceTemplate, Tool + + +def get_display_name(obj: Tool | Resource | Prompt | ResourceTemplate | Implementation) -> str: + """ + Get the display name for an MCP object with proper precedence. + + This is a client-side utility function designed to help MCP clients display + human-readable names in their user interfaces. 
When servers provide a 'title' + field, it should be preferred over the programmatic 'name' field for display. + + For tools: title > annotations.title > name + For other objects: title > name + + Example: + # In a client displaying available tools + tools = await session.list_tools() + for tool in tools.tools: + display_name = get_display_name(tool) + print(f"Available tool: {display_name}") + + Args: + obj: An MCP object with name and optional title fields + + Returns: + The display name to use for UI presentation + """ + if isinstance(obj, Tool): + # Tools have special precedence: title > annotations.title > name + if hasattr(obj, "title") and obj.title is not None: + return obj.title + if obj.annotations and hasattr(obj.annotations, "title") and obj.annotations.title is not None: + return obj.annotations.title + return obj.name + else: + # All other objects: title > name + if hasattr(obj, "title") and obj.title is not None: + return obj.title + return obj.name diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 856a8d3b6..1ad81a779 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -23,22 +23,8 @@ class Progress(BaseModel): @dataclass -class ProgressContext( - Generic[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT, - ] -): - session: BaseSession[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT, - ] +class ProgressContext(Generic[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]): + session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT] progress_token: ProgressToken total: float | None current: float = field(default=0.0, init=False) @@ -54,24 +40,12 @@ async def progress(self, amount: float, message: str | None = None) -> None: @contextmanager def progress( ctx: RequestContext[ - BaseSession[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT, - ], + BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], LifespanContextT, ], total: float | None = None, ) -> Generator[ - ProgressContext[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT, - ], + ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], None, ]: if ctx.meta is None or ctx.meta.progressToken is None: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 4b13709c6..6536272d9 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -15,6 +15,7 @@ from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.types import ( CONNECTION_CLOSED, + INVALID_PARAMS, CancelledNotification, ClientNotification, ClientRequest, @@ -37,9 +38,7 @@ SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) -ReceiveNotificationT = TypeVar( - "ReceiveNotificationT", ClientNotification, ServerNotification -) +ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) RequestId = str | int @@ -47,9 +46,7 @@ class ProgressFnT(Protocol): """Protocol for progress notification callbacks.""" - async def __call__( - self, progress: float, total: float | None, message: str | None - ) -> None: 
... + async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... class RequestResponder(Generic[ReceiveRequestT, SendResultT]): @@ -176,9 +173,7 @@ class BaseSession( messages when entered. """ - _response_streams: dict[ - RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError] - ] + _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]] _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] @@ -241,9 +236,7 @@ async def send_request( request_id = self._request_id self._request_id = request_id + 1 - response_stream, response_stream_reader = anyio.create_memory_object_stream[ - JSONRPCResponse | JSONRPCError - ](1) + response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) self._response_streams[request_id] = response_stream # Set up progress token if progress callback is provided @@ -265,11 +258,7 @@ async def send_request( **request_data, ) - await self._write_stream.send( - SessionMessage( - message=JSONRPCMessage(jsonrpc_request), metadata=metadata - ) - ) + await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) # request read timeout takes precedence over session read timeout timeout = None @@ -321,15 +310,11 @@ async def send_notification( ) session_message = SessionMessage( message=JSONRPCMessage(jsonrpc_notification), - metadata=ServerMessageMetadata(related_request_id=related_request_id) - if related_request_id - else None, + metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, ) await self._write_stream.send(session_message) - async def _send_response( - self, request_id: RequestId, response: SendResultT | ErrorData - ) -> None: + async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) @@ -338,9 +323,7 @@ async def _send_response( jsonrpc_response = JSONRPCResponse( jsonrpc="2.0", id=request_id, - result=response.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=response.model_dump(by_alias=True, mode="json", exclude_none=True), ) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) await self._write_stream.send(session_message) @@ -350,88 +333,109 @@ async def _receive_loop(self) -> None: self._read_stream, self._write_stream, ): - async for message in self._read_stream: - if isinstance(message, Exception): - await self._handle_incoming(message) - elif isinstance(message.message.root, JSONRPCRequest): - validated_request = self._receive_request_type.model_validate( - message.message.root.model_dump( - by_alias=True, mode="json", exclude_none=True - ) - ) - responder = RequestResponder( - request_id=message.message.root.id, - request_meta=validated_request.root.params.meta - if validated_request.root.params - else None, - request=validated_request, - session=self, - on_complete=lambda r: self._in_flight.pop(r.request_id, None), - message_metadata=message.metadata, - ) - - self._in_flight[responder.request_id] = responder - await self._received_request(responder) - - if not responder._completed: # type: ignore[reportPrivateUsage] - await self._handle_incoming(responder) + try: + async for 
message in self._read_stream: + if isinstance(message, Exception): + await self._handle_incoming(message) + elif isinstance(message.message.root, JSONRPCRequest): + try: + validated_request = self._receive_request_type.model_validate( + message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + responder = RequestResponder( + request_id=message.message.root.id, + request_meta=validated_request.root.params.meta + if validated_request.root.params + else None, + request=validated_request, + session=self, + on_complete=lambda r: self._in_flight.pop(r.request_id, None), + message_metadata=message.metadata, + ) + self._in_flight[responder.request_id] = responder + await self._received_request(responder) + + if not responder._completed: # type: ignore[reportPrivateUsage] + await self._handle_incoming(responder) + except Exception as e: + # For request validation errors, send a proper JSON-RPC error + # response instead of crashing the server + logging.warning(f"Failed to validate request: {e}") + logging.debug(f"Message that failed validation: {message.message.root}") + error_response = JSONRPCError( + jsonrpc="2.0", + id=message.message.root.id, + error=ErrorData( + code=INVALID_PARAMS, + message="Invalid request parameters", + data="", + ), + ) + session_message = SessionMessage(message=JSONRPCMessage(error_response)) + await self._write_stream.send(session_message) - elif isinstance(message.message.root, JSONRPCNotification): - try: - notification = self._receive_notification_type.model_validate( - message.message.root.model_dump( - by_alias=True, mode="json", exclude_none=True + elif isinstance(message.message.root, JSONRPCNotification): + try: + notification = self._receive_notification_type.model_validate( + message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) ) - ) - # Handle cancellation notifications - if isinstance(notification.root, CancelledNotification): - cancelled_id = notification.root.params.requestId - if cancelled_id in self._in_flight: - await self._in_flight[cancelled_id].cancel() + # Handle cancellation notifications + if isinstance(notification.root, CancelledNotification): + cancelled_id = notification.root.params.requestId + if cancelled_id in self._in_flight: + await self._in_flight[cancelled_id].cancel() + else: + # Handle progress notifications callback + if isinstance(notification.root, ProgressNotification): + progress_token = notification.root.params.progressToken + # If there is a progress callback for this token, + # call it with the progress information + if progress_token in self._progress_callbacks: + callback = self._progress_callbacks[progress_token] + await callback( + notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + ) + await self._received_notification(notification) + await self._handle_incoming(notification) + except Exception as e: + # For other validation errors, log and continue + logging.warning( + f"Failed to validate notification: {e}. 
" f"Message was: {message.message.root}" + ) + else: # Response or error + stream = self._response_streams.pop(message.message.root.id, None) + if stream: + await stream.send(message.message.root) else: - # Handle progress notifications callback - if isinstance(notification.root, ProgressNotification): - progress_token = notification.root.params.progressToken - # If there is a progress callback for this token, - # call it with the progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] - await callback( - notification.root.params.progress, - notification.root.params.total, - notification.root.params.message, - ) - await self._received_notification(notification) - await self._handle_incoming(notification) - except Exception as e: - # For other validation errors, log and continue - logging.warning( - f"Failed to validate notification: {e}. " - f"Message was: {message.message.root}" - ) - else: # Response or error - stream = self._response_streams.pop(message.message.root.id, None) - if stream: - await stream.send(message.message.root) - else: - await self._handle_incoming( - RuntimeError( - "Received response with an unknown " - f"request ID: {message}" + await self._handle_incoming( + RuntimeError("Received response with an unknown " f"request ID: {message}") ) - ) - - # after the read stream is closed, we need to send errors - # to any pending requests - for id, stream in self._response_streams.items(): - error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") - await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) - await stream.aclose() - self._response_streams.clear() - - async def _received_request( - self, responder: RequestResponder[ReceiveRequestT, SendResultT] - ) -> None: + + except anyio.ClosedResourceError: + # This is expected when the client disconnects abruptly. + # Without this handler, the exception would propagate up and + # crash the server's task group. + logging.debug("Read stream closed by client") + except Exception as e: + # Other exceptions are not expected and should be logged. We purposefully + # catch all exceptions here to avoid crashing the server. + logging.exception(f"Unhandled exception in receive loop: {e}") + finally: + # after the read stream is closed, we need to send errors + # to any pending requests + for id, stream in self._response_streams.items(): + error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") + try: + await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) + await stream.aclose() + except Exception: + # Stream might already be closed + pass + self._response_streams.clear() + + async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: """ Can be overridden by subclasses to handle a request without needing to listen on the message stream. @@ -460,9 +464,7 @@ async def send_progress_notification( async def _handle_incoming( self, - req: RequestResponder[ReceiveRequestT, SendResultT] - | ReceiveNotificationT - | Exception, + req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, ) -> None: """A generic handler for incoming messages. 
Overwritten by subclasses.""" pass diff --git a/src/mcp/types.py b/src/mcp/types.py index 4f5af27b9..d5663dad6 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1,15 +1,9 @@ from collections.abc import Callable -from typing import ( - Annotated, - Any, - Generic, - Literal, - TypeAlias, - TypeVar, -) +from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel from pydantic.networks import AnyUrl, UrlConstraints +from typing_extensions import deprecated """ Model Context Protocol bindings for Python @@ -31,10 +25,18 @@ LATEST_PROTOCOL_VERSION = "2025-03-26" +""" +The default negotiated version of the Model Context Protocol when no version is specified. +We need this to satisfy the MCP specification, which requires the server to assume a +specific version if none is provided by the client. See section "Protocol Version Header" at +https://modelcontextprotocol.io/specification +""" +DEFAULT_NEGOTIATED_VERSION = "2025-03-26" + ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] -RequestId = str | int +RequestId = Annotated[int | str, Field(union_mode="left_to_right")] AnyFunction: TypeAlias = Callable[..., Any] @@ -67,15 +69,13 @@ class Meta(BaseModel): meta: Meta | None = Field(alias="_meta", default=None) """ - This parameter name is reserved by MCP to allow clients and servers to attach - additional metadata to their notifications. + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. """ RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None) -NotificationParamsT = TypeVar( - "NotificationParamsT", bound=NotificationParams | dict[str, Any] | None -) +NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None) MethodT = TypeVar("MethodT", bound=str) @@ -87,9 +87,7 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]): model_config = ConfigDict(extra="allow") -class PaginatedRequest( - Request[PaginatedRequestParams | None, MethodT], Generic[MethodT] -): +class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): """Base class for paginated requests, matching the schema's PaginatedRequest interface.""" @@ -107,13 +105,12 @@ class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): class Result(BaseModel): """Base class for JSON-RPC results.""" - model_config = ConfigDict(extra="allow") - meta: dict[str, Any] | None = Field(alias="_meta", default=None) """ - This result property is reserved by the protocol to allow clients and servers to - attach additional metadata to their responses. + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. 
""" + model_config = ConfigDict(extra="allow") class PaginatedResult(Result): @@ -191,9 +188,7 @@ class JSONRPCError(BaseModel): model_config = ConfigDict(extra="allow") -class JSONRPCMessage( - RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError] -): +class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]): pass @@ -201,10 +196,26 @@ class EmptyResult(Result): """A response that indicates success but carries no data.""" -class Implementation(BaseModel): - """Describes the name and version of an MCP implementation.""" +class BaseMetadata(BaseModel): + """Base class for entities with name and optional title fields.""" name: str + """The programmatic name of the entity.""" + + title: str | None = None + """ + Intended for UI and end-user contexts β€” optimized to be human-readable and easily understood, + even by those unfamiliar with domain-specific terminology. + + If not provided, the name should be used for display (except for Tool, + where `annotations.title` should be given precedence over using `name`, + if present). + """ + + +class Implementation(BaseMetadata): + """Describes the name and version of an MCP implementation.""" + version: str model_config = ConfigDict(extra="allow") @@ -223,6 +234,12 @@ class SamplingCapability(BaseModel): model_config = ConfigDict(extra="allow") +class ElicitationCapability(BaseModel): + """Capability for elicitation operations.""" + + model_config = ConfigDict(extra="allow") + + class ClientCapabilities(BaseModel): """Capabilities a client may support.""" @@ -230,6 +247,8 @@ class ClientCapabilities(BaseModel): """Experimental, non-standard capabilities that the client supports.""" sampling: SamplingCapability | None = None """Present if the client supports sampling from an LLM.""" + elicitation: ElicitationCapability | None = None + """Present if the client supports elicitation from the user.""" roots: RootsCapability | None = None """Present if the client supports listing roots.""" model_config = ConfigDict(extra="allow") @@ -314,9 +333,7 @@ class InitializeResult(Result): """Instructions describing how to use the server and its features.""" -class InitializedNotification( - Notification[NotificationParams | None, Literal["notifications/initialized"]] -): +class InitializedNotification(Notification[NotificationParams | None, Literal["notifications/initialized"]]): """ This notification is sent from the client to the server after initialization has finished. @@ -350,18 +367,16 @@ class ProgressNotificationParams(NotificationParams): total is unknown. """ total: float | None = None + """Total number of items to process (or total progress required), if known.""" + message: str | None = None """ - Message related to progress. This should provide relevant human readable + Message related to progress. This should provide relevant human readable progress information. """ - message: str | None = None - """Total number of items to process (or total progress required), if known.""" model_config = ConfigDict(extra="allow") -class ProgressNotification( - Notification[ProgressNotificationParams, Literal["notifications/progress"]] -): +class ProgressNotification(Notification[ProgressNotificationParams, Literal["notifications/progress"]]): """ An out-of-band notification used to inform the receiver of a progress update for a long-running request. 
@@ -383,13 +398,11 @@ class Annotations(BaseModel): model_config = ConfigDict(extra="allow") -class Resource(BaseModel): +class Resource(BaseMetadata): """A known resource that the server is capable of reading.""" uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] """The URI of this resource.""" - name: str - """A human-readable name for this resource.""" description: str | None = None """A description of what this resource represents.""" mimeType: str | None = None @@ -402,10 +415,15 @@ class Resource(BaseModel): This can be used by Hosts to display file sizes and estimate context window usage. """ annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") -class ResourceTemplate(BaseModel): +class ResourceTemplate(BaseMetadata): """A template description for resources available on the server.""" uriTemplate: str @@ -413,8 +431,6 @@ class ResourceTemplate(BaseModel): A URI template (according to RFC 6570) that can be used to construct resource URIs. """ - name: str - """A human-readable name for the type of resource this template refers to.""" description: str | None = None """A human-readable description of what this template is for.""" mimeType: str | None = None @@ -423,6 +439,11 @@ class ResourceTemplate(BaseModel): included if all resources matching this template have the same type. """ annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") @@ -432,9 +453,7 @@ class ListResourcesResult(PaginatedResult): resources: list[Resource] -class ListResourceTemplatesRequest( - PaginatedRequest[Literal["resources/templates/list"]] -): +class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]): """Sent from the client to request a list of resource templates the server has.""" method: Literal["resources/templates/list"] @@ -457,9 +476,7 @@ class ReadResourceRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class ReadResourceRequest( - Request[ReadResourceRequestParams, Literal["resources/read"]] -): +class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]): """Sent from the client to the server, to read a specific resource URI.""" method: Literal["resources/read"] @@ -473,6 +490,11 @@ class ResourceContents(BaseModel): """The URI of this resource.""" mimeType: str | None = None """The MIME type of this resource, if known.""" + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. 
+ """ model_config = ConfigDict(extra="allow") @@ -500,9 +522,7 @@ class ReadResourceResult(Result): class ResourceListChangedNotification( - Notification[ - NotificationParams | None, Literal["notifications/resources/list_changed"] - ] + Notification[NotificationParams | None, Literal["notifications/resources/list_changed"]] ): """ An optional notification from the server to the client, informing it that the list @@ -542,9 +562,7 @@ class UnsubscribeRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class UnsubscribeRequest( - Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]] -): +class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]): """ Sent from the client to request cancellation of resources/updated notifications from the server. @@ -566,9 +584,7 @@ class ResourceUpdatedNotificationParams(NotificationParams): class ResourceUpdatedNotification( - Notification[ - ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"] - ] + Notification[ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]] ): """ A notification from the server to the client, informing it that a resource has @@ -597,15 +613,18 @@ class PromptArgument(BaseModel): model_config = ConfigDict(extra="allow") -class Prompt(BaseModel): +class Prompt(BaseMetadata): """A prompt or prompt template that the server offers.""" - name: str - """The name of the prompt or prompt template.""" description: str | None = None """An optional description of what this prompt provides.""" arguments: list[PromptArgument] | None = None """A list of arguments to use for templating the prompt.""" + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") @@ -639,6 +658,11 @@ class TextContent(BaseModel): text: str """The text content of the message.""" annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") @@ -654,6 +678,31 @@ class ImageContent(BaseModel): image types. """ annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ + model_config = ConfigDict(extra="allow") + + +class AudioContent(BaseModel): + """Audio content for a message.""" + + type: Literal["audio"] + data: str + """The base64-encoded audio data.""" + mimeType: str + """ + The MIME type of the audio. Different providers may support different + audio types. + """ + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. 
+ """ model_config = ConfigDict(extra="allow") @@ -661,7 +710,7 @@ class SamplingMessage(BaseModel): """Describes a message issued to or received from an LLM API.""" role: Role - content: TextContent | ImageContent + content: TextContent | ImageContent | AudioContent model_config = ConfigDict(extra="allow") @@ -676,14 +725,36 @@ class EmbeddedResource(BaseModel): type: Literal["resource"] resource: TextResourceContents | BlobResourceContents annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") +class ResourceLink(Resource): + """ + A resource that the server is capable of reading, included in a prompt or tool call result. + + Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests. + """ + + type: Literal["resource_link"] + + +ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource +"""A content block that can be used in prompts and tool results.""" + +Content: TypeAlias = ContentBlock +# """DEPRECATED: Content is deprecated, you should use ContentBlock directly.""" + + class PromptMessage(BaseModel): """Describes a message returned as part of a prompt.""" role: Role - content: TextContent | ImageContent | EmbeddedResource + content: ContentBlock model_config = ConfigDict(extra="allow") @@ -696,9 +767,7 @@ class GetPromptResult(Result): class PromptListChangedNotification( - Notification[ - NotificationParams | None, Literal["notifications/prompts/list_changed"] - ] + Notification[NotificationParams | None, Literal["notifications/prompts/list_changed"]] ): """ An optional notification from the server to the client, informing it that the list @@ -763,17 +832,20 @@ class ToolAnnotations(BaseModel): model_config = ConfigDict(extra="allow") -class Tool(BaseModel): +class Tool(BaseMetadata): """Definition for a tool the client can call.""" - name: str - """The name of the tool.""" description: str | None = None """A human-readable description of the tool.""" inputSchema: dict[str, Any] """A JSON Schema object defining the expected parameters for the tool.""" annotations: ToolAnnotations | None = None """Optional additional tool information.""" + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ model_config = ConfigDict(extra="allow") @@ -801,13 +873,11 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): class CallToolResult(Result): """The server's response to a tool call.""" - content: list[TextContent | ImageContent | EmbeddedResource] + content: list[ContentBlock] isError: bool = False -class ToolListChangedNotification( - Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]] -): +class ToolListChangedNotification(Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]): """ An optional notification from the server to the client, informing it that the list of tools it offers has changed. 
@@ -817,9 +887,7 @@ class ToolListChangedNotification( params: NotificationParams | None = None -LoggingLevel = Literal[ - "debug", "info", "notice", "warning", "error", "critical", "alert", "emergency" -] +LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"] class SetLevelRequestParams(RequestParams): @@ -852,9 +920,7 @@ class LoggingMessageNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") -class LoggingMessageNotification( - Notification[LoggingMessageNotificationParams, Literal["notifications/message"]] -): +class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]): """Notification of a log message passed from server to client.""" method: Literal["notifications/message"] @@ -949,9 +1015,7 @@ class CreateMessageRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class CreateMessageRequest( - Request[CreateMessageRequestParams, Literal["sampling/createMessage"]] -): +class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]): """A request from the server to sample an LLM via the client.""" method: Literal["sampling/createMessage"] @@ -965,14 +1029,14 @@ class CreateMessageResult(Result): """The client's response to a sampling/create_message request from the server.""" role: Role - content: TextContent | ImageContent + content: TextContent | ImageContent | AudioContent model: str """The name of the model that generated the message.""" stopReason: StopReason | None = None """The reason why sampling stopped, if known.""" -class ResourceReference(BaseModel): +class ResourceTemplateReference(BaseModel): """A reference to a resource or resource template definition.""" type: Literal["ref/resource"] @@ -981,6 +1045,11 @@ class ResourceReference(BaseModel): model_config = ConfigDict(extra="allow") +@deprecated("`ResourceReference` is deprecated, you should use `ResourceTemplateReference`.") +class ResourceReference(ResourceTemplateReference): + pass + + class PromptReference(BaseModel): """Identifies a prompt.""" @@ -1000,11 +1069,21 @@ class CompletionArgument(BaseModel): model_config = ConfigDict(extra="allow") +class CompletionContext(BaseModel): + """Additional, optional context for completions.""" + + arguments: dict[str, str] | None = None + """Previously-resolved variables in a URI template or prompt.""" + model_config = ConfigDict(extra="allow") + + class CompleteRequestParams(RequestParams): """Parameters for completion requests.""" - ref: ResourceReference | PromptReference + ref: ResourceTemplateReference | PromptReference argument: CompletionArgument + context: CompletionContext | None = None + """Additional, optional context for completions""" model_config = ConfigDict(extra="allow") @@ -1069,6 +1148,11 @@ class Root(BaseModel): identifier for the root, which may be useful for display purposes or for referencing the root in other parts of the application. """ + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. 
+ """ model_config = ConfigDict(extra="allow") @@ -1108,9 +1192,7 @@ class CancelledNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") -class CancelledNotification( - Notification[CancelledNotificationParams, Literal["notifications/cancelled"]] -): +class CancelledNotification(Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]): """ This notification can be sent by either side to indicate that it is canceling a previously-issued request. @@ -1141,21 +1223,54 @@ class ClientRequest( class ClientNotification( - RootModel[ - CancelledNotification - | ProgressNotification - | InitializedNotification - | RootsListChangedNotification - ] + RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification] ): pass -class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult]): +# Type for elicitation schema - a JSON Schema dict +ElicitRequestedSchema: TypeAlias = dict[str, Any] +"""Schema for elicitation requests.""" + + +class ElicitRequestParams(RequestParams): + """Parameters for elicitation requests.""" + + message: str + requestedSchema: ElicitRequestedSchema + model_config = ConfigDict(extra="allow") + + +class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]): + """A request from the server to elicit information from the client.""" + + method: Literal["elicitation/create"] + params: ElicitRequestParams + + +class ElicitResult(Result): + """The client's response to an elicitation request.""" + + action: Literal["accept", "decline", "cancel"] + """ + The user action in response to the elicitation. + - "accept": User submitted the form/confirmed the action + - "decline": User explicitly declined the action + - "cancel": User dismissed without making an explicit choice + """ + + content: dict[str, str | int | float | bool | None] | None = None + """ + The submitted form data, only present when action is "accept". + Contains values matching the requested schema. 
+ """ + + +class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]): pass -class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest]): +class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest]): pass diff --git a/tests/client/conftest.py b/tests/client/conftest.py index 60ccac743..0c8283903 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -49,8 +49,7 @@ def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest] return [ req.message.root for req in self.client.sent_messages - if isinstance(req.message.root, JSONRPCRequest) - and (method is None or req.message.root.method == method) + if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method) ] def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]: @@ -58,13 +57,10 @@ def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest] return [ req.message.root for req in self.server.sent_messages - if isinstance(req.message.root, JSONRPCRequest) - and (method is None or req.message.root.method == method) + if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method) ] - def get_client_notifications( - self, method: str | None = None - ) -> list[JSONRPCNotification]: + def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: """Get client-sent notifications, optionally filtered by method.""" return [ notif.message.root @@ -73,9 +69,7 @@ def get_client_notifications( and (method is None or notif.message.root.method == method) ] - def get_server_notifications( - self, method: str | None = None - ) -> list[JSONRPCNotification]: + def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: """Get server-sent notifications, optionally filtered by method.""" return [ notif.message.root @@ -133,9 +127,7 @@ async def patched_create_streams(): yield (client_read, spy_client_write), (server_read, spy_server_write) # Apply the patch for the duration of the test - with patch( - "mcp.shared.memory.create_client_server_memory_streams", patched_create_streams - ): + with patch("mcp.shared.memory.create_client_server_memory_streams", patched_create_streams): # Return a collection with helper methods def get_spy_collection() -> StreamSpyCollection: assert client_spy is not None, "client_spy was not initialized" diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 2edaff946..de4eb70af 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -11,7 +11,7 @@ import httpx import pytest from inline_snapshot import snapshot -from pydantic import AnyHttpUrl +from pydantic import AnyHttpUrl, AnyUrl from mcp.client.auth import OAuthClientProvider from mcp.server.auth.routes import build_metadata @@ -52,7 +52,7 @@ def mock_storage(): @pytest.fixture def client_metadata(): return OAuthClientMetadata( - redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")], + redirect_uris=[AnyUrl("http://localhost:3000/callback")], client_name="Test Client", grant_types=["authorization_code", "refresh_token"], response_types=["code"], @@ -79,7 +79,7 @@ def oauth_client_info(): return OAuthClientInformationFull( client_id="test_client_id", client_secret="test_client_secret", - redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")], + redirect_uris=[AnyUrl("http://localhost:3000/callback")], 
client_name="Test Client", grant_types=["authorization_code", "refresh_token"], response_types=["code"], @@ -91,7 +91,7 @@ def oauth_client_info(): def oauth_token(): return OAuthToken( access_token="test_access_token", - token_type="bearer", + token_type="Bearer", expires_in=3600, refresh_token="test_refresh_token", scope="read write", @@ -134,25 +134,22 @@ def test_generate_code_verifier(self, oauth_provider): assert len(verifier) == 128 # Check charset (RFC 7636: A-Z, a-z, 0-9, "-", ".", "_", "~") - allowed_chars = set( - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" - ) + allowed_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~") assert set(verifier) <= allowed_chars # Check uniqueness (generate multiple and ensure they're different) verifiers = {oauth_provider._generate_code_verifier() for _ in range(10)} assert len(verifiers) == 10 - def test_generate_code_challenge(self, oauth_provider): + @pytest.mark.anyio + async def test_generate_code_challenge(self, oauth_provider): """Test PKCE code challenge generation.""" verifier = "test_code_verifier_123" challenge = oauth_provider._generate_code_challenge(verifier) # Manually calculate expected challenge expected_digest = hashlib.sha256(verifier.encode()).digest() - expected_challenge = ( - base64.urlsafe_b64encode(expected_digest).decode().rstrip("=") - ) + expected_challenge = base64.urlsafe_b64encode(expected_digest).decode().rstrip("=") assert challenge == expected_challenge @@ -161,32 +158,23 @@ def test_generate_code_challenge(self, oauth_provider): assert "+" not in challenge assert "/" not in challenge - def test_get_authorization_base_url(self, oauth_provider): + @pytest.mark.anyio + async def test_get_authorization_base_url(self, oauth_provider): """Test authorization base URL extraction.""" # Test with path - assert ( - oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp") - == "https://api.example.com" - ) + assert oauth_provider._get_authorization_base_url("https://api.example.com/v1/mcp") == "https://api.example.com" # Test with no path - assert ( - oauth_provider._get_authorization_base_url("https://api.example.com") - == "https://api.example.com" - ) + assert oauth_provider._get_authorization_base_url("https://api.example.com") == "https://api.example.com" # Test with port assert ( - oauth_provider._get_authorization_base_url( - "https://api.example.com:8080/path/to/mcp" - ) + oauth_provider._get_authorization_base_url("https://api.example.com:8080/path/to/mcp") == "https://api.example.com:8080" ) @pytest.mark.anyio - async def test_discover_oauth_metadata_success( - self, oauth_provider, oauth_metadata - ): + async def test_discover_oauth_metadata_success(self, oauth_provider, oauth_metadata): """Test successful OAuth metadata discovery.""" metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json") @@ -199,23 +187,16 @@ async def test_discover_oauth_metadata_success( mock_response.json.return_value = metadata_response mock_client.get.return_value = mock_response - result = await oauth_provider._discover_oauth_metadata( - "https://api.example.com/v1/mcp" - ) + result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp") assert result is not None - assert ( - result.authorization_endpoint == oauth_metadata.authorization_endpoint - ) + assert result.authorization_endpoint == oauth_metadata.authorization_endpoint assert result.token_endpoint == oauth_metadata.token_endpoint # Verify correct URL was called 
mock_client.get.assert_called_once() call_args = mock_client.get.call_args[0] - assert ( - call_args[0] - == "https://api.example.com/.well-known/oauth-authorization-server" - ) + assert call_args[0] == "https://api.example.com/.well-known/oauth-authorization-server" @pytest.mark.anyio async def test_discover_oauth_metadata_not_found(self, oauth_provider): @@ -228,16 +209,12 @@ async def test_discover_oauth_metadata_not_found(self, oauth_provider): mock_response.status_code = 404 mock_client.get.return_value = mock_response - result = await oauth_provider._discover_oauth_metadata( - "https://api.example.com/v1/mcp" - ) + result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp") assert result is None @pytest.mark.anyio - async def test_discover_oauth_metadata_cors_fallback( - self, oauth_provider, oauth_metadata - ): + async def test_discover_oauth_metadata_cors_fallback(self, oauth_provider, oauth_metadata): """Test OAuth metadata discovery with CORS fallback.""" metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json") @@ -255,17 +232,13 @@ async def test_discover_oauth_metadata_cors_fallback( mock_response_success, # Second call succeeds ] - result = await oauth_provider._discover_oauth_metadata( - "https://api.example.com/v1/mcp" - ) + result = await oauth_provider._discover_oauth_metadata("https://api.example.com/v1/mcp") assert result is not None assert mock_client.get.call_count == 2 @pytest.mark.anyio - async def test_register_oauth_client_success( - self, oauth_provider, oauth_metadata, oauth_client_info - ): + async def test_register_oauth_client_success(self, oauth_provider, oauth_metadata, oauth_client_info): """Test successful OAuth client registration.""" registration_response = oauth_client_info.model_dump(by_alias=True, mode="json") @@ -293,9 +266,7 @@ async def test_register_oauth_client_success( assert call_args[0][0] == str(oauth_metadata.registration_endpoint) @pytest.mark.anyio - async def test_register_oauth_client_fallback_endpoint( - self, oauth_provider, oauth_client_info - ): + async def test_register_oauth_client_fallback_endpoint(self, oauth_provider, oauth_client_info): """Test OAuth client registration with fallback endpoint.""" registration_response = oauth_client_info.model_dump(by_alias=True, mode="json") @@ -309,9 +280,7 @@ async def test_register_oauth_client_fallback_endpoint( mock_client.post.return_value = mock_response # Mock metadata discovery to return None (fallback) - with patch.object( - oauth_provider, "_discover_oauth_metadata", return_value=None - ): + with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None): result = await oauth_provider._register_oauth_client( "https://api.example.com/v1/mcp", oauth_provider.client_metadata, @@ -338,9 +307,7 @@ async def test_register_oauth_client_failure(self, oauth_provider): mock_client.post.return_value = mock_response # Mock metadata discovery to return None (fallback) - with patch.object( - oauth_provider, "_discover_oauth_metadata", return_value=None - ): + with patch.object(oauth_provider, "_discover_oauth_metadata", return_value=None): with pytest.raises(httpx.HTTPStatusError): await oauth_provider._register_oauth_client( "https://api.example.com/v1/mcp", @@ -348,11 +315,13 @@ async def test_register_oauth_client_failure(self, oauth_provider): None, ) - def test_has_valid_token_no_token(self, oauth_provider): + @pytest.mark.anyio + async def test_has_valid_token_no_token(self, oauth_provider): """Test token validation with no 
token.""" assert not oauth_provider._has_valid_token() - def test_has_valid_token_valid(self, oauth_provider, oauth_token): + @pytest.mark.anyio + async def test_has_valid_token_valid(self, oauth_provider, oauth_token): """Test token validation with valid token.""" oauth_provider._current_tokens = oauth_token oauth_provider._token_expiry_time = time.time() + 3600 # Future expiry @@ -370,7 +339,7 @@ async def test_has_valid_token_expired(self, oauth_provider, oauth_token): @pytest.mark.anyio async def test_validate_token_scopes_no_scope(self, oauth_provider): """Test scope validation with no scope returned.""" - token = OAuthToken(access_token="test", token_type="bearer") + token = OAuthToken(access_token="test", token_type="Bearer") # Should not raise exception await oauth_provider._validate_token_scopes(token) @@ -381,7 +350,7 @@ async def test_validate_token_scopes_valid(self, oauth_provider, client_metadata oauth_provider.client_metadata = client_metadata token = OAuthToken( access_token="test", - token_type="bearer", + token_type="Bearer", scope="read write", ) @@ -394,7 +363,7 @@ async def test_validate_token_scopes_subset(self, oauth_provider, client_metadat oauth_provider.client_metadata = client_metadata token = OAuthToken( access_token="test", - token_type="bearer", + token_type="Bearer", scope="read", ) @@ -402,14 +371,12 @@ async def test_validate_token_scopes_subset(self, oauth_provider, client_metadat await oauth_provider._validate_token_scopes(token) @pytest.mark.anyio - async def test_validate_token_scopes_unauthorized( - self, oauth_provider, client_metadata - ): + async def test_validate_token_scopes_unauthorized(self, oauth_provider, client_metadata): """Test scope validation with unauthorized scopes.""" oauth_provider.client_metadata = client_metadata token = OAuthToken( access_token="test", - token_type="bearer", + token_type="Bearer", scope="read write admin", # Includes unauthorized "admin" ) @@ -423,7 +390,7 @@ async def test_validate_token_scopes_no_requested(self, oauth_provider): oauth_provider.client_metadata.scope = None token = OAuthToken( access_token="test", - token_type="bearer", + token_type="Bearer", scope="admin super", ) @@ -432,9 +399,7 @@ async def test_validate_token_scopes_no_requested(self, oauth_provider): await oauth_provider._validate_token_scopes(token) @pytest.mark.anyio - async def test_initialize( - self, oauth_provider, mock_storage, oauth_token, oauth_client_info - ): + async def test_initialize(self, oauth_provider, mock_storage, oauth_token, oauth_client_info): """Test initialization loading from storage.""" mock_storage._tokens = oauth_token mock_storage._client_info = oauth_client_info @@ -445,9 +410,7 @@ async def test_initialize( assert oauth_provider._client_info == oauth_client_info @pytest.mark.anyio - async def test_get_or_register_client_existing( - self, oauth_provider, oauth_client_info - ): + async def test_get_or_register_client_existing(self, oauth_provider, oauth_client_info): """Test getting existing client info.""" oauth_provider._client_info = oauth_client_info @@ -456,13 +419,9 @@ async def test_get_or_register_client_existing( assert result == oauth_client_info @pytest.mark.anyio - async def test_get_or_register_client_register_new( - self, oauth_provider, oauth_client_info - ): + async def test_get_or_register_client_register_new(self, oauth_provider, oauth_client_info): """Test registering new client.""" - with patch.object( - oauth_provider, "_register_oauth_client", return_value=oauth_client_info - ) as 
mock_register: + with patch.object(oauth_provider, "_register_oauth_client", return_value=oauth_client_info) as mock_register: result = await oauth_provider._get_or_register_client() assert result == oauth_client_info @@ -470,9 +429,7 @@ async def test_get_or_register_client_register_new( mock_register.assert_called_once() @pytest.mark.anyio - async def test_exchange_code_for_token_success( - self, oauth_provider, oauth_client_info, oauth_token - ): + async def test_exchange_code_for_token_success(self, oauth_provider, oauth_client_info, oauth_token): """Test successful code exchange for token.""" oauth_provider._code_verifier = "test_verifier" token_response = oauth_token.model_dump(by_alias=True, mode="json") @@ -486,23 +443,14 @@ async def test_exchange_code_for_token_success( mock_response.json.return_value = token_response mock_client.post.return_value = mock_response - with patch.object( - oauth_provider, "_validate_token_scopes" - ) as mock_validate: - await oauth_provider._exchange_code_for_token( - "test_auth_code", oauth_client_info - ) + with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate: + await oauth_provider._exchange_code_for_token("test_auth_code", oauth_client_info) - assert ( - oauth_provider._current_tokens.access_token - == oauth_token.access_token - ) + assert oauth_provider._current_tokens.access_token == oauth_token.access_token mock_validate.assert_called_once() @pytest.mark.anyio - async def test_exchange_code_for_token_failure( - self, oauth_provider, oauth_client_info - ): + async def test_exchange_code_for_token_failure(self, oauth_provider, oauth_client_info): """Test failed code exchange for token.""" oauth_provider._code_verifier = "test_verifier" @@ -516,21 +464,17 @@ async def test_exchange_code_for_token_failure( mock_client.post.return_value = mock_response with pytest.raises(Exception, match="Token exchange failed"): - await oauth_provider._exchange_code_for_token( - "invalid_auth_code", oauth_client_info - ) + await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info) @pytest.mark.anyio - async def test_refresh_access_token_success( - self, oauth_provider, oauth_client_info, oauth_token - ): + async def test_refresh_access_token_success(self, oauth_provider, oauth_client_info, oauth_token): """Test successful token refresh.""" oauth_provider._current_tokens = oauth_token oauth_provider._client_info = oauth_client_info new_token = OAuthToken( access_token="new_access_token", - token_type="bearer", + token_type="Bearer", expires_in=3600, refresh_token="new_refresh_token", scope="read write", @@ -546,16 +490,11 @@ async def test_refresh_access_token_success( mock_response.json.return_value = token_response mock_client.post.return_value = mock_response - with patch.object( - oauth_provider, "_validate_token_scopes" - ) as mock_validate: + with patch.object(oauth_provider, "_validate_token_scopes") as mock_validate: result = await oauth_provider._refresh_access_token() assert result is True - assert ( - oauth_provider._current_tokens.access_token - == new_token.access_token - ) + assert oauth_provider._current_tokens.access_token == new_token.access_token mock_validate.assert_called_once() @pytest.mark.anyio @@ -563,7 +502,7 @@ async def test_refresh_access_token_no_refresh_token(self, oauth_provider): """Test token refresh with no refresh token.""" oauth_provider._current_tokens = OAuthToken( access_token="test", - token_type="bearer", + token_type="Bearer", # No refresh_token ) @@ -571,9 +510,7 @@ async 
def test_refresh_access_token_no_refresh_token(self, oauth_provider): assert result is False @pytest.mark.anyio - async def test_refresh_access_token_failure( - self, oauth_provider, oauth_client_info, oauth_token - ): + async def test_refresh_access_token_failure(self, oauth_provider, oauth_client_info, oauth_token): """Test failed token refresh.""" oauth_provider._current_tokens = oauth_token oauth_provider._client_info = oauth_client_info @@ -590,9 +527,7 @@ async def test_refresh_access_token_failure( assert result is False @pytest.mark.anyio - async def test_perform_oauth_flow_success( - self, oauth_provider, oauth_metadata, oauth_client_info - ): + async def test_perform_oauth_flow_success(self, oauth_provider, oauth_metadata, oauth_client_info): """Test successful OAuth flow.""" oauth_provider._metadata = oauth_metadata oauth_provider._client_info = oauth_client_info @@ -636,9 +571,7 @@ async def mock_callback_handler() -> tuple[str, str | None]: mock_exchange.assert_called_once_with("test_auth_code", oauth_client_info) @pytest.mark.anyio - async def test_perform_oauth_flow_state_mismatch( - self, oauth_provider, oauth_metadata, oauth_client_info - ): + async def test_perform_oauth_flow_state_mismatch(self, oauth_provider, oauth_metadata, oauth_client_info): """Test OAuth flow with state parameter mismatch.""" oauth_provider._metadata = oauth_metadata oauth_provider._client_info = oauth_client_info @@ -674,9 +607,7 @@ async def test_ensure_token_refresh(self, oauth_provider, oauth_token): oauth_provider._current_tokens = oauth_token oauth_provider._token_expiry_time = time.time() - 3600 # Expired - with patch.object( - oauth_provider, "_refresh_access_token", return_value=True - ) as mock_refresh: + with patch.object(oauth_provider, "_refresh_access_token", return_value=True) as mock_refresh: await oauth_provider.ensure_token() mock_refresh.assert_called_once() @@ -703,10 +634,7 @@ async def test_async_auth_flow_add_token(self, oauth_provider, oauth_token): auth_flow = oauth_provider.async_auth_flow(request) updated_request = await auth_flow.__anext__() - assert ( - updated_request.headers["Authorization"] - == f"Bearer {oauth_token.access_token}" - ) + assert updated_request.headers["Authorization"] == f"Bearer {oauth_token.access_token}" # Send mock response try: @@ -756,9 +684,8 @@ async def test_async_auth_flow_no_token(self, oauth_provider): # No Authorization header should be added if no token assert "Authorization" not in updated_request.headers - def test_scope_priority_client_metadata_first( - self, oauth_provider, oauth_client_info - ): + @pytest.mark.anyio + async def test_scope_priority_client_metadata_first(self, oauth_provider, oauth_client_info): """Test that client metadata scope takes priority.""" oauth_provider.client_metadata.scope = "read write" oauth_provider._client_info = oauth_client_info @@ -777,17 +704,13 @@ def test_scope_priority_client_metadata_first( # Apply scope logic from _perform_oauth_flow if oauth_provider.client_metadata.scope: auth_params["scope"] = oauth_provider.client_metadata.scope - elif ( - hasattr(oauth_provider._client_info, "scope") - and oauth_provider._client_info.scope - ): + elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope: auth_params["scope"] = oauth_provider._client_info.scope assert auth_params["scope"] == "read write" - def test_scope_priority_no_client_metadata_scope( - self, oauth_provider, oauth_client_info - ): + @pytest.mark.anyio + async def 
test_scope_priority_no_client_metadata_scope(self, oauth_provider, oauth_client_info): """Test that no scope parameter is set when client metadata has no scope.""" oauth_provider.client_metadata.scope = None oauth_provider._client_info = oauth_client_info @@ -831,10 +754,7 @@ async def test_scope_priority_no_scope(self, oauth_provider, oauth_client_info): # Apply scope logic from _perform_oauth_flow if oauth_provider.client_metadata.scope: auth_params["scope"] = oauth_provider.client_metadata.scope - elif ( - hasattr(oauth_provider._client_info, "scope") - and oauth_provider._client_info.scope - ): + elif hasattr(oauth_provider._client_info, "scope") and oauth_provider._client_info.scope: auth_params["scope"] = oauth_provider._client_info.scope # No scope should be set @@ -860,9 +780,7 @@ async def mock_redirect_handler(url: str) -> None: oauth_provider.redirect_handler = mock_redirect_handler # Patch secrets.compare_digest to verify it's being called - with patch( - "mcp.client.auth.secrets.compare_digest", return_value=False - ) as mock_compare: + with patch("mcp.client.auth.secrets.compare_digest", return_value=False) as mock_compare: with pytest.raises(Exception, match="State parameter mismatch"): await oauth_provider._perform_oauth_flow() @@ -870,9 +788,7 @@ async def mock_redirect_handler(url: str) -> None: mock_compare.assert_called_once() @pytest.mark.anyio - async def test_state_parameter_validation_none_state( - self, oauth_provider, oauth_metadata, oauth_client_info - ): + async def test_state_parameter_validation_none_state(self, oauth_provider, oauth_metadata, oauth_client_info): """Test that None state is handled correctly.""" oauth_provider._metadata = oauth_metadata oauth_provider._client_info = oauth_client_info @@ -907,9 +823,7 @@ async def test_token_exchange_error_basic(self, oauth_provider, oauth_client_inf mock_client.post.return_value = mock_response with pytest.raises(Exception, match="Token exchange failed"): - await oauth_provider._exchange_code_for_token( - "invalid_auth_code", oauth_client_info - ) + await oauth_provider._exchange_code_for_token("invalid_auth_code", oauth_client_info) @pytest.mark.parametrize( @@ -962,9 +876,7 @@ def test_build_metadata( metadata = build_metadata( issuer_url=AnyHttpUrl(issuer_url), service_documentation_url=AnyHttpUrl(service_documentation_url), - client_registration_options=ClientRegistrationOptions( - enabled=True, valid_scopes=["read", "write", "admin"] - ), + client_registration_options=ClientRegistrationOptions(enabled=True, valid_scopes=["read", "write", "admin"]), revocation_options=RevocationOptions(enabled=True), ) diff --git a/tests/client/test_config.py b/tests/client/test_config.py index 69efb4024..f144dcffb 100644 --- a/tests/client/test_config.py +++ b/tests/client/test_config.py @@ -44,9 +44,7 @@ def test_command_execution(mock_config_path: Path): test_args = [command] + args + ["--help"] - result = subprocess.run( - test_args, capture_output=True, text=True, timeout=5, check=False - ) + result = subprocess.run(test_args, capture_output=True, text=True, timeout=5, check=False) assert result.returncode == 0 assert "usage" in result.stdout.lower() diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index a6df7ec7e..f7b031737 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -182,9 +182,7 @@ async def test_template(name: str) -> str: # Test without cursor parameter (omitted) _ = await 
client_session.list_resource_templates() - list_templates_requests = spies.get_client_requests( - method="resources/templates/list" - ) + list_templates_requests = spies.get_client_requests(method="resources/templates/list") assert len(list_templates_requests) == 1 assert list_templates_requests[0].params is None @@ -192,9 +190,7 @@ async def test_template(name: str) -> str: # Test with cursor=None _ = await client_session.list_resource_templates(cursor=None) - list_templates_requests = spies.get_client_requests( - method="resources/templates/list" - ) + list_templates_requests = spies.get_client_requests(method="resources/templates/list") assert len(list_templates_requests) == 1 assert list_templates_requests[0].params is None @@ -202,9 +198,7 @@ async def test_template(name: str) -> str: # Test with cursor as string _ = await client_session.list_resource_templates(cursor="some_cursor") - list_templates_requests = spies.get_client_requests( - method="resources/templates/list" - ) + list_templates_requests = spies.get_client_requests(method="resources/templates/list") assert len(list_templates_requests) == 1 assert list_templates_requests[0].params is not None assert list_templates_requests[0].params["cursor"] == "some_cursor" @@ -213,9 +207,7 @@ async def test_template(name: str) -> str: # Test with empty string cursor _ = await client_session.list_resource_templates(cursor="") - list_templates_requests = spies.get_client_requests( - method="resources/templates/list" - ) + list_templates_requests = spies.get_client_requests(method="resources/templates/list") assert len(list_templates_requests) == 1 assert list_templates_requests[0].params is not None assert list_templates_requests[0].params["cursor"] == "" diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index f5b598218..f65490421 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -41,13 +41,9 @@ async def test_list_roots(context: Context, message: str): # type: ignore[repor return True # Test with list_roots callback - async with create_session( - server._mcp_server, list_roots_callback=list_roots_callback - ) as client_session: + async with create_session(server._mcp_server, list_roots_callback=list_roots_callback) as client_session: # Make a request to trigger sampling callback - result = await client_session.call_tool( - "test_list_roots", {"message": "test message"} - ) + result = await client_session.call_tool("test_list_roots", {"message": "test message"}) assert result.isError is False assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" @@ -55,12 +51,7 @@ async def test_list_roots(context: Context, message: str): # type: ignore[repor # Test without list_roots callback async with create_session(server._mcp_server) as client_session: # Make a request to trigger sampling callback - result = await client_session.call_tool( - "test_list_roots", {"message": "test message"} - ) + result = await client_session.call_tool("test_list_roots", {"message": "test message"}) assert result.isError is True assert isinstance(result.content[0], TextContent) - assert ( - result.content[0].text - == "Error executing tool test_list_roots: List roots not supported" - ) + assert result.content[0].text == "Error executing tool test_list_roots: List roots not supported" diff --git a/tests/client/test_logging_callback.py b/tests/client/test_logging_callback.py index 0c9eeb397..f298ee287 100644 --- 
a/tests/client/test_logging_callback.py +++ b/tests/client/test_logging_callback.py @@ -49,9 +49,7 @@ async def test_tool_with_log( # Create a message handler to catch exceptions async def message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): raise message diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index ba586d4a8..a3f6affda 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -21,9 +21,7 @@ async def test_sampling_callback(): callback_return = CreateMessageResult( role="assistant", - content=TextContent( - type="text", text="This is a response from the sampling callback" - ), + content=TextContent(type="text", text="This is a response from the sampling callback"), model="test-model", stopReason="endTurn", ) @@ -37,24 +35,16 @@ async def sampling_callback( @server.tool("test_sampling") async def test_sampling_tool(message: str): value = await server.get_context().session.create_message( - messages=[ - SamplingMessage( - role="user", content=TextContent(type="text", text=message) - ) - ], + messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))], max_tokens=100, ) assert value == callback_return return True # Test with sampling callback - async with create_session( - server._mcp_server, sampling_callback=sampling_callback - ) as client_session: + async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session: # Make a request to trigger sampling callback - result = await client_session.call_tool( - "test_sampling", {"message": "Test message for sampling"} - ) + result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"}) assert result.isError is False assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" @@ -62,12 +52,7 @@ async def test_sampling_tool(message: str): # Test without sampling callback async with create_session(server._mcp_server) as client_session: # Make a request to trigger sampling callback - result = await client_session.call_tool( - "test_sampling", {"message": "Test message for sampling"} - ) + result = await client_session.call_tool("test_sampling", {"message": "Test message for sampling"}) assert result.isError is True assert isinstance(result.content[0], TextContent) - assert ( - result.content[0].text - == "Error executing tool test_sampling: Sampling not supported" - ) + assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported" diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 72b4413d2..327d1a9e4 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -28,12 +28,8 @@ @pytest.mark.anyio async def test_client_session_initialize(): - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) initialized_notification = None @@ -70,9 +66,7 @@ async def 
mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -81,16 +75,12 @@ async def mock_server(): jsonrpc_notification = session_notification.message assert isinstance(jsonrpc_notification.root, JSONRPCNotification) initialized_notification = ClientNotification.model_validate( - jsonrpc_notification.model_dump( - by_alias=True, mode="json", exclude_none=True - ) + jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True) ) # Create a message handler to catch exceptions async def message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): raise message @@ -124,12 +114,8 @@ async def message_handler( @pytest.mark.anyio async def test_client_session_custom_client_info(): - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) custom_client_info = Implementation(name="test-client", version="1.2.3") received_client_info = None @@ -161,9 +147,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -192,12 +176,8 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_default_client_info(): - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_client_info = None @@ -228,9 +208,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -259,12 +237,8 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_version_negotiation_success(): """Test successful version negotiation with supported version""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) async def mock_server(): session_message = await client_to_server_receive.receive() @@ -294,9 +268,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", 
id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -327,12 +299,8 @@ async def mock_server(): @pytest.mark.anyio async def test_client_session_version_negotiation_failure(): """Test version negotiation failure with unsupported version""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) async def mock_server(): session_message = await client_to_server_receive.receive() @@ -359,9 +327,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -388,12 +354,8 @@ async def mock_server(): @pytest.mark.anyio async def test_client_capabilities_default(): """Test that client capabilities are properly set with default callbacks""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_capabilities = None @@ -424,9 +386,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -457,12 +417,8 @@ async def mock_server(): @pytest.mark.anyio async def test_client_capabilities_with_custom_callbacks(): """Test that client capabilities are properly set with custom callbacks""" - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) received_capabilities = None @@ -508,9 +464,7 @@ async def mock_server(): JSONRPCResponse( jsonrpc="2.0", id=jsonrpc_request.root.id, - result=result.model_dump( - by_alias=True, mode="json", exclude_none=True - ), + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -536,14 +490,8 @@ async def mock_server(): # Assert that capabilities are properly set with custom callbacks assert received_capabilities is not None - assert ( - received_capabilities.sampling is not None - ) # Custom sampling callback provided + assert received_capabilities.sampling is not None # Custom sampling callback provided assert isinstance(received_capabilities.sampling, types.SamplingCapability) - assert ( - received_capabilities.roots is not None - ) # Custom list_roots callback provided + assert received_capabilities.roots is 
not None # Custom list_roots callback provided assert isinstance(received_capabilities.roots, types.RootsCapability) - assert ( - received_capabilities.roots.listChanged is True - ) # Should be True for custom callback + assert received_capabilities.roots.listChanged is True # Should be True for custom callback diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 924ef7a06..16a887e00 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -58,14 +58,10 @@ def hook(name, server_info): return f"{(server_info.name)}-{name}" mcp_session_group = ClientSessionGroup(component_name_hook=hook) - mcp_session_group._tools = { - "server1-my_tool": types.Tool(name="my_tool", inputSchema={}) - } + mcp_session_group._tools = {"server1-my_tool": types.Tool(name="my_tool", inputSchema={})} mcp_session_group._tool_to_session = {"server1-my_tool": mock_session} text_content = types.TextContent(type="text", text="OK") - mock_session.call_tool.return_value = types.CallToolResult( - content=[text_content] - ) + mock_session.call_tool.return_value = types.CallToolResult(content=[text_content]) # --- Test Execution --- result = await mcp_session_group.call_tool( @@ -96,16 +92,12 @@ async def test_connect_to_server(self, mock_exit_stack): mock_prompt1 = mock.Mock(spec=types.Prompt) mock_prompt1.name = "prompt_c" mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1]) - mock_session.list_resources.return_value = mock.AsyncMock( - resources=[mock_resource1] - ) + mock_session.list_resources.return_value = mock.AsyncMock(resources=[mock_resource1]) mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1]) # --- Test Execution --- group = ClientSessionGroup(exit_stack=mock_exit_stack) - with mock.patch.object( - group, "_establish_session", return_value=(mock_server_info, mock_session) - ): + with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)): await group.connect_to_server(StdioServerParameters(command="test")) # --- Assertions --- @@ -141,12 +133,8 @@ def name_hook(name: str, server_info: types.Implementation) -> str: return f"{server_info.name}.{name}" # --- Test Execution --- - group = ClientSessionGroup( - exit_stack=mock_exit_stack, component_name_hook=name_hook - ) - with mock.patch.object( - group, "_establish_session", return_value=(mock_server_info, mock_session) - ): + group = ClientSessionGroup(exit_stack=mock_exit_stack, component_name_hook=name_hook) + with mock.patch.object(group, "_establish_session", return_value=(mock_server_info, mock_session)): await group.connect_to_server(StdioServerParameters(command="test")) # --- Assertions --- @@ -231,9 +219,7 @@ async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_sta # Need a dummy session associated with the existing tool mock_session = mock.MagicMock(spec=mcp.ClientSession) group._tool_to_session[existing_tool_name] = mock_session - group._session_exit_stacks[mock_session] = mock.Mock( - spec=contextlib.AsyncExitStack - ) + group._session_exit_stacks[mock_session] = mock.Mock(spec=contextlib.AsyncExitStack) # --- Mock New Connection Attempt --- mock_server_info_new = mock.Mock(spec=types.Implementation) @@ -243,9 +229,7 @@ async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_sta # Configure the new session to return a tool with the *same name* duplicate_tool = mock.Mock(spec=types.Tool) duplicate_tool.name = existing_tool_name - 
mock_session_new.list_tools.return_value = mock.AsyncMock( - tools=[duplicate_tool] - ) + mock_session_new.list_tools.return_value = mock.AsyncMock(tools=[duplicate_tool]) # Keep other lists empty for simplicity mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[]) mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[]) @@ -266,9 +250,7 @@ async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_sta # Verify the duplicate tool was *not* added again (state should be unchanged) assert len(group._tools) == 1 # Should still only have the original - assert ( - group._tools[existing_tool_name] is not duplicate_tool - ) # Ensure it's the original mock + assert group._tools[existing_tool_name] is not duplicate_tool # Ensure it's the original mock # No patching needed here async def test_disconnect_non_existent_server(self): @@ -292,9 +274,7 @@ async def test_disconnect_non_existent_server(self): "mcp.client.session_group.sse_client", ), # url, headers, timeout, sse_read_timeout ( - StreamableHttpParameters( - url="http://test.com/stream", terminate_on_close=False - ), + StreamableHttpParameters(url="http://test.com/stream", terminate_on_close=False), "streamablehttp", "mcp.client.session_group.streamablehttp_client", ), # url, headers, timeout, sse_read_timeout, terminate_on_close @@ -306,13 +286,9 @@ async def test_establish_session_parameterized( client_type_name, # Just for clarity or conditional logic if needed patch_target_for_client_func, ): - with mock.patch( - "mcp.client.session_group.mcp.ClientSession" - ) as mock_ClientSession_class: + with mock.patch("mcp.client.session_group.mcp.ClientSession") as mock_ClientSession_class: with mock.patch(patch_target_for_client_func) as mock_specific_client_func: - mock_client_cm_instance = mock.AsyncMock( - name=f"{client_type_name}ClientCM" - ) + mock_client_cm_instance = mock.AsyncMock(name=f"{client_type_name}ClientCM") mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read") mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write") @@ -344,9 +320,7 @@ async def test_establish_session_parameterized( # Mock session.initialize() mock_initialize_result = mock.AsyncMock(name="InitializeResult") - mock_initialize_result.serverInfo = types.Implementation( - name="foo", version="1" - ) + mock_initialize_result.serverInfo = types.Implementation(name="foo", version="1") mock_entered_session.initialize.return_value = mock_initialize_result # --- Test Execution --- @@ -364,9 +338,7 @@ async def test_establish_session_parameterized( # --- Assertions --- # 1. Assert the correct specific client function was called if client_type_name == "stdio": - mock_specific_client_func.assert_called_once_with( - server_params_instance - ) + mock_specific_client_func.assert_called_once_with(server_params_instance) elif client_type_name == "sse": mock_specific_client_func.assert_called_once_with( url=server_params_instance.url, @@ -386,9 +358,7 @@ async def test_establish_session_parameterized( mock_client_cm_instance.__aenter__.assert_awaited_once() # 2. 
Assert ClientSession was called correctly - mock_ClientSession_class.assert_called_once_with( - mock_read_stream, mock_write_stream - ) + mock_ClientSession_class.assert_called_once_with(mock_read_stream, mock_write_stream) mock_raw_session_cm.__aenter__.assert_awaited_once() mock_entered_session.initialize.assert_awaited_once() diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 1c6ffe000..c66a16ab9 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -50,20 +50,14 @@ async def test_stdio_client(): break assert len(read_messages) == 2 - assert read_messages[0] == JSONRPCMessage( - root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") - ) - assert read_messages[1] == JSONRPCMessage( - root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) - ) + assert read_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) + assert read_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) @pytest.mark.anyio async def test_stdio_client_bad_path(): """Check that the connection doesn't hang if process errors.""" - server_params = StdioServerParameters( - command="python", args=["-c", "non-existent-file.py"] - ) + server_params = StdioServerParameters(command="python", args=["-c", "non-existent-file.py"]) async with stdio_client(server_params) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # The session should raise an error when the connection closes diff --git a/tests/issues/test_100_tool_listing.py b/tests/issues/test_100_tool_listing.py index 2bc386c96..6dccec84d 100644 --- a/tests/issues/test_100_tool_listing.py +++ b/tests/issues/test_100_tool_listing.py @@ -17,9 +17,7 @@ def dummy_tool_func(): f"""Tool number {i}""" return i - globals()[f"dummy_tool_{i}"] = ( - dummy_tool_func # Keep reference to avoid garbage collection - ) + globals()[f"dummy_tool_{i}"] = dummy_tool_func # Keep reference to avoid garbage collection # Get all tools tools = await mcp.list_tools() @@ -30,6 +28,4 @@ def dummy_tool_func(): # Verify each tool is unique and has the correct name tool_names = [tool.name for tool in tools] expected_names = [f"tool_{i}" for i in range(num_tools)] - assert sorted(tool_names) == sorted( - expected_names - ), "Tool names don't match expected names" + assert sorted(tool_names) == sorted(expected_names), "Tool names don't match expected names" diff --git a/tests/issues/test_129_resource_templates.py b/tests/issues/test_129_resource_templates.py index 314952303..4bedb15d5 100644 --- a/tests/issues/test_129_resource_templates.py +++ b/tests/issues/test_129_resource_templates.py @@ -24,9 +24,7 @@ def get_user_profile(user_id: str) -> str: # Note: list_resource_templates() returns a decorator that wraps the handler # The handler returns a ServerResult with a ListResourceTemplatesResult inside result = await mcp._mcp_server.request_handlers[types.ListResourceTemplatesRequest]( - types.ListResourceTemplatesRequest( - method="resources/templates/list", params=None - ) + types.ListResourceTemplatesRequest(method="resources/templates/list", params=None) ) assert isinstance(result.root, types.ListResourceTemplatesResult) templates = result.root.resourceTemplates diff --git a/tests/issues/test_141_resource_templates.py b/tests/issues/test_141_resource_templates.py index 3c17cd559..3145f65e8 100644 --- a/tests/issues/test_141_resource_templates.py +++ b/tests/issues/test_141_resource_templates.py @@ -61,9 +61,7 @@ def get_user_profile_missing(user_id: 
str) -> str: await mcp.read_resource("resource://users/123/posts") # Missing post_id with pytest.raises(ValueError, match="Unknown resource"): - await mcp.read_resource( - "resource://users/123/posts/456/extra" - ) # Extra path component + await mcp.read_resource("resource://users/123/posts/456/extra") # Extra path component @pytest.mark.anyio @@ -110,11 +108,7 @@ def get_user_profile(user_id: str) -> str: # Verify invalid resource URIs raise appropriate errors with pytest.raises(Exception): # Specific exception type may vary - await session.read_resource( - AnyUrl("resource://users/123/posts") - ) # Missing post_id + await session.read_resource(AnyUrl("resource://users/123/posts")) # Missing post_id with pytest.raises(Exception): # Specific exception type may vary - await session.read_resource( - AnyUrl("resource://users/123/invalid") - ) # Invalid template + await session.read_resource(AnyUrl("resource://users/123/invalid")) # Invalid template diff --git a/tests/issues/test_152_resource_mime_type.py b/tests/issues/test_152_resource_mime_type.py index 1143195e5..a99e5a5c7 100644 --- a/tests/issues/test_152_resource_mime_type.py +++ b/tests/issues/test_152_resource_mime_type.py @@ -45,31 +45,19 @@ def get_image_as_bytes() -> bytes: bytes_resource = mapping["test://image_bytes"] # Verify mime types - assert ( - string_resource.mimeType == "image/png" - ), "String resource mime type not respected" - assert ( - bytes_resource.mimeType == "image/png" - ), "Bytes resource mime type not respected" + assert string_resource.mimeType == "image/png", "String resource mime type not respected" + assert bytes_resource.mimeType == "image/png", "Bytes resource mime type not respected" # Also verify the content can be read correctly string_result = await client.read_resource(AnyUrl("test://image")) assert len(string_result.contents) == 1 - assert ( - getattr(string_result.contents[0], "text") == base64_string - ), "Base64 string mismatch" - assert ( - string_result.contents[0].mimeType == "image/png" - ), "String content mime type not preserved" + assert getattr(string_result.contents[0], "text") == base64_string, "Base64 string mismatch" + assert string_result.contents[0].mimeType == "image/png", "String content mime type not preserved" bytes_result = await client.read_resource(AnyUrl("test://image_bytes")) assert len(bytes_result.contents) == 1 - assert ( - base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes - ), "Bytes mismatch" - assert ( - bytes_result.contents[0].mimeType == "image/png" - ), "Bytes content mime type not preserved" + assert base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes, "Bytes mismatch" + assert bytes_result.contents[0].mimeType == "image/png", "Bytes content mime type not preserved" async def test_lowlevel_resource_mime_type(): @@ -82,9 +70,7 @@ async def test_lowlevel_resource_mime_type(): # Create test resources with specific mime types test_resources = [ - types.Resource( - uri=AnyUrl("test://image"), name="test image", mimeType="image/png" - ), + types.Resource(uri=AnyUrl("test://image"), name="test image", mimeType="image/png"), types.Resource( uri=AnyUrl("test://image_bytes"), name="test image bytes", @@ -101,9 +87,7 @@ async def handle_read_resource(uri: AnyUrl): if str(uri) == "test://image": return [ReadResourceContents(content=base64_string, mime_type="image/png")] elif str(uri) == "test://image_bytes": - return [ - ReadResourceContents(content=bytes(image_bytes), mime_type="image/png") - ] + return 
[ReadResourceContents(content=bytes(image_bytes), mime_type="image/png")] raise Exception(f"Resource not found: {uri}") # Test that resources are listed with correct mime type @@ -119,28 +103,16 @@ async def handle_read_resource(uri: AnyUrl): bytes_resource = mapping["test://image_bytes"] # Verify mime types - assert ( - string_resource.mimeType == "image/png" - ), "String resource mime type not respected" - assert ( - bytes_resource.mimeType == "image/png" - ), "Bytes resource mime type not respected" + assert string_resource.mimeType == "image/png", "String resource mime type not respected" + assert bytes_resource.mimeType == "image/png", "Bytes resource mime type not respected" # Also verify the content can be read correctly string_result = await client.read_resource(AnyUrl("test://image")) assert len(string_result.contents) == 1 - assert ( - getattr(string_result.contents[0], "text") == base64_string - ), "Base64 string mismatch" - assert ( - string_result.contents[0].mimeType == "image/png" - ), "String content mime type not preserved" + assert getattr(string_result.contents[0], "text") == base64_string, "Base64 string mismatch" + assert string_result.contents[0].mimeType == "image/png", "String content mime type not preserved" bytes_result = await client.read_resource(AnyUrl("test://image_bytes")) assert len(bytes_result.contents) == 1 - assert ( - base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes - ), "Bytes mismatch" - assert ( - bytes_result.contents[0].mimeType == "image/png" - ), "Bytes content mime type not preserved" + assert base64.b64decode(getattr(bytes_result.contents[0], "blob")) == image_bytes, "Bytes mismatch" + assert bytes_result.contents[0].mimeType == "image/png", "Bytes content mime type not preserved" diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 4ad22f294..eb5f19d64 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -35,15 +35,7 @@ async def test_progress_token_zero_first_call(): await ctx.report_progress(10, 10) # Complete # Verify progress notifications - assert ( - mock_session.send_progress_notification.call_count == 3 - ), "All progress notifications should be sent" - mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=0.0, total=10.0, message=None - ) - mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=5.0, total=10.0, message=None - ) - mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=10.0, total=10.0, message=None - ) + assert mock_session.send_progress_notification.call_count == 3, "All progress notifications should be sent" + mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=0.0, total=10.0, message=None) + mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=5.0, total=10.0, message=None) + mock_session.send_progress_notification.assert_any_call(progress_token=0, progress=10.0, total=10.0, message=None) diff --git a/tests/issues/test_188_concurrency.py b/tests/issues/test_188_concurrency.py index d0a86885f..9ccffefa9 100644 --- a/tests/issues/test_188_concurrency.py +++ b/tests/issues/test_188_concurrency.py @@ -35,7 +35,7 @@ async def slow_resource(): end_time = anyio.current_time() duration = end_time - start_time - assert duration < 6 * _sleep_time_seconds + assert duration < 10 * _sleep_time_seconds print(duration) diff --git 
a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index cf5eb6083..3c63f00b7 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -66,9 +66,7 @@ async def run_server(): ) await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req))) - response = ( - await server_reader.receive() - ) # Get init response but don't need to check it + response = await server_reader.receive() # Get init response but don't need to check it # Send initialized notification initialized_notification = JSONRPCNotification( @@ -76,14 +74,10 @@ async def run_server(): params=NotificationParams().model_dump(by_alias=True, exclude_none=True), jsonrpc="2.0", ) - await client_writer.send( - SessionMessage(JSONRPCMessage(root=initialized_notification)) - ) + await client_writer.send(SessionMessage(JSONRPCMessage(root=initialized_notification))) # Send ping request with custom ID - ping_request = JSONRPCRequest( - id=custom_request_id, method="ping", params={}, jsonrpc="2.0" - ) + ping_request = JSONRPCRequest(id=custom_request_id, method="ping", params={}, jsonrpc="2.0") await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request))) @@ -91,9 +85,7 @@ async def run_server(): response = await server_reader.receive() # Verify response ID matches request ID - assert ( - response.message.root.id == custom_request_id - ), "Response ID should match request ID" + assert response.message.root.id == custom_request_id, "Response ID should match request ID" # Cancel server task tg.cancel_scope.cancel() diff --git a/tests/issues/test_342_base64_encoding.py b/tests/issues/test_342_base64_encoding.py index cff8ec543..6a6e410c7 100644 --- a/tests/issues/test_342_base64_encoding.py +++ b/tests/issues/test_342_base64_encoding.py @@ -47,11 +47,7 @@ async def test_server_base64_encoding_issue(): # Register a resource handler that returns our test data @server.read_resource() async def read_resource(uri: AnyUrl) -> list[ReadResourceContents]: - return [ - ReadResourceContents( - content=binary_data, mime_type="application/octet-stream" - ) - ] + return [ReadResourceContents(content=binary_data, mime_type="application/octet-stream")] # Get the handler directly from the server handler = server.request_handlers[ReadResourceRequest] diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 88e41d66d..7ba970f0b 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -11,11 +11,7 @@ from mcp.client.session import ClientSession from mcp.server.lowlevel import Server from mcp.shared.exceptions import McpError -from mcp.types import ( - EmbeddedResource, - ImageContent, - TextContent, -) +from mcp.types import ContentBlock, TextContent @pytest.mark.anyio @@ -35,9 +31,7 @@ async def test_notification_validation_error(tmp_path: Path): slow_request_complete = anyio.Event() @server.call_tool() - async def slow_tool( - name: str, arg - ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: + async def slow_tool(name: str, arg) -> Sequence[ContentBlock]: nonlocal request_count request_count += 1 @@ -74,9 +68,7 @@ async def client(read_stream, write_stream, scope): # - Long enough for fast operations (>10ms) # - Short enough for slow operations (<200ms) # - Not too short to avoid flakiness - async with ClientSession( - read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50) - ) as session: + async with ClientSession(read_stream, write_stream, 
read_timeout_seconds=timedelta(milliseconds=50)) as session: await session.initialize() # First call should work (fast operation) diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py new file mode 100644 index 000000000..97edb651e --- /dev/null +++ b/tests/issues/test_malformed_input.py @@ -0,0 +1,160 @@ +# Claude Debug +"""Test for HackerOne vulnerability report #3156202 - malformed input DOS.""" + +import anyio +import pytest + +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage +from mcp.types import ( + INVALID_PARAMS, + JSONRPCError, + JSONRPCMessage, + JSONRPCRequest, + ServerCapabilities, +) + + +@pytest.mark.anyio +async def test_malformed_initialize_request_does_not_crash_server(): + """ + Test that malformed initialize requests return proper error responses + instead of crashing the server (HackerOne #3156202). + """ + # Create in-memory streams for testing + read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) + write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](10) + + try: + # Create a malformed initialize request (missing required params field) + malformed_request = JSONRPCRequest( + jsonrpc="2.0", + id="f20fe86132ed4cd197f89a7134de5685", + method="initialize", + # params=None # Missing required params field + ) + + # Wrap in session message + request_message = SessionMessage(message=JSONRPCMessage(malformed_request)) + + # Start a server session + async with ServerSession( + read_stream=read_receive_stream, + write_stream=write_send_stream, + init_options=InitializationOptions( + server_name="test_server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ), + ): + # Send the malformed request + await read_send_stream.send(request_message) + + # Give the session time to process the request + await anyio.sleep(0.1) + + # Check that we received an error response instead of a crash + try: + response_message = write_receive_stream.receive_nowait() + response = response_message.message.root + + # Verify it's a proper JSON-RPC error response + assert isinstance(response, JSONRPCError) + assert response.jsonrpc == "2.0" + assert response.id == "f20fe86132ed4cd197f89a7134de5685" + assert response.error.code == INVALID_PARAMS + assert "Invalid request parameters" in response.error.message + + # Verify the session is still alive and can handle more requests + # Send another malformed request to confirm server stability + another_malformed_request = JSONRPCRequest( + jsonrpc="2.0", + id="test_id_2", + method="tools/call", + # params=None # Missing required params + ) + another_request_message = SessionMessage(message=JSONRPCMessage(another_malformed_request)) + + await read_send_stream.send(another_request_message) + await anyio.sleep(0.1) + + # Should get another error response, not a crash + second_response_message = write_receive_stream.receive_nowait() + second_response = second_response_message.message.root + + assert isinstance(second_response, JSONRPCError) + assert second_response.id == "test_id_2" + assert second_response.error.code == INVALID_PARAMS + + except anyio.WouldBlock: + pytest.fail("No response received - server likely crashed") + finally: + # Close all streams to ensure proper cleanup + await read_send_stream.aclose() + await write_send_stream.aclose() + await read_receive_stream.aclose() + await write_receive_stream.aclose() + + 
+@pytest.mark.anyio +async def test_multiple_concurrent_malformed_requests(): + """ + Test that multiple concurrent malformed requests don't crash the server. + """ + # Create in-memory streams for testing + read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](100) + write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](100) + + try: + # Start a server session + async with ServerSession( + read_stream=read_receive_stream, + write_stream=write_send_stream, + init_options=InitializationOptions( + server_name="test_server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ), + ): + # Send multiple malformed requests concurrently + malformed_requests = [] + for i in range(10): + malformed_request = JSONRPCRequest( + jsonrpc="2.0", + id=f"malformed_{i}", + method="initialize", + # params=None # Missing required params + ) + request_message = SessionMessage(message=JSONRPCMessage(malformed_request)) + malformed_requests.append(request_message) + + # Send all requests + for request in malformed_requests: + await read_send_stream.send(request) + + # Give time to process + await anyio.sleep(0.2) + + # Verify we get error responses for all requests + error_responses = [] + try: + while True: + response_message = write_receive_stream.receive_nowait() + error_responses.append(response_message.message.root) + except anyio.WouldBlock: + pass # No more messages + + # Should have received 10 error responses + assert len(error_responses) == 10 + + for i, response in enumerate(error_responses): + assert isinstance(response, JSONRPCError) + assert response.id == f"malformed_{i}" + assert response.error.code == INVALID_PARAMS + finally: + # Close all streams to ensure proper cleanup + await read_send_stream.aclose() + await write_send_stream.aclose() + await read_receive_stream.aclose() + await write_receive_stream.aclose() diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index e8c17a4c4..79b813096 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -116,18 +116,14 @@ def no_expiry_access_token() -> AccessToken: class TestBearerAuthBackend: """Tests for the BearerAuthBackend class.""" - async def test_no_auth_header( - self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any] - ): + async def test_no_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Test authentication with no Authorization header.""" backend = BearerAuthBackend(provider=mock_oauth_provider) request = Request({"type": "http", "headers": []}) result = await backend.authenticate(request) assert result is None - async def test_non_bearer_auth_header( - self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any] - ): + async def test_non_bearer_auth_header(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Test authentication with non-Bearer Authorization header.""" backend = BearerAuthBackend(provider=mock_oauth_provider) request = Request( @@ -139,9 +135,7 @@ async def test_non_bearer_auth_header( result = await backend.authenticate(request) assert result is None - async def test_invalid_token( - self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any] - ): + async def test_invalid_token(self, mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any]): """Test authentication with 
invalid token.""" backend = BearerAuthBackend(provider=mock_oauth_provider) request = Request( @@ -160,9 +154,7 @@ async def test_expired_token( ): """Test authentication with expired token.""" backend = BearerAuthBackend(provider=mock_oauth_provider) - add_token_to_provider( - mock_oauth_provider, "expired_token", expired_access_token - ) + add_token_to_provider(mock_oauth_provider, "expired_token", expired_access_token) request = Request( { "type": "http", @@ -203,9 +195,7 @@ async def test_token_without_expiry( ): """Test authentication with token that has no expiry.""" backend = BearerAuthBackend(provider=mock_oauth_provider) - add_token_to_provider( - mock_oauth_provider, "no_expiry_token", no_expiry_access_token - ) + add_token_to_provider(mock_oauth_provider, "no_expiry_token", no_expiry_access_token) request = Request( { "type": "http", diff --git a/tests/server/auth/test_error_handling.py b/tests/server/auth/test_error_handling.py index 18e9933e7..7846c8adb 100644 --- a/tests/server/auth/test_error_handling.py +++ b/tests/server/auth/test_error_handling.py @@ -128,16 +128,12 @@ async def test_registration_error_handling(self, client, oauth_provider): class TestAuthorizeErrorHandling: @pytest.mark.anyio - async def test_authorize_error_handling( - self, client, oauth_provider, registered_client, pkce_challenge - ): + async def test_authorize_error_handling(self, client, oauth_provider, registered_client, pkce_challenge): # Mock the authorize method to raise an authorize error with unittest.mock.patch.object( oauth_provider, "authorize", - side_effect=AuthorizeError( - error="access_denied", error_description="The user denied the request" - ), + side_effect=AuthorizeError(error="access_denied", error_description="The user denied the request"), ): # Register the client client_id = registered_client["client_id"] @@ -169,9 +165,7 @@ async def test_authorize_error_handling( class TestTokenErrorHandling: @pytest.mark.anyio - async def test_token_error_handling_auth_code( - self, client, oauth_provider, registered_client, pkce_challenge - ): + async def test_token_error_handling_auth_code(self, client, oauth_provider, registered_client, pkce_challenge): # Register the client and get an auth code client_id = registered_client["client_id"] client_secret = registered_client["client_secret"] @@ -224,9 +218,7 @@ async def test_token_error_handling_auth_code( assert data["error_description"] == "The authorization code is invalid" @pytest.mark.anyio - async def test_token_error_handling_refresh_token( - self, client, oauth_provider, registered_client, pkce_challenge - ): + async def test_token_error_handling_refresh_token(self, client, oauth_provider, registered_client, pkce_challenge): # Register the client and get tokens client_id = registered_client["client_id"] client_secret = registered_client["client_secret"] diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index d237e860e..5db5d58c2 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -47,9 +47,7 @@ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: async def register_client(self, client_info: OAuthClientInformationFull): self.clients[client_info.client_id] = client_info - async def authorize( - self, client: OAuthClientInformationFull, params: AuthorizationParams - ) -> str: + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> 
str: # toy authorize implementation which just immediately generates an authorization # code and completes the redirect code = AuthorizationCode( @@ -63,9 +61,7 @@ async def authorize( ) self.auth_codes[code.code] = code - return construct_redirect_uri( - str(params.redirect_uri), code=code.code, state=params.state - ) + return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state) async def load_authorization_code( self, client: OAuthClientInformationFull, authorization_code: str @@ -96,15 +92,13 @@ async def exchange_authorization_code( return OAuthToken( access_token=access_token, - token_type="bearer", + token_type="Bearer", expires_in=3600, scope="read write", refresh_token=refresh_token, ) - async def load_refresh_token( - self, client: OAuthClientInformationFull, refresh_token: str - ) -> RefreshToken | None: + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: old_access_token = self.refresh_tokens.get(refresh_token) if old_access_token is None: return None @@ -160,7 +154,7 @@ async def exchange_refresh_token( return OAuthToken( access_token=new_access_token, - token_type="bearer", + token_type="Bearer", expires_in=3600, scope=" ".join(scopes) if scopes else " ".join(token_info.scopes), refresh_token=new_refresh_token, @@ -224,9 +218,7 @@ def auth_app(mock_oauth_provider): @pytest.fixture async def test_client(auth_app): - async with httpx.AsyncClient( - transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com" - ) as client: + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com") as client: yield client @@ -261,11 +253,7 @@ async def registered_client(test_client: httpx.AsyncClient, request): def pkce_challenge(): """Create a PKCE challenge with code_verifier and code_challenge.""" code_verifier = "some_random_verifier_string" - code_challenge = ( - base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) - .decode() - .rstrip("=") - ) + code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()).decode().rstrip("=") return {"code_verifier": code_verifier, "code_challenge": code_challenge} @@ -356,17 +344,13 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): metadata = response.json() assert metadata["issuer"] == "https://auth.example.com/" - assert ( - metadata["authorization_endpoint"] == "https://auth.example.com/authorize" - ) + assert metadata["authorization_endpoint"] == "https://auth.example.com/authorize" assert metadata["token_endpoint"] == "https://auth.example.com/token" assert metadata["registration_endpoint"] == "https://auth.example.com/register" assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke" assert metadata["response_types_supported"] == ["code"] assert metadata["code_challenge_methods_supported"] == ["S256"] - assert metadata["token_endpoint_auth_methods_supported"] == [ - "client_secret_post" - ] + assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"] assert metadata["grant_types_supported"] == [ "authorization_code", "refresh_token", @@ -386,14 +370,10 @@ async def test_token_validation_error(self, test_client: httpx.AsyncClient): ) error_response = response.json() assert error_response["error"] == "invalid_request" - assert ( - "error_description" in error_response - ) # Contains validation error messages + assert "error_description" in error_response # Contains validation 
error messages @pytest.mark.anyio - async def test_token_invalid_auth_code( - self, test_client, registered_client, pkce_challenge - ): + async def test_token_invalid_auth_code(self, test_client, registered_client, pkce_challenge): """Test token endpoint error - authorization code does not exist.""" # Try to use a non-existent authorization code response = await test_client.post( @@ -413,9 +393,7 @@ async def test_token_invalid_auth_code( assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" - assert ( - "authorization code does not exist" in error_response["error_description"] - ) + assert "authorization code does not exist" in error_response["error_description"] @pytest.mark.anyio async def test_token_expired_auth_code( @@ -458,9 +436,7 @@ async def test_token_expired_auth_code( assert response.status_code == 400 error_response = response.json() assert error_response["error"] == "invalid_grant" - assert ( - "authorization code has expired" in error_response["error_description"] - ) + assert "authorization code has expired" in error_response["error_description"] @pytest.mark.anyio @pytest.mark.parametrize( @@ -475,9 +451,7 @@ async def test_token_expired_auth_code( ], indirect=True, ) - async def test_token_redirect_uri_mismatch( - self, test_client, registered_client, auth_code, pkce_challenge - ): + async def test_token_redirect_uri_mismatch(self, test_client, registered_client, auth_code, pkce_challenge): """Test token endpoint error - redirect URI mismatch.""" # Try to use the code with a different redirect URI response = await test_client.post( @@ -498,9 +472,7 @@ async def test_token_redirect_uri_mismatch( assert "redirect_uri did not match" in error_response["error_description"] @pytest.mark.anyio - async def test_token_code_verifier_mismatch( - self, test_client, registered_client, auth_code - ): + async def test_token_code_verifier_mismatch(self, test_client, registered_client, auth_code): """Test token endpoint error - PKCE code verifier mismatch.""" # Try to use the code with an incorrect code verifier response = await test_client.post( @@ -569,9 +541,7 @@ async def test_token_expired_refresh_token( # Step 2: Time travel forward 4 hours (tokens expire in 1 hour by default) # Mock the time.time() function to return a value 4 hours in the future - with unittest.mock.patch( - "time.time", return_value=current_time + 14400 - ): # 4 hours = 14400 seconds + with unittest.mock.patch("time.time", return_value=current_time + 14400): # 4 hours = 14400 seconds # Try to use the refresh token which should now be considered expired response = await test_client.post( "/token", @@ -590,9 +560,7 @@ async def test_token_expired_refresh_token( assert "refresh token has expired" in error_response["error_description"] @pytest.mark.anyio - async def test_token_invalid_scope( - self, test_client, registered_client, auth_code, pkce_challenge - ): + async def test_token_invalid_scope(self, test_client, registered_client, auth_code, pkce_challenge): """Test token endpoint error - invalid scope in refresh token request.""" # Exchange authorization code for tokens token_response = await test_client.post( @@ -628,9 +596,7 @@ async def test_token_invalid_scope( assert "cannot request scope" in error_response["error_description"] @pytest.mark.anyio - async def test_client_registration( - self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider - ): + async def test_client_registration(self, test_client: httpx.AsyncClient, 
mock_oauth_provider: MockOAuthProvider): """Test client registration.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -656,9 +622,7 @@ async def test_client_registration( # ) is not None @pytest.mark.anyio - async def test_client_registration_missing_required_fields( - self, test_client: httpx.AsyncClient - ): + async def test_client_registration_missing_required_fields(self, test_client: httpx.AsyncClient): """Test client registration with missing required fields.""" # Missing redirect_uris which is a required field client_metadata = { @@ -677,9 +641,7 @@ async def test_client_registration_missing_required_fields( assert error_data["error_description"] == "redirect_uris: Field required" @pytest.mark.anyio - async def test_client_registration_invalid_uri( - self, test_client: httpx.AsyncClient - ): + async def test_client_registration_invalid_uri(self, test_client: httpx.AsyncClient): """Test client registration with invalid URIs.""" # Invalid redirect_uri format client_metadata = { @@ -696,14 +658,11 @@ async def test_client_registration_invalid_uri( assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" assert error_data["error_description"] == ( - "redirect_uris.0: Input should be a valid URL, " - "relative URL without a base" + "redirect_uris.0: Input should be a valid URL, " "relative URL without a base" ) @pytest.mark.anyio - async def test_client_registration_empty_redirect_uris( - self, test_client: httpx.AsyncClient - ): + async def test_client_registration_empty_redirect_uris(self, test_client: httpx.AsyncClient): """Test client registration with empty redirect_uris array.""" client_metadata = { "redirect_uris": [], # Empty array @@ -719,8 +678,7 @@ async def test_client_registration_empty_redirect_uris( assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" assert ( - error_data["error_description"] - == "redirect_uris: List should have at least 1 item after validation, not 0" + error_data["error_description"] == "redirect_uris: List should have at least 1 item after validation, not 0" ) @pytest.mark.anyio @@ -831,7 +789,7 @@ async def test_authorization_get( assert "token_type" in token_response assert "refresh_token" in token_response assert "expires_in" in token_response - assert token_response["token_type"] == "bearer" + assert token_response["token_type"] == "Bearer" # 5. 
Verify the access token access_token = token_response["access_token"] @@ -875,12 +833,7 @@ async def test_authorization_get( assert response.status_code == 200 # Verify that the token was revoked - assert ( - await mock_oauth_provider.load_access_token( - new_token_response["access_token"] - ) - is None - ) + assert await mock_oauth_provider.load_access_token(new_token_response["access_token"]) is None @pytest.mark.anyio async def test_revoke_invalid_token(self, test_client, registered_client): @@ -913,9 +866,7 @@ async def test_revoke_with_malformed_token(self, test_client, registered_client) assert "token_type_hint" in error_response["error_description"] @pytest.mark.anyio - async def test_client_registration_disallowed_scopes( - self, test_client: httpx.AsyncClient - ): + async def test_client_registration_disallowed_scopes(self, test_client: httpx.AsyncClient): """Test client registration with scopes that are not allowed.""" client_metadata = { "redirect_uris": ["https://client.example.com/callback"], @@ -955,18 +906,14 @@ async def test_client_registration_default_scopes( assert client_info["scope"] == "read write" # Retrieve the client from the store to verify default scopes - registered_client = await mock_oauth_provider.get_client( - client_info["client_id"] - ) + registered_client = await mock_oauth_provider.get_client(client_info["client_id"]) assert registered_client is not None # Check that default scopes were applied assert registered_client.scope == "read write" @pytest.mark.anyio - async def test_client_registration_invalid_grant_type( - self, test_client: httpx.AsyncClient - ): + async def test_client_registration_invalid_grant_type(self, test_client: httpx.AsyncClient): client_metadata = { "redirect_uris": ["https://client.example.com/callback"], "client_name": "Test Client", @@ -981,19 +928,14 @@ async def test_client_registration_invalid_grant_type( error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" - assert ( - error_data["error_description"] - == "grant_types must be authorization_code and refresh_token" - ) + assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token" class TestAuthorizeEndpointErrors: """Test error handling in the OAuth authorization endpoint.""" @pytest.mark.anyio - async def test_authorize_missing_client_id( - self, test_client: httpx.AsyncClient, pkce_challenge - ): + async def test_authorize_missing_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): """Test authorization endpoint with missing client_id. According to the OAuth2.0 spec, if client_id is missing, the server should @@ -1017,9 +959,7 @@ async def test_authorize_missing_client_id( assert "client_id" in response.text.lower() @pytest.mark.anyio - async def test_authorize_invalid_client_id( - self, test_client: httpx.AsyncClient, pkce_challenge - ): + async def test_authorize_invalid_client_id(self, test_client: httpx.AsyncClient, pkce_challenge): """Test authorization endpoint with invalid client_id. 
According to the OAuth2.0 spec, if client_id is invalid, the server should @@ -1202,9 +1142,7 @@ async def test_authorize_missing_response_type( assert query_params["state"][0] == "test_state" @pytest.mark.anyio - async def test_authorize_missing_pkce_challenge( - self, test_client: httpx.AsyncClient, registered_client - ): + async def test_authorize_missing_pkce_challenge(self, test_client: httpx.AsyncClient, registered_client): """Test authorization endpoint with missing PKCE code_challenge. Missing PKCE parameters should result in invalid_request error. @@ -1233,9 +1171,7 @@ async def test_authorize_missing_pkce_challenge( assert query_params["state"][0] == "test_state" @pytest.mark.anyio - async def test_authorize_invalid_scope( - self, test_client: httpx.AsyncClient, registered_client, pkce_challenge - ): + async def test_authorize_invalid_scope(self, test_client: httpx.AsyncClient, registered_client, pkce_challenge): """Test authorization endpoint with invalid scope. Invalid scope should redirect with invalid_scope error. diff --git a/tests/server/fastmcp/prompts/test_base.py b/tests/server/fastmcp/prompts/test_base.py index c4af044a6..5b7b50e63 100644 --- a/tests/server/fastmcp/prompts/test_base.py +++ b/tests/server/fastmcp/prompts/test_base.py @@ -18,9 +18,7 @@ def fn() -> str: return "Hello, world!" prompt = Prompt.from_function(fn) - assert await prompt.render() == [ - UserMessage(content=TextContent(type="text", text="Hello, world!")) - ] + assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio async def test_async_fn(self): @@ -28,9 +26,7 @@ async def fn() -> str: return "Hello, world!" prompt = Prompt.from_function(fn) - assert await prompt.render() == [ - UserMessage(content=TextContent(type="text", text="Hello, world!")) - ] + assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio async def test_fn_with_args(self): @@ -39,11 +35,7 @@ async def fn(name: str, age: int = 30) -> str: prompt = Prompt.from_function(fn) assert await prompt.render(arguments={"name": "World"}) == [ - UserMessage( - content=TextContent( - type="text", text="Hello, World! You're 30 years old." - ) - ) + UserMessage(content=TextContent(type="text", text="Hello, World! 
You're 30 years old.")) ] @pytest.mark.anyio @@ -61,21 +53,15 @@ async def fn() -> UserMessage: return UserMessage(content="Hello, world!") prompt = Prompt.from_function(fn) - assert await prompt.render() == [ - UserMessage(content=TextContent(type="text", text="Hello, world!")) - ] + assert await prompt.render() == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio async def test_fn_returns_assistant_message(self): async def fn() -> AssistantMessage: - return AssistantMessage( - content=TextContent(type="text", text="Hello, world!") - ) + return AssistantMessage(content=TextContent(type="text", text="Hello, world!")) prompt = Prompt.from_function(fn) - assert await prompt.render() == [ - AssistantMessage(content=TextContent(type="text", text="Hello, world!")) - ] + assert await prompt.render() == [AssistantMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio async def test_fn_returns_multiple_messages(self): @@ -156,9 +142,7 @@ async def fn() -> list[Message]: prompt = Prompt.from_function(fn) assert await prompt.render() == [ - UserMessage( - content=TextContent(type="text", text="Please analyze this file:") - ), + UserMessage(content=TextContent(type="text", text="Please analyze this file:")), UserMessage( content=EmbeddedResource( type="resource", @@ -169,9 +153,7 @@ async def fn() -> list[Message]: ), ) ), - AssistantMessage( - content=TextContent(type="text", text="I'll help analyze that file.") - ), + AssistantMessage(content=TextContent(type="text", text="I'll help analyze that file.")), ] @pytest.mark.anyio diff --git a/tests/server/fastmcp/prompts/test_manager.py b/tests/server/fastmcp/prompts/test_manager.py index c64a4a564..82b234638 100644 --- a/tests/server/fastmcp/prompts/test_manager.py +++ b/tests/server/fastmcp/prompts/test_manager.py @@ -72,9 +72,7 @@ def fn() -> str: prompt = Prompt.from_function(fn) manager.add_prompt(prompt) messages = await manager.render_prompt("fn") - assert messages == [ - UserMessage(content=TextContent(type="text", text="Hello, world!")) - ] + assert messages == [UserMessage(content=TextContent(type="text", text="Hello, world!"))] @pytest.mark.anyio async def test_render_prompt_with_args(self): @@ -87,9 +85,7 @@ def fn(name: str) -> str: prompt = Prompt.from_function(fn) manager.add_prompt(prompt) messages = await manager.render_prompt("fn", arguments={"name": "World"}) - assert messages == [ - UserMessage(content=TextContent(type="text", text="Hello, World!")) - ] + assert messages == [UserMessage(content=TextContent(type="text", text="Hello, World!"))] @pytest.mark.anyio async def test_render_unknown_prompt(self): diff --git a/tests/server/fastmcp/resources/test_file_resources.py b/tests/server/fastmcp/resources/test_file_resources.py index 36cbca32c..ec3c85d8d 100644 --- a/tests/server/fastmcp/resources/test_file_resources.py +++ b/tests/server/fastmcp/resources/test_file_resources.py @@ -100,9 +100,7 @@ async def test_missing_file_error(self, temp_file: Path): with pytest.raises(ValueError, match="Error reading file"): await resource.read() - @pytest.mark.skipif( - os.name == "nt", reason="File permissions behave differently on Windows" - ) + @pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows") @pytest.mark.anyio async def test_permission_error(self, temp_file: Path): """Test reading a file without permissions.""" diff --git a/tests/server/fastmcp/test_elicitation.py b/tests/server/fastmcp/test_elicitation.py new file mode 100644 
index 000000000..20937d91d --- /dev/null +++ b/tests/server/fastmcp/test_elicitation.py @@ -0,0 +1,210 @@ +""" +Test the elicitation feature using stdio transport. +""" + +import pytest +from pydantic import BaseModel, Field + +from mcp.server.fastmcp import Context, FastMCP +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.types import ElicitResult, TextContent + + +# Shared schema for basic tests +class AnswerSchema(BaseModel): + answer: str = Field(description="The user's answer to the question") + + +def create_ask_user_tool(mcp: FastMCP): + """Create a standard ask_user tool that handles all elicitation responses.""" + + @mcp.tool(description="A tool that uses elicitation") + async def ask_user(prompt: str, ctx: Context) -> str: + result = await ctx.elicit( + message=f"Tool wants to ask: {prompt}", + schema=AnswerSchema, + ) + + if result.action == "accept" and result.data: + return f"User answered: {result.data.answer}" + elif result.action == "decline": + return "User declined to answer" + else: + return "User cancelled" + + return ask_user + + +async def call_tool_and_assert( + mcp: FastMCP, + elicitation_callback, + tool_name: str, + args: dict, + expected_text: str | None = None, + text_contains: list[str] | None = None, +): + """Helper to create session, call tool, and assert result.""" + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + result = await client_session.call_tool(tool_name, args) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + + if expected_text is not None: + assert result.content[0].text == expected_text + elif text_contains is not None: + for substring in text_contains: + assert substring in result.content[0].text + + return result + + +@pytest.mark.anyio +async def test_stdio_elicitation(): + """Test the elicitation feature using stdio transport.""" + mcp = FastMCP(name="StdioElicitationServer") + create_ask_user_tool(mcp) + + # Create a custom handler for elicitation requests + async def elicitation_callback(context, params): + if params.message == "Tool wants to ask: What is your name?": + return ElicitResult(action="accept", content={"answer": "Test User"}) + else: + raise ValueError(f"Unexpected elicitation message: {params.message}") + + await call_tool_and_assert( + mcp, elicitation_callback, "ask_user", {"prompt": "What is your name?"}, "User answered: Test User" + ) + + +@pytest.mark.anyio +async def test_stdio_elicitation_decline(): + """Test elicitation with user declining.""" + mcp = FastMCP(name="StdioElicitationDeclineServer") + create_ask_user_tool(mcp) + + async def elicitation_callback(context, params): + return ElicitResult(action="decline") + + await call_tool_and_assert( + mcp, elicitation_callback, "ask_user", {"prompt": "What is your name?"}, "User declined to answer" + ) + + +@pytest.mark.anyio +async def test_elicitation_schema_validation(): + """Test that elicitation schemas must only contain primitive types.""" + mcp = FastMCP(name="ValidationTestServer") + + def create_validation_tool(name: str, schema_class: type[BaseModel]): + @mcp.tool(name=name, description=f"Tool testing {name}") + async def tool(ctx: Context) -> str: + try: + await ctx.elicit(message="This should fail validation", schema=schema_class) + return "Should not reach here" + except TypeError as e: + return f"Validation failed as expected: {str(e)}" + + return tool + + # 
Test cases for invalid schemas + class InvalidListSchema(BaseModel): + names: list[str] = Field(description="List of names") + + class NestedModel(BaseModel): + value: str + + class InvalidNestedSchema(BaseModel): + nested: NestedModel = Field(description="Nested model") + + create_validation_tool("invalid_list", InvalidListSchema) + create_validation_tool("nested_model", InvalidNestedSchema) + + # Dummy callback (won't be called due to validation failure) + async def elicitation_callback(context, params): + return ElicitResult(action="accept", content={}) + + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + await client_session.initialize() + + # Test both invalid schemas + for tool_name, field_name in [("invalid_list", "names"), ("nested_model", "nested")]: + result = await client_session.call_tool(tool_name, {}) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert "Validation failed as expected" in result.content[0].text + assert field_name in result.content[0].text + + +@pytest.mark.anyio +async def test_elicitation_with_optional_fields(): + """Test that Optional fields work correctly in elicitation schemas.""" + mcp = FastMCP(name="OptionalFieldServer") + + class OptionalSchema(BaseModel): + required_name: str = Field(description="Your name (required)") + optional_age: int | None = Field(default=None, description="Your age (optional)") + optional_email: str | None = Field(default=None, description="Your email (optional)") + subscribe: bool | None = Field(default=False, description="Subscribe to newsletter?") + + @mcp.tool(description="Tool with optional fields") + async def optional_tool(ctx: Context) -> str: + result = await ctx.elicit(message="Please provide your information", schema=OptionalSchema) + + if result.action == "accept" and result.data: + info = [f"Name: {result.data.required_name}"] + if result.data.optional_age is not None: + info.append(f"Age: {result.data.optional_age}") + if result.data.optional_email is not None: + info.append(f"Email: {result.data.optional_email}") + info.append(f"Subscribe: {result.data.subscribe}") + return ", ".join(info) + else: + return f"User {result.action}" + + # Test cases with different field combinations + test_cases = [ + ( + # All fields provided + {"required_name": "John Doe", "optional_age": 30, "optional_email": "john@example.com", "subscribe": True}, + "Name: John Doe, Age: 30, Email: john@example.com, Subscribe: True", + ), + ( + # Only required fields + {"required_name": "Jane Smith"}, + "Name: Jane Smith, Subscribe: False", + ), + ] + + for content, expected in test_cases: + + async def callback(context, params): + return ElicitResult(action="accept", content=content) + + await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) + + # Test invalid optional field + class InvalidOptionalSchema(BaseModel): + name: str = Field(description="Name") + optional_list: list[str] | None = Field(default=None, description="Invalid optional list") + + @mcp.tool(description="Tool with invalid optional field") + async def invalid_optional_tool(ctx: Context) -> str: + try: + await ctx.elicit(message="This should fail", schema=InvalidOptionalSchema) + return "Should not reach here" + except TypeError as e: + return f"Validation failed: {str(e)}" + + await call_tool_and_assert( + mcp, + lambda c, p: ElicitResult(action="accept", content={}), + "invalid_optional_tool", + {}, + 
text_contains=["Validation failed:", "optional_list"], + ) diff --git a/tests/server/fastmcp/test_func_metadata.py b/tests/server/fastmcp/test_func_metadata.py index b1828ffe9..b13685e88 100644 --- a/tests/server/fastmcp/test_func_metadata.py +++ b/tests/server/fastmcp/test_func_metadata.py @@ -28,9 +28,7 @@ def complex_arguments_fn( # list[str] | str is an interesting case because if it comes in as JSON like # "[\"a\", \"b\"]" then it will be naively parsed as a string. list_str_or_str: list[str] | str, - an_int_annotated_with_field: Annotated[ - int, Field(description="An int with a field") - ], + an_int_annotated_with_field: Annotated[int, Field(description="An int with a field")], an_int_annotated_with_field_and_others: Annotated[ int, str, # Should be ignored, really @@ -42,9 +40,7 @@ def complex_arguments_fn( "123", 456, ], - field_with_default_via_field_annotation_before_nondefault_arg: Annotated[ - int, Field(1) - ], + field_with_default_via_field_annotation_before_nondefault_arg: Annotated[int, Field(1)], unannotated, my_model_a: SomeInputModelA, my_model_a_forward_ref: "SomeInputModelA", @@ -179,9 +175,7 @@ def func_with_str_types(str_or_list: str | list[str]): def test_skip_names(): """Test that skipped parameters are not included in the model""" - def func_with_many_params( - keep_this: int, skip_this: str, also_keep: float, also_skip: bool - ): + def func_with_many_params(keep_this: int, skip_this: str, also_keep: float, also_skip: bool): return keep_this, skip_this, also_keep, also_skip # Skip some parameters diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 121492bc6..526201f9a 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -14,26 +14,38 @@ import pytest import uvicorn -from pydantic import AnyUrl +from pydantic import AnyUrl, BaseModel, Field from starlette.applications import Starlette from starlette.requests import Request -import mcp.types as types from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client -from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.resources import FunctionResource +from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import RequestContext from mcp.types import ( + Completion, + CompletionArgument, + CompletionContext, CreateMessageRequestParams, CreateMessageResult, + ElicitResult, GetPromptResult, InitializeResult, + LoggingMessageNotification, + ProgressNotification, + PromptReference, ReadResourceResult, + ResourceLink, + ResourceListChangedNotification, + ResourceTemplateReference, SamplingMessage, + ServerNotification, TextContent, TextResourceContents, + ToolListChangedNotification, ) @@ -82,13 +94,30 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str: # Create a function to make the FastMCP server app def make_fastmcp_app(): """Create a FastMCP server without auth settings.""" - mcp = FastMCP(name="NoAuthServer") + transport_security = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + mcp = FastMCP(name="NoAuthServer", transport_security=transport_security) # Add a simple tool @mcp.tool(description="A simple echo tool") def echo(message: str) -> str: return f"Echo: {message}" + # Add a tool that uses elicitation + 
@mcp.tool(description="A tool that uses elicitation") + async def ask_user(prompt: str, ctx: Context) -> str: + class AnswerSchema(BaseModel): + answer: str = Field(description="The user's answer to the question") + + result = await ctx.elicit(message=f"Tool wants to ask: {prompt}", schema=AnswerSchema) + + if result.action == "accept" and result.data: + return f"User answered: {result.data.answer}" + else: + # Handle cancellation or decline + return f"User cancelled or declined: {result.action}" + # Create the SSE app app = mcp.sse_app() @@ -97,12 +126,13 @@ def echo(message: str) -> str: def make_everything_fastmcp() -> FastMCP: """Create a FastMCP server with all features enabled for testing.""" - from mcp.server.fastmcp import Context - - mcp = FastMCP(name="EverythingServer") + transport_security = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + mcp = FastMCP(name="EverythingServer", transport_security=transport_security) # Tool with context for logging and progress - @mcp.tool(description="A tool that demonstrates logging and progress") + @mcp.tool(description="A tool that demonstrates logging and progress", title="Progress Tool") async def tool_with_progress(message: str, ctx: Context, steps: int = 3) -> str: await ctx.info(f"Starting processing of '{message}' with {steps} steps") @@ -119,22 +149,37 @@ async def tool_with_progress(message: str, ctx: Context, steps: int = 3) -> str: return f"Processed '{message}' in {steps} steps" # Simple tool for basic functionality - @mcp.tool(description="A simple echo tool") + @mcp.tool(description="A simple echo tool", title="Echo Tool") def echo(message: str) -> str: return f"Echo: {message}" + # Tool that returns ResourceLinks + @mcp.tool(description="Lists files and returns resource links", title="List Files Tool") + def list_files() -> list[ResourceLink]: + """Returns a list of resource links for files matching the pattern.""" + + # Mock some file resources for testing + file_resources = [ + { + "type": "resource_link", + "uri": "file:///project/README.md", + "name": "README.md", + "mimeType": "text/markdown", + } + ] + + result: list[ResourceLink] = [ResourceLink.model_validate(file_json) for file_json in file_resources] + + return result + # Tool with sampling capability - @mcp.tool(description="A tool that uses sampling to generate content") + @mcp.tool(description="A tool that uses sampling to generate content", title="Sampling Tool") async def sampling_tool(prompt: str, ctx: Context) -> str: await ctx.info(f"Requesting sampling for prompt: {prompt}") # Request sampling from the client result = await ctx.session.create_message( - messages=[ - SamplingMessage( - role="user", content=TextContent(type="text", text=prompt) - ) - ], + messages=[SamplingMessage(role="user", content=TextContent(type="text", text=prompt))], max_tokens=100, temperature=0.7, ) @@ -147,7 +192,7 @@ async def sampling_tool(prompt: str, ctx: Context) -> str: return f"Sampling result: {str(result.content)[:100]}..." 
# Tool that sends notifications and logging - @mcp.tool(description="A tool that demonstrates notifications and logging") + @mcp.tool(description="A tool that demonstrates notifications and logging", title="Notification Tool") async def notification_tool(message: str, ctx: Context) -> str: # Send different log levels await ctx.debug("Debug: Starting notification tool") @@ -168,35 +213,70 @@ def get_static_info() -> str: static_resource = FunctionResource( uri=AnyUrl("resource://static/info"), name="Static Info", + title="Static Information", description="Static information resource", fn=get_static_info, ) mcp.add_resource(static_resource) # Resource - dynamic function - @mcp.resource("resource://dynamic/{category}") + @mcp.resource("resource://dynamic/{category}", title="Dynamic Resource") def dynamic_resource(category: str) -> str: return f"Dynamic resource content for category: {category}" # Resource template - @mcp.resource("resource://template/{id}/data") + @mcp.resource("resource://template/{id}/data", title="Template Resource") def template_resource(id: str) -> str: return f"Template resource data for ID: {id}" # Prompt - simple - @mcp.prompt(description="A simple prompt") + @mcp.prompt(description="A simple prompt", title="Simple Prompt") def simple_prompt(topic: str) -> str: return f"Tell me about {topic}" # Prompt - complex with multiple messages - @mcp.prompt(description="Complex prompt with context") + @mcp.prompt(description="Complex prompt with context", title="Complex Prompt") def complex_prompt(user_query: str, context: str = "general") -> str: # For simplicity, return a single string that incorporates the context # Since FastMCP doesn't support system messages in the same way return f"Context: {context}. Query: {user_query}" + # Resource template with completion support + @mcp.resource("github://repos/{owner}/{repo}", title="GitHub Repository") + def github_repo_resource(owner: str, repo: str) -> str: + return f"Repository: {owner}/{repo}" + + # Add completion handler for the server + @mcp.completion() + async def handle_completion( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + # Handle GitHub repository completion + if isinstance(ref, ResourceTemplateReference): + if ref.uri == "github://repos/{owner}/{repo}" and argument.name == "repo": + if context and context.arguments and context.arguments.get("owner") == "modelcontextprotocol": + # Return repos for modelcontextprotocol org + return Completion(values=["python-sdk", "typescript-sdk", "specification"], total=3, hasMore=False) + elif context and context.arguments and context.arguments.get("owner") == "test-org": + # Return repos for test-org + return Completion(values=["test-repo1", "test-repo2"], total=2, hasMore=False) + + # Handle prompt completions + if isinstance(ref, PromptReference): + if ref.name == "complex_prompt" and argument.name == "context": + # Complete context values + contexts = ["general", "technical", "business", "academic"] + return Completion( + values=[c for c in contexts if c.startswith(argument.value)], total=None, hasMore=False + ) + + # Default: no completion available + return Completion(values=[], total=0, hasMore=False) + # Tool that echoes request headers from context - @mcp.tool(description="Echo request headers from context") + @mcp.tool(description="Echo request headers from context", title="Echo Headers") def echo_headers(ctx: Context[Any, Any, Request]) -> str: """Returns the request 
headers as JSON.""" headers_info = {} @@ -206,7 +286,7 @@ def echo_headers(ctx: Context[Any, Any, Request]) -> str: return json.dumps(headers_info) # Tool that returns full request context - @mcp.tool(description="Echo request context with custom data") + @mcp.tool(description="Echo request context with custom data", title="Echo Context") def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str: """Returns request context including headers and custom data.""" context_data = { @@ -222,6 +302,49 @@ def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str context_data["path"] = request.url.path return json.dumps(context_data) + # Restaurant booking tool with elicitation + @mcp.tool(description="Book a table at a restaurant with elicitation", title="Restaurant Booking") + async def book_restaurant( + date: str, + time: str, + party_size: int, + ctx: Context, + ) -> str: + """Book a table - uses elicitation if requested date is unavailable.""" + + class AlternativeDateSchema(BaseModel): + checkAlternative: bool = Field(description="Would you like to try another date?") + alternativeDate: str = Field( + default="2024-12-26", + description="What date would you prefer? (YYYY-MM-DD)", + ) + + # For testing: assume dates starting with "2024-12-25" are unavailable + if date.startswith("2024-12-25"): + # Use elicitation to ask about alternatives + result = await ctx.elicit( + message=( + f"No tables available for {party_size} people on {date} " + f"at {time}. Would you like to check another date?" + ), + schema=AlternativeDateSchema, + ) + + if result.action == "accept" and result.data: + if result.data.checkAlternative: + alt_date = result.data.alternativeDate + return f"βœ… Booked table for {party_size} on {alt_date} at {time}" + else: + return "❌ No booking made" + elif result.action in ("decline", "cancel"): + return "❌ Booking cancelled" + else: + # Handle case where action is "accept" but data is None + return "❌ No booking data received" + else: + # Available - book directly + return f"βœ… Booked table for {party_size} on {date} at {time}" + return mcp @@ -235,8 +358,10 @@ def make_everything_fastmcp_app(): def make_fastmcp_streamable_http_app(): """Create a FastMCP server with StreamableHTTP transport.""" - - mcp = FastMCP(name="NoAuthServer") + transport_security = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + mcp = FastMCP(name="NoAuthServer", transport_security=transport_security) # Add a simple tool @mcp.tool(description="A simple echo tool") @@ -261,8 +386,10 @@ def make_everything_fastmcp_streamable_http_app(): def make_fastmcp_stateless_http_app(): """Create a FastMCP server with stateless StreamableHTTP transport.""" - - mcp = FastMCP(name="StatelessServer", stateless_http=True) + transport_security = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + mcp = FastMCP(name="StatelessServer", stateless_http=True, transport_security=transport_security) # Add a simple tool @mcp.tool(description="A simple echo tool") @@ -278,11 +405,7 @@ def echo(message: str) -> str: def run_server(server_port: int) -> None: """Run the server.""" _, app = make_fastmcp_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", 
port=server_port, log_level="error")) print(f"Starting server on port {server_port}") server.run() @@ -290,11 +413,7 @@ def run_server(server_port: int) -> None: def run_everything_legacy_sse_http_server(server_port: int) -> None: """Run the comprehensive server with all features.""" _, app = make_everything_fastmcp_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"Starting comprehensive server on port {server_port}") server.run() @@ -302,11 +421,7 @@ def run_everything_legacy_sse_http_server(server_port: int) -> None: def run_streamable_http_server(server_port: int) -> None: """Run the StreamableHTTP server.""" _, app = make_fastmcp_streamable_http_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"Starting StreamableHTTP server on port {server_port}") server.run() @@ -314,11 +429,7 @@ def run_streamable_http_server(server_port: int) -> None: def run_everything_server(server_port: int) -> None: """Run the comprehensive StreamableHTTP server with all features.""" _, app = make_everything_fastmcp_streamable_http_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"Starting comprehensive StreamableHTTP server on port {server_port}") server.run() @@ -326,11 +437,7 @@ def run_everything_server(server_port: int) -> None: def run_stateless_http_server(server_port: int) -> None: """Run the stateless StreamableHTTP server.""" _, app = make_fastmcp_stateless_http_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"Starting stateless StreamableHTTP server on port {server_port}") server.run() @@ -369,9 +476,7 @@ def server(server_port: int) -> Generator[None, None, None]: @pytest.fixture() def streamable_http_server(http_server_port: int) -> Generator[None, None, None]: """Start the StreamableHTTP server in a separate process.""" - proc = multiprocessing.Process( - target=run_streamable_http_server, args=(http_server_port,), daemon=True - ) + proc = multiprocessing.Process(target=run_streamable_http_server, args=(http_server_port,), daemon=True) print("Starting StreamableHTTP server process") proc.start() @@ -388,9 +493,7 @@ def streamable_http_server(http_server_port: int) -> Generator[None, None, None] time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"StreamableHTTP server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"StreamableHTTP server failed to start after {max_attempts} attempts") yield @@ -427,9 +530,7 @@ def stateless_http_server( time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Stateless server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Stateless server failed to start after {max_attempts} attempts") yield @@ -459,9 +560,7 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None: @pytest.mark.anyio 
-async def test_fastmcp_streamable_http( - streamable_http_server: None, http_server_url: str -) -> None: +async def test_fastmcp_streamable_http(streamable_http_server: None, http_server_url: str) -> None: """Test that FastMCP works with StreamableHTTP transport.""" # Connect to the server using StreamableHTTP async with streamablehttp_client(http_server_url + "/mcp") as ( @@ -484,9 +583,7 @@ async def test_fastmcp_streamable_http( @pytest.mark.anyio -async def test_fastmcp_stateless_streamable_http( - stateless_http_server: None, stateless_http_server_url: str -) -> None: +async def test_fastmcp_stateless_streamable_http(stateless_http_server: None, stateless_http_server_url: str) -> None: """Test that FastMCP works with stateless StreamableHTTP transport.""" # Connect to the server using StreamableHTTP async with streamablehttp_client(stateless_http_server_url + "/mcp") as ( @@ -562,9 +659,7 @@ def everything_server(everything_server_port: int) -> Generator[None, None, None time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Comprehensive server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Comprehensive server failed to start after {max_attempts} attempts") yield @@ -601,10 +696,7 @@ def everything_streamable_http_server( time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Comprehensive StreamableHTTP server failed to start after " - f"{max_attempts} attempts" - ) + raise RuntimeError(f"Comprehensive StreamableHTTP server failed to start after " f"{max_attempts} attempts") yield @@ -636,21 +728,35 @@ async def handle_tool_list_changed(self, params) -> None: async def handle_generic_notification(self, message) -> None: # Check if this is a ServerNotification - if isinstance(message, types.ServerNotification): + if isinstance(message, ServerNotification): # Check the specific notification type - if isinstance(message.root, types.ProgressNotification): + if isinstance(message.root, ProgressNotification): await self.handle_progress(message.root.params) - elif isinstance(message.root, types.LoggingMessageNotification): + elif isinstance(message.root, LoggingMessageNotification): await self.handle_log(message.root.params) - elif isinstance(message.root, types.ResourceListChangedNotification): + elif isinstance(message.root, ResourceListChangedNotification): await self.handle_resource_list_changed(message.root.params) - elif isinstance(message.root, types.ToolListChangedNotification): + elif isinstance(message.root, ToolListChangedNotification): await self.handle_tool_list_changed(message.root.params) -async def call_all_mcp_features( - session: ClientSession, collector: NotificationCollector -) -> None: +async def create_test_elicitation_callback(context, params): + """Shared elicitation callback for tests. + + Handles elicitation requests for restaurant booking tests. + """ + # For restaurant booking test + if "No tables available" in params.message: + return ElicitResult( + action="accept", + content={"checkAlternative": True, "alternativeDate": "2024-12-26"}, + ) + else: + # Default response + return ElicitResult(action="decline") + + +async def call_all_mcp_features(session: ClientSession, collector: NotificationCollector) -> None: """ Test all MCP features using the provided session. @@ -676,13 +782,21 @@ async def call_all_mcp_features( assert isinstance(tool_result.content[0], TextContent) assert tool_result.content[0].text == "Echo: hello" - # 2. Tool with context (logging and progress) + # 2. 
Test tool that returns ResourceLinks + list_files_result = await session.call_tool("list_files") + assert len(list_files_result.content) == 1 + + # Rest should be ResourceLinks + content = list_files_result.content[0] + assert isinstance(content, ResourceLink) + assert str(content.uri).startswith("file:///") + assert content.name is not None + assert content.mimeType is not None + # Test progress callback functionality progress_updates = [] - async def progress_callback( - progress: float, total: float | None, message: str | None - ) -> None: + async def progress_callback(progress: float, total: float | None, message: str | None) -> None: """Collect progress updates for testing (async version).""" progress_updates.append((progress, total, message)) print(f"Progress: {progress}/{total} - {message}") @@ -726,19 +840,12 @@ async def progress_callback( # Verify we received log messages from the sampling tool assert len(collector.log_messages) > 0 - assert any( - "Requesting sampling for prompt" in msg.data for msg in collector.log_messages - ) - assert any( - "Received sampling result from model" in msg.data - for msg in collector.log_messages - ) + assert any("Requesting sampling for prompt" in msg.data for msg in collector.log_messages) + assert any("Received sampling result from model" in msg.data for msg in collector.log_messages) # 4. Test notification tool notification_message = "test_notifications" - notification_result = await session.call_tool( - "notification_tool", {"message": notification_message} - ) + notification_result = await session.call_tool("notification_tool", {"message": notification_message}) assert len(notification_result.content) == 1 assert isinstance(notification_result.content[0], TextContent) assert "Sent notifications and logs" in notification_result.content[0].text @@ -754,6 +861,21 @@ async def progress_callback( assert "info" in log_levels assert "warning" in log_levels + # 5. Test elicitation tool + # Test restaurant booking with unavailable date (triggers elicitation) + booking_result = await session.call_tool( + "book_restaurant", + { + "date": "2024-12-25", # Unavailable date to trigger elicitation + "time": "19:00", + "party_size": 4, + }, + ) + assert len(booking_result.content) == 1 + assert isinstance(booking_result.content[0], TextContent) + # Should have booked the alternative date from elicitation callback + assert "βœ… Booked table for 4 on 2024-12-26" in booking_result.content[0].text + # Test resources # 1. Static resource resources = await session.list_resources() @@ -773,36 +895,24 @@ async def progress_callback( # 2. Dynamic resource resource_category = "test" - dynamic_content = await session.read_resource( - AnyUrl(f"resource://dynamic/{resource_category}") - ) + dynamic_content = await session.read_resource(AnyUrl(f"resource://dynamic/{resource_category}")) assert isinstance(dynamic_content, ReadResourceResult) assert len(dynamic_content.contents) == 1 assert isinstance(dynamic_content.contents[0], TextResourceContents) - assert ( - f"Dynamic resource content for category: {resource_category}" - in dynamic_content.contents[0].text - ) + assert f"Dynamic resource content for category: {resource_category}" in dynamic_content.contents[0].text # 3. 
Template resource resource_id = "456" - template_content = await session.read_resource( - AnyUrl(f"resource://template/{resource_id}/data") - ) + template_content = await session.read_resource(AnyUrl(f"resource://template/{resource_id}/data")) assert isinstance(template_content, ReadResourceResult) assert len(template_content.contents) == 1 assert isinstance(template_content.contents[0], TextResourceContents) - assert ( - f"Template resource data for ID: {resource_id}" - in template_content.contents[0].text - ) + assert f"Template resource data for ID: {resource_id}" in template_content.contents[0].text # Test prompts # 1. Simple prompt prompts = await session.list_prompts() - simple_prompt = next( - (p for p in prompts.prompts if p.name == "simple_prompt"), None - ) + simple_prompt = next((p for p in prompts.prompts if p.name == "simple_prompt"), None) assert simple_prompt is not None prompt_topic = "AI" @@ -812,16 +922,12 @@ async def progress_callback( # The actual message structure depends on the prompt implementation # 2. Complex prompt - complex_prompt = next( - (p for p in prompts.prompts if p.name == "complex_prompt"), None - ) + complex_prompt = next((p for p in prompts.prompts if p.name == "complex_prompt"), None) assert complex_prompt is not None query = "What is AI?" context = "technical" - complex_result = await session.get_prompt( - "complex_prompt", {"user_query": query, "context": context} - ) + complex_result = await session.get_prompt("complex_prompt", {"user_query": query, "context": context}) assert isinstance(complex_result, GetPromptResult) assert len(complex_result.messages) >= 1 @@ -837,9 +943,7 @@ async def progress_callback( print(f"Received headers: {headers_data}") # Test 6: Call tool that returns full context - context_result = await session.call_tool( - "echo_context", {"custom_request_id": "test-123"} - ) + context_result = await session.call_tool("echo_context", {"custom_request_id": "test-123"}) assert len(context_result.content) == 1 assert isinstance(context_result.content[0], TextContent) @@ -849,6 +953,41 @@ async def progress_callback( if context_data["method"]: assert context_data["method"] == "POST" + # Test completion functionality + # 1. Test resource template completion with context + repo_result = await session.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": ""}, + context_arguments={"owner": "modelcontextprotocol"}, + ) + assert repo_result.completion.values == ["python-sdk", "typescript-sdk", "specification"] + assert repo_result.completion.total == 3 + assert repo_result.completion.hasMore is False + + # 2. Test with different context + repo_result2 = await session.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": ""}, + context_arguments={"owner": "test-org"}, + ) + assert repo_result2.completion.values == ["test-repo1", "test-repo2"] + assert repo_result2.completion.total == 2 + + # 3. Test prompt argument completion + context_result = await session.complete( + ref=PromptReference(type="ref/prompt", name="complex_prompt"), + argument={"name": "context", "value": "tech"}, + ) + assert "technical" in context_result.completion.values + + # 4. 
Test completion without context (should return empty) + no_context_result = await session.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": "test"}, + ) + assert no_context_result.completion.values == [] + assert no_context_result.completion.total == 0 + async def sampling_callback( context: RequestContext[ClientSession, None], @@ -871,16 +1010,12 @@ async def sampling_callback( @pytest.mark.anyio -async def test_fastmcp_all_features_sse( - everything_server: None, everything_server_url: str -) -> None: +async def test_fastmcp_all_features_sse(everything_server: None, everything_server_url: str) -> None: """Test all MCP features work correctly with SSE transport.""" # Create notification collector collector = NotificationCollector() - # Create a sampling callback that simulates an LLM - # Connect to the server with callbacks async with sse_client(everything_server_url + "/sse") as streams: # Set up message handler to capture notifications @@ -893,6 +1028,7 @@ async def message_handler(message): async with ClientSession( *streams, sampling_callback=sampling_callback, + elicitation_callback=create_test_elicitation_callback, message_handler=message_handler, ) as session: # Run the common test suite @@ -925,7 +1061,111 @@ async def message_handler(message): read_stream, write_stream, sampling_callback=sampling_callback, + elicitation_callback=create_test_elicitation_callback, message_handler=message_handler, ) as session: # Run the common test suite with HTTP-specific test suffix await call_all_mcp_features(session, collector) + + +@pytest.mark.anyio +async def test_elicitation_feature(server: None, server_url: str) -> None: + """Test the elicitation feature.""" + + # Create a custom handler for elicitation requests + async def elicitation_callback(context, params): + # Verify the elicitation parameters + if params.message == "Tool wants to ask: What is your name?": + return ElicitResult(content={"answer": "Test User"}, action="accept") + else: + raise ValueError("Unexpected elicitation message") + + # Connect to the server with our custom elicitation handler + async with sse_client(server_url + "/sse") as streams: + async with ClientSession(*streams, elicitation_callback=elicitation_callback) as session: + # First initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "NoAuthServer" + + # Call the tool that uses elicitation + tool_result = await session.call_tool("ask_user", {"prompt": "What is your name?"}) + # Verify the result + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + # # The test should only succeed with the successful elicitation response + assert tool_result.content[0].text == "User answered: Test User" + + +@pytest.mark.anyio +async def test_title_precedence(everything_server: None, everything_server_url: str) -> None: + """Test that titles are properly returned for tools, resources, and prompts.""" + from mcp.shared.metadata_utils import get_display_name + + async with sse_client(everything_server_url + "/sse") as streams: + async with ClientSession(*streams) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Test tools have titles + tools_result = await session.list_tools() + assert tools_result.tools + + # Check specific tools have titles + tool_names_to_titles = { + 
"tool_with_progress": "Progress Tool", + "echo": "Echo Tool", + "sampling_tool": "Sampling Tool", + "notification_tool": "Notification Tool", + "echo_headers": "Echo Headers", + "echo_context": "Echo Context", + "book_restaurant": "Restaurant Booking", + } + + for tool in tools_result.tools: + if tool.name in tool_names_to_titles: + assert tool.title == tool_names_to_titles[tool.name] + # Test get_display_name utility + assert get_display_name(tool) == tool_names_to_titles[tool.name] + + # Test resources have titles + resources_result = await session.list_resources() + assert resources_result.resources + + # Check specific resources have titles + static_resource = next((r for r in resources_result.resources if r.name == "Static Info"), None) + assert static_resource is not None + assert static_resource.title == "Static Information" + assert get_display_name(static_resource) == "Static Information" + + # Test resource templates have titles + resource_templates = await session.list_resource_templates() + assert resource_templates.resourceTemplates + + # Check specific resource templates have titles + template_uris_to_titles = { + "resource://dynamic/{category}": "Dynamic Resource", + "resource://template/{id}/data": "Template Resource", + "github://repos/{owner}/{repo}": "GitHub Repository", + } + + for template in resource_templates.resourceTemplates: + if template.uriTemplate in template_uris_to_titles: + assert template.title == template_uris_to_titles[template.uriTemplate] + assert get_display_name(template) == template_uris_to_titles[template.uriTemplate] + + # Test prompts have titles + prompts_result = await session.list_prompts() + assert prompts_result.prompts + + # Check specific prompts have titles + prompt_names_to_titles = { + "simple_prompt": "Simple Prompt", + "complex_prompt": "Complex Prompt", + } + + for prompt in prompts_result.prompts: + if prompt.name in prompt_names_to_titles: + assert prompt.title == prompt_names_to_titles[prompt.name] + assert get_display_name(prompt) == prompt_names_to_titles[prompt.name] diff --git a/tests/server/fastmcp/test_server.py b/tests/server/fastmcp/test_server.py index b817761ea..8719b78d5 100644 --- a/tests/server/fastmcp/test_server.py +++ b/tests/server/fastmcp/test_server.py @@ -8,7 +8,7 @@ from starlette.routing import Mount, Route from mcp.server.fastmcp import Context, FastMCP -from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage +from mcp.server.fastmcp.prompts.base import Message, UserMessage from mcp.server.fastmcp.resources import FileResource, FunctionResource from mcp.server.fastmcp.utilities.types import Image from mcp.shared.exceptions import McpError @@ -16,7 +16,10 @@ create_connected_server_and_client_session as client_session, ) from mcp.types import ( + AudioContent, BlobResourceContents, + ContentBlock, + EmbeddedResource, ImageContent, TextContent, TextResourceContents, @@ -58,9 +61,7 @@ async def test_sse_app_with_mount_path(self): """Test SSE app creation with different mount paths.""" # Test with default mount path mcp = FastMCP() - with patch.object( - mcp, "_normalize_path", return_value="/messages/" - ) as mock_normalize: + with patch.object(mcp, "_normalize_path", return_value="/messages/") as mock_normalize: mcp.sse_app() # Verify _normalize_path was called with correct args mock_normalize.assert_called_once_with("/", "/messages/") @@ -68,18 +69,14 @@ async def test_sse_app_with_mount_path(self): # Test with custom mount path in settings mcp = FastMCP() mcp.settings.mount_path = 
"/custom" - with patch.object( - mcp, "_normalize_path", return_value="/custom/messages/" - ) as mock_normalize: + with patch.object(mcp, "_normalize_path", return_value="/custom/messages/") as mock_normalize: mcp.sse_app() # Verify _normalize_path was called with correct args mock_normalize.assert_called_once_with("/custom", "/messages/") # Test with mount_path parameter mcp = FastMCP() - with patch.object( - mcp, "_normalize_path", return_value="/param/messages/" - ) as mock_normalize: + with patch.object(mcp, "_normalize_path", return_value="/param/messages/") as mock_normalize: mcp.sse_app(mount_path="/param") # Verify _normalize_path was called with correct args mock_normalize.assert_called_once_with("/param", "/messages/") @@ -102,9 +99,7 @@ async def test_starlette_routes_with_mount_path(self): # Verify path values assert sse_routes[0].path == "/sse", "SSE route path should be /sse" - assert ( - mount_routes[0].path == "/messages" - ), "Mount route path should be /messages" + assert mount_routes[0].path == "/messages", "Mount route path should be /messages" # Test with mount path as parameter mcp = FastMCP() @@ -120,20 +115,14 @@ async def test_starlette_routes_with_mount_path(self): # Verify path values assert sse_routes[0].path == "/sse", "SSE route path should be /sse" - assert ( - mount_routes[0].path == "/messages" - ), "Mount route path should be /messages" + assert mount_routes[0].path == "/messages", "Mount route path should be /messages" @pytest.mark.anyio async def test_non_ascii_description(self): """Test that FastMCP handles non-ASCII characters in descriptions correctly""" mcp = FastMCP() - @mcp.tool( - description=( - "🌟 This tool uses emojis and UTF-8 characters: Γ‘ Γ© Γ­ Γ³ ΓΊ Γ± ζΌ’ε­— πŸŽ‰" - ) - ) + @mcp.tool(description=("🌟 This tool uses emojis and UTF-8 characters: Γ‘ Γ© Γ­ Γ³ ΓΊ Γ± ζΌ’ε­— πŸŽ‰")) def hello_world(name: str = "δΈ–η•Œ") -> str: return f"Β‘Hola, {name}! 
πŸ‘‹" @@ -186,9 +175,7 @@ def get_data(x: str) -> str: async def test_add_resource_decorator_incorrect_usage(self): mcp = FastMCP() - with pytest.raises( - TypeError, match="The @resource decorator was used incorrectly" - ): + with pytest.raises(TypeError, match="The @resource decorator was used incorrectly"): @mcp.resource # Missing parentheses #type: ignore def get_data(x: str) -> str: @@ -207,10 +194,11 @@ def image_tool_fn(path: str) -> Image: return Image(path) -def mixed_content_tool_fn() -> list[TextContent | ImageContent]: +def mixed_content_tool_fn() -> list[ContentBlock]: return [ TextContent(type="text", text="Hello"), ImageContent(type="image", data="abc", mimeType="image/png"), + AudioContent(type="audio", data="def", mimeType="audio/wav"), ] @@ -312,14 +300,16 @@ async def test_tool_mixed_content(self): mcp.add_tool(mixed_content_tool_fn) async with client_session(mcp._mcp_server) as client: result = await client.call_tool("mixed_content_tool_fn", {}) - assert len(result.content) == 2 - content1 = result.content[0] - content2 = result.content[1] + assert len(result.content) == 3 + content1, content2, content3 = result.content assert isinstance(content1, TextContent) assert content1.text == "Hello" assert isinstance(content2, ImageContent) assert content2.mimeType == "image/png" assert content2.data == "abc" + assert isinstance(content3, AudioContent) + assert content3.mimeType == "audio/wav" + assert content3.data == "def" @pytest.mark.anyio async def test_tool_mixed_list_with_image(self, tmp_path: Path): @@ -369,9 +359,7 @@ async def test_text_resource(self): def get_text(): return "Hello, world!" - resource = FunctionResource( - uri=AnyUrl("resource://test"), name="test", fn=get_text - ) + resource = FunctionResource(uri=AnyUrl("resource://test"), name="test", fn=get_text) mcp.add_resource(resource) async with client_session(mcp._mcp_server) as client: @@ -407,9 +395,7 @@ async def test_file_resource_text(self, tmp_path: Path): text_file = tmp_path / "test.txt" text_file.write_text("Hello from file!") - resource = FileResource( - uri=AnyUrl("file://test.txt"), name="test.txt", path=text_file - ) + resource = FileResource(uri=AnyUrl("file://test.txt"), name="test.txt", path=text_file) mcp.add_resource(resource) async with client_session(mcp._mcp_server) as client: @@ -436,10 +422,7 @@ async def test_file_resource_binary(self, tmp_path: Path): async with client_session(mcp._mcp_server) as client: result = await client.read_resource(AnyUrl("file://test.bin")) assert isinstance(result.contents[0], BlobResourceContents) - assert ( - result.contents[0].blob - == base64.b64encode(b"Binary file data").decode() - ) + assert result.contents[0].blob == base64.b64encode(b"Binary file data").decode() @pytest.mark.anyio async def test_function_resource(self): @@ -528,9 +511,7 @@ def get_data(org: str, repo: str) -> str: return f"Data for {org}/{repo}" async with client_session(mcp._mcp_server) as client: - result = await client.read_resource( - AnyUrl("resource://cursor/fastmcp/data") - ) + result = await client.read_resource(AnyUrl("resource://cursor/fastmcp/data")) assert isinstance(result.contents[0], TextResourceContents) assert result.contents[0].text == "Data for cursor/fastmcp" diff --git a/tests/server/fastmcp/test_title.py b/tests/server/fastmcp/test_title.py new file mode 100644 index 000000000..a94f6671d --- /dev/null +++ b/tests/server/fastmcp/test_title.py @@ -0,0 +1,215 @@ +"""Integration tests for title field functionality.""" + +import pytest +from pydantic import AnyUrl 
+ +from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp.resources import FunctionResource +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.shared.metadata_utils import get_display_name +from mcp.types import Prompt, Resource, ResourceTemplate, Tool, ToolAnnotations + + +@pytest.mark.anyio +async def test_tool_title_precedence(): + """Test that tool title precedence works correctly: title > annotations.title > name.""" + # Create server with various tool configurations + mcp = FastMCP(name="TitleTestServer") + + # Tool with only name + @mcp.tool(description="Basic tool") + def basic_tool(message: str) -> str: + return message + + # Tool with title + @mcp.tool(description="Tool with title", title="User-Friendly Tool") + def tool_with_title(message: str) -> str: + return message + + # Tool with annotations.title (when title is not supported on decorator) + # We'll need to add this manually after registration + @mcp.tool(description="Tool with annotations") + def tool_with_annotations(message: str) -> str: + return message + + # Tool with both title and annotations.title + @mcp.tool(description="Tool with both", title="Primary Title") + def tool_with_both(message: str) -> str: + return message + + # Start server and connect client + async with create_connected_server_and_client_session(mcp._mcp_server) as client: + await client.initialize() + + # List tools + tools_result = await client.list_tools() + tools = {tool.name: tool for tool in tools_result.tools} + + # Verify basic tool uses name + assert "basic_tool" in tools + basic = tools["basic_tool"] + # Since we haven't implemented get_display_name yet, we'll check the raw fields + assert basic.title is None + assert basic.name == "basic_tool" + + # Verify tool with title + assert "tool_with_title" in tools + titled = tools["tool_with_title"] + assert titled.title == "User-Friendly Tool" + + # For now, we'll skip the annotations.title test as it requires modifying + # the tool after registration, which we'll implement later + + # Verify tool with both uses title over annotations.title + assert "tool_with_both" in tools + both = tools["tool_with_both"] + assert both.title == "Primary Title" + + +@pytest.mark.anyio +async def test_prompt_title(): + """Test that prompt titles work correctly.""" + mcp = FastMCP(name="PromptTitleServer") + + # Prompt with only name + @mcp.prompt(description="Basic prompt") + def basic_prompt(topic: str) -> str: + return f"Tell me about {topic}" + + # Prompt with title + @mcp.prompt(description="Titled prompt", title="Ask About Topic") + def titled_prompt(topic: str) -> str: + return f"Tell me about {topic}" + + # Start server and connect client + async with create_connected_server_and_client_session(mcp._mcp_server) as client: + await client.initialize() + + # List prompts + prompts_result = await client.list_prompts() + prompts = {prompt.name: prompt for prompt in prompts_result.prompts} + + # Verify basic prompt uses name + assert "basic_prompt" in prompts + basic = prompts["basic_prompt"] + assert basic.title is None + assert basic.name == "basic_prompt" + + # Verify prompt with title + assert "titled_prompt" in prompts + titled = prompts["titled_prompt"] + assert titled.title == "Ask About Topic" + + +@pytest.mark.anyio +async def test_resource_title(): + """Test that resource titles work correctly.""" + mcp = FastMCP(name="ResourceTitleServer") + + # Static resource without title + def get_basic_data() -> str: + return "Basic data" + + basic_resource = 
FunctionResource( + uri=AnyUrl("resource://basic"), + name="basic_resource", + description="Basic resource", + fn=get_basic_data, + ) + mcp.add_resource(basic_resource) + + # Static resource with title + def get_titled_data() -> str: + return "Titled data" + + titled_resource = FunctionResource( + uri=AnyUrl("resource://titled"), + name="titled_resource", + title="User-Friendly Resource", + description="Resource with title", + fn=get_titled_data, + ) + mcp.add_resource(titled_resource) + + # Dynamic resource without title + @mcp.resource("resource://dynamic/{id}") + def dynamic_resource(id: str) -> str: + return f"Data for {id}" + + # Dynamic resource with title (when supported) + @mcp.resource("resource://titled-dynamic/{id}", title="Dynamic Data") + def titled_dynamic_resource(id: str) -> str: + return f"Data for {id}" + + # Start server and connect client + async with create_connected_server_and_client_session(mcp._mcp_server) as client: + await client.initialize() + + # List resources + resources_result = await client.list_resources() + resources = {str(res.uri): res for res in resources_result.resources} + + # Verify basic resource uses name + assert "resource://basic" in resources + basic = resources["resource://basic"] + assert basic.title is None + assert basic.name == "basic_resource" + + # Verify resource with title + assert "resource://titled" in resources + titled = resources["resource://titled"] + assert titled.title == "User-Friendly Resource" + + # List resource templates + templates_result = await client.list_resource_templates() + templates = {tpl.uriTemplate: tpl for tpl in templates_result.resourceTemplates} + + # Verify dynamic resource template + assert "resource://dynamic/{id}" in templates + dynamic = templates["resource://dynamic/{id}"] + assert dynamic.title is None + assert dynamic.name == "dynamic_resource" + + # Verify titled dynamic resource template (when supported) + if "resource://titled-dynamic/{id}" in templates: + titled_dynamic = templates["resource://titled-dynamic/{id}"] + assert titled_dynamic.title == "Dynamic Data" + + +@pytest.mark.anyio +async def test_get_display_name_utility(): + """Test the get_display_name utility function.""" + + # Test tool precedence: title > annotations.title > name + tool_name_only = Tool(name="test_tool", inputSchema={}) + assert get_display_name(tool_name_only) == "test_tool" + + tool_with_title = Tool(name="test_tool", title="Test Tool", inputSchema={}) + assert get_display_name(tool_with_title) == "Test Tool" + + tool_with_annotations = Tool(name="test_tool", inputSchema={}, annotations=ToolAnnotations(title="Annotated Tool")) + assert get_display_name(tool_with_annotations) == "Annotated Tool" + + tool_with_both = Tool( + name="test_tool", title="Primary Title", inputSchema={}, annotations=ToolAnnotations(title="Secondary Title") + ) + assert get_display_name(tool_with_both) == "Primary Title" + + # Test other types: title > name + resource = Resource(uri=AnyUrl("file://test"), name="test_res") + assert get_display_name(resource) == "test_res" + + resource_with_title = Resource(uri=AnyUrl("file://test"), name="test_res", title="Test Resource") + assert get_display_name(resource_with_title) == "Test Resource" + + prompt = Prompt(name="test_prompt") + assert get_display_name(prompt) == "test_prompt" + + prompt_with_title = Prompt(name="test_prompt", title="Test Prompt") + assert get_display_name(prompt_with_title) == "Test Prompt" + + template = ResourceTemplate(uriTemplate="file://{id}", name="test_template") + assert 
get_display_name(template) == "test_template" + + template_with_title = ResourceTemplate(uriTemplate="file://{id}", name="test_template", title="Test Template") + assert get_display_name(template_with_title) == "Test Template" diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index b45c7ac38..206df42d7 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -44,6 +44,7 @@ class AddArguments(ArgModelBase): original_tool = Tool( name="add", + title="Add Tool", description="Add two numbers.", fn=add, fn_metadata=fn_metadata, @@ -147,9 +148,7 @@ def test_add_lambda(self): def test_add_lambda_with_no_name(self): manager = ToolManager() - with pytest.raises( - ValueError, match="You must provide a name for lambda functions" - ): + with pytest.raises(ValueError, match="You must provide a name for lambda functions"): manager.add_tool(lambda x: x) def test_warn_on_duplicate_tools(self, caplog): @@ -346,9 +345,7 @@ def tool_without_context(x: int) -> str: tool = manager.add_tool(tool_without_context) assert tool.context_kwarg is None - def tool_with_parametrized_context( - x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT] - ) -> str: + def tool_with_parametrized_context(x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT]) -> str: return str(x) tool = manager.add_tool(tool_with_parametrized_context) diff --git a/tests/server/test_completion_with_context.py b/tests/server/test_completion_with_context.py new file mode 100644 index 000000000..f0d154587 --- /dev/null +++ b/tests/server/test_completion_with_context.py @@ -0,0 +1,180 @@ +""" +Tests for completion handler with context functionality. +""" + +import pytest + +from mcp.server.lowlevel import Server +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.types import ( + Completion, + CompletionArgument, + CompletionContext, + PromptReference, + ResourceTemplateReference, +) + + +@pytest.mark.anyio +async def test_completion_handler_receives_context(): + """Test that the completion handler receives context correctly.""" + server = Server("test-server") + + # Track what the handler receives + received_args = {} + + @server.completion() + async def handle_completion( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + received_args["ref"] = ref + received_args["argument"] = argument + received_args["context"] = context + + # Return test completion + return Completion(values=["test-completion"], total=1, hasMore=False) + + async with create_connected_server_and_client_session(server) as client: + # Test with context + result = await client.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="test://resource/{param}"), + argument={"name": "param", "value": "test"}, + context_arguments={"previous": "value"}, + ) + + # Verify handler received the context + assert received_args["context"] is not None + assert received_args["context"].arguments == {"previous": "value"} + assert result.completion.values == ["test-completion"] + + +@pytest.mark.anyio +async def test_completion_backward_compatibility(): + """Test that completion works without context (backward compatibility).""" + server = Server("test-server") + + context_was_none = False + + @server.completion() + async def handle_completion( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: 
CompletionContext | None, + ) -> Completion | None: + nonlocal context_was_none + context_was_none = context is None + + return Completion(values=["no-context-completion"], total=1, hasMore=False) + + async with create_connected_server_and_client_session(server) as client: + # Test without context + result = await client.complete( + ref=PromptReference(type="ref/prompt", name="test-prompt"), argument={"name": "arg", "value": "val"} + ) + + # Verify context was None + assert context_was_none + assert result.completion.values == ["no-context-completion"] + + +@pytest.mark.anyio +async def test_dependent_completion_scenario(): + """Test a real-world scenario with dependent completions.""" + server = Server("test-server") + + @server.completion() + async def handle_completion( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + # Simulate database/table completion scenario + if isinstance(ref, ResourceTemplateReference): + if ref.uri == "db://{database}/{table}": + if argument.name == "database": + # Complete database names + return Completion(values=["users_db", "products_db", "analytics_db"], total=3, hasMore=False) + elif argument.name == "table": + # Complete table names based on selected database + if context and context.arguments: + db = context.arguments.get("database") + if db == "users_db": + return Completion(values=["users", "sessions", "permissions"], total=3, hasMore=False) + elif db == "products_db": + return Completion(values=["products", "categories", "inventory"], total=3, hasMore=False) + + return Completion(values=[], total=0, hasMore=False) + + async with create_connected_server_and_client_session(server) as client: + # First, complete database + db_result = await client.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"), + argument={"name": "database", "value": ""}, + ) + assert "users_db" in db_result.completion.values + assert "products_db" in db_result.completion.values + + # Then complete table with database context + table_result = await client.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"), + argument={"name": "table", "value": ""}, + context_arguments={"database": "users_db"}, + ) + assert table_result.completion.values == ["users", "sessions", "permissions"] + + # Different database gives different tables + table_result2 = await client.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"), + argument={"name": "table", "value": ""}, + context_arguments={"database": "products_db"}, + ) + assert table_result2.completion.values == ["products", "categories", "inventory"] + + +@pytest.mark.anyio +async def test_completion_error_on_missing_context(): + """Test that server can raise error when required context is missing.""" + server = Server("test-server") + + @server.completion() + async def handle_completion( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + if isinstance(ref, ResourceTemplateReference): + if ref.uri == "db://{database}/{table}": + if argument.name == "table": + # Check if database context is provided + if not context or not context.arguments or "database" not in context.arguments: + # Raise an error instead of returning error as completion + raise ValueError("Please select a database first to see available tables") + # Normal 
completion if context is provided + db = context.arguments.get("database") + if db == "test_db": + return Completion(values=["users", "orders", "products"], total=3, hasMore=False) + + return Completion(values=[], total=0, hasMore=False) + + async with create_connected_server_and_client_session(server) as client: + # Try to complete table without database context - should raise error + with pytest.raises(Exception) as exc_info: + await client.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"), + argument={"name": "table", "value": ""}, + ) + + # Verify error message + assert "Please select a database first" in str(exc_info.value) + + # Now complete with proper context - should work normally + result_with_context = await client.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="db://{database}/{table}"), + argument={"name": "table", "value": ""}, + context_arguments={"database": "test_db"}, + ) + + # Should get normal completions + assert result_with_context.completion.values == ["users", "orders", "products"] diff --git a/tests/server/test_lowlevel_tool_annotations.py b/tests/server/test_lowlevel_tool_annotations.py index e9eff9ed0..2eb3b7ddb 100644 --- a/tests/server/test_lowlevel_tool_annotations.py +++ b/tests/server/test_lowlevel_tool_annotations.py @@ -10,13 +10,7 @@ from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import ( - ClientResult, - ServerNotification, - ServerRequest, - Tool, - ToolAnnotations, -) +from mcp.types import ClientResult, ServerNotification, ServerRequest, Tool, ToolAnnotations @pytest.mark.anyio @@ -45,18 +39,12 @@ async def list_tools(): ) ] - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](10) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) # Message handler for client async def message_handler( - message: RequestResponder[ServerRequest, ClientResult] - | ServerNotification - | Exception, + message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, ) -> None: if isinstance(message, Exception): raise message diff --git a/tests/server/test_read_resource.py b/tests/server/test_read_resource.py index 469eef857..91f6ef8c8 100644 --- a/tests/server/test_read_resource.py +++ b/tests/server/test_read_resource.py @@ -56,11 +56,7 @@ async def test_read_resource_binary(temp_file: Path): @server.read_resource() async def read_resource(uri: AnyUrl) -> Iterable[ReadResourceContents]: - return [ - ReadResourceContents( - content=b"Hello World", mime_type="application/octet-stream" - ) - ] + return [ReadResourceContents(content=b"Hello World", mime_type="application/octet-stream")] # Get the handler directly from the server handler = server.request_handlers[types.ReadResourceRequest] diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 1375df12f..69321f87c 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -20,18 +20,12 @@ @pytest.mark.anyio async def test_server_session_initialize(): - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - 
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) # Create a message handler to catch exceptions async def message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): raise message @@ -54,9 +48,7 @@ async def run_server(): if isinstance(message, Exception): raise message - if isinstance(message, ClientNotification) and isinstance( - message.root, InitializedNotification - ): + if isinstance(message, ClientNotification) and isinstance(message.root, InitializedNotification): received_initialized = True return @@ -111,12 +103,8 @@ async def list_resources(): @pytest.mark.anyio async def test_server_session_initialize_with_older_protocol_version(): """Test that server accepts and responds with older protocol (2024-11-05).""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage | Exception - ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) received_initialized = False received_protocol_version = None @@ -137,9 +125,7 @@ async def run_server(): if isinstance(message, Exception): raise message - if isinstance(message, types.ClientNotification) and isinstance( - message.root, InitializedNotification - ): + if isinstance(message, types.ClientNotification) and isinstance(message.root, InitializedNotification): received_initialized = True return @@ -157,9 +143,7 @@ async def mock_client(): params=types.InitializeRequestParams( protocolVersion="2024-11-05", capabilities=types.ClientCapabilities(), - clientInfo=types.Implementation( - name="test-client", version="1.0.0" - ), + clientInfo=types.Implementation(name="test-client", version="1.0.0"), ).model_dump(by_alias=True, mode="json", exclude_none=True), ) ) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py new file mode 100644 index 000000000..43af35061 --- /dev/null +++ b/tests/server/test_sse_security.py @@ -0,0 +1,293 @@ +"""Tests for SSE server DNS rebinding protection.""" + +import logging +import multiprocessing +import socket +import time + +import httpx +import pytest +import uvicorn +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount, Route + +from mcp.server import Server +from mcp.server.sse import SseServerTransport +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import Tool + +logger = logging.getLogger(__name__) +SERVER_NAME = "test_sse_security_server" + + +@pytest.fixture +def server_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def server_url(server_port: int) -> str: + return f"http://127.0.0.1:{server_port}" + + +class SecurityTestServer(Server): + def 
__init__(self):
+        super().__init__(SERVER_NAME)
+
+    async def on_list_tools(self) -> list[Tool]:
+        return []
+
+
+def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None):
+    """Run the SSE server with specified security settings."""
+    app = SecurityTestServer()
+    sse_transport = SseServerTransport("/messages/", security_settings)
+
+    async def handle_sse(request: Request):
+        try:
+            async with sse_transport.connect_sse(request.scope, request.receive, request._send) as streams:
+                if streams:
+                    await app.run(streams[0], streams[1], app.create_initialization_options())
+        except ValueError as e:
+            # Validation error was already handled inside connect_sse
+            logger.debug(f"SSE connection failed validation: {e}")
+        return Response()
+
+    routes = [
+        Route("/sse", endpoint=handle_sse),
+        Mount("/messages/", app=sse_transport.handle_post_message),
+    ]
+
+    starlette_app = Starlette(routes=routes)
+    uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error")
+
+
+def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None):
+    """Start server in a separate process."""
+    process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings))
+    process.start()
+    # Give server time to start
+    time.sleep(1)
+    return process
+
+
+@pytest.mark.anyio
+async def test_sse_security_default_settings(server_port: int):
+    """Test SSE with default security settings (protection disabled)."""
+    process = start_server_process(server_port)
+
+    try:
+        headers = {"Host": "evil.com", "Origin": "http://evil.com"}
+
+        async with httpx.AsyncClient(timeout=5.0) as client:
+            async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
+                assert response.status_code == 200
+    finally:
+        process.terminate()
+        process.join()
+
+
+@pytest.mark.anyio
+async def test_sse_security_invalid_host_header(server_port: int):
+    """Test SSE with invalid Host header."""
+    # Enable security with an allowed_hosts list that does not match this request's Host header
+    security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"])
+    process = start_server_process(server_port, security_settings)
+
+    try:
+        # Test with invalid host header
+        headers = {"Host": "evil.com"}
+
+        async with httpx.AsyncClient() as client:
+            response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
+            assert response.status_code == 421
+            assert response.text == "Invalid Host header"
+
+    finally:
+        process.terminate()
+        process.join()
+
+
+@pytest.mark.anyio
+async def test_sse_security_invalid_origin_header(server_port: int):
+    """Test SSE with invalid Origin header."""
+    # Configure security to allow the host but restrict origins
+    security_settings = TransportSecuritySettings(
+        enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"]
+    )
+    process = start_server_process(server_port, security_settings)
+
+    try:
+        # Test with invalid origin header
+        headers = {"Origin": "http://evil.com"}
+
+        async with httpx.AsyncClient() as client:
+            response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
+            assert response.status_code == 400
+            assert response.text == "Invalid Origin header"
+
+    finally:
+        process.terminate()
+        process.join()
+
+
+@pytest.mark.anyio
+async def test_sse_security_post_invalid_content_type(server_port: int):
+    """Test POST endpoint with invalid Content-Type header."""
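+    # The assertions below expect 400 "Invalid Content-Type header" even though the
+    # session ID is fake; compare test_sse_security_post_valid_content_type at the end
+    # of this file, where a valid Content-Type reaches session lookup and returns 404.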
+ # Configure security to allow the host + security_settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] + ) + process = start_server_process(server_port, security_settings) + + try: + async with httpx.AsyncClient(timeout=5.0) as client: + # Test POST with invalid content type + fake_session_id = "12345678123456781234567812345678" + response = await client.post( + f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + headers={"Content-Type": "text/plain"}, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + # Test POST with missing content type + response = await client.post( + f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test" + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_sse_security_disabled(server_port: int): + """Test SSE with security disabled.""" + settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) + process = start_server_process(server_port, settings) + + try: + # Test with invalid host header - should still work + headers = {"Host": "evil.com"} + + async with httpx.AsyncClient(timeout=5.0) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + # Should connect successfully even with invalid host + assert response.status_code == 200 + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_sse_security_custom_allowed_hosts(server_port: int): + """Test SSE with custom allowed hosts.""" + settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["localhost", "127.0.0.1", "custom.host"], + allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], + ) + process = start_server_process(server_port, settings) + + try: + # Test with custom allowed host + headers = {"Host": "custom.host"} + + async with httpx.AsyncClient(timeout=5.0) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + # Should connect successfully with custom host + assert response.status_code == 200 + + # Test with non-allowed host + headers = {"Host": "evil.com"} + + async with httpx.AsyncClient() as client: + response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) + assert response.status_code == 421 + assert response.text == "Invalid Host header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_sse_security_wildcard_ports(server_port: int): + """Test SSE with wildcard port patterns.""" + settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["localhost:*", "127.0.0.1:*"], + allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], + ) + process = start_server_process(server_port, settings) + + try: + # Test with various port numbers + for test_port in [8080, 3000, 9999]: + headers = {"Host": f"localhost:{test_port}"} + + async with httpx.AsyncClient(timeout=5.0) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream("GET", 
f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + # Should connect successfully with any port + assert response.status_code == 200 + + headers = {"Origin": f"http://localhost:{test_port}"} + + async with httpx.AsyncClient(timeout=5.0) as client: + # For SSE endpoints, we need to use stream to avoid timeout + async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response: + # Should connect successfully with any port + assert response.status_code == 200 + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_sse_security_post_valid_content_type(server_port: int): + """Test POST endpoint with valid Content-Type headers.""" + # Configure security to allow the host + security_settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"] + ) + process = start_server_process(server_port, security_settings) + + try: + async with httpx.AsyncClient() as client: + # Test with various valid content types + valid_content_types = [ + "application/json", + "application/json; charset=utf-8", + "application/json;charset=utf-8", + "APPLICATION/JSON", # Case insensitive + ] + + for content_type in valid_content_types: + # Use a valid UUID format (even though session won't exist) + fake_session_id = "12345678123456781234567812345678" + response = await client.post( + f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", + headers={"Content-Type": content_type}, + json={"test": "data"}, + ) + # Will get 404 because session doesn't exist, but that's OK + # We're testing that it passes the content-type check + assert response.status_code == 404 + assert response.text == "Could not find session" + + finally: + process.terminate() + process.join() diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index c546a7167..2d1850b73 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -22,9 +22,10 @@ async def test_stdio_server(): stdin.write(message.model_dump_json(by_alias=True, exclude_none=True) + "\n") stdin.seek(0) - async with stdio_server( - stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout) - ) as (read_stream, write_stream): + async with stdio_server(stdin=anyio.AsyncFile(stdin), stdout=anyio.AsyncFile(stdout)) as ( + read_stream, + write_stream, + ): received_messages = [] async with read_stream: async for message in read_stream: @@ -36,12 +37,8 @@ async def test_stdio_server(): # Verify received messages assert len(received_messages) == 2 - assert received_messages[0] == JSONRPCMessage( - root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") - ) - assert received_messages[1] == JSONRPCMessage( - root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) - ) + assert received_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) + assert received_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) # Test sending responses from the server responses = [ @@ -58,13 +55,7 @@ async def test_stdio_server(): output_lines = stdout.readlines() assert len(output_lines) == 2 - received_responses = [ - JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines - ] + received_responses = [JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines] assert len(received_responses) == 2 - assert received_responses[0] == JSONRPCMessage( - root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") 
- ) - assert received_responses[1] == JSONRPCMessage( - root=JSONRPCResponse(jsonrpc="2.0", id=4, result={}) - ) + assert received_responses[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")) + assert received_responses[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})) diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 32782e458..65828b63b 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -22,10 +22,7 @@ async def test_run_can_only_be_called_once(): async with manager.run(): pass - assert ( - "StreamableHTTPSessionManager .run() can only be called once per instance" - in str(excinfo.value) - ) + assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(excinfo.value) @pytest.mark.anyio @@ -51,10 +48,7 @@ async def try_run(): # One should succeed, one should fail assert len(errors) == 1 - assert ( - "StreamableHTTPSessionManager .run() can only be called once per instance" - in str(errors[0]) - ) + assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0]) @pytest.mark.anyio @@ -76,6 +70,4 @@ async def send(message): with pytest.raises(RuntimeError) as excinfo: await manager.handle_request(scope, receive, send) - assert "Task group is not initialized. Make sure to use run()." in str( - excinfo.value - ) + assert "Task group is not initialized. Make sure to use run()." in str(excinfo.value) diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py new file mode 100644 index 000000000..eed791924 --- /dev/null +++ b/tests/server/test_streamable_http_security.py @@ -0,0 +1,293 @@ +"""Tests for StreamableHTTP server DNS rebinding protection.""" + +import logging +import multiprocessing +import socket +import time +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + +import httpx +import pytest +import uvicorn +from starlette.applications import Starlette +from starlette.routing import Mount +from starlette.types import Receive, Scope, Send + +from mcp.server import Server +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import Tool + +logger = logging.getLogger(__name__) +SERVER_NAME = "test_streamable_http_security_server" + + +@pytest.fixture +def server_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def server_url(server_port: int) -> str: + return f"http://127.0.0.1:{server_port}" + + +class SecurityTestServer(Server): + def __init__(self): + super().__init__(SERVER_NAME) + + async def on_list_tools(self) -> list[Tool]: + return [] + + +def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): + """Run the StreamableHTTP server with specified security settings.""" + app = SecurityTestServer() + + # Create session manager with security settings + session_manager = StreamableHTTPSessionManager( + app=app, + json_response=False, + stateless=False, + security_settings=security_settings, + ) + + # Create the ASGI handler + async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None: + await session_manager.handle_request(scope, receive, send) + + # Create Starlette app with lifespan + @asynccontextmanager + async def 
lifespan(app: Starlette) -> AsyncGenerator[None, None]: + async with session_manager.run(): + yield + + routes = [ + Mount("/", app=handle_streamable_http), + ] + + starlette_app = Starlette(routes=routes, lifespan=lifespan) + uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error") + + +def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None): + """Start server in a separate process.""" + process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings)) + process.start() + # Give server time to start + time.sleep(1) + return process + + +@pytest.mark.anyio +async def test_streamable_http_security_default_settings(server_port: int): + """Test StreamableHTTP with default security settings (protection enabled).""" + process = start_server_process(server_port) + + try: + # Test with valid localhost headers + async with httpx.AsyncClient(timeout=5.0) as client: + # POST request to initialize session + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + ) + assert response.status_code == 200 + assert "mcp-session-id" in response.headers + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_streamable_http_security_invalid_host_header(server_port: int): + """Test StreamableHTTP with invalid Host header.""" + security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True) + process = start_server_process(server_port, security_settings) + + try: + # Test with invalid host header + headers = { + "Host": "evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers=headers, + ) + assert response.status_code == 421 + assert response.text == "Invalid Host header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_streamable_http_security_invalid_origin_header(server_port: int): + """Test StreamableHTTP with invalid Origin header.""" + security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"]) + process = start_server_process(server_port, security_settings) + + try: + # Test with invalid origin header + headers = { + "Origin": "http://evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers=headers, + ) + assert response.status_code == 400 + assert response.text == "Invalid Origin header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_streamable_http_security_invalid_content_type(server_port: int): + """Test StreamableHTTP POST with invalid Content-Type header.""" + process = start_server_process(server_port) + + try: + async with httpx.AsyncClient(timeout=5.0) as client: + # Test POST with invalid content type + response = await client.post( + f"http://127.0.0.1:{server_port}/", + headers={ + "Content-Type": 
"text/plain", + "Accept": "application/json, text/event-stream", + }, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + # Test POST with missing content type + response = await client.post( + f"http://127.0.0.1:{server_port}/", + headers={"Accept": "application/json, text/event-stream"}, + content="test", + ) + assert response.status_code == 400 + assert response.text == "Invalid Content-Type header" + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_streamable_http_security_disabled(server_port: int): + """Test StreamableHTTP with security disabled.""" + settings = TransportSecuritySettings(enable_dns_rebinding_protection=False) + process = start_server_process(server_port, settings) + + try: + # Test with invalid host header - should still work + headers = { + "Host": "evil.com", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers=headers, + ) + # Should connect successfully even with invalid host + assert response.status_code == 200 + + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_streamable_http_security_custom_allowed_hosts(server_port: int): + """Test StreamableHTTP with custom allowed hosts.""" + settings = TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["localhost", "127.0.0.1", "custom.host"], + allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"], + ) + process = start_server_process(server_port, settings) + + try: + # Test with custom allowed host + headers = { + "Host": "custom.host", + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"http://127.0.0.1:{server_port}/", + json={"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}}, + headers=headers, + ) + # Should connect successfully with custom host + assert response.status_code == 200 + finally: + process.terminate() + process.join() + + +@pytest.mark.anyio +async def test_streamable_http_security_get_request(server_port: int): + """Test StreamableHTTP GET request with security.""" + security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1"]) + process = start_server_process(server_port, security_settings) + + try: + # Test GET request with invalid host header + headers = { + "Host": "evil.com", + "Accept": "text/event-stream", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) + assert response.status_code == 421 + assert response.text == "Invalid Host header" + + # Test GET request with valid host header + headers = { + "Host": "127.0.0.1", + "Accept": "text/event-stream", + } + + async with httpx.AsyncClient(timeout=5.0) as client: + # GET requests need a session ID in StreamableHTTP + # So it will fail with "Missing session ID" not security error + response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) + # This should pass security but fail on session validation + assert response.status_code == 400 + body = response.json() + assert "Missing session ID" 
in body["error"]["message"] + + finally: + process.terminate() + process.join() diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 1e0409e14..08bcb2662 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -22,12 +22,8 @@ async def test_bidirectional_progress_notifications(): """Test that both client and server can send progress notifications.""" # Create memory streams for client/server - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](5) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](5) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) # Run a server session so we can send progress updates in tool async def run_server(): @@ -134,9 +130,7 @@ async def handle_call_tool(name: str, arguments: dict | None) -> list: # Client message handler to store progress notifications async def handle_client_message( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): raise message @@ -172,9 +166,7 @@ async def handle_client_message( await client_session.list_tools() # Call test_tool with progress token - await client_session.call_tool( - "test_tool", {"_meta": {"progressToken": client_progress_token}} - ) + await client_session.call_tool("test_tool", {"_meta": {"progressToken": client_progress_token}}) # Send progress notifications from client to server await client_session.send_progress_notification( @@ -221,12 +213,8 @@ async def handle_client_message( async def test_progress_context_manager(): """Test client using progress context manager for sending progress notifications.""" # Create memory streams for client/server - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - SessionMessage - ](5) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - SessionMessage - ](5) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) # Track progress updates server_progress_updates = [] @@ -270,9 +258,7 @@ async def run_server(): # Client message handler async def handle_client_message( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): raise message diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index eb4e004ae..864e0d1b4 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -90,9 +90,7 @@ async def make_request(client_session): ClientRequest( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams( - name="slow_tool", arguments={} - ), + params=types.CallToolRequestParams(name="slow_tool", arguments={}), ) ), types.CallToolResult, @@ -103,9 +101,7 @@ async def make_request(client_session): assert "Request 
cancelled" in str(e) ev_cancelled.set() - async with create_connected_server_and_client_session( - make_server() - ) as client_session: + async with create_connected_server_and_client_session(make_server()) as client_session: async with anyio.create_task_group() as tg: tg.start_soon(make_request, client_session) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 78bbbb235..8e1912e9b 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -8,16 +8,19 @@ import httpx import pytest import uvicorn +from inline_snapshot import snapshot from pydantic import AnyUrl from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response from starlette.routing import Mount, Route +import mcp.types as types from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.server import Server from mcp.server.sse import SseServerTransport +from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import McpError from mcp.types import ( EmptyResult, @@ -58,11 +61,7 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: await anyio.sleep(2.0) return f"Slow response from {uri.host}" - raise McpError( - error=ErrorData( - code=404, message="OOPS! no resource with that URI was found" - ) - ) + raise McpError(error=ErrorData(code=404, message="OOPS! no resource with that URI was found")) @self.list_tools() async def handle_list_tools() -> list[Tool]: @@ -82,16 +81,16 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: # Test fixtures def make_server_app() -> Starlette: """Create test Starlette app with SSE transport""" - sse = SseServerTransport("/messages/") + # Configure security with allowed hosts/origins for testing + security_settings = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + sse = SseServerTransport("/messages/", security_settings=security_settings) server = ServerTest() async def handle_sse(request: Request) -> Response: - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: - await server.run( - streams[0], streams[1], server.create_initialization_options() - ) + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: + await server.run(streams[0], streams[1], server.create_initialization_options()) return Response() app = Starlette( @@ -106,11 +105,7 @@ async def handle_sse(request: Request) -> Response: def run_server(server_port: int) -> None: app = make_server_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"starting server on {server_port}") server.run() @@ -122,9 +117,7 @@ def run_server(server_port: int) -> None: @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process( - target=run_server, kwargs={"server_port": server_port}, daemon=True - ) + proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) print("starting process") proc.start() @@ -169,10 +162,7 @@ async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: async def connection_test() -> None: async with http_client.stream("GET", "/sse") as response: assert 
response.status_code == 200 - assert ( - response.headers["content-type"] - == "text/event-stream; charset=utf-8" - ) + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" line_number = 0 async for line in response.aiter_lines(): @@ -204,9 +194,7 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non @pytest.fixture -async def initialized_sse_client_session( - server, server_url: str -) -> AsyncGenerator[ClientSession, None]: +async def initialized_sse_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]: async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: async with ClientSession(*streams) as session: await session.initialize() @@ -234,9 +222,7 @@ async def test_sse_client_exception_handling( @pytest.mark.anyio -@pytest.mark.skip( - "this test highlights a possible bug in SSE read timeout exception handling" -) +@pytest.mark.skip("this test highlights a possible bug in SSE read timeout exception handling") async def test_sse_client_timeout( initialized_sse_client_session: ClientSession, ) -> None: @@ -258,11 +244,7 @@ async def test_sse_client_timeout( def run_mounted_server(server_port: int) -> None: app = make_server_app() main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) - server = uvicorn.Server( - config=uvicorn.Config( - app=main_app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=main_app, host="127.0.0.1", port=server_port, log_level="error")) print(f"starting server on {server_port}") server.run() @@ -274,9 +256,7 @@ def run_mounted_server(server_port: int) -> None: @pytest.fixture() def mounted_server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process( - target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True - ) + proc = multiprocessing.Process(target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True) print("starting process") proc.start() @@ -306,9 +286,7 @@ def mounted_server(server_port: int) -> Generator[None, None, None]: @pytest.mark.anyio -async def test_sse_client_basic_connection_mounted_app( - mounted_server: None, server_url: str -) -> None: +async def test_sse_client_basic_connection_mounted_app(mounted_server: None, server_url: str) -> None: async with sse_client(server_url + "/mounted_app/sse") as streams: async with ClientSession(*streams) as session: # Test initialization @@ -366,16 +344,16 @@ async def handle_list_tools() -> list[Tool]: def run_context_server(server_port: int) -> None: """Run a server that captures request context""" - sse = SseServerTransport("/messages/") + # Configure security with allowed hosts/origins for testing + security_settings = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) + sse = SseServerTransport("/messages/", security_settings=security_settings) context_server = RequestContextServer() async def handle_sse(request: Request) -> Response: - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: - await context_server.run( - streams[0], streams[1], context_server.create_initialization_options() - ) + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: + await context_server.run(streams[0], streams[1], context_server.create_initialization_options()) return Response() app = Starlette( @@ -385,11 +363,7 @@ async def 
handle_sse(request: Request) -> Response: ] ) - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"starting context server on {server_port}") server.run() @@ -397,9 +371,7 @@ async def handle_sse(request: Request) -> Response: @pytest.fixture() def context_server(server_port: int) -> Generator[None, None, None]: """Fixture that provides a server with request context capture""" - proc = multiprocessing.Process( - target=run_context_server, kwargs={"server_port": server_port}, daemon=True - ) + proc = multiprocessing.Process(target=run_context_server, kwargs={"server_port": server_port}, daemon=True) print("starting context server process") proc.start() @@ -416,9 +388,7 @@ def context_server(server_port: int) -> Generator[None, None, None]: time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Context server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Context server failed to start after {max_attempts} attempts") yield @@ -430,9 +400,7 @@ def context_server(server_port: int) -> Generator[None, None, None]: @pytest.mark.anyio -async def test_request_context_propagation( - context_server: None, server_url: str -) -> None: +async def test_request_context_propagation(context_server: None, server_url: str) -> None: """Test that request context is properly propagated through SSE transport.""" # Test with custom headers custom_headers = { @@ -456,11 +424,7 @@ async def test_request_context_propagation( # Parse the JSON response assert len(tool_result.content) == 1 - headers_data = json.loads( - tool_result.content[0].text - if tool_result.content[0].type == "text" - else "{}" - ) + headers_data = json.loads(tool_result.content[0].text if tool_result.content[0].type == "text" else "{}") # Verify headers were propagated assert headers_data.get("authorization") == "Bearer test-token" @@ -485,15 +449,11 @@ async def test_request_context_isolation(context_server: None, server_url: str) await session.initialize() # Call the tool that echoes context - tool_result = await session.call_tool( - "echo_context", {"request_id": f"request-{i}"} - ) + tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) assert len(tool_result.content) == 1 context_data = json.loads( - tool_result.content[0].text - if tool_result.content[0].type == "text" - else "{}" + tool_result.content[0].text if tool_result.content[0].type == "text" else "{}" ) contexts.append(context_data) @@ -503,3 +463,13 @@ async def test_request_context_isolation(context_server: None, server_url: str) assert ctx["request_id"] == f"request-{i}" assert ctx["headers"].get("x-request-id") == f"request-{i}" assert ctx["headers"].get("x-custom-value") == f"value-{i}" + + +def test_sse_message_id_coercion(): + """Test that string message IDs that look like integers are parsed as integers. + + See for more details. 
+ """ + json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}' + msg = types.JSONRPCMessage.model_validate_json(json_message) + assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123))) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 5cf346e1a..1ffcc13b0 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -26,6 +26,7 @@ from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server from mcp.server.streamable_http import ( + MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, SESSION_ID_PATTERN, EventCallback, @@ -36,6 +37,7 @@ StreamId, ) from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.message import ( @@ -64,6 +66,17 @@ } +# Helper functions +def extract_protocol_version_from_sse(response: requests.Response) -> str: + """Extract the negotiated protocol version from an SSE initialization response.""" + assert response.headers.get("Content-Type") == "text/event-stream" + for line in response.text.splitlines(): + if line.startswith("data: "): + init_data = json.loads(line[6:]) + return init_data["result"]["protocolVersion"] + raise ValueError("Could not extract protocol version from SSE response") + + # Simple in-memory event store for testing class SimpleEventStore(EventStore): """Simple in-memory event store for testing.""" @@ -72,9 +85,7 @@ def __init__(self): self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = [] self._event_id_counter = 0 - async def store_event( - self, stream_id: StreamId, message: types.JSONRPCMessage - ) -> EventId: + async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage) -> EventId: """Store an event and return its ID.""" self._event_id_counter += 1 event_id = str(self._event_id_counter) @@ -156,9 +167,7 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: # When the tool is called, send a notification to test GET stream if name == "test_tool_with_standalone_notification": - await ctx.session.send_resource_updated( - uri=AnyUrl("http://test_resource") - ) + await ctx.session.send_resource_updated(uri=AnyUrl("http://test_resource")) return [TextContent(type="text", text=f"Called {name}")] elif name == "long_running_with_checkpoints": @@ -189,9 +198,7 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: messages=[ types.SamplingMessage( role="user", - content=types.TextContent( - type="text", text="Server needs client sampling" - ), + content=types.TextContent(type="text", text="Server needs client sampling"), ) ], max_tokens=100, @@ -199,11 +206,7 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: ) # Return the sampling result in the tool response - response = ( - sampling_result.content.text - if sampling_result.content.type == "text" - else None - ) + response = sampling_result.content.text if sampling_result.content.type == "text" else None return [ TextContent( type="text", @@ -214,9 +217,7 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: return [TextContent(type="text", text=f"Called {name}")] -def create_app( - is_json_response_enabled=False, event_store: EventStore | None = None -) -> Starlette: +def 
create_app(is_json_response_enabled=False, event_store: EventStore | None = None) -> Starlette: """Create a Starlette application for testing using the session manager. Args: @@ -227,10 +228,14 @@ def create_app( server = ServerTest() # Create the session manager + security_settings = TransportSecuritySettings( + allowed_hosts=["127.0.0.1:*", "localhost:*"], allowed_origins=["http://127.0.0.1:*", "http://localhost:*"] + ) session_manager = StreamableHTTPSessionManager( app=server, event_store=event_store, json_response=is_json_response_enabled, + security_settings=security_settings, ) # Create an ASGI application that uses the session manager @@ -245,9 +250,7 @@ def create_app( return app -def run_server( - port: int, is_json_response_enabled=False, event_store: EventStore | None = None -) -> None: +def run_server(port: int, is_json_response_enabled=False, event_store: EventStore | None = None) -> None: """Run the test server. Args: @@ -300,9 +303,7 @@ def json_server_port() -> int: @pytest.fixture def basic_server(basic_server_port: int) -> Generator[None, None, None]: """Start a basic server.""" - proc = multiprocessing.Process( - target=run_server, kwargs={"port": basic_server_port}, daemon=True - ) + proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) proc.start() # Wait for server to be running @@ -440,8 +441,9 @@ def test_content_type_validation(basic_server, basic_server_url): }, data="This is not JSON", ) - assert response.status_code == 415 - assert "Unsupported Media Type" in response.text + + assert response.status_code == 400 + assert "Invalid Content-Type" in response.text def test_json_validation(basic_server, basic_server_url): @@ -576,11 +578,17 @@ def test_session_termination(basic_server, basic_server_url): ) assert response.status_code == 200 + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) + # Now terminate the session session_id = response.headers.get(MCP_SESSION_ID_HEADER) response = requests.delete( f"{basic_server_url}/mcp", - headers={MCP_SESSION_ID_HEADER: session_id}, + headers={ + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, ) assert response.status_code == 200 @@ -611,16 +619,20 @@ def test_response(basic_server, basic_server_url): ) assert response.status_code == 200 - # Now terminate the session + # Extract negotiated protocol version from SSE response + negotiated_version = extract_protocol_version_from_sse(response) + + # Now get the session ID session_id = response.headers.get(MCP_SESSION_ID_HEADER) - # Try to use the terminated session + # Try to use the session with proper headers tools_response = requests.post( mcp_url, headers={ "Accept": "application/json, text/event-stream", "Content-Type": "application/json", MCP_SESSION_ID_HEADER: session_id, # Use the session ID we got earlier + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, }, json={"jsonrpc": "2.0", "method": "tools/list", "id": "tools-1"}, stream=True, @@ -662,12 +674,23 @@ def test_get_sse_stream(basic_server, basic_server_url): session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) assert session_id is not None + # Extract negotiated protocol version from SSE response + init_data = None + assert init_response.headers.get("Content-Type") == "text/event-stream" + for line in init_response.text.splitlines(): + if line.startswith("data: "): + init_data = json.loads(line[6:]) + break + assert init_data is not None + 
negotiated_version = init_data["result"]["protocolVersion"] + # Now attempt to establish an SSE stream via GET get_response = requests.get( mcp_url, headers={ "Accept": "text/event-stream", MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, }, stream=True, ) @@ -682,6 +705,7 @@ def test_get_sse_stream(basic_server, basic_server_url): headers={ "Accept": "text/event-stream", MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, }, stream=True, ) @@ -710,11 +734,22 @@ def test_get_validation(basic_server, basic_server_url): session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) assert session_id is not None + # Extract negotiated protocol version from SSE response + init_data = None + assert init_response.headers.get("Content-Type") == "text/event-stream" + for line in init_response.text.splitlines(): + if line.startswith("data: "): + init_data = json.loads(line[6:]) + break + assert init_data is not None + negotiated_version = init_data["result"]["protocolVersion"] + # Test without Accept header response = requests.get( mcp_url, headers={ MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, }, stream=True, ) @@ -727,6 +762,7 @@ def test_get_validation(basic_server, basic_server_url): headers={ "Accept": "application/json", MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, }, ) assert response.status_code == 406 @@ -778,9 +814,7 @@ async def test_streamablehttp_client_basic_connection(basic_server, basic_server @pytest.mark.anyio async def test_streamablehttp_client_resource_read(initialized_client_session): """Test client resource read functionality.""" - response = await initialized_client_session.read_resource( - uri=AnyUrl("foobar://test-resource") - ) + response = await initialized_client_session.read_resource(uri=AnyUrl("foobar://test-resource")) assert len(response.contents) == 1 assert response.contents[0].uri == AnyUrl("foobar://test-resource") assert response.contents[0].text == "Read test-resource" @@ -805,17 +839,13 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session) async def test_streamablehttp_client_error_handling(initialized_client_session): """Test error handling in client.""" with pytest.raises(McpError) as exc_info: - await initialized_client_session.read_resource( - uri=AnyUrl("unknown://test-error") - ) + await initialized_client_session.read_resource(uri=AnyUrl("unknown://test-error")) assert exc_info.value.error.code == 0 assert "Unknown resource: unknown://test-error" in exc_info.value.error.message @pytest.mark.anyio -async def test_streamablehttp_client_session_persistence( - basic_server, basic_server_url -): +async def test_streamablehttp_client_session_persistence(basic_server, basic_server_url): """Test that session ID persists across requests.""" async with streamablehttp_client(f"{basic_server_url}/mcp") as ( read_stream, @@ -843,9 +873,7 @@ async def test_streamablehttp_client_session_persistence( @pytest.mark.anyio -async def test_streamablehttp_client_json_response( - json_response_server, json_server_url -): +async def test_streamablehttp_client_json_response(json_response_server, json_server_url): """Test client with JSON response mode.""" async with streamablehttp_client(f"{json_server_url}/mcp") as ( read_stream, @@ -882,9 +910,7 @@ async def test_streamablehttp_client_get_stream(basic_server, basic_server_url): # Define message handler to capture notifications async def 
message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, types.ServerNotification): notifications_received.append(message) @@ -894,9 +920,7 @@ async def message_handler( write_stream, _, ): - async with ClientSession( - read_stream, write_stream, message_handler=message_handler - ) as session: + async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: # Initialize the session - this triggers the GET stream setup result = await session.initialize() assert isinstance(result, InitializeResult) @@ -914,15 +938,11 @@ async def message_handler( assert str(notif.root.params.uri) == "http://test_resource/" resource_update_found = True - assert ( - resource_update_found - ), "ResourceUpdatedNotification not received via GET stream" + assert resource_update_found, "ResourceUpdatedNotification not received via GET stream" @pytest.mark.anyio -async def test_streamablehttp_client_session_termination( - basic_server, basic_server_url -): +async def test_streamablehttp_client_session_termination(basic_server, basic_server_url): """Test client session termination functionality.""" captured_session_id = None @@ -963,9 +983,7 @@ async def test_streamablehttp_client_session_termination( @pytest.mark.anyio -async def test_streamablehttp_client_session_termination_204( - basic_server, basic_server_url, monkeypatch -): +async def test_streamablehttp_client_session_termination_204(basic_server, basic_server_url, monkeypatch): """Test client session termination functionality with a 204 response. This test patches the httpx client to return a 204 response for DELETEs. 
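
For orientation alongside the session termination tests above: the raw HTTP exchange they ultimately exercise looks roughly like the sketch below. This is not part of the diff; the port, URL, and initialize payload are placeholders (the tests use their own INIT_REQUEST constant and a randomly chosen port), and the header names mirror the MCP_SESSION_ID_HEADER and MCP_PROTOCOL_VERSION_HEADER constants imported above.

import json
import requests

# Hypothetical local server URL; the tests spin up their own server on a random port.
url = "http://127.0.0.1:3000/mcp"

# Initialize (illustrative payload; the tests use their INIT_REQUEST constant).
init = requests.post(
    url,
    headers={
        "Accept": "application/json, text/event-stream",
        "Content-Type": "application/json",
    },
    json={
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2025-03-26",
            "capabilities": {},
            "clientInfo": {"name": "sketch-client", "version": "0.0.0"},
        },
    },
)
session_id = init.headers["mcp-session-id"]

# In SSE response mode the negotiated version sits in a "data: {...}" line,
# which is what extract_protocol_version_from_sse scans for.
negotiated_version = next(
    json.loads(line[6:])["result"]["protocolVersion"]
    for line in init.text.splitlines()
    if line.startswith("data: ")
)

# Terminate the session: DELETE with both headers. The tests expect 200 here;
# the 204 variant above patches the client so a 204 response is tolerated too.
resp = requests.delete(
    url,
    headers={"mcp-session-id": session_id, "mcp-protocol-version": negotiated_version},
)
assert resp.status_code == 200
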
@@ -1038,11 +1056,10 @@ async def test_streamablehttp_client_resumption(event_server): captured_resumption_token = None captured_notifications = [] tool_started = False + captured_protocol_version = None async def message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] - | types.ServerNotification - | Exception, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, types.ServerNotification): captured_notifications.append(message) @@ -1062,14 +1079,14 @@ async def on_resumption_token_update(token: str) -> None: write_stream, get_session_id, ): - async with ClientSession( - read_stream, write_stream, message_handler=message_handler - ) as session: + async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) captured_session_id = get_session_id() assert captured_session_id is not None + # Capture the negotiated protocol version + captured_protocol_version = result.protocolVersion # Start a long-running tool in a task async with anyio.create_task_group() as tg: @@ -1082,9 +1099,7 @@ async def run_tool(): types.ClientRequest( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams( - name="long_running_with_checkpoints", arguments={} - ), + params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}), ) ), types.CallToolResult, @@ -1104,19 +1119,19 @@ async def run_tool(): captured_notifications_pre = captured_notifications.copy() captured_notifications = [] - # Now resume the session with the same mcp-session-id + # Now resume the session with the same mcp-session-id and protocol version headers = {} if captured_session_id: headers[MCP_SESSION_ID_HEADER] = captured_session_id + if captured_protocol_version: + headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( read_stream, write_stream, _, ): - async with ClientSession( - read_stream, write_stream, message_handler=message_handler - ) as session: + async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: # Don't initialize - just use the existing session # Resume the tool with the resumption token @@ -1129,9 +1144,7 @@ async def run_tool(): types.ClientRequest( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams( - name="long_running_with_checkpoints", arguments={} - ), + params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}), ) ), types.CallToolResult, @@ -1149,14 +1162,11 @@ async def run_tool(): # Should not have the first notification # Check that "Tool started" notification isn't repeated when resuming assert not any( - isinstance(n.root, types.LoggingMessageNotification) - and n.root.params.data == "Tool started" + isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started" for n in captured_notifications ) # there is no intersection between pre and post notifications - assert not any( - n in captured_notifications_pre for n in captured_notifications - ) + assert not any(n in captured_notifications_pre for n in captured_notifications) @pytest.mark.anyio @@ -1175,11 +1185,7 @@ async def sampling_callback( nonlocal sampling_callback_invoked, captured_message_params sampling_callback_invoked = True 
captured_message_params = params - message_received = ( - params.messages[0].content.text - if params.messages[0].content.type == "text" - else None - ) + message_received = params.messages[0].content.text if params.messages[0].content.type == "text" else None return types.CreateMessageResult( role="assistant", @@ -1212,19 +1218,13 @@ async def sampling_callback( # Verify the tool result contains the expected content assert len(tool_result.content) == 1 assert tool_result.content[0].type == "text" - assert ( - "Response from sampling: Received message from server" - in tool_result.content[0].text - ) + assert "Response from sampling: Received message from server" in tool_result.content[0].text # Verify sampling callback was invoked assert sampling_callback_invoked assert captured_message_params is not None assert len(captured_message_params.messages) == 1 - assert ( - captured_message_params.messages[0].content.text - == "Server needs client sampling" - ) + assert captured_message_params.messages[0].content.text == "Server needs client sampling" # Context-aware server implementation for testing request context propagation @@ -1325,9 +1325,7 @@ def run_context_aware_server(port: int): @pytest.fixture def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process( - target=run_context_aware_server, args=(basic_server_port,), daemon=True - ) + proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) proc.start() # Wait for server to be running @@ -1342,9 +1340,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Context-aware server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Context-aware server failed to start after {max_attempts} attempts") yield @@ -1355,9 +1351,7 @@ def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: """Test that request context is properly propagated through StreamableHTTP.""" custom_headers = { "Authorization": "Bearer test-token", @@ -1365,9 +1359,11 @@ async def test_streamablehttp_request_context_propagation( "X-Trace-Id": "trace-123", } - async with streamablehttp_client( - f"{basic_server_url}/mcp", headers=custom_headers - ) as (read_stream, write_stream, _): + async with streamablehttp_client(f"{basic_server_url}/mcp", headers=custom_headers) as ( + read_stream, + write_stream, + _, + ): async with ClientSession(read_stream, write_stream) as session: result = await session.initialize() assert isinstance(result, InitializeResult) @@ -1388,9 +1384,7 @@ async def test_streamablehttp_request_context_propagation( @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: """Test that request contexts are isolated between StreamableHTTP clients.""" contexts = [] @@ -1402,16 +1396,12 @@ async def test_streamablehttp_request_context_isolation( "Authorization": f"Bearer token-{i}", } - async with 
streamablehttp_client( - f"{basic_server_url}/mcp", headers=headers - ) as (read_stream, write_stream, _): + async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as session: await session.initialize() # Call the tool that echoes context - tool_result = await session.call_tool( - "echo_context", {"request_id": f"request-{i}"} - ) + tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"}) assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) @@ -1425,3 +1415,152 @@ async def test_streamablehttp_request_context_isolation( assert ctx["headers"].get("x-request-id") == f"request-{i}" assert ctx["headers"].get("x-custom-value") == f"value-{i}" assert ctx["headers"].get("authorization") == f"Bearer token-{i}" + + +@pytest.mark.anyio +async def test_client_includes_protocol_version_header_after_init(context_aware_server, basic_server_url): + """Test that client includes mcp-protocol-version header after initialization.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + # Initialize and get the negotiated version + init_result = await session.initialize() + negotiated_version = init_result.protocolVersion + + # Call a tool that echoes headers to verify the header is present + tool_result = await session.call_tool("echo_headers", {}) + + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) + + # Verify protocol version header is present + assert "mcp-protocol-version" in headers_data + assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version + + +def test_server_validates_protocol_version_header(basic_server, basic_server_url): + """Test that server returns 400 Bad Request version if header unsupported or invalid.""" + # First initialize a session to get a valid session ID + init_response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + + # Test request with invalid protocol version (should fail) + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "invalid-version", + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with unsupported protocol version (should fail) + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, + ) + assert response.status_code == 400 + assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + + # Test request with valid protocol version (should succeed) + 
negotiated_version = extract_protocol_version_from_sse(init_response) + + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + MCP_PROTOCOL_VERSION_HEADER: negotiated_version, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-4"}, + ) + assert response.status_code == 200 + + +def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_server_url): + """Test server accepts requests without protocol version header.""" + # First initialize a session to get a valid session ID + init_response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json=INIT_REQUEST, + ) + assert init_response.status_code == 200 + session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) + + # Test request without mcp-protocol-version header (backwards compatibility) + response = requests.post( + f"{basic_server_url}/mcp", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + MCP_SESSION_ID_HEADER: session_id, + }, + json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-backwards-compat"}, + stream=True, + ) + assert response.status_code == 200 # Should succeed for backwards compatibility + assert response.headers.get("Content-Type") == "text/event-stream" + + +@pytest.mark.anyio +async def test_client_crash_handled(basic_server, basic_server_url): + """Test that cases where the client crashes are handled gracefully.""" + + # Simulate bad client that crashes after init + async def bad_client(): + """Client that triggers ClosedResourceError""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + raise Exception("client crash") + + # Run bad client a few times to trigger the crash + for _ in range(3): + try: + await bad_client() + except Exception: + pass + await anyio.sleep(0.1) + + # Try a good client, it should still be able to connect and list tools + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + tools = await session.list_tools() + assert tools.tools diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 1381c8153..5081f1d53 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -54,11 +54,7 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: await anyio.sleep(2.0) return f"Slow response from {uri.host}" - raise McpError( - error=ErrorData( - code=404, message="OOPS! no resource with that URI was found" - ) - ) + raise McpError(error=ErrorData(code=404, message="OOPS! 
no resource with that URI was found")) @self.list_tools() async def handle_list_tools() -> list[Tool]: @@ -81,12 +77,8 @@ def make_server_app() -> Starlette: server = ServerTest() async def handle_ws(websocket): - async with websocket_server( - websocket.scope, websocket.receive, websocket.send - ) as streams: - await server.run( - streams[0], streams[1], server.create_initialization_options() - ) + async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: + await server.run(streams[0], streams[1], server.create_initialization_options()) app = Starlette( routes=[ @@ -99,11 +91,7 @@ async def handle_ws(websocket): def run_server(server_port: int) -> None: app = make_server_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) print(f"starting server on {server_port}") server.run() @@ -115,9 +103,7 @@ def run_server(server_port: int) -> None: @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process( - target=run_server, kwargs={"server_port": server_port}, daemon=True - ) + proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) print("starting process") proc.start() @@ -147,9 +133,7 @@ def server(server_port: int) -> Generator[None, None, None]: @pytest.fixture() -async def initialized_ws_client_session( - server, server_url: str -) -> AsyncGenerator[ClientSession, None]: +async def initialized_ws_client_session(server, server_url: str) -> AsyncGenerator[ClientSession, None]: """Create and initialize a WebSocket client session""" async with websocket_client(server_url + "/ws") as streams: async with ClientSession(*streams) as session: @@ -186,9 +170,7 @@ async def test_ws_client_happy_request_and_response( initialized_ws_client_session: ClientSession, ) -> None: """Test a successful request and response via WebSocket""" - result = await initialized_ws_client_session.read_resource( - AnyUrl("foobar://example") - ) + result = await initialized_ws_client_session.read_resource(AnyUrl("foobar://example")) assert isinstance(result, ReadResourceResult) assert isinstance(result.contents, list) assert len(result.contents) > 0 @@ -218,9 +200,7 @@ async def test_ws_client_timeout( # Now test that we can still use the session after a timeout with anyio.fail_after(5): # Longer timeout to allow completion - result = await initialized_ws_client_session.read_resource( - AnyUrl("foobar://example") - ) + result = await initialized_ws_client_session.read_resource(AnyUrl("foobar://example")) assert isinstance(result, ReadResourceResult) assert isinstance(result.contents, list) assert len(result.contents) > 0 diff --git a/tests/test_examples.py b/tests/test_examples.py index b2fff1a91..230e7d394 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -31,9 +31,7 @@ async def test_complex_inputs(): async with client_session(mcp._mcp_server) as client: tank = {"shrimp": [{"name": "bob"}, {"name": "alice"}]} - result = await client.call_tool( - "name_shrimp", {"tank": tank, "extra_names": ["charlie"]} - ) + result = await client.call_tool("name_shrimp", {"tank": tank, "extra_names": ["charlie"]}) assert len(result.content) == 3 assert isinstance(result.content[0], TextContent) assert isinstance(result.content[1], TextContent) @@ -86,9 +84,7 @@ async def test_desktop(monkeypatch): def 
test_docs_examples(example: CodeExample, eval_example: EvalExample): ruff_ignore: list[str] = ["F841", "I001"] - eval_example.set_config( - ruff_ignore=ruff_ignore, target_version="py310", line_length=88 - ) + eval_example.set_config(ruff_ignore=ruff_ignore, target_version="py310", line_length=88) if eval_example.update_examples: # pragma: no cover eval_example.format(example)
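
Taken together, the streamable HTTP tests in this change repeat one client-side pattern: initialize, capture the session id and the negotiated protocol version, then present both as headers on later connections. The sketch below condenses that pattern; it is not part of the diff, the server URL is a placeholder, and the header names mirror the MCP_SESSION_ID_HEADER and MCP_PROTOCOL_VERSION_HEADER constants used in the tests.

import anyio

from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def reconnect_with_negotiated_headers(server_url: str) -> None:
    # First connection: initialize, then capture the session id and the
    # protocol version the server negotiated.
    async with streamablehttp_client(f"{server_url}/mcp") as (read_stream, write_stream, get_session_id):
        async with ClientSession(read_stream, write_stream) as session:
            result = await session.initialize()
            session_id = get_session_id()
            protocol_version = result.protocolVersion
    assert session_id is not None

    # Second connection: reuse the session by sending both headers, mirroring
    # what test_streamablehttp_client_resumption does (no re-initialization;
    # this assumes the server still holds the session, as in that test).
    headers = {"mcp-session-id": session_id, "mcp-protocol-version": protocol_version}
    async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (read_stream, write_stream, _):
        async with ClientSession(read_stream, write_stream) as session:
            tools = await session.list_tools()
            print([tool.name for tool in tools.tools])


if __name__ == "__main__":
    # Placeholder URL; point this at a running streamable HTTP MCP server.
    anyio.run(reconnect_with_negotiated_headers, "http://127.0.0.1:3000")
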