8000 fix: address @ochafik comments · modelcontextprotocol/python-sdk@8db5cb3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8db5cb3

Browse files
fix: address @ochafik comments
- update test comment - log warning if no protocol version is set - nits
1 parent ad5f887 commit 8db5cb3

File tree

4 files changed

+16
-30
lines changed

4 files changed

+16
-30
lines changed

src/mcp/client/streamable_http.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ def _maybe_extract_protocol_version_from_message(
146146
if "protocolVersion" in result:
147147
self.protocol_version = result["protocolVersion"]
148148
logger.info(f"Negotiated protocol version: {self.protocol_version}")
149+
else:
150+
logger.warning(f"Initialization response does not contain protocolVersion: {result}")
149151

150152
async def _handle_sse_event(
151153
self,
@@ -286,9 +288,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
286288
content_type = response.headers.get(CONTENT_TYPE, "").lower()
287289

288290
if content_type.startswith(JSON):
289-
await self._handle_json_response(
290-
response, ctx.read_stream_writer, is_initialization
291-
)
291+
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
292292
elif content_type.startswith(SSE):
293293
await self._handle_sse_response(response, ctx, is_initialization)
294294
else:
@@ -331,11 +331,7 @@ async def _handle_sse_response(
331331
is_complete = await self._handle_sse_event(
332332
sse,
333333
ctx.read_stream_writer,
334-
resumption_callback=(
335-
ctx.metadata.on_resumption_token_update
336-
if ctx.metadata
337-
else None
338-
),
334+
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
339335
is_initialization=is_initialization,
340336
)
341337
# If the SSE event indicates completion, like returning respose/error

src/mcp/server/streamable_http.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from mcp.shared.message import ServerMessageMetadata, SessionMessage
2828
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
2929
from mcp.types import (
30+
DEFAULT_PROTOCOL_VERSION,
3031
INTERNAL_ERROR,
3132
INVALID_PARAMS,
3233
INVALID_REQUEST,
@@ -295,7 +296,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
295296
has_json, has_sse = self._check_accept_headers(request)
296297
if not (has_json and has_sse):
297298
response = self._create_error_response(
298-
("Not Acceptable: Client must accept both application/json and " "text/event-stream"),
2 8000 99+
("Not Acceptable: Client must accept both application/json and text/event-stream"),
299300
HTTPStatus.NOT_ACCEPTABLE,
300301
)
301302
await response(scope, receive, send)
@@ -696,9 +697,9 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
696697
# Get the protocol version from the request headers
697698
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)
698699

699-
# If no protocol version provided, assume version 2025-03-26
700-
if not protocol_version:
701-
protocol_version = "2025-03-26"
700+
# If no protocol version provided, assume default version
701+
if protocol_version is None:
702+
protocol_version = DEFAULT_PROTOCOL_VERSION
702703

703704
# Check if the protocol version is supported
704705
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
@@ -711,9 +712,7 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
711712

712713
return True
713714

714-
async def _replay_events(
715-
self, last_event_id: str, request: Request, send: Send
716-
) -> None:
715+
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
717716
"""
718717
Replays events that would have been sent after the specified event ID.
719718
Only used when resumability is enabled.

src/mcp/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"""
2424

2525
LATEST_PROTOCOL_VERSION = "2025-03-26"
26+
DEFAULT_PROTOCOL_VERSION = "2025-03-26"
2627

2728
ProgressToken = str | int
2829
Cursor = str

tests/shared/test_streamable_http.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1412,9 +1412,7 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No
14121412

14131413

14141414
@pytest.mark.anyio
1415-
async def test_client_includes_protocol_version_header_after_init(
1416-
context_aware_server, basic_server_url
1417-
):
1415+
async def test_client_includes_protocol_version_header_after_init(context_aware_server, basic_server_url):
14181416
"""Test that client includes mcp-protocol-version header after initialization."""
14191417
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
14201418
read_stream,
@@ -1439,7 +1437,7 @@ async def test_client_includes_protocol_version_header_after_init(
14391437

14401438

14411439
def test_server_validates_protocol_version_header(basic_server, basic_server_url):
1442-
"""Test that server returns 400 Bad Request version if header missing or invalid."""
1440+
"""Test that server returns 400 Bad Request version if header unsupported or invalid."""
14431441
# First initialize a session to get a valid session ID
14441442
init_response = requests.post(
14451443
f"{basic_server_url}/mcp",
@@ -1464,10 +1462,7 @@ def test_server_validates_protocol_version_header(basic_server, basic_server_url
14641462
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"},
14651463
)
14661464
assert response.status_code == 400
1467-
assert (
1468-
MCP_PROTOCOL_VERSION_HEADER in response.text
1469-
or "protocol version" in response.text.lower()
1470-
)
1465+
assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower()
14711466

14721467
# Test request with unsupported protocol version (should fail)
14731468
response = requests.post(
@@ -1481,10 +1476,7 @@ def test_server_validates_protocol_version_header(basic_server, basic_server_url
14811476
json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"},
14821477
)
14831478
assert response.status_code == 400
1484-
assert (
1485-
MCP_PROTOCOL_VERSION_HEADER in response.text
1486-
or "protocol version" in response.text.lower()
1487-
)
1479+
assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower()
14881480

14891481
# Test request with valid protocol version (should succeed)
14901482
negotiated_version = extract_protocol_version_from_sse(init_response)
@@ -1502,9 +1494,7 @@ def test_server_validates_protocol_version_header(basic_server, basic_server_url
15021494
assert response.status_code == 200
1503 65BB 1495

15041496

1505-
def test_server_backwards_compatibility_no_protocol_version(
1506-
basic_server, basic_server_url
1507-
):
1497+
def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_server_url):
15081498
"""Test server accepts requests without protocol version header."""
15091499
# First initialize a session to get a valid session ID
15101500
init_response = requests.post(

0 commit comments

Comments
 (0)
0