8000 [pull] main from modelcontextprotocol:main by pull[bot] · Pull Request #39 · rpatil524/python-sdk · GitHub
[go: up one dir, main page]

Skip to content

[pull] main from modelcontextprotocol:main #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 100 additions & 83 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,90 +333,107 @@ async def _receive_loop(self) -> None:
self._read_stream,
self._write_stream,
):
async for message in self._read_stream:
if isinstance(message, Exception):
await self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
try:
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)

if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
except Exception as e:
# For request validation errors, send a proper JSON-RPC error
# response instead of crashing the server
logging.warning(f"Failed to validate request: {e}")
logging.debug(f"Message that failed validation: {message.message.root}")
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.root.id,
error=ErrorData(
code=INVALID_PARAMS,
message="Invalid request parameters",
data="",
),
)
session_message = SessionMessage(message=JSONRPCMessage(error_response))
await self._write_stream.send(session_message)

elif isinstance(message.message.root, JSONRPCNotification):
try:
10000 notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
try:
async for message in self._read_stream:
if isinstance(message, Exception):
await self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
try:
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)

if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
except Exception as e:
# For request validation errors, send a proper JSON-RPC error
# response instead of crashing the server
logging.warning(f"Failed to validate request: {e}")
logging.debug(f"Message that failed validation: {message.message.root}")
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.root.id,
error=ErrorData(
code=INVALID_PARAMS,
message="Invalid request parameters",
data="",
),
)
session_message = SessionMessage(message=JSONRPCMessage(error_response))
await self._write_stream.send(session_message)

elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
# Handle progress notifications callback
if isinstance(notification.root, ProgressNotification):
progress_token = notification.root.params.progressToken
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[progress_token]
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
)
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
)
else: # Response or error
stream = self._response_streams.pop(message.message.root.id, None)
if stream:
await stream.send(message.message.root)
else:
# Handle progress notifications callback
if isinstance(notification.root, ProgressNotification):
progress_token = notification.root.params.progressToken
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[progress_token]
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
)
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. " f"Message was: {message.message.root}"
)
else: # Response or error
stream = self._response_streams.pop(message.message.root.id, None)
if stream:
await stream.send(message.message.root)
else:
await self._handle_incoming(
RuntimeError("Received response with an unknown " f"request ID: {message}")
)

# after the read stream is closed, we need to send errors
# to any pending requests
for id, stream in self._response_streams.items():
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
self._response_streams.clear()
await self._handle_incoming(
RuntimeError("Received response with an unknown " f"request ID: {message}")
)

except anyio.ClosedResourceError:
# This is expected when the client disconnects abruptly.
# Without this handler, the exception would propagate up and
# crash the server's task group.
logging.debug("Read stream closed by client")
except Exception as e:
# Other exceptions are not expected and should be logged. We purposefully
# catch all exceptions here to avoid crashing the server.
logging.exception(f"Unhandled exception in receive loop: {e}")
finally:
# after the read stream is closed, we need to send errors
# to any pending requests
for id, stream in self._response_streams.items():
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
try:
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
except Exception:
# Stream might already be closed
pass
self._response_streams.clear()

async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
"""
Expand Down
37 changes: 37 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,3 +1521,40 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_
)
assert response.status_code == 200 # Should succeed for backwards compatibility
assert response.headers.get("Content-Type") == "text/event-stream"


@pytest.mark.anyio
async def test_client_crash_handled(basic_server, basic_server_url):
"""Test that cases where the client crashes are handled gracefully."""

# Simulate bad client that crashes after init
async def bad_client():
"""Client that triggers ClosedResourceError"""
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
raise Exception("client crash")

# Run bad client a few times to trigger the crash
for _ in range(3):
try:
await bad_client()
except Exception:
pass
await anyio.sleep(0.1)

# Try a good client, it should still be able to connect and list tools
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
result = await session.initialize()
assert isinstance(result, InitializeResult)
tools = await session.list_tools()
assert tools.tools
0