8000 refactor: enhance mcp tool session management · samsepassi/adk-python@40b15ad · GitHub
[go: up one dir, main page]

Skip to content

Commit 40b15ad

Browse files
seanzhougooglecopybara-github
authored andcommitted
refactor: enhance mcp tool session management
1. remove unnecessary cached session instance in mcp toolset 2. move session reinitialization logic from mcp tool and mcp toolset to mcp session manager 3. add lock for the code block of session creation to avoid race conditions PiperOrigin-RevId: 770949529
1 parent dbdeb49 commit 40b15ad

File tree

3 files changed

+120
-117
lines changed

3 files changed

+120
-117
lines changed

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 116 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
from contextlib import AsyncExitStack
1819
from datetime import timedelta
1920
import functools
@@ -34,7 +35,6 @@
3435
from mcp.client.stdio import stdio_client
3536
from mcp.client.streamable_http import streamablehttp_client
3637
except ImportError as e:
37-
import sys
3838

3939
if sys.version_info < (3, 10):
4040
raise ImportError(
@@ -105,30 +105,29 @@ class StreamableHTTPConnectionParams(BaseModel):
105105
terminate_on_close: bool = True
106106

107107

108-
def retry_on_closed_resource(async_reinit_func_name: str):
108+
def retry_on_closed_resource(session_manager_field_name: str):
109109
"""Decorator to automatically reinitialize session and retry action.
110110
111111
When MCP session was closed, the decorator will automatically recreate the
112112
session and retry the action with the same parameters.
113113
114114
Note:
115-
1. async_reinit_func_name is the name of the class member function that
116-
reinitializes the MCP session.
117-
2. Both the decorated function and the async_reinit_func_name must be async
118-
functions.
115+
1. session_manager_field_name is the name of the class member field that
116+
contains the MCPSessionManager instance.
117+
2. The session manager must have a reinitialize_session() async method.
119118
120119
Usage:
121120
class MCPTool:
122-
...
123-
async def create_session(self):
124-
self.session = ...
121+
def __init__(self):
122+
self._mcp_session_manager = MCPSessionManager(...)
125123
126-
@retry_on_closed_resource('create_session')
124+
@retry_on_closed_resource('_mcp_session_manager')
127125
async def use_session(self):
128-
await self.session.call_tool()
126+
session = await self._mcp_session_manager.create_session()
127+
await session.call_tool()
129128
130129
Args:
131-
async_reinit_func_name: The name of the async function to recreate session.
130+
session_manager_field_name: The name of the session manager field.
132131
133132
Returns:
134133
The decorated function.
@@ -141,15 +140,21 @@ async def wrapper(self, *args, **kwargs):
141140
return await func(self, *args, **kwargs)
142141
except anyio.ClosedResourceError as close_err:
143142
try:
144-
if hasattr(self, async_reinit_func_name) and callable(
145-
getattr(self, async_reinit_func_name)
146-
):
147-
async_init_fn = getattr(self, async_reinit_func_name)
148-
await async_init_fn()
143+
if hasattr(self, session_manager_field_name):
144+
session_manager = getattr(self, session_manager_field_name)
145+
if hasattr(session_manager, 'reinitialize_session') and callable(
146+
getattr(session_manager, 'reinitialize_session')
147+
):
148+
await session_manager.reinitialize_session()
149+
else:
150+
raise ValueError(
151+
f'Session manager {session_manager_field_name} does not have'
152+
' reinitialize_session method.'
153+
) from close_err
149154
else:
150155
raise ValueError(
151-
f'Function {async_reinit_func_name} does not exist in decorated'
152-
' class. Please check the function name in'
156+
f'Session manager field {session_manager_field_name} does not'
157+
' exist in decorated class. Please check the field name in'
153158
' retry_on_closed_resource decorator.'
154159
) from close_err
155160
except Exception as reinit_err:
@@ -207,90 +212,111 @@ def __init__(
207212
# Each session manager maintains its own exit stack for proper cleanup
208213
self._exit_stack: Optional[AsyncExitStack] = None
209214
self._session: Optional[ClientSession] = None
215+
# Lock to prevent race conditions in session creation
216+
self._session_lock = asyncio.Lock()
210217

211218
async def create_session(self) -> ClientSession:
212219
"""Creates and initializes an MCP client session.
213220
214221
Returns:
215222
ClientSession: The initialized MCP client session.
216223
"""
224+
# Fast path: if session already exists, return it without acquiring lock
217225
if self._session is not None:
218226
return self._session
219227

220-
# Create a new exit stack for this session
221-
self._exit_stack = AsyncExitStack()
222-
223-
try:
224-
if isinstance(self._connection_params, StdioConnectionParams):
225-
client = stdio_client(
226-
server=self._connection_params.server_params,
227-
errlog=self._errlog,
228-
)
229-
elif isinstance(self._connection_params, SseConnectionParams):
230-
client = sse_client(
231-
url=self._connection_params.url,
232-
headers=self._connection_params.headers,
233-
timeout=self._connection_params.timeout,
234-
sse_read_timeout=self._connection_params.sse_read_timeout,
235-
)
236-
elif isinstance(self._connection_params, StreamableHTTPConnectionParams):
237-
client = streamablehttp_client(
238-
url=self._connection_params.url,
239-
headers=self._connection_params.headers,
240-
timeout=timedelta(seconds=self._connection_params.timeout),
241-
sse_read_timeout=timedelta(
242-
seconds=self._connection_params.sse_read_timeout
243-
),
244-
terminate_on_close=self._connection_params.terminate_on_close,
245-
)
246-
else:
247-
raise ValueError(
248-
'Unable to initialize connection. Connection should be'
249-
' StdioServerParameters or SseServerParams, but got'
250-
f' {self._connection_params}'
251-
)
252-
253-
transports = await self._exit_stack.enter_async_context(client)
254-
# The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
255-
# needed to build the ClientSession, we limit then to the two first values to be compatible with all clients.
256-
if isinstance(self._connection_params, StdioConnectionParams):
257-
session = await self._exit_stack.enter_async_context(
258-
ClientSession(
259-
*transports[:2],
260-
read_timeout_seconds=timedelta(
261-
seconds=self._connection_params.timeout
262-
),
263-
)
264-
)
265-
else:
266-
session = await self._exit_stack.enter_async_context(
267-
ClientSession(*transports[:2])
268-
)
269-
await session.initialize()
270-
271-
self._session = session
272-
return session
273-
274-
except Exception:
275-
# If session creation fails, clean up the exit stack
276-
if self._exit_stack:
277-
await self._exit_stack.aclose()
278-
self._exit_stack = None
279-
raise
228+
# Use async lock to prevent race conditions
229+
async with self._session_lock:
230+
# Double-check: session might have been created while waiting for lock
231+
if self._session is not None:
232+
return self._session
233+
234+
# Create a new exit stack for this session
235+
self._exit_stack = AsyncExitStack()
236+
237+
try:
238+
if isinstance(self._connection_params, StdioConnectionParams):
239+
client = stdio_client(
240+
server=self._connection_params.server_params,
241+
errlog=self._errlog,
242+
)
243+
elif isinstance(self._connection_params, SseConnectionParams):
244+
client = sse_client(
245+
url=self._connection_params.url,
246+
headers=self._connection_params.headers,
247+
timeout=self._connection_params.timeout,
248+
sse_read_timeout=self._connection_params.sse_read_timeout,
249+
)
250+
elif isinstance(
251+
self._connection_params, StreamableHTTPConnectionParams
252+
):
253+
client = streamablehttp_client(
254+
url=self._connection_params.url,
255+
headers=self._connection_params.headers,
< 10000 code>256+
timeout=timedelta(seconds=self._connection_params.timeout),
257+
sse_read_timeout=timedelta(
258+
seconds=self._connection_params.sse_read_timeout
259+
),
260+
terminate_on_close=self._connection_params.terminate_on_close,
261+
)
262+
else:
263+
raise ValueError(
264+
'Unable to initialize connection. Connection should be'
265+
' StdioServerParameters or SseServerParams, but got'
266+
f' {self._connection_params}'
267+
)
268+
269+
transports = await self._exit_stack.enter_async_context(client)
270+
# The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
271+
# needed to build the ClientSession, we limit then to the two first values to be compatible with all clients.
272+
if isinstance(self._connection_params, StdioConnectionParams):
273+
session = await self._exit_stack.enter_async_context(
274+
ClientSession(
275+
*transports[:2],
276+
read_timeout_seconds=timedelta(
277+
seconds=self._connection_params.timeout
278+
),
279+
)
280+
)
281+
else:
282+
session = await self._exit_stack.enter_async_context(
283+
ClientSession(*transports[:2])
284+
)
285+
await session.initialize()
286+
287+
self._session = session
288+
return session
289+
290+
except Exception:
291+
# If session creation fails, clean up the exit stack
292+
if self._exit_stack:
293+
await self._exit_stack.aclose()
294+
self._exit_stack = None
295+
raise
280296

281297
async def close(self):
282298
"""Closes the session and cleans up resources."""
283-
if self._exit_stack:
284-
try:
285-
await self._exit_stack.aclose()
286-
except Exception as e:
287-
# Log the error but don't re-raise to avoid blocking shutdown
288-
print(
289-
f'Warning: Error during MCP session cleanup: {e}', file=self._errlog
290-
)
291-
finally:
292-
self._exit_stack = None
293-
self._session = None
299+
if not self._exit_stack:
300+
return
301+
async with self._session_lock:
302+
if self._exit_stack:
303+
try:
304+
await self._exit_stack.aclose()
305+
except Exception as e:
306+
# Log the error but don't re-raise to avoid blocking shutdown
307+
print(
308+
f'Warning: Error during MCP session cleanup: {e}',
309+
file=self._errlog,
310+
)
311+
finally:
312+
self._exit_stack = None
313+
self._session = None
314+
315+
async def reinitialize_session(self):
316+
"""Reinitializes the session when connection is lost."""
317+
# Close the old session and create a new one
318+
await self.close()
319+
await self.create_session()
294320

295321

296322
SseServerParams = SseConnectionParams

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _get_declaration(self) -> FunctionDeclaration:
105105
)
106106
return function_decl
107107

108-
@retry_on_closed_resource("_reinitialize_session")
108+
@retry_on_closed_resource("_mcp_session_manager")
109109
async def run_async(self, *, args, tool_context: ToolContext):
110110
"""Runs the tool asynchronously.
111111
@@ -122,9 +122,3 @@ async def run_async(self, *, args, tool_context: ToolContext):
122122
# TODO(cheliu): Support passing tool context to MCP Server.
123123
response = await session.call_tool(self.name, arguments=args)
124124
return response
125-
126-
async def _reinitialize_session(self):
127-
"""Reinitializes the session when connection is lost."""
128-
# Close the old session and create a new one
129-
await self._mcp_session_manager.close()
130-
await self._mcp_session_manager.create_session()

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@
2828
from .mcp_session_manager import MCPSessionManager
2929
from .mcp_session_manager import retry_on_closed_resource
3030
from .mcp_session_manager import SseConnectionParams
31-
from .mcp_session_manager import SseServerParams
3231
from .mcp_session_manager import StdioConnectionParams
3332
from .mcp_session_manager import StreamableHTTPConnectionParams
34-
from .mcp_session_manager import StreamableHTTPServerParams
3533

3634
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
3735
# their Python version to 3.10 if it fails.
@@ -127,9 +125,7 @@ def __init__(
127125
errlog=self._errlog,
128126
)
129127

130-
self._session = None
131-
132-
@retry_on_closed_resource("_reinitialize_session")
128+
@retry_on_closed_resource("_mcp_session_manager")
133129
async def get_tools(
134130
self,
135131
readonly_context: Optional[ReadonlyContext] = None,
@@ -144,11 +140,10 @@ async def get_tools(
144140
List[BaseTool]: A list of tools available under the specified context.
145141
"""
146142
# Get session from session manager
147-
if not self._session:
148-
self._session = await self._mcp_session_manager.create_session()
143+
session = await self._mcp_session_manager.create_session()
149144

150145
# Fetch available tools from the MCP server
151-
tools_response: ListToolsResult = await self._session.list_tools()
146+
tools_response: ListToolsResult = await session.list_tools()
152147

153148
# Apply filtering based on context and tool_filter
154149
tools = []
@@ -162,14 +157,6 @@ async def get_tools(
162157
tools.append(mcp_tool)
163158
return tools
164159

165-
async def _reinitialize_session(self):
166-
"""Reinitializes the session when connection is lost."""
167-
# Close the old session and clear cache
168-
await self._mcp_session_manager.close()
169-
self._session = await self._mcp_session_manager.create_session()
170-
171-
# Tools will be reloaded on next get_tools call
172-
173160
async def close(self) -> None:
174161
"""Performs cleanup and releases resources held by the toolset.
175162
@@ -182,7 +169,3 @@ async def close(self) -> None:
182169
except Exception as e:
183170
# Log the error but don't re-raise to avoid blocking shutdown
184171
print(f"Warning: Error during MCPToolset cleanup: {e}", file=self._errlog)
185-
finally:
186-
# Clear cached tools
187-
self._tools_cache = None
188-
self._tools_loaded = False

0 commit comments

Comments
 (0)
0