8000 Fix uncaught exception in MCP server by ddworken · Pull Request #822 · modelcontextprotocol/python-sdk · GitHub
[go: up one dir, main page]

Skip to content

Fix uncaught exception in MCP server #822

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 10, 2025
Merged
59 changes: 40 additions & 19 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.types import (
CONNECTION_CLOSED,
INVALID_PARAMS,
CancelledNotification,
ClientNotification,
ClientRequest,
Expand Down Expand Up @@ -354,27 +355,47 @@ async def _receive_loop(self) -> None:
if isinstance(message, Exception):
await self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
try:
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
message_metadata=message.metadata,
)

self._in_flight[responder.request_id] = responder
await self._received_request(responder)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(
r.request_id, None),
message_metadata=message.metadata,
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)

if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
except Exception as e:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think is better to handle a specific exceptions before the general one, such as RuntimeError (Which mentioned in this issue)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that the risk is the server becoming unresponsive, I believe catching all exceptions to isolate errors to a single request is correct.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is a valid answer from the package's point of view.

It makes it considerably hard to maintain the package if everything is except Exception.

# For request validation errors, send a proper JSON-RPC error
# response instead of crashing the server
logging.warning(f"Failed to validate request: {e}")
logging.debug(
f"Message that failed validation: {message.message.root}"
)
error_response = JSONRPCError(
jsonrpc="2.0",
id=message.message.root.id,
error=ErrorData(
code=INVALID_PARAMS,
message="Invalid request parameters",
data="",
),
)
session_message = SessionMessage(
message=JSONRPCMessage(error_response))
Comment on lines +396 to +397
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How this got merged?

await self._write_stream.send(session_message)

elif isinstance(message.message.root, JSONRPCNotification):
try:
Expand Down
172 changes: 172 additions & 0 deletions tests/issues/test_malformed_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Claude Debug
"""Test for HackerOne vulnerability report #3156202 - malformed input DOS."""

import anyio
import pytest

from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.message import SessionMessage
from mcp.types import (
INVALID_PARAMS,
JSONRPCError,
JSONRPCMessage,
JSONRPCRequest,
ServerCapabilities,
)


@pytest.mark.anyio
async def test_malformed_initialize_request_does_not_crash_server():
"""
Test that malformed initialize requests return proper error responses
instead of crashing the server (HackerOne #3156202).
"""
# Create in-memory streams for testing
read_send_stream, read_receive_stream = anyio.create_memory_object_stream[
SessionMessage | Exception
](10)
write_send_stream, write_receive_stream = anyio.create_memory_object_stream[
SessionMessage
](10)

try:
# Create a malformed initialize request (missing required params field)
malformed_request = JSONRPCRequest(
jsonrpc="2.0",
id="f20fe86132ed4cd197f89a7134de5685",
method="initialize",
# params=None # Missing required params field
)

# Wrap in session message
request_message = SessionMessage(message=JSONRPCMessage(malformed_request))

# Start a server session
async with ServerSession(
read_stream=read_receive_stream,
write_stream=write_send_stream,
init_options=InitializationOptions(
server_name="test_server",
server_version="1.0.0",
capabilities=ServerCapabilities(),
),
):
# Send the malformed request
await read_send_stream.send(request_message)

# Give the session time to process the request
await anyio.sleep(0.1)

# Check that we received an error response instead of a crash
try:
response_message = write_receive_stream.receive_nowait()
response = response_message.message.root

# Verify it's a proper JSON-RPC error response
assert isinstance(response, JSONRPCError)
assert response.jsonrpc == "2.0"
assert response.id == "f20fe86132ed4cd197f89a7134de5685"
assert response.error.code == INVALID_PARAMS
assert "Invalid request parameters" in response.error.message

# Verify the session is still alive and can handle more requests
# Send another malformed request to confirm server stability
another_malformed_request = JSONRPCRequest(
jsonrpc="2.0",
id="test_id_2",
method="tools/call",
# params=None # Missing required params
)
another_request_message = SessionMessage(
message=JSONRPCMessage(another_malformed_request)
)

await read_send_stream.send(another_request_message)
await anyio.sleep(0.1)

# Should get another error response, not a crash
second_response_message = write_receive_stream.receive_nowait()
second_response = second_response_message.message.root

assert isinstance(second_response, JSONRPCError)
assert second_response.id == "test_id_2"
assert second_response.error.code == INVALID_PARAMS

except anyio.WouldBlock:
pytest.fail("No response received - server likely crashed")
finally:
# Close all streams to ensure proper cleanup
await read_send_stream.aclose()
await write_send_stream.aclose()
await read_receive_stream.aclose()
await write_receive_stream.aclose()


@pytest.mark.anyio
async def test_multiple_concurrent_malformed_requests():
"""
Test that multiple concurrent malformed requests don't crash the server.
"""
# Create in-memory streams for testing
read_send_stream, read_receive_stream = anyio.create_memory_object_stream[
SessionMessage | Exception
](100)
write_send_stream, write_receive_stream = anyio.create_memory_object_stream[
SessionMessage
](100)

try:
# Start a server session
async with ServerSession(
read_stream=read_receive_stream,
write_stream=write_send_stream,
init_options=InitializationOptions(
server_name="test_server",
server_version="1.0.0",
capabilities=ServerCapabilities(),
),
):
# Send multiple malformed requests concurrently
malformed_requests = []
for i in range(10):
malformed_request = JSONRPCRequest(
jsonrpc="2.0",
id=f"malformed_{i}",
method="initialize",
# params=None # Missing required params
)
request_message = SessionMessage(
message=JSONRPCMessage(malformed_request)
)
malformed_requests.append(request_message)

# Send all requests
for request in malformed_requests:
await read_send_stream.send(request)

# Give time to process
await anyio.sleep(0.2)

# Verify we get error responses for all requests
error_responses = []
try:
while True:
response_message = write_receive_stream.receive_nowait()
error_responses.append(response_message.message.root)
except anyio.WouldBlock:
pass # No more messages

# Should have received 10 error responses
assert len(error_responses) == 10

for i, response in enumerate(error_responses):
assert isinstance(response, JSONRPCError)
assert response.id == f"malformed_{i}"
assert response.err 50F2 or.code == INVALID_PARAMS
finally:
# Close all streams to ensure proper cleanup
await read_send_stream.aclose()
await write_send_stream.aclose()
await read_receive_stream.aclose()
await write_receive_stream.aclose()
Loading
0