5
5
6
6
import mcp
7
7
from mcp import types
8
- from mcp .client .session_group import ClientSessionGroup
8
+ from mcp .client .session_group import (
9
+ ClientSessionGroup ,
10
+ SseServerParameters ,
11
+ StreamableHttpParameters ,
12
+ )
9
13
from mcp .client .stdio import StdioServerParameters
10
14
from mcp .shared .exceptions import McpError
11
15
@@ -19,12 +23,6 @@ def mock_exit_stack():
19
23
return mock .MagicMock (spec = contextlib .AsyncExitStack )
20
24
21
25
22
- @pytest .fixture
23
- def mock_server_params (): # No mocker needed here
24
- """Fixture for mocked StdioServerParameters."""
25
- return mock .Mock (spec = StdioServerParameters )
26
-
27
-
28
26
@pytest .mark .anyio
29
27
class TestClientSessionGroup :
30
28
def test_init (self ):
@@ -79,7 +77,7 @@ async def test_call_tool(self):
79
77
{"name" : "value1" , "args" : {}},
80
78
)
81
79
82
- async def test_connect_to_server (self , mock_exit_stack , mock_server_params ):
80
+ async def test_connect_to_server (self , mock_exit_stack ):
83
81
"""Test connecting to a server and aggregating components."""
84
82
# --- Mock Dependencies ---
85
83
mock_server_info = mock .Mock (spec = types .Implementation )
@@ -102,7 +100,7 @@ async def test_connect_to_server(self, mock_exit_stack, mock_server_params):
102
100
with mock .patch .object (
103
101
group , "_establish_session" , return_value = (mock_server_info , mock_session )
104
102
):
105
- await group .connect_to_server (mock_server_params )
103
+ await group .connect_to_server (StdioServerParameters ( command = "test" ) )
106
104
107
105
# --- Assertions ---
108
106
assert mock_session in group ._sessions
@@ -120,9 +118,7 @@ async def test_connect_to_server(self, mock_exit_stack, mock_server_params):
120
118
mock_session .list_resources .assert_awaited_once ()
121
119
mock_session .list_prompts .assert_awaited_once ()
122
120
123
- async def test_connect_to_server_with_name_hook (
124
- self , mock_exit_stack , mock_server_params
125
- ):
121
+ async def test_connect_to_server_with_name_hook (self , mock_exit_stack ):
126
122
"""Test connecting with a component name hook."""
127
123
# --- Mock Dependencies ---
128
124
mock_server_info = mock .Mock (spec = types .Implementation )
@@ -145,7 +141,7 @@ def name_hook(name: str, server_info: types.Implementation) -> str:
145
141
with mock .patch .object (
146
142
group , "_establish_session" , return_value = (mock_server_info , mock_session )
147
143
):
148
- await group .connect_to_server (mock_server_params )
144
+ await group .connect_to_server (StdioServerParameters ( command = "test" ) )
149
145
150
146
# --- Assertions ---
151
147
assert mock_session in group ._sessions
@@ -218,9 +214,7 @@ def test_disconnect_from_server(self): # No mock arguments needed
218
214
assert "res1" not in group ._resources
219
215
assert "prm1" not in group ._prompts
220
216
221
- async def test_connect_to_server_duplicate_tool_raises_error (
222
- self , mock_exit_stack , mock_server_params
223
- ):
217
+ async def test_connect_to_server_duplicate_tool_raises_error (self , mock_exit_stack ):
224
218
"""Test McpError raised when connecting a server with a dup name."""
225
219
# --- Setup Pre-existing State ---
226
220
group = ClientSessionGroup (exit_stack = mock_exit_stack )
@@ -255,7 +249,7 @@ async def test_connect_to_server_duplicate_tool_raises_error(
255
249
"_establish_session" ,
256
250
return_value = (mock_server_info_new , mock_session_new ),
257
251
):
258
- await group .connect_to_server (mock_server_params )
252
+ await group .connect_to_server (StdioServerParameters ( command = "test" ) )
259
253
260
254
# Assert details about the raised error
261
255
assert excinfo .value .error .code == types .INVALID_PARAMS
@@ -269,9 +263,127 @@ async def test_connect_to_server_duplicate_tool_raises_error(
269
263
) # Ensure it's the original mock
270
264
271
265
# No patching needed here
272
- def test_disconnect_non_existent_server (self ): # No mock arguments needed
266
+ def test_disconnect_non_existent_server (self ):
273
267
"""Test disconnecting a server that isn't connected."""
274
268
session = mock .Mock (spec = mcp .ClientSession )
275
269
group = ClientSessionGroup ()
276
270
with pytest .raises (McpError ):
277
271
group .disconnect_from_server (session )
272
+
273
+ @pytest .mark .parametrize (
274
+ "server_params_instance, client_type_name, patch_target_for_client_func" ,
275
+ [
276
+ (
277
+ StdioServerParameters (command = "test_stdio_cmd" ),
278
+ "stdio" ,
279
+ "mcp.client.session_group.mcp.stdio_client" ,
280
+ ),
281
+ (
282
+ SseServerParameters (url = "http://test.com/sse" , timeout = 10 ),
283
+ "sse" ,
284
+ "mcp.client.session_group.sse_client" ,
285
+ ), # url, headers, timeout, sse_read_timeout
286
+ (
287
+ StreamableHttpParameters (
288
+ url = "http://test.com/stream" , terminate_on_close = False
289
+ ),
290
+ "streamablehttp" ,
291
+ "mcp.client.session_group.streamablehttp_client" ,
292
+ ), # url, headers, timeout, sse_read_timeout, terminate_on_close
293
+ ],
294
+ )
295
+ async def test_establish_session_parameterized (
296
+ self ,
297
+ server_params_instance ,
298
+ client_type_name , # Just for clarity or conditional logic if needed
299
+ patch_target_for_client_func ,
300
+ ):
301
+ with mock .patch (
302
+ "mcp.client.session_group.mcp.ClientSession"
303
+ ) as mock_ClientSession_class :
304
+ with mock .patch (patch_target_for_client_func ) as mock_specific_client_func :
305
+ mock_client_cm_instance = mock .AsyncMock (
306
+ name = f"{ client_type_name } ClientCM"
307
+ )
308
+ mock_read_stream = mock .AsyncMock (name = f"{ client_type_name } Read" )
309
+ mock_write_stream = mock .AsyncMock (name = f"{ client_type_name } Write" )
310
+
311
+ # streamablehttp_client's __aenter__ returns three values
312
+ if client_type_name == "streamablehttp" :
313
+ mock_extra_stream_val = mock .AsyncMock (name = "StreamableExtra" )
314
+ mock_client_cm_instance .__aenter__ .return_value = (
315
+ mock_read_stream ,
316
+ mock_write_stream ,
317
+ mock_extra_stream_val ,
318
+ )
319
+ else :
320
+ mock_client_cm_instance .__aenter__ .return_value = (
321
+ mock_read_stream ,
322
+ mock_write_stream ,
323
+ )
324
+
325
+ mock_client_cm_instance .__aexit__ = mock .AsyncMock (return_value = None )
326
+ mock_specific_client_func .return_value = mock_client_cm_instance
327
+
328
+ # --- Mock mcp.ClientSession (class) ---
329
+ # mock_ClientSession_class is already provided by the outer patch
330
+ mock_raw_session_cm = mock .AsyncMock (name = "RawSessionCM" )
331
+ mock_ClientSession_class .return_value = mock_raw_session_cm
332
+
333
+ mock_entered_session = mock .AsyncMock (name = "EnteredSessionInstance" )
334
+ mock_raw_session_cm .__aenter__ .return_value = mock_entered_session
335
+ mock_raw_session_cm .__aexit__ = mock .AsyncMock (return_value = None )
336
+
337
+ # Mock session.initialize()
338
+ mock_initialize_result = mock .AsyncMock (name = "InitializeResult" )
339
+ mock_initialize_result .serverInfo = types .Implementation (
340
+ name = "foo" , version = "1"
341
+ )
342
+ mock_entered_session .initialize .return_value = mock_initialize_result
343
+
344
+ # --- Test Execution ---
345
+ group = ClientSessionGroup ()
346
+ returned_server_info = None
347
+ returned_session = None
348
+
349
+ async with contextlib .AsyncExitStack () as stack :
350
+ group ._exit_stack = stack
351
+ (
352
+ returned_server_info ,
353
+ returned_session ,
354
+ ) = await group ._establish_session (server_params_instance )
355
+
356
+ # --- Assertions ---
357
+ # 1. Assert the correct specific client function was called
358
+ if client_type_name == "stdio" :
359
+ mock_specific_client_func .assert_called_once_with (
360
+ server_params_instance
361
+ )
362
+ elif client_type_name == "sse" :
363
+ mock_specific_client_func .assert_called_once_with (
364
+ url = server_params_instance .url ,
365
+ headers = server_params_instance .headers ,
366
+ timeout = server_params_instance .timeout ,
367
+ sse_read_timeout = server_params_instance .sse_read_timeout ,
368
+ )
369
+ elif client_type_name == "streamablehttp" :
370
+ mock_specific_client_func .assert_called_once_with (
371
+ url = server_params_instance .url ,
372
+ headers = server_pa
94DA
rams_instance .headers ,
373
+ timeout = server_params_instance .timeout ,
374
+ sse_read_timeout = server_params_instance .sse_read_timeout ,
375
+ terminate_on_close = server_params_instance .terminate_on_close ,
376
+ )
377
+
378
+ mock_client_cm_instance .__aenter__ .assert_awaited_once ()
379
+
380
+ # 2. Assert ClientSession was called correctly
381
+ mock_ClientSession_class .assert_called_once_with (
382
+ mock_read_stream , mock_write_stream
383
+ )
384
+ mock_raw_session_cm .__aenter__ .assert_awaited_once ()
385
+ mock_entered_session .initialize .assert_awaited_once ()
386
+
387
+ # 3. Assert returned values
388
+ assert returned_server_info is mock_initialize_result .serverInfo
389
+ assert returned_session is mock_entered_session
0 commit comments