From 26bac1787ae6cb093c63072bacee91dd4984f998 Mon Sep 17 00:00:00 2001 From: Mo Date: Tue, 6 May 2025 16:05:54 +0000 Subject: [PATCH 1/3] Create `ClientSessionGroup` for managing multiple session connections. This abstraction concurrently manages multiple MCP session connections. Tools, resources, and prompts are aggregated across servers. Servers may be connected to or disconnected from at any point after initialization. This abstractions can handle naming collisions using a custom user-provided hook. --- src/mcp/__init__.py | 2 + src/mcp/client/session_group.py | 207 +++++++++++++++++++++ tests/client/test_session_group.py | 277 +++++++++++++++++++++++++++++ 3 files changed, 486 insertions(+) create mode 100644 src/mcp/client/session_group.py create mode 100644 tests/client/test_session_group.py diff --git a/src/mcp/__init__.py b/src/mcp/__init__.py index 0d3c372ce..e93b95c90 100644 --- a/src/mcp/__init__.py +++ b/src/mcp/__init__.py @@ -1,4 +1,5 @@ from .client.session import ClientSession +from .client.session_group import ClientSessionGroup from .client.stdio import StdioServerParameters, stdio_client from .server.session import ServerSession from .server.stdio import stdio_server @@ -63,6 +64,7 @@ "ClientRequest", "ClientResult", "ClientSession", + "ClientSessionGroup", "CreateMessageRequest", "CreateMessageResult", "ErrorData", diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py new file mode 100644 index 000000000..724d6a439 --- /dev/null +++ b/src/mcp/client/session_group.py @@ -0,0 +1,207 @@ +""" +SessionGroup concurrently manages multiple MCP session connections. + +Tools, resources, and prompts are aggregated across servers. Servers may +be connected to or disconnected from at any point after initialization. + +This abstractions can handle naming collisions using a custom user-provided +hook. +""" + +import contextlib +from collections.abc import Callable +from typing import Any, TypeAlias + +from pydantic import BaseModel + +import mcp +from mcp import types +from mcp.client.stdio import StdioServerParameters +from mcp.shared.exceptions import McpError + + +class ClientSessionGroup: + """Client for managing connections to multiple MCP servers. + + This class is responsible for encapsulating management of server connections. + It it aggregates tools, resources, and prompts from all connected servers. + + For auxiliary handlers, such as resource subscription, this is delegated to + the client and can be accessed via the session. For example: + mcp_session_group.get_session("server_name").subscribe_to_resource(...) + """ + + class _ComponentNames(BaseModel): + """Used for reverse index to find components.""" + + prompts: set[str] = set() + resources: set[str] = set() + tools: set[str] = set() + + # Standard MCP components. + _prompts: dict[str, types.Prompt] + _resources: dict[str, types.Resource] + _tools: dict[str, types.Tool] + + # Client-server connection management. + _sessions: dict[mcp.ClientSession, _ComponentNames] + _tool_to_session: dict[str, mcp.ClientSession] + _exit_stack: contextlib.AsyncExitStack + + # Optional fn consuming (component_name, serverInfo) for custom names. + # This is provide a means to mitigate naming conflicts across servers. + # Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}" + _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str] + _component_name_hook: _ComponentNameHook | None + + def __init__( + self, + exit_stack: contextlib.AsyncExitStack = contextlib.AsyncExitStack(), + component_name_hook: _ComponentNameHook | None = None, + ) -> None: + """Initializes the MCP client.""" + + self._tools = {} + self._resources = {} + self._prompts = {} + + self._sessions = {} + self._tool_to_session = {} + self._exit_stack = exit_stack + self._component_name_hook = component_name_hook + + @property + def prompts(self) -> dict[str, types.Prompt]: + """Returns the prompts as a dictionary of names to prompts.""" + return self._prompts + + @property + def resources(self) -> dict[str, types.Resource]: + """Returns the resources as a dictionary of names to resources.""" + return self._resources + + @property + def tools(self) -> dict[str, types.Tool]: + """Returns the tools as a dictionary of names to tools.""" + return self._tools + + async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: + """Executes a tool given its name and arguments.""" + session = self._tool_to_session[name] + return await session.call_tool(name, args) + + def disconnect_from_server(self, session: mcp.ClientSession) -> None: + """Disconnects from a single MCP server.""" + + if session not in self._sessions: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message="Provided session is not being managed.", + ) + ) + component_names = self._sessions[session] + + # Remove prompts associated with the session. + for name in component_names.prompts: + del self._prompts[name] + + # Remove resources associated with the session. + for name in component_names.resources: + del self._resources[name] + + # Remove tools associated with the session. + for name in component_names.tools: + del self._tools[name] + + del self._sessions[session] + + async def connect_to_server( + self, + server_params: StdioServerParameters, + ) -> mcp.ClientSession: + """Connects to a single MCP server.""" + + # Establish server connection and create session. + server_info, session = await self._establish_session(server_params) + + # Create a reverse index so we can find all prompts, resources, and + # tools belonging to this session. Used for removing components from + # the session group via self.disconnect_from_server. + component_names = self._ComponentNames() + + # Temporary components dicts. We do not want to modify the aggregate + # lists in case of an intermediate failure. + prompts_temp: dict[str, types.Prompt] = {} + resources_temp: dict[str, types.Resource] = {} + tools_temp: dict[str, types.Tool] = {} + tool_to_session_temp: dict[str, mcp.ClientSession] = {} + + # Query the server for its prompts and aggregate to list. + prompts = (await session.list_prompts()).prompts + for prompt in prompts: + name = self._component_name(prompt.name, server_info) + if name in self._prompts: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"{name} already exists in group prompts.", + ) + ) + prompts_temp[name] = prompt + component_names.prompts.add(name) + + # Query the server for its resources and aggregate to list. + resources = (await session.list_resources()).resources + for resource in resources: + name = self._component_name(resource.name, server_info) + if name in self._resources: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"{name} already exists in group resources.", + ) + ) + resources_temp[name] = resource + component_names.resources.add(name) + + # Query the server for its tools and aggregate to list. + tools = (await session.list_tools()).tools + for tool in tools: + name = self._component_name(tool.name, server_info) + if name in self._tools: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"{name} already exists in group tools.", + ) + ) + tools_temp[name] = tool + tool_to_session_temp[name] = session + component_names.tools.add(name) + + # Aggregate components. + self._sessions[session] = component_names + self._prompts.update(prompts_temp) + self._resources.update(resources_temp) + self._tools.update(tools_temp) + self._tool_to_session.update(tool_to_session_temp) + + return session + + async def _establish_session( + self, server_params: StdioServerParameters + ) -> tuple[types.Implementation, mcp.ClientSession]: + """Establish a client session to an MCP server.""" + client = mcp.stdio_client(server_params) + read, write = await self._exit_stack.enter_async_context(client) + session = await self._exit_stack.enter_async_context( + mcp.ClientSession(read, write) + ) + result = await session.initialize() + return result.serverInfo, session + + def _component_name(self, name: str, server_info: types.Implementation) -> str: + if self._component_name_hook: + return self._component_name_hook(name, server_info) + return name diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py new file mode 100644 index 000000000..5219009be --- /dev/null +++ b/tests/client/test_session_group.py @@ -0,0 +1,277 @@ +import contextlib +from unittest import mock + +import pytest + +import mcp +from mcp import types +from mcp.client.session_group import ClientSessionGroup +from mcp.client.stdio import StdioServerParameters +from mcp.shared.exceptions import McpError + + +@pytest.fixture +def mock_exit_stack(): + """Fixture for a mocked AsyncExitStack.""" + # Use unittest.mock.Mock directly if needed, or just a plain object + # if only attribute access/existence is needed. + # For AsyncExitStack, Mock or MagicMock is usually fine. + return mock.MagicMock(spec=contextlib.AsyncExitStack) + + +@pytest.fixture +def mock_server_params(): # No mocker needed here + """Fixture for mocked StdioServerParameters.""" + return mock.Mock(spec=StdioServerParameters) + + +@pytest.mark.anyio +class TestClientSessionGroup: + def test_init(self): + mcp_session_group = ClientSessionGroup() + assert not mcp_session_group._tools + assert not mcp_session_group._resources + assert not mcp_session_group._prompts + assert not mcp_session_group._tool_to_session + + def test_component_properties(self): + # --- Mock Dependencies --- + mock_prompt = mock.Mock() + mock_resource = mock.Mock() + mock_tool = mock.Mock() + + # --- Prepare Session Group --- + mcp_session_group = ClientSessionGroup() + mcp_session_group._prompts = {"my_prompt": mock_prompt} + mcp_session_group._resources = {"my_resource": mock_resource} + mcp_session_group._tools = {"my_tool": mock_tool} + + # --- Assertions --- + assert mcp_session_group.prompts == {"my_prompt": mock_prompt} + assert mcp_session_group.resources == {"my_resource": mock_resource} + assert mcp_session_group.tools == {"my_tool": mock_tool} + + async def test_call_tool(self): + # --- Mock Dependencies --- + mock_session = mock.AsyncMock() + + # --- Prepare Session Group --- + mcp_session_group = ClientSessionGroup() + mcp_session_group._tool_to_session = {"my_tool": mock_session} + text_content = types.TextContent(type="text", text="OK") + mock_session.call_tool.return_value = types.CallToolResult( + content=[text_content] + ) + + # --- Test Execution --- + result = await mcp_session_group.call_tool( + name="my_tool", + args={ + "name": "value1", + "args": {}, + }, + ) + + # --- Assertions --- + assert result.content == [text_content] + mock_session.call_tool.assert_called_once_with( + "my_tool", + {"name": "value1", "args": {}}, + ) + + async def test_connect_to_server(self, mock_exit_stack, mock_server_params): + """Test connecting to a server and aggregating components.""" + # --- Mock Dependencies --- + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "TestServer1" + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_tool1 = mock.Mock(spec=types.Tool) + mock_tool1.name = "tool_a" + mock_resource1 = mock.Mock(spec=types.Resource) + mock_resource1.name = "resource_b" + mock_prompt1 = mock.Mock(spec=types.Prompt) + mock_prompt1.name = "prompt_c" + mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool1]) + mock_session.list_resources.return_value = mock.AsyncMock( + resources=[mock_resource1] + ) + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[mock_prompt1]) + + # --- Test Execution --- + group = ClientSessionGroup(exit_stack=mock_exit_stack) + with mock.patch.object( + group, "_establish_session", return_value=(mock_server_info, mock_session) + ): + await group.connect_to_server(mock_server_params) + + # --- Assertions --- + assert mock_session in group._sessions + assert len(group.tools) == 1 + assert "tool_a" in group.tools + assert group.tools["tool_a"] == mock_tool1 + assert group._tool_to_session["tool_a"] == mock_session + assert len(group.resources) == 1 + assert "resource_b" in group.resources + assert group.resources["resource_b"] == mock_resource1 + assert len(group.prompts) == 1 + assert "prompt_c" in group.prompts + assert group.prompts["prompt_c"] == mock_prompt1 + mock_session.list_tools.assert_awaited_once() + mock_session.list_resources.assert_awaited_once() + mock_session.list_prompts.assert_awaited_once() + + async def test_connect_to_server_with_name_hook( + self, mock_exit_stack, mock_server_params + ): + """Test connecting with a component name hook.""" + # --- Mock Dependencies --- + mock_server_info = mock.Mock(spec=types.Implementation) + mock_server_info.name = "HookServer" + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + mock_tool = mock.Mock(spec=types.Tool) + mock_tool.name = "base_tool" + mock_session.list_tools.return_value = mock.AsyncMock(tools=[mock_tool]) + mock_session.list_resources.return_value = mock.AsyncMock(resources=[]) + mock_session.list_prompts.return_value = mock.AsyncMock(prompts=[]) + + # --- Test Setup --- + def name_hook(name: str, server_info: types.Implementation) -> str: + return f"{server_info.name}.{name}" + + # --- Test Execution --- + group = ClientSessionGroup( + exit_stack=mock_exit_stack, component_name_hook=name_hook + ) + with mock.patch.object( + group, "_establish_session", return_value=(mock_server_info, mock_session) + ): + await group.connect_to_server(mock_server_params) + + # --- Assertions --- + assert mock_session in group._sessions + assert len(group.tools) == 1 + expected_tool_name = "HookServer.base_tool" + assert expected_tool_name in group.tools + assert group.tools[expected_tool_name] == mock_tool + assert group._tool_to_session[expected_tool_name] == mock_session + + def test_disconnect_from_server(self): # No mock arguments needed + """Test disconnecting from a server.""" + # --- Test Setup --- + group = ClientSessionGroup() + server_name = "ServerToDisconnect" + + # Manually populate state using standard mocks + mock_session1 = mock.MagicMock(spec=mcp.ClientSession) + mock_session2 = mock.MagicMock(spec=mcp.ClientSession) + mock_tool1 = mock.Mock(spec=types.Tool) + mock_tool1.name = "tool1" + mock_resource1 = mock.Mock(spec=types.Resource) + mock_resource1.name = "res1" + mock_prompt1 = mock.Mock(spec=types.Prompt) + mock_prompt1.name = "prm1" + mock_tool2 = mock.Mock(spec=types.Tool) + mock_tool2.name = "tool2" + mock_component_named_like_server = mock.Mock() + mock_session = mock.Mock(spec=mcp.ClientSession) + + group._tools = { + "tool1": mock_tool1, + "tool2": mock_tool2, + server_name: mock_component_named_like_server, + } + group._tool_to_session = { + "tool1": mock_session1, + "tool2": mock_session2, + server_name: mock_session1, + } + group._resources = { + "res1": mock_resource1, + server_name: mock_component_named_like_server, + } + group._prompts = { + "prm1": mock_prompt1, + server_name: mock_component_named_like_server, + } + group._sessions = { + mock_session: ClientSessionGroup._ComponentNames( + prompts=set({"prm1"}), + resources=set({"res1"}), + tools=set({"tool1", "tool2"}), + ) + } + + # --- Assertions --- + assert mock_session in group._sessions + assert "tool1" in group._tools + assert "tool2" in group._tools + assert "res1" in group._resources + assert "prm1" in group._prompts + + # --- Test Execution --- + group.disconnect_from_server(mock_session) + + # --- Assertions --- + assert mock_session not in group._sessions + assert "tool1" not in group._tools + assert "tool2" not in group._tools + assert "res1" not in group._resources + assert "prm1" not in group._prompts + + async def test_connect_to_server_duplicate_tool_raises_error( + self, mock_exit_stack, mock_server_params + ): + """Test McpError raised when connecting a server with a dup name.""" + # --- Setup Pre-existing State --- + group = ClientSessionGroup(exit_stack=mock_exit_stack) + existing_tool_name = "shared_tool" + # Manually add a tool to simulate a previous connection + group._tools[existing_tool_name] = mock.Mock(spec=types.Tool) + group._tools[existing_tool_name].name = existing_tool_name + # Need a dummy session associated with the existing tool + group._tool_to_session[existing_tool_name] = mock.MagicMock( + spec=mcp.ClientSession + ) + + # --- Mock New Connection Attempt --- + mock_server_info_new = mock.Mock(spec=types.Implementation) + mock_server_info_new.name = "ServerWithDuplicate" + mock_session_new = mock.AsyncMock(spec=mcp.ClientSession) + + # Configure the new session to return a tool with the *same name* + duplicate_tool = mock.Mock(spec=types.Tool) + duplicate_tool.name = existing_tool_name + mock_session_new.list_tools.return_value = mock.AsyncMock( + tools=[duplicate_tool] + ) + # Keep other lists empty for simplicity + mock_session_new.list_resources.return_value = mock.AsyncMock(resources=[]) + mock_session_new.list_prompts.return_value = mock.AsyncMock(prompts=[]) + + # --- Test Execution and Assertion --- + with pytest.raises(McpError) as excinfo: + with mock.patch.object( + group, + "_establish_session", + return_value=(mock_server_info_new, mock_session_new), + ): + await group.connect_to_server(mock_server_params) + + # Assert details about the raised error + assert excinfo.value.error.code == types.INVALID_PARAMS + assert existing_tool_name in excinfo.value.error.message + assert "already exists in group tools" in excinfo.value.error.message + + # Verify the duplicate tool was *not* added again (state should be unchanged) + assert len(group._tools) == 1 # Should still only have the original + assert ( + group._tools[existing_tool_name] is not duplicate_tool + ) # Ensure it's the original mock + + # No patching needed here + def test_disconnect_non_existent_server(self): # No mock arguments needed + """Test disconnecting a server that isn't connected.""" + session = mock.Mock(spec=mcp.ClientSession) + group = ClientSessionGroup() + with pytest.raises(McpError): + group.disconnect_from_server(session) From 14bfcf1dad882f5aa1ae364eca9df2c0b1a13f66 Mon Sep 17 00:00:00 2001 From: Mo Date: Fri, 9 May 2025 01:23:33 +0000 Subject: [PATCH 2/3] Add support for sse and streamable http transports in the ClientSessionGroup This change introduces additional support for the other MCP transports and includes respective server parameter types to align with the stdio client implementation. --- src/mcp/client/session_group.py | 72 +++++++++++++- tests/client/test_session_group.py | 148 +++++++++++++++++++++++++---- 2 files changed, 198 insertions(+), 22 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 724d6a439..0c518de80 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -10,16 +10,59 @@ import contextlib from collections.abc import Callable +from datetime import timedelta from typing import Any, TypeAlias from pydantic import BaseModel import mcp from mcp import types +from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters +from mcp.client.streamable_http import streamablehttp_client from mcp.shared.exceptions import McpError +class SseServerParameters(BaseModel): + """Parameters for intializing a sse_client.""" + + # The endpoint URL. + url: str + + # Optional headers to include in requests. + headers: dict[str, Any] | None = None + + # HTTP timeout for regular operations. + timeout: float = 5 + + # Timeout for SSE read operations. + sse_read_timeout: float = 60 * 5 + + +class StreamableHttpParameters(BaseModel): + """Parameters for intializing a streamablehttp_client.""" + + # The endpoint URL. + url: str + + # Optional headers to include in requests. + headers: dict[str, Any] | None = None + + # HTTP timeout for regular operations. + timeout: timedelta = timedelta(seconds=30) + + # Timeout for SSE read operations. + sse_read_timeout: timedelta = timedelta(seconds=60 * 5) + + # Close the client session when the transport closes. + terminate_on_close: bool = True + + +ServerParameters: TypeAlias = ( + StdioServerParameters | SseServerParameters | StreamableHttpParameters +) + + class ClientSessionGroup: """Client for managing connections to multiple MCP servers. @@ -118,7 +161,7 @@ def disconnect_from_server(self, session: mcp.ClientSession) -> None: async def connect_to_server( self, - server_params: StdioServerParameters, + server_params: ServerParameters, ) -> mcp.ClientSession: """Connects to a single MCP server.""" @@ -190,11 +233,32 @@ async def connect_to_server( return session async def _establish_session( - self, server_params: StdioServerParameters + self, server_params: ServerParameters ) -> tuple[types.Implementation, mcp.ClientSession]: """Establish a client session to an MCP server.""" - client = mcp.stdio_client(server_params) - read, write = await self._exit_stack.enter_async_context(client) + + # Create read and write streams that facilitate io with the server. + if isinstance(server_params, StdioServerParameters): + client = mcp.stdio_client(server_params) + read, write = await self._exit_stack.enter_async_context(client) + elif isinstance(server_params, SseServerParameters): + client = sse_client( + url=server_params.url, + headers=server_params.headers, + timeout=server_params.timeout, + sse_read_timeout=server_params.sse_read_timeout, + ) + read, write = await self._exit_stack.enter_async_context(client) + else: + client = streamablehttp_client( + url=server_params.url, + headers=server_params.headers, + timeout=server_params.timeout, + sse_read_timeout=server_params.sse_read_timeout, + terminate_on_close=server_params.terminate_on_close, + ) + read, write, _ = await self._exit_stack.enter_async_context(client) + session = await self._exit_stack.enter_async_context( mcp.ClientSession(read, write) ) diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 5219009be..22687d21d 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -5,7 +5,11 @@ import mcp from mcp import types -from mcp.client.session_group import ClientSessionGroup +from mcp.client.session_group import ( + ClientSessionGroup, + SseServerParameters, + StreamableHttpParameters, +) from mcp.client.stdio import StdioServerParameters from mcp.shared.exceptions import McpError @@ -19,12 +23,6 @@ def mock_exit_stack(): return mock.MagicMock(spec=contextlib.AsyncExitStack) -@pytest.fixture -def mock_server_params(): # No mocker needed here - """Fixture for mocked StdioServerParameters.""" - return mock.Mock(spec=StdioServerParameters) - - @pytest.mark.anyio class TestClientSessionGroup: def test_init(self): @@ -79,7 +77,7 @@ async def test_call_tool(self): {"name": "value1", "args": {}}, ) - async def test_connect_to_server(self, mock_exit_stack, mock_server_params): + async def test_connect_to_server(self, mock_exit_stack): """Test connecting to a server and aggregating components.""" # --- Mock Dependencies --- mock_server_info = mock.Mock(spec=types.Implementation) @@ -102,7 +100,7 @@ async def test_connect_to_server(self, mock_exit_stack, mock_server_params): with mock.patch.object( group, "_establish_session", return_value=(mock_server_info, mock_session) ): - await group.connect_to_server(mock_server_params) + await group.connect_to_server(StdioServerParameters(command="test")) # --- Assertions --- assert mock_session in group._sessions @@ -120,9 +118,7 @@ async def test_connect_to_server(self, mock_exit_stack, mock_server_params): mock_session.list_resources.assert_awaited_once() mock_session.list_prompts.assert_awaited_once() - async def test_connect_to_server_with_name_hook( - self, mock_exit_stack, mock_server_params - ): + async def test_connect_to_server_with_name_hook(self, mock_exit_stack): """Test connecting with a component name hook.""" # --- Mock Dependencies --- mock_server_info = mock.Mock(spec=types.Implementation) @@ -145,7 +141,7 @@ def name_hook(name: str, server_info: types.Implementation) -> str: with mock.patch.object( group, "_establish_session", return_value=(mock_server_info, mock_session) ): - await group.connect_to_server(mock_server_params) + await group.connect_to_server(StdioServerParameters(command="test")) # --- Assertions --- assert mock_session in group._sessions @@ -218,9 +214,7 @@ def test_disconnect_from_server(self): # No mock arguments needed assert "res1" not in group._resources assert "prm1" not in group._prompts - async def test_connect_to_server_duplicate_tool_raises_error( - self, mock_exit_stack, mock_server_params - ): + async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_stack): """Test McpError raised when connecting a server with a dup name.""" # --- Setup Pre-existing State --- group = ClientSessionGroup(exit_stack=mock_exit_stack) @@ -255,7 +249,7 @@ async def test_connect_to_server_duplicate_tool_raises_error( "_establish_session", return_value=(mock_server_info_new, mock_session_new), ): - await group.connect_to_server(mock_server_params) + await group.connect_to_server(StdioServerParameters(command="test")) # Assert details about the raised error assert excinfo.value.error.code == types.INVALID_PARAMS @@ -269,9 +263,127 @@ async def test_connect_to_server_duplicate_tool_raises_error( ) # Ensure it's the original mock # No patching needed here - def test_disconnect_non_existent_server(self): # No mock arguments needed + def test_disconnect_non_existent_server(self): """Test disconnecting a server that isn't connected.""" session = mock.Mock(spec=mcp.ClientSession) group = ClientSessionGroup() with pytest.raises(McpError): group.disconnect_from_server(session) + + @pytest.mark.parametrize( + "server_params_instance, client_type_name, patch_target_for_client_func", + [ + ( + StdioServerParameters(command="test_stdio_cmd"), + "stdio", + "mcp.client.session_group.mcp.stdio_client", + ), + ( + SseServerParameters(url="http://test.com/sse", timeout=10), + "sse", + "mcp.client.session_group.sse_client", + ), # url, headers, timeout, sse_read_timeout + ( + StreamableHttpParameters( + url="http://test.com/stream", terminate_on_close=False + ), + "streamablehttp", + "mcp.client.session_group.streamablehttp_client", + ), # url, headers, timeout, sse_read_timeout, terminate_on_close + ], + ) + async def test_establish_session_parameterized( + self, + server_params_instance, + client_type_name, # Just for clarity or conditional logic if needed + patch_target_for_client_func, + ): + with mock.patch( + "mcp.client.session_group.mcp.ClientSession" + ) as mock_ClientSession_class: + with mock.patch(patch_target_for_client_func) as mock_specific_client_func: + mock_client_cm_instance = mock.AsyncMock( + name=f"{client_type_name}ClientCM" + ) + mock_read_stream = mock.AsyncMock(name=f"{client_type_name}Read") + mock_write_stream = mock.AsyncMock(name=f"{client_type_name}Write") + + # streamablehttp_client's __aenter__ returns three values + if client_type_name == "streamablehttp": + mock_extra_stream_val = mock.AsyncMock(name="StreamableExtra") + mock_client_cm_instance.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + mock_extra_stream_val, + ) + else: + mock_client_cm_instance.__aenter__.return_value = ( + mock_read_stream, + mock_write_stream, + ) + + mock_client_cm_instance.__aexit__ = mock.AsyncMock(return_value=None) + mock_specific_client_func.return_value = mock_client_cm_instance + + # --- Mock mcp.ClientSession (class) --- + # mock_ClientSession_class is already provided by the outer patch + mock_raw_session_cm = mock.AsyncMock(name="RawSessionCM") + mock_ClientSession_class.return_value = mock_raw_session_cm + + mock_entered_session = mock.AsyncMock(name="EnteredSessionInstance") + mock_raw_session_cm.__aenter__.return_value = mock_entered_session + mock_raw_session_cm.__aexit__ = mock.AsyncMock(return_value=None) + + # Mock session.initialize() + mock_initialize_result = mock.AsyncMock(name="InitializeResult") + mock_initialize_result.serverInfo = types.Implementation( + name="foo", version="1" + ) + mock_entered_session.initialize.return_value = mock_initialize_result + + # --- Test Execution --- + group = ClientSessionGroup() + returned_server_info = None + returned_session = None + + async with contextlib.AsyncExitStack() as stack: + group._exit_stack = stack + ( + returned_server_info, + returned_session, + ) = await group._establish_session(server_params_instance) + + # --- Assertions --- + # 1. Assert the correct specific client function was called + if client_type_name == "stdio": + mock_specific_client_func.assert_called_once_with( + server_params_instance + ) + elif client_type_name == "sse": + mock_specific_client_func.assert_called_once_with( + url=server_params_instance.url, + headers=server_params_instance.headers, + timeout=server_params_instance.timeout, + sse_read_timeout=server_params_instance.sse_read_timeout, + ) + elif client_type_name == "streamablehttp": + mock_specific_client_func.assert_called_once_with( + url=server_params_instance.url, + headers=server_params_instance.headers, + timeout=server_params_instance.timeout, + sse_read_timeout=server_params_instance.sse_read_timeout, + terminate_on_close=server_params_instance.terminate_on_close, + ) + + mock_client_cm_instance.__aenter__.assert_awaited_once() + + # 2. Assert ClientSession was called correctly + mock_ClientSession_class.assert_called_once_with( + mock_read_stream, mock_write_stream + ) + mock_raw_session_cm.__aenter__.assert_awaited_once() + mock_entered_session.initialize.assert_awaited_once() + + # 3. Assert returned values + assert returned_server_info is mock_initialize_result.serverInfo + assert returned_session is mock_entered_session From e798eeb12a2d44d8705d67587f783cd2bc906bdd Mon Sep 17 00:00:00 2001 From: Mo Date: Sat, 10 May 2025 17:26:57 +0000 Subject: [PATCH 3/3] Add support for async context management to `ClientSessionGroup` This changes enables context management for setting up and tearing down async exit stacks durring server connection and disconnection respectively. Documentation has been added to show an example use case that demonstrates how `ClientSessionGroup` can be used with `async with`. --- src/mcp/client/session_group.py | 293 +++++++++++++++++++---------- tests/client/test_session_group.py | 28 ++- 2 files changed, 215 insertions(+), 106 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 0c518de80..c23f2523e 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -9,11 +9,15 @@ """ import contextlib +import logging from collections.abc import Callable from datetime import timedelta +from types import TracebackType from typing import Any, TypeAlias +import anyio from pydantic import BaseModel +from typing_extensions import Self import mcp from mcp import types @@ -67,11 +71,18 @@ class ClientSessionGroup: """Client for managing connections to multiple MCP servers. This class is responsible for encapsulating management of server connections. - It it aggregates tools, resources, and prompts from all connected servers. + It aggregates tools, resources, and prompts from all connected servers. For auxiliary handlers, such as resource subscription, this is delegated to - the client and can be accessed via the session. For example: - mcp_session_group.get_session("server_name").subscribe_to_resource(...) + the client and can be accessed via the session. + + Example Usage: + name_fn = lambda name, server_info: f"{(server_info.name)}-{name}" + async with ClientSessionGroup(component_name_hook=name_fn) as group: + for server_params in server_params: + group.connect_to_server(server_param) + ... + """ class _ComponentNames(BaseModel): @@ -90,6 +101,7 @@ class _ComponentNames(BaseModel): _sessions: dict[mcp.ClientSession, _ComponentNames] _tool_to_session: dict[str, mcp.ClientSession] _exit_stack: contextlib.AsyncExitStack + _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] # Optional fn consuming (component_name, serverInfo) for custom names. # This is provide a means to mitigate naming conflicts across servers. @@ -99,7 +111,7 @@ class _ComponentNames(BaseModel): def __init__( self, - exit_stack: contextlib.AsyncExitStack = contextlib.AsyncExitStack(), + exit_stack: contextlib.AsyncExitStack | None = None, component_name_hook: _ComponentNameHook | None = None, ) -> None: """Initializes the MCP client.""" @@ -110,9 +122,43 @@ def __init__( self._sessions = {} self._tool_to_session = {} - self._exit_stack = exit_stack + if exit_stack is None: + self._exit_stack = contextlib.AsyncExitStack() + self._owns_exit_stack = True + else: + self._exit_stack = exit_stack + self._owns_exit_stack = False + self._session_exit_stacks = {} self._component_name_hook = component_name_hook + async def __aenter__(self) -> Self: + # Enter the exit stack only if we created it ourselves + if self._owns_exit_stack: + await self._exit_stack.__aenter__() + return self + + async def __aexit__( + self, + _exc_type: type[BaseException] | None, + _exc_val: BaseException | None, + _exc_tb: TracebackType | None, + ) -> bool | None: + """Closes session exit stacks and main exit stack upon completion.""" + + # Concurrently close session stacks. + async with anyio.create_task_group() as tg: + for exit_stack in self._session_exit_stacks.values(): + tg.start_soon(exit_stack.aclose) + + # Only close the main exit stack if we created it + if self._owns_exit_stack: + await self._exit_stack.aclose() + + @property + def sessions(self) -> list[mcp.ClientSession]: + """Returns the list of sessions being managed.""" + return list(self._sessions.keys()) + @property def prompts(self) -> dict[str, types.Prompt]: """Returns the prompts as a dictionary of names to prompts.""" @@ -131,42 +177,113 @@ def tools(self) -> dict[str, types.Tool]: async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: """Executes a tool given its name and arguments.""" session = self._tool_to_session[name] - return await session.call_tool(name, args) + session_tool_name = self.tools[name].name + return await session.call_tool(session_tool_name, args) - def disconnect_from_server(self, session: mcp.ClientSession) -> None: + async def disconnect_from_server(self, session: mcp.ClientSession) -> None: """Disconnects from a single MCP server.""" - if session not in self._sessions: + session_known_for_components = session in self._sessions + session_known_for_stack = session in self._session_exit_stacks + + if not session_known_for_components and not session_known_for_stack: raise McpError( types.ErrorData( code=types.INVALID_PARAMS, - message="Provided session is not being managed.", + message="Provided session is not managed or already disconnected.", ) ) - component_names = self._sessions[session] - - # Remove prompts associated with the session. - for name in component_names.prompts: - del self._prompts[name] - # Remove resources associated with the session. - for name in component_names.resources: - del self._resources[name] - - # Remove tools associated with the session. - for name in component_names.tools: - del self._tools[name] - - del self._sessions[session] + if session_known_for_components: + component_names = self._sessions.pop(session) # Pop from _sessions tracking + + # Remove prompts associated with the session. + for name in component_names.prompts: + if name in self._prompts: + del self._prompts[name] + # Remove resources associated with the session. + for name in component_names.resources: + if name in self._resources: + del self._resources[name] + # Remove tools associated with the session. + for name in component_names.tools: + if name in self._tools: + del self._tools[name] + if name in self._tool_to_session: + del self._tool_to_session[name] + + # Clean up the session's resources via its dedicated exit stack + if session_known_for_stack: + session_stack_to_close = self._session_exit_stacks.pop(session) + await session_stack_to_close.aclose() + + async def connect_with_session( + self, server_info: types.Implementation, session: mcp.ClientSession + ) -> mcp.ClientSession: + """Connects to a single MCP server.""" + await self._aggregate_components(server_info, session) + return session async def connect_to_server( self, server_params: ServerParameters, ) -> mcp.ClientSession: """Connects to a single MCP server.""" - - # Establish server connection and create session. server_info, session = await self._establish_session(server_params) + return await self.connect_with_session(server_info, session) + + async def _establish_session( + self, server_params: ServerParameters + ) -> tuple[types.Implementation, mcp.ClientSession]: + """Establish a client session to an MCP server.""" + + session_stack = contextlib.AsyncExitStack() + try: + # Create read and write streams that facilitate io with the server. + if isinstance(server_params, StdioServerParameters): + client = mcp.stdio_client(server_params) + read, write = await session_stack.enter_async_context(client) + elif isinstance(server_params, SseServerParameters): + client = sse_client( + url=server_params.url, + headers=server_params.headers, + timeout=server_params.timeout, + sse_read_timeout=server_params.sse_read_timeout, + ) + read, write = await session_stack.enter_async_context(client) + else: + client = streamablehttp_client( + url=server_params.url, + headers=server_params.headers, + timeout=server_params.timeout, + sse_read_timeout=server_params.sse_read_timeout, + terminate_on_close=server_params.terminate_on_close, + ) + read, write, _ = await session_stack.enter_async_context(client) + + session = await session_stack.enter_async_context( + mcp.ClientSession(read, write) + ) + result = await session.initialize() + + # Session successfully initialized. + # Store its stack and register the stack with the main group stack. + self._session_exit_stacks[session] = session_stack + # session_stack itself becomes a resource managed by the + # main _exit_stack. + await self._exit_stack.enter_async_context(session_stack) + + return result.serverInfo, session + except Exception: + # If anything during this setup fails, ensure the session-specific + # stack is closed. + await session_stack.aclose() + raise + + async def _aggregate_components( + self, server_info: types.Implementation, session: mcp.ClientSession + ) -> None: + """Aggregates prompts, resources, and tools from a given session.""" # Create a reverse index so we can find all prompts, resources, and # tools belonging to this session. Used for removing components from @@ -181,47 +298,66 @@ async def connect_to_server( tool_to_session_temp: dict[str, mcp.ClientSession] = {} # Query the server for its prompts and aggregate to list. - prompts = (await session.list_prompts()).prompts - for prompt in prompts: - name = self._component_name(prompt.name, server_info) - if name in self._prompts: - raise McpError( - types.ErrorData( - code=types.INVALID_PARAMS, - message=f"{name} already exists in group prompts.", - ) - ) - prompts_temp[name] = prompt - component_names.prompts.add(name) + try: + prompts = (await session.list_prompts()).prompts + for prompt in prompts: + name = self._component_name(prompt.name, server_info) + prompts_temp[name] = prompt + component_names.prompts.add(name) + except McpError as err: + logging.warning(f"Could not fetch prompts: {err}") # Query the server for its resources and aggregate to list. - resources = (await session.list_resources()).resources - for resource in resources: - name = self._component_name(resource.name, server_info) - if name in self._resources: - raise McpError( - types.ErrorData( - code=types.INVALID_PARAMS, - message=f"{name} already exists in group resources.", - ) - ) - resources_temp[name] = resource - component_names.resources.add(name) + try: + resources = (await session.list_resources()).resources + for resource in resources: + name = self._component_name(resource.name, server_info) + resources_temp[name] = resource + component_names.resources.add(name) + except McpError as err: + logging.warning(f"Could not fetch resources: {err}") # Query the server for its tools and aggregate to list. - tools = (await session.list_tools()).tools - for tool in tools: - name = self._component_name(tool.name, server_info) - if name in self._tools: - raise McpError( - types.ErrorData( - code=types.INVALID_PARAMS, - message=f"{name} already exists in group tools.", - ) + try: + tools = (await session.list_tools()).tools + for tool in tools: + name = self._component_name(tool.name, server_info) + tools_temp[name] = tool + tool_to_session_temp[name] = session + component_names.tools.add(name) + except McpError as err: + logging.warning(f"Could not fetch tools: {err}") + + # Clean up exit stack for session if we couldn't retrieve anything + # from the server. + if not any((prompts_temp, resources_temp, tools_temp)): + del self._session_exit_stacks[session] + + # Check for duplicates. + matching_prompts = prompts_temp.keys() & self._prompts.keys() + if matching_prompts: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"{matching_prompts} already exist in group prompts.", + ) + ) + matching_resources = resources_temp.keys() & self._resources.keys() + if matching_resources: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"{matching_resources} already exist in group resources.", + ) + ) + matching_tools = tools_temp.keys() & self._tools.keys() + if matching_tools: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"{matching_tools} already exist in group tools.", ) - tools_temp[name] = tool - tool_to_session_temp[name] = session - component_names.tools.add(name) + ) # Aggregate components. self._sessions[session] = component_names @@ -230,41 +366,6 @@ async def connect_to_server( self._tools.update(tools_temp) self._tool_to_session.update(tool_to_session_temp) - return session - - async def _establish_session( - self, server_params: ServerParameters - ) -> tuple[types.Implementation, mcp.ClientSession]: - """Establish a client session to an MCP server.""" - - # Create read and write streams that facilitate io with the server. - if isinstance(server_params, StdioServerParameters): - client = mcp.stdio_client(server_params) - read, write = await self._exit_stack.enter_async_context(client) - elif isinstance(server_params, SseServerParameters): - client = sse_client( - url=server_params.url, - headers=server_params.headers, - timeout=server_params.timeout, - sse_read_timeout=server_params.sse_read_timeout, - ) - read, write = await self._exit_stack.enter_async_context(client) - else: - client = streamablehttp_client( - url=server_params.url, - headers=server_params.headers, - timeout=server_params.timeout, - sse_read_timeout=server_params.sse_read_timeout, - terminate_on_close=server_params.terminate_on_close, - ) - read, write, _ = await self._exit_stack.enter_async_context(client) - - session = await self._exit_stack.enter_async_context( - mcp.ClientSession(read, write) - ) - result = await session.initialize() - return result.serverInfo, session - def _component_name(self, name: str, server_info: types.Implementation) -> str: if self._component_name_hook: return self._component_name_hook(name, server_info) diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 22687d21d..924ef7a06 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -54,8 +54,14 @@ async def test_call_tool(self): mock_session = mock.AsyncMock() # --- Prepare Session Group --- - mcp_session_group = ClientSessionGroup() - mcp_session_group._tool_to_session = {"my_tool": mock_session} + def hook(name, server_info): + return f"{(server_info.name)}-{name}" + + mcp_session_group = ClientSessionGroup(component_name_hook=hook) + mcp_session_group._tools = { + "server1-my_tool": types.Tool(name="my_tool", inputSchema={}) + } + mcp_session_group._tool_to_session = {"server1-my_tool": mock_session} text_content = types.TextContent(type="text", text="OK") mock_session.call_tool.return_value = types.CallToolResult( content=[text_content] @@ -63,7 +69,7 @@ async def test_call_tool(self): # --- Test Execution --- result = await mcp_session_group.call_tool( - name="my_tool", + name="server1-my_tool", args={ "name": "value1", "args": {}, @@ -151,7 +157,7 @@ def name_hook(name: str, server_info: types.Implementation) -> str: assert group.tools[expected_tool_name] == mock_tool assert group._tool_to_session[expected_tool_name] == mock_session - def test_disconnect_from_server(self): # No mock arguments needed + async def test_disconnect_from_server(self): # No mock arguments needed """Test disconnecting from a server.""" # --- Test Setup --- group = ClientSessionGroup() @@ -205,7 +211,7 @@ def test_disconnect_from_server(self): # No mock arguments needed assert "prm1" in group._prompts # --- Test Execution --- - group.disconnect_from_server(mock_session) + await group.disconnect_from_server(mock_session) # --- Assertions --- assert mock_session not in group._sessions @@ -223,8 +229,10 @@ async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_sta group._tools[existing_tool_name] = mock.Mock(spec=types.Tool) group._tools[existing_tool_name].name = existing_tool_name # Need a dummy session associated with the existing tool - group._tool_to_session[existing_tool_name] = mock.MagicMock( - spec=mcp.ClientSession + mock_session = mock.MagicMock(spec=mcp.ClientSession) + group._tool_to_session[existing_tool_name] = mock_session + group._session_exit_stacks[mock_session] = mock.Mock( + spec=contextlib.AsyncExitStack ) # --- Mock New Connection Attempt --- @@ -254,7 +262,7 @@ async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_sta # Assert details about the raised error assert excinfo.value.error.code == types.INVALID_PARAMS assert existing_tool_name in excinfo.value.error.message - assert "already exists in group tools" in excinfo.value.error.message + assert "already exist " in excinfo.value.error.message # Verify the duplicate tool was *not* added again (state should be unchanged) assert len(group._tools) == 1 # Should still only have the original @@ -263,12 +271,12 @@ async def test_connect_to_server_duplicate_tool_raises_error(self, mock_exit_sta ) # Ensure it's the original mock # No patching needed here - def test_disconnect_non_existent_server(self): + async def test_disconnect_non_existent_server(self): """Test disconnecting a server that isn't connected.""" session = mock.Mock(spec=mcp.ClientSession) group = ClientSessionGroup() with pytest.raises(McpError): - group.disconnect_from_server(session) + await group.disconnect_from_server(session) @pytest.mark.parametrize( "server_params_instance, client_type_name, patch_target_for_client_func",