8000 Add support for sse and streamable http transports in the ClientSessi… · modelcontextprotocol/python-sdk@14bfcf1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 14bfcf1

Browse files
committed
Add support for sse and streamable http transports in the ClientSessionGroup
This change introduces additional support for the other MCP transports and includes respective server parameter types to align with the stdio client implementation.
1 parent c9acac5 commit 14bfcf1

File tree

2 files changed

+198
-22
lines changed

2 files changed

+198
-22
lines changed

src/mcp/client/session_group.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,59 @@
1010

1111
import contextlib
1212
from collections.abc import Callable
13+
from datetime import timedelta
1314
from typing import Any, TypeAlias
1415

1516
from pydantic import BaseModel
1617

1718
import mcp
1819
from mcp import types
20+
from mcp.client.sse import sse_client
1921
from mcp.client.stdio import StdioServerParameters
22+
from mcp.client.streamable_http import streamablehttp_client
2023
from mcp.shared.exceptions import McpError
2124

2225

26+
class SseServerParameters(BaseModel):
27+
"""Parameters for intializing a sse_client."""
28+
29+
# The endpoint URL.
30+
url: str
31+
32+
# Optional headers to include in requests.
33+
headers: dict[str, Any] | None = None
34+
35+
# HTTP timeout for regular operations.
36+
timeout: float = 5
37+
38+
# Timeout for SSE read operations.
39+
sse_read_timeout: float = 60 * 5
40+
41+
42+
class StreamableHttpParameters(BaseModel):
43+
"""Parameters for intializing a streamablehttp_client."""
44+
45+
# The endpoint URL.
46+
url: str
47+
48+
# Optional headers to include in requests.
49+
headers: dict[str, Any] | None = None
50+
51+
# HTTP timeout for regular operations.
52+
timeout: timedelta = timedelta(seconds=30)
53+
54+
# Timeout for SSE read operations.
55+
sse_read_timeout: timedelta = timedelta(seconds=60 * 5)
56+
57+
# Close the client session when the transport closes.
58+
terminate_on_close: bool = True
59+
60+
61+
ServerParameters: TypeAlias = (
62+
StdioServerParameters | SseServerParameters | StreamableHttpParameters
63+
)
64+
65+
2366
class ClientSessionGroup:
2467
"""Client for managing connections to multiple MCP servers.
2568
@@ -118,7 +161,7 @@ def disconnect_from_server(self, session: mcp.ClientSession) -> None:
118161

119162
async def connect_to_server(
120163
self,
121-
server_params: StdioServerParameters,
164+
server_params: ServerParameters,
122165
) -> mcp.ClientSession:
123166
"""Connects to a single MCP server."""
124167

@@ -190,11 +233,32 @@ async def connect_to_server(
190233
return session
191234

192235
async def _establish_session(
193-
self, server_params: StdioServerParameters
236+
self, server_params: ServerParameters
194237
) -> tuple[types.Implementation, mcp.ClientSession]:
195238
"""Establish a client session to an MCP server."""
196-
client = mcp.stdio_client(server_params)
197-
read, write = await self._exit_stack.enter_async_context(client)
239+
240+
# Create read and write streams that facilitate io with the server.
241+
if isinstance(server_params, StdioServerParameters):
242+
client = mcp.stdio_client(server_params)
243+
read, write = await self._exit_stack.enter_async_context(client)
244+
elif isinstance(server_params, SseServerParameters):
245+
client = sse_client(
246+
url=server_params.url,
247+
headers=server_params.headers,
248+
timeout=server_params.timeout,
249+
sse_read_timeout=server_params.sse_read_timeout,
250+
)
251+
read, write = await self._exit_stack.enter_async_context(client)
252+
else:
253+
client = streamablehttp_client(
254+
url=server_params.url,
255+
headers=server_params.headers,
256+
timeout=server_params.timeout,
257+
sse_read_timeout=server_params.sse_read_timeout,
258+
terminate_on_close=server_params.terminate_on_close,
259+
)
260+
read, write, _ = await self._exit_stack.enter_async_context(client)
261+
198262
session = await self._exit_stack.enter_async_context(
199263
mcp.ClientSession(read, write)
200264
)

tests/client/test_session_group.py

Lines changed: 130 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
import mcp
77
from mcp import types
8-
from mcp.client.session_group import ClientSessionGroup
8+
from mcp.client.session_group import (
9+
ClientSessionGroup,
10+
SseServerParameters,
11+
StreamableHttpParameters,
12+
)
913
from mcp.client.stdio import StdioServerParameters
1014
from mcp.shared.exceptions import McpError
1115

@@ -19,12 +23,6 @@ def mock_exit_stack():
1923
return mock.MagicMock(spec=contextlib.AsyncExitStack)
2024

2125

22-
@pytest.fixture
23-
def mock_server_params(): # No mocker needed here
24-
"""Fixture for mocked StdioServerParameters."""
25-
return mock.Mock(spec=StdioServerParameters)
26-
27-
2826
@pytest.mark.anyio
2927
class TestClientSessionGroup:
3028
def test_init(self):
@@ -79,7 +77,7 @@ async def test_call_tool(self):
7977
{"name": "value1", "args": {}},
8078
)
8179

82-
async def test_connect_to_server(self, mock_exit_stack, mock_server_params):
80+
async def test_connect_to_server(self, mock_exit_stack):
8381
"""Test connecting to a server and aggregating components."""
8482
# --- Mock Dependencies ---
8583
mock_server_info = mock.Mock(spec=types.Implementation)
@@ -102,7 +100,7 @@ async def test_connect_to_server(self, mock_exit_stack, mock_server_params):
102100
with mock.patch.object(
103101
group, "_establish_session", return_value=(mock_server_info, mock_session)
104102
):
105-
await group.connect_to_server(mock_server_params)
103+
await group.connect_to_server(StdioServerParameters(command="test"))
106104

107105
# --- Assertions ---
108106
assert mock_session in group._sessions
@@ -120,9 +118,7 @@ async def test_connect_to_server(self, mock_exit_stack, mock_server_params):
120118
mock_session.list_resources.assert_awaited_once()
121119
mock_session.list_prompts.assert_awaited_once()
122120

123-
async def test_connect_to_server_with_name_hook(
124-
self, mock_exit_stack, mock_server_params
125-
):
121+
async def test_connect_to_server_with_name_hook(self, mock_exit_stack):
126122
"""Test connecting with a component name hook."""
127123
# --- Mock Dependencies ---
128124
mock_server_info = mock.Mock(spec=types.Implementation)
@@ -145,7 +141,7 @@ def name_hook(name: str, server_info: types.Implementation) -> str:
145141
with mock.patch.object(
146142
group, "_establish_session", return_value=(mock_server_info, mock_session)
147143
):
148-
await group.connect_to_server(mock_server_params)
144+
await group.connect_to_server(StdioServerParameters(command="test"))
149145

150146
# --- Assertions ---
151147
assert mock_session in group._sessions
@@ -218,9 +214,7 @@ def test_disconnect_from_server(self): # No mock arguments needed
218214
assert "res1" not in group._resources
219215
assert "prm1" not in group._prompts
220216

221-
async def test_connect_to_server_duplicate_tool_raises_error(
222-
self, mock_exit_stack, mock_server_params
223-
):
217+
async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_stack):
224218
"""Test McpError raised when connecting a server with a dup name."""
225219
# --- Setup Pre-existing State ---
226220
group = ClientSessionGroup(exit_stack=mock_exit_stack)
@@ -255,7 +249,7 @@ async def test_connect_to_server_duplicate_tool_raises_error(
255249
"_establish_session",
256250
return_value=(mock_server_info_new, mock_session_new),
257251
):
258-
await group.connect_to_server(mock_server_params)
252+
await group.connect_to_server(StdioServerParameters(command="test"))
259253

260254
# Assert details about the raised error
261255
assert excinfo.value.error.code == types.INVALID_PARAMS
@@ -269,9 +263,127 @@ async def test_connect_to_server_duplicate_tool_raises_error(
269263
) # Ensure it's the original mock
270264

271265
# No patching needed here
272-
def test_disconnect_non_existent_server(self): # No mock arguments needed
266+
def test_disconnect_non_existent_server(self):
273267
"""Test disconnecting a server that isn't connected."""
274268
session = mock.Mock(spec=mcp.ClientSession)
275269
group = ClientSessionGroup()
276270
with pytest.raises(McpError):
277271
group.disconnect_from_server(session)
272+
273+
@pytest.mark.parametrize(
274+
"server_params_instance, client_type_name, patch_target_for_client_func",
275+
[
276+
(
277+
StdioServerParameters(command="test_stdio_cmd"),
278+
"stdio",
279+
"mcp.client.session_group.mcp.stdio_client",
280+
),
281+
(
282+
SseServerParameters(url="http://test.com/sse", timeout=10),
283+
"sse",
284+
"mcp.client.session_group.sse_client",
285+
), # url, headers, timeout, sse_read_timeout
286+
(
287+
StreamableHttpParameters(
288+
url="http://test.com/stream", terminate_on_close=False
289+
),
290+
"streamablehttp",
291+
"mcp.client.session_group.streamablehttp_client",
292+
), # url, headers, timeout, sse_read_timeout, terminate_on_close
293+
],
294+
)
295+
async def test_establish_session_parameterized(
296+
self,
297+
server_params_instance,
298+
client_type_name, # Just for clarity or conditional logic if needed
299+
patch_target_for_client_func,
300+
):
301+
with mock.patch(
302+
"mcp.client.session_group.mcp.ClientSession"
303+
) as mock_ClientSession_class:
304+
with mock.patch(patch_target_for_client_func) as mock_specific_client_func:
305+
mock_client_cm_instance = mock.AsyncMock(
306+
name=f"{client_type_name}ClientCM"
307+
)
308+
mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read")
309+
mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write")
310+
311+
# streamablehttp_client's __aenter__ returns three values
312+
if client_type_name == "streamablehttp":
313+
mock_extra_stream_val = mock.AsyncMock(name="StreamableExtra")
314+
mock_client_cm_instance.__aenter__.return_value = (
315+
mock_read_stream,
316+
mock_write_stream,
317+
mock_extra_stream_val,
318+
)
319+
else:
320+
mock_client_cm_instance.__aenter__.return_value = (
321+
mock_read_stream,
322+
mock_write_stream,
323+
)
324+
325+
mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None)
326+
mock_specific_client_func.return_value = mock_client_cm_instance
327+
328+
# --- Mock mcp.ClientSession (class) ---
329+
# mock_ClientSession_class is already provided by the outer patch
330+
mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM")
331+
mock_ClientSession_class.return_value = mock_raw_session_cm
332+
333+
mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance")
334+
mock_raw_session_cm.__aenter__.return_value = mock_entered_session
335+
mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None)
336+
337+
# Mock session.initialize()
338+
mock_initialize_result = mock.AsyncMock(name="InitializeResult")
339+
mock_initialize_result.serverInfo = types.Implementation(
340+
name="foo", version="1"
341+
)
342+
mock_entered_session.initialize.return_value = mock_initialize_result
343+
344+
# --- Test Execution ---
345+
group = ClientSessionGroup()
346+
returned_server_info = None
347+
returned_session = None
348+
349+
async with contextlib.AsyncExitStack() as stack:
350+
group._exit_stack = stack
351+
(
352+
returned_server_info,
353+
returned_session,
354+
) = await group._establish_session(server_params_instance)
355+
356+
# --- Assertions ---
357+
# 1. Assert the correct specific client function was called
358+
if client_type_name == "stdio":
359+
mock_specific_client_func.assert_called_once_with(
360+
server_params_instance
361+
)
362+
elif client_type_name == "sse":
363+
mock_specific_client_func.assert_called_once_with(
364+
url=server_params_instance.url,
365+
headers=server_params_instance.headers,
366+
timeout=server_params_instance.timeout,
367+
sse_read_timeout=server_params_instance.sse_read_timeout,
368+
)
369+
elif client_type_name == "streamablehttp":
370+
mock_specific_client_func.assert_called_once_with(
371+
url=server_params_instance.url,
372+
headers=server_pa 94DA rams_instance.headers,
373+
timeout=server_params_instance.timeout,
374+
sse_read_timeout=server_params_instance.sse_read_timeout,
375+
terminate_on_close=server_params_instance.terminate_on_close,
376+
)
377+
378+
mock_client_cm_instance.__aenter__.assert_awaited_once()
379+
380+
# 2. Assert ClientSession was called correctly
381+
mock_ClientSession_class.assert_called_once_with(
382+
mock_read_stream, mock_write_stream
383+
)
384+
mock_raw_session_cm.__aenter__.assert_awaited_once()
385+
mock_entered_session.initialize.assert_awaited_once()
386+
387+
# 3. Assert returned values
388+
assert returned_server_info is mock_initialize_result.serverInfo
389+
assert returned_session is mock_entered_session

0 commit comments

Comments
 (0)
0