diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 29195cbd9..357029002 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -26,12 +26,17 @@ async def sse_client( headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, + **client_kwargs: Any, ): """ Client transport for SSE. `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + + `**client_kwargs` : dict, optional - Additional http client configurations used + to configure the AsyncClient. + """ read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] @@ -45,7 +50,9 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with create_mcp_http_client(headers=headers) as client: + async with create_mcp_http_client( + headers=headers, client_kwargs=client_kwargs + ) as client: async with aconnect_sse( client, "GET", diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3324dab5a..756813f8a 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -427,6 +427,7 @@ async def streamablehttp_client( timeout: timedelta = timedelta(seconds=30), sse_read_timeout: timedelta = timedelta(seconds=60 * 5), terminate_on_close: bool = True, + **client_kwargs: Any, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[SessionMessage | Exception], @@ -441,6 +442,9 @@ async def streamablehttp_client( `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. + `**client_kwargs` : dict, optional - Additional http client configurations used + to configure the AsyncClient. + Yields: Tuple containing: - read_stream: Stream for reading messages from the server @@ -465,6 +469,7 @@ async def streamablehttp_client( timeout=httpx.Timeout( transport.timeout.seconds, read=transport.sse_read_timeout.seconds ), + client_kwargs=client_kwargs, ) as client: # Define callbacks that need access to tg def start_get_stream() -> None: diff --git a/src/mcp/shared/_httpx_utils.py b/src/mcp/shared/_httpx_utils.py index 95080bde1..b66605106 100644 --- a/src/mcp/shared/_httpx_utils.py +++ b/src/mcp/shared/_httpx_utils.py @@ -10,6 +10,7 @@ def create_mcp_http_client( headers: dict[str, str] | None = None, timeout: httpx.Timeout | None = None, + client_kwargs: dict[str, Any] | None = None, ) -> httpx.AsyncClient: """Create a standardized httpx AsyncClient with MCP defaults. @@ -21,6 +22,7 @@ def create_mcp_http_client( headers: Optional headers to include with all requests. timeout: Request timeout as httpx.Timeout object. Defaults to 30 seconds if not specified. + client_kwargs : dict[str, Any]. Optional. To configure the AsyncClient. Returns: Configured httpx.AsyncClient instance with MCP defaults. @@ -45,18 +47,18 @@ def create_mcp_http_client( response = await client.get("/long-request") """ # Set MCP defaults - kwargs: dict[str, Any] = { - "follow_redirects": True, - } + if not client_kwargs: + client_kwargs = {} + client_kwargs["follow_redirects"] = True # Handle timeout if timeout is None: - kwargs["timeout"] = httpx.Timeout(30.0) + client_kwargs["timeout"] = httpx.Timeout(30.0) else: - kwargs["timeout"] = timeout + client_kwargs["timeout"] = timeout # Handle headers if headers is not None: - kwargs["headers"] = headers + client_kwargs["headers"] = headers - return httpx.AsyncClient(**kwargs) + return httpx.AsyncClient(**client_kwargs) diff --git a/tests/shared/test_httpx_utils.py b/tests/shared/test_httpx_utils.py index dcc6fd003..284cda0a1 100644 --- a/tests/shared/test_httpx_utils.py +++ b/tests/shared/test_httpx_utils.py @@ -22,3 +22,11 @@ def test_custom_parameters(): assert client.headers["Authorization"] == "Bearer token" assert client.timeout.connect == 60.0 + + +def test_client_kwargs_parameters(): + """Test if additional kwargs are set correctly.""" + client_kwargs = {"max_redirects": 999} + + client = create_mcp_http_client(client_kwargs=client_kwargs) + assert client.max_redirects == 999