diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 4b13709c6..c0345d6ab 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -15,6 +15,7 @@ from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.types import ( CONNECTION_CLOSED, + INVALID_PARAMS, CancelledNotification, ClientNotification, ClientRequest, @@ -354,27 +355,47 @@ async def _receive_loop(self) -> None: if isinstance(message, Exception): await self._handle_incoming(message) elif isinstance(message.message.root, JSONRPCRequest): - validated_request = self._receive_request_type.model_validate( - message.message.root.model_dump( - by_alias=True, mode="json", exclude_none=True + try: + validated_request = self._receive_request_type.model_validate( + message.message.root.model_dump( + by_alias=True, mode="json", exclude_none=True + ) ) - ) - responder = RequestResponder( - request_id=message.message.root.id, - request_meta=validated_request.root.params.meta - if validated_request.root.params - else None, - request=validated_request, - session=self, - on_complete=lambda r: self._in_flight.pop(r.request_id, None), - message_metadata=message.metadata, - ) - - self._in_flight[responder.request_id] = responder - await self._received_request(responder) + responder = RequestResponder( + request_id=message.message.root.id, + request_meta=validated_request.root.params.meta + if validated_request.root.params + else None, + request=validated_request, + session=self, + on_complete=lambda r: self._in_flight.pop( + r.request_id, None), + message_metadata=message.metadata, + ) + self._in_flight[responder.request_id] = responder + await self._received_request(responder) - if not responder._completed: # type: ignore[reportPrivateUsage] - await self._handle_incoming(responder) + if not responder._completed: # type: ignore[reportPrivateUsage] + await self._handle_incoming(responder) + except Exception as e: + # For request validation errors, send a proper JSON-RPC error + # response instead of crashing the server + logging.warning(f"Failed to validate request: {e}") + logging.debug( + f"Message that failed validation: {message.message.root}" + ) + error_response = JSONRPCError( + jsonrpc="2.0", + id=message.message.root.id, + error=ErrorData( + code=INVALID_PARAMS, + message="Invalid request parameters", + data="", + ), + ) + session_message = SessionMessage( + message=JSONRPCMessage(error_response)) + await self._write_stream.send(session_message) elif isinstance(message.message.root, JSONRPCNotification): try: diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py new file mode 100644 index 000000000..e4fda9e13 --- /dev/null +++ b/tests/issues/test_malformed_input.py @@ -0,0 +1,172 @@ +# Claude Debug +"""Test for HackerOne vulnerability report #3156202 - malformed input DOS.""" + +import anyio +import pytest + +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.message import SessionMessage +from mcp.types import ( + INVALID_PARAMS, + JSONRPCError, + JSONRPCMessage, + JSONRPCRequest, + ServerCapabilities, +) + + +@pytest.mark.anyio +async def test_malformed_initialize_request_does_not_crash_server(): + """ + Test that malformed initialize requests return proper error responses + instead of crashing the server (HackerOne #3156202). + """ + # Create in-memory streams for testing + read_send_stream, read_receive_stream = anyio.create_memory_object_stream[ + SessionMessage | Exception + ](10) + write_send_stream, write_receive_stream = anyio.create_memory_object_stream[ + SessionMessage + ](10) + + try: + # Create a malformed initialize request (missing required params field) + malformed_request = JSONRPCRequest( + jsonrpc="2.0", + id="f20fe86132ed4cd197f89a7134de5685", + method="initialize", + # params=None # Missing required params field + ) + + # Wrap in session message + request_message = SessionMessage(message=JSONRPCMessage(malformed_request)) + + # Start a server session + async with ServerSession( + read_stream=read_receive_stream, + write_stream=write_send_stream, + init_options=InitializationOptions( + server_name="test_server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ), + ): + # Send the malformed request + await read_send_stream.send(request_message) + + # Give the session time to process the request + await anyio.sleep(0.1) + + # Check that we received an error response instead of a crash + try: + response_message = write_receive_stream.receive_nowait() + response = response_message.message.root + + # Verify it's a proper JSON-RPC error response + assert isinstance(response, JSONRPCError) + assert response.jsonrpc == "2.0" + assert response.id == "f20fe86132ed4cd197f89a7134de5685" + assert response.error.code == INVALID_PARAMS + assert "Invalid request parameters" in response.error.message + + # Verify the session is still alive and can handle more requests + # Send another malformed request to confirm server stability + another_malformed_request = JSONRPCRequest( + jsonrpc="2.0", + id="test_id_2", + method="tools/call", + # params=None # Missing required params + ) + another_request_message = SessionMessage( + message=JSONRPCMessage(another_malformed_request) + ) + + await read_send_stream.send(another_request_message) + await anyio.sleep(0.1) + + # Should get another error response, not a crash + second_response_message = write_receive_stream.receive_nowait() + second_response = second_response_message.message.root + + assert isinstance(second_response, JSONRPCError) + assert second_response.id == "test_id_2" + assert second_response.error.code == INVALID_PARAMS + + except anyio.WouldBlock: + pytest.fail("No response received - server likely crashed") + finally: + # Close all streams to ensure proper cleanup + await read_send_stream.aclose() + await write_send_stream.aclose() + await read_receive_stream.aclose() + await write_receive_stream.aclose() + + +@pytest.mark.anyio +async def test_multiple_concurrent_malformed_requests(): + """ + Test that multiple concurrent malformed requests don't crash the server. + """ + # Create in-memory streams for testing + read_send_stream, read_receive_stream = anyio.create_memory_object_stream[ + SessionMessage | Exception + ](100) + write_send_stream, write_receive_stream = anyio.create_memory_object_stream[ + SessionMessage + ](100) + + try: + # Start a server session + async with ServerSession( + read_stream=read_receive_stream, + write_stream=write_send_stream, + init_options=InitializationOptions( + server_name="test_server", + server_version="1.0.0", + capabilities=ServerCapabilities(), + ), + ): + # Send multiple malformed requests concurrently + malformed_requests = [] + for i in range(10): + malformed_request = JSONRPCRequest( + jsonrpc="2.0", + id=f"malformed_{i}", + method="initialize", + # params=None # Missing required params + ) + request_message = SessionMessage( + message=JSONRPCMessage(malformed_request) + ) + malformed_requests.append(request_message) + + # Send all requests + for request in malformed_requests: + await read_send_stream.send(request) + + # Give time to process + await anyio.sleep(0.2) + + # Verify we get error responses for all requests + error_responses = [] + try: + while True: + response_message = write_receive_stream.receive_nowait() + error_responses.append(response_message.message.root) + except anyio.WouldBlock: + pass # No more messages + + # Should have received 10 error responses + assert len(error_responses) == 10 + + for i, response in enumerate(error_responses): + assert isinstance(response, JSONRPCError) + assert response.id == f"malformed_{i}" + assert response.error.code == INVALID_PARAMS + finally: + # Close all streams to ensure proper cleanup + await read_send_stream.aclose() + await write_send_stream.aclose() + await read_receive_stream.aclose() + await write_receive_stream.aclose() \ No newline at end of file