diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 294986acb..6536272d9 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -333,90 +333,107 @@ async def _receive_loop(self) -> None: self._read_stream, self._write_stream, ): - async for message in self._read_stream: - if isinstance(message, Exception): - await self._handle_incoming(message) - elif isinstance(message.message.root, JSONRPCRequest): - try: - validated_request = self._receive_request_type.model_validate( - message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - responder = RequestResponder( - request_id=message.message.root.id, - request_meta=validated_request.root.params.meta if validated_request.root.params else None, - request=validated_request, - session=self, - on_complete=lambda r: self._in_flight.pop(r.request_id, None), - message_metadata=message.metadata, - ) - self._in_flight[responder.request_id] = responder - await self._received_request(responder) - - if not responder._completed: # type: ignore[reportPrivateUsage] - await self._handle_incoming(responder) - except Exception as e: - # For request validation errors, send a proper JSON-RPC error - # response instead of crashing the server - logging.warning(f"Failed to validate request: {e}") - logging.debug(f"Message that failed validation: {message.message.root}") - error_response = JSONRPCError( - jsonrpc="2.0", - id=message.message.root.id, - error=ErrorData( - code=INVALID_PARAMS, - message="Invalid request parameters", - data="", - ), - ) - session_message = SessionMessage(message=JSONRPCMessage(error_response)) - await self._write_stream.send(session_message) - - elif isinstance(message.message.root, JSONRPCNotification): - try: - notification = self._receive_notification_type.model_validate( - message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - # Handle cancellation notifications - if isinstance(notification.root, CancelledNotification): - cancelled_id = notification.root.params.requestId - if cancelled_id in self._in_flight: - await self._in_flight[cancelled_id].cancel() + try: + async for message in self._read_stream: + if isinstance(message, Exception): + await self._handle_incoming(message) + elif isinstance(message.message.root, JSONRPCRequest): + try: + validated_request = self._receive_request_type.model_validate( + message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + responder = RequestResponder( + request_id=message.message.root.id, + request_meta=validated_request.root.params.meta + if validated_request.root.params + else None, + request=validated_request, + session=self, + on_complete=lambda r: self._in_flight.pop(r.request_id, None), + message_metadata=message.metadata, + ) + self._in_flight[responder.request_id] = responder + await self._received_request(responder) + + if not responder._completed: # type: ignore[reportPrivateUsage] + await self._handle_incoming(responder) + except Exception as e: + # For request validation errors, send a proper JSON-RPC error + # response instead of crashing the server + logging.warning(f"Failed to validate request: {e}") + logging.debug(f"Message that failed validation: {message.message.root}") + error_response = JSONRPCError( + jsonrpc="2.0", + id=message.message.root.id, + error=ErrorData( + code=INVALID_PARAMS, + message="Invalid request parameters", + data="", + ), + ) + session_message = SessionMessage(message=JSONRPCMessage(error_response)) + await self._write_stream.send(session_message) + + elif isinstance(message.message.root, JSONRPCNotification): + try: + notification = self._receive_notification_type.model_validate( + message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + # Handle cancellation notifications + if isinstance(notification.root, CancelledNotification): + cancelled_id = notification.root.params.requestId + if cancelled_id in self._in_flight: + await self._in_flight[cancelled_id].cancel() + else: + # Handle progress notifications callback + if isinstance(notification.root, ProgressNotification): + progress_token = notification.root.params.progressToken + # If there is a progress callback for this token, + # call it with the progress information + if progress_token in self._progress_callbacks: + callback = self._progress_callbacks[progress_token] + await callback( + notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + ) + await self._received_notification(notification) + await self._handle_incoming(notification) + except Exception as e: + # For other validation errors, log and continue + logging.warning( + f"Failed to validate notification: {e}. " f"Message was: {message.message.root}" + ) + else: # Response or error + stream = self._response_streams.pop(message.message.root.id, None) + if stream: + await stream.send(message.message.root) else: - # Handle progress notifications callback - if isinstance(notification.root, ProgressNotification): - progress_token = notification.root.params.progressToken - # If there is a progress callback for this token, - # call it with the progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] - await callback( - notification.root.params.progress, - notification.root.params.total, - notification.root.params.message, - ) - await self._received_notification(notification) - await self._handle_incoming(notification) - except Exception as e: - # For other validation errors, log and continue - logging.warning( - f"Failed to validate notification: {e}. " f"Message was: {message.message.root}" - ) - else: # Response or error - stream = self._response_streams.pop(message.message.root.id, None) - if stream: - await stream.send(message.message.root) - else: - await self._handle_incoming( - RuntimeError("Received response with an unknown " f"request ID: {message}") - ) - - # after the read stream is closed, we need to send errors - # to any pending requests - for id, stream in self._response_streams.items(): - error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") - await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) - await stream.aclose() - self._response_streams.clear() + await self._handle_incoming( + RuntimeError("Received response with an unknown " f"request ID: {message}") + ) + + except anyio.ClosedResourceError: + # This is expected when the client disconnects abruptly. + # Without this handler, the exception would propagate up and + # crash the server's task group. + logging.debug("Read stream closed by client") + except Exception as e: + # Other exceptions are not expected and should be logged. We purposefully + # catch all exceptions here to avoid crashing the server. + logging.exception(f"Unhandled exception in receive loop: {e}") + finally: + # after the read stream is closed, we need to send errors + # to any pending requests + for id, stream in self._response_streams.items(): + error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") + try: + await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) + await stream.aclose() + except Exception: + # Stream might already be closed + pass + self._response_streams.clear() async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: """ diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index c4604e3fa..68d58fc91 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1521,3 +1521,40 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server, basic_ ) assert response.status_code == 200 # Should succeed for backwards compatibility assert response.headers.get("Content-Type") == "text/event-stream" + + +@pytest.mark.anyio +async def test_client_crash_handled(basic_server, basic_server_url): + """Test that cases where the client crashes are handled gracefully.""" + + # Simulate bad client that crashes after init + async def bad_client(): + """Client that triggers ClosedResourceError""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + raise Exception("client crash") + + # Run bad client a few times to trigger the crash + for _ in range(3): + try: + await bad_client() + except Exception: + pass + await anyio.sleep(0.1) + + # Try a good client, it should still be able to connect and list tools + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + tools = await session.list_tools() + assert tools.tools