E5FF feat: Enable MCP Tool Auth (Experimental) · cloudbuilderspa/adk-python@157d9be · GitHub
[go: up one dir, main page]

Skip to content

Commit 157d9be

Browse files
seanzhougooglecopybara-github
authored andcommitted
feat: Enable MCP Tool Auth (Experimental)
PiperOrigin-RevId: 773002759
1 parent 18a541c commit 157d9be

File tree

7 files changed

+1231
-103
lines changed

7 files changed

+1231
-103
lines changed

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 147 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@
1818
from contextlib import AsyncExitStack
1919
from datetime import timedelta
2020
import functools
21+
import hashlib
22+
import json
2123
import logging
2224
import sys
2325
from typing import Any
26+
from typing import Dict
2427
from typing import Optional
2528
from typing import TextIO
2629
from typing import Union
@@ -105,74 +108,39 @@ class StreamableHTTPConnectionParams(BaseModel):
105108
terminate_on_close: bool = True
106109

107110

108-
def retry_on_closed_resource(session_manager_field_name: str):
109-
"""Decorator to automatically reinitialize session and retry action.
111+
def retry_on_closed_resource(func):
112+
"""Decorator to automatically retry action when MCP session is closed.
110113
111-
When MCP session was closed, the decorator will automatically recreate the
112-
session and retry the action with the same parameters.
113-
114-
Note:
115-
1. session_manager_field_name is the name of the class member field that
116-
contains the MCPSessionManager instance.
117-
2. The session manager must have a reinitialize_session() async method.
118-
119-
Usage:
120-
class MCPTool:
121-
def __init__(self):
122-
self._mcp_session_manager = MCPSessionManager(...)
123-
124-
@retry_on_closed_resource('_mcp_session_manager')
125-
async def use_session(self):
126-
session = await self._mcp_session_manager.create_session()
127-
await session.call_tool()
114+
When MCP session was closed, the decorator will automatically retry the
115+
action once. The create_session method will handle creating a new session
116+
if the old one was disconnected.
128117
129118
Args:
130-
session_manager_field_name: The name of the session manager field.
119+
func: The function to decorate.
131120
132121
Returns:
133122
The decorated function.
134123
"""
135124

136-
def decorator(func):
137-
@functools.wraps(func) # Preserves original function metadata
138-
async def wrapper(self, *args, **kwargs):
139-
try:
140-
return await func(self, *args, **kwargs)
141-
except anyio.ClosedResourceError as close_err:
142-
try:
143-
if hasattr(self, session_manager_field_name):
144-
session_manager = getattr(self, session_manager_field_name)
145-
if hasattr(session_manager, 'reinitialize_session') and callable(
146-
getattr(session_manager, 'reinitialize_session')
147-
):
148-
await session_manager.reinitialize_session()
149-
else:
150-
raise ValueError(
151-
f'Session manager {session_manager_field_name} does not have'
152-
' reinitialize_session method.'
153-
) from close_err
154-
else:
155-
raise ValueError(
156-
f'Session manager field {session_manager_field_name} does not'
157-
' exist in decorated class. Please check the field name in'
158-
' retry_on_closed_resource decorator.'
159-
) from close_err
160-
except Exception as reinit_err:
161-
raise RuntimeError(
162-
f'Error reinitializing: {reinit_err}'
163-
) from reinit_err
164-
return await func(self, *args, **kwargs)
165-
166-
return wrapper
167-
168-
return decorator
125+
@functools.wraps(func) # Preserves original function metadata
126+
async def wrapper(self, *args, **kwargs):
127+
try:
128+
return await func(self, *args, **kwargs)
129+
except anyio.ClosedResourceError:
130+
# Simply retry the function - create_session will handle
131+
# detecting and replacing disconnected sessions
132+
logger.info('Retrying %s due to closed resource', func.__name__)
133+
return await func(self, *args, **kwargs)
134+
135+
return wrapper
169136

170137

171138
class MCPSessionManager:
172139
"""Manages MCP client sessions.
173140
174141
This class provides methods for creating and initializing MCP client sessions,
175-
handling different connection parameters (Stdio and SSE).
142+
handling different connection parameters (Stdio and SSE) and supporting
143+
session pooling based on authentication headers.
176144
"""
177145

178146
def __init__(
@@ -209,30 +177,125 @@ def __init__(
209177
else:
210178
self._connection_params = connection_params
211179
self._errlog = errlog
212-
# Each session manager maintains its own exit stack for proper cleanup
213-
self._exit_stack: Optional[AsyncExitStack] = None
214-
self._session: Optional[ClientSession] = None
180+
181+
# Session pool: maps session keys to (session, exit_stack) tuples
182+
self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {}
183+
215184
# Lock to prevent race conditions in session creation
216185
self._session_lock = asyncio.Lock()
217186

218-
async def create_session(self) -> ClientSession:
187+
def _generate_session_key(
188+
self, merged_headers: Optional[Dict[str, str]] = None
189+
) -> str:
190+
"""Generates a session key based on connection params and merged headers.
191+
192+
For StdioConnectionParams, returns a constant key since headers are not
193+
supported. For SSE and StreamableHTTP connections, generates a key based
194+
on the provided merged headers.
195+
196+
Args:
197+
merged_headers: Already merged headers (base + additional).
198+
199+
Returns:
200+
A unique session key string.
201+
"""
202+
if isinstance(self._connection_params, StdioConnectionParams):
203+
# For stdio connections, headers are not supported, so use constant key
204+
return 'stdio_session'
205+
206+
# For SSE and StreamableHTTP connections, use merged headers
207+
if merged_headers:
208+
headers_json = json.dumps(merged_headers, sort_keys=True)
209+
headers_hash = hashlib.md5(headers_json.encode()).hexdigest()
210+
return f'session_{headers_hash}'
211+
else:
212+
return 'session_no_headers'
213+
214+
def _merge_headers(
215+
self, additional_headers: Optional[Dict[str, str]] = None
216+
) -> Optional[Dict[str, str]]:
217+
"""Merges base connection headers with additional headers.
218+
219+
Args:
220+
additional_headers: Optional headers to merge with connection headers.
221+
222+
Returns:
223+
Merged headers dictionary, or None if no headers are provided.
224+
"""
225+
if isinstance(self._connection_params, StdioConnectionParams) or isinstance(
226+
self._connection_params, StdioServerParameters
227+
):
228+
# Stdio connections don't support headers
229+
return None
230+
231+
base_headers = {}
232+
if (
233+
hasattr(self._connection_params, 'headers')
234+
and self._connection_params.headers
235+
):
236+
base_headers = self._connection_params.headers.copy()
237+
238+
if additional_headers:
239+
base_headers.update(additional_headers)
240+
241+
return base_headers
242+
243+
def _is_session_disconnected(self, session: ClientSession) -> bool:
244+
"""Checks if a session is disconnected or closed.
245+
246+
Args:
247+
session: The ClientSession to check.
248+
249+
Returns:
250+
True if the session is disconnected, False otherwise.
251+
"""
252+
return session._read_stream._closed or session._write_stream._closed
253+
254+
async def create_session(
255+
self, headers: Optional[Dict[str, str]] = None
256+
) -> ClientSession:
219257
"""Creates and initializes an MCP client session.
220258
259+
This method will check if an existing session for the given headers
260+
is still connected. If it's disconnected, it will be cleaned up and
261+
a new session will be created.
262+
263+
Args:
264+
headers: Optional headers to include in the session. These will be
265+
merged with any existing connection headers. Only applicable
266+
for SSE and StreamableHTTP connections.
267+
221268
Returns:
222269
ClientSession: The initialized MCP client session.
223270
"""
224-
# Fast path: if session already exists, return it without acquiring lock
225-
if self._session is not None:
226-
return self._session
271+
# Merge headers once at the beginning
272+
merged_headers = self._merge_headers(headers)
273+
274+
# Generate session key using merged headers
275+
session_key = self._generate_session_key(merged_headers)
227276

228277
# Use async lock to prevent race conditions
229278
async with self._session_lock:
230-
# Double-check: session might have been created while waiting for lock
231-
if self._session is not None:
232-
return self._session
233-
234-
# Create a new exit stack for this session
235-
self._exit_stack = AsyncExitStack()
279+
# Check if we have an existing session
280+
if session_key in self._sessions:
281+
session, exit_stack = self._sessions[session_key]
282+
283+
# Check if the existing session is still connected
284+
if not self._is_session_disconnected(session):
285+
# Session is still good, return it
286+
return session
287+
else:
288+
# Session is disconnected, clean it up
289+
logger.info('Cleaning up disconnected session: %s', session_key)
290+
try:
291+
await exit_stack.aclose()
292+
except Exception as e:
293+
logger.warning('Error during disconnected session cleanup: %s', e)
294+
finally:
295+
del self._sessions[session_key]
296+
297+
# Create a new session (either first time or replacing disconnected one)
298+
exit_stack = AsyncExitStack()
236299

237300
try:
238301
if isinstance(self._connection_params, StdioConnectionParams):
@@ -243,7 +306,7 @@ async def create_session(self) -> ClientSession:
243306
elif isinstance(self._connection_params, SseConnectionParams):
244307
client = sse_client(
245308
url=self._connection_params.url,
246-
headers=self._connection_params.headers,
309+
headers=merged_headers,
247310
timeout=self._connection_params.timeout,
248311
sse_read_timeout=self._connection_params.sse_read_timeout,
249312
)
@@ -252,7 +315,7 @@ async def create_session(self) -> ClientSession:
252315
):
253316
client = streamablehttp_client(
254317
url=self._connection_params.url,
255-
headers=self._connection_params.headers,
318+
headers=merged_headers,
256319
timeout=timedelta(seconds=self._connection_params.timeout),
257320
sse_read_timeout=timedelta(
258321
seconds=self._connection_params.sse_read_timeout
@@ -266,11 +329,11 @@ async def create_session(self) -> ClientSession:
266329
f' {self._connection_params}'
267330
)
268331

269-
transports = await self._exit_stack.enter_async_context(client)
332+
transports = await exit_stack.enter_async_context(client)
270333
# The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
271334
# needed to build the ClientSession, we limit then to the two first values to be compatible with all clients.
272335
if isinstance(self._connection_params, StdioConnectionParams):
273-
session = await self._exit_stack.enter_async_context(
336+
session = await exit_stack.enter_async_context(
274337
ClientSession(
275338
*transports[:2],
276339
read_timeout_seconds=timedelta(
@@ -279,44 +342,38 @@ async def create_session(self) -> ClientSession:
279342
)
280343
)
281344
else:
282-
session = await self._exit_stack.enter_async_context(
345+
session = await exit_stack.enter_async_context(
283346
ClientSession(*transports[:2])
284347
)
285348
await session.initialize()
286349

287-
self._session = session
350+
# Store session and exit stack in the pool
351+
self._sessions[session_key] = (session, exit_stack)
352+
logger.debug('Created new session: %s', session_key)
288353
return session
289354

290355
except Exception:
291356
# If session creation fails, clean up the exit stack
292-
if self._exit_stack:
293-
await self._exit_stack.aclose()
294-
self._exit_stack = None
357+
if exit_stack:
358+
await exit_stack.aclose()
295359
raise
296360

297361
async def close(self):
298-
"""Closes the session and cleans up resources."""
299-
if not self._exit_stack:
300-
return
362+
"""Closes all sessions and cleans up resources."""
301363
async with self._session_lock:
302-
if self._exit_stack:
364+
for session_key in list(self._sessions.keys()):
365+
_, exit_stack = self._sessions[session_key]
303366
try:
304-
await self._exit_stack.aclose()
367+
await exit_stack.aclose()
305368
except Exception as e:
306369
# Log the error but don't re-raise to avoid blocking shutdown
307370
print(
308-
f'Warning: Error during MCP session cleanup: {e}',
371+
'Warning: Error during MCP session cleanup for'
372+
f' {session_key}: {e}',
309373
file=self._errlog,
310374
)
311375
finally:
312-
self._exit_stack = None
313-
self._session = None
314-
315-
async def reinitialize_session(self):
316-
"""Reinitializes the session when connection is lost."""
317-
# Close the old session and create a new one
318-
await self.close()
319-
await self.create_session()
376+
del self._sessions[session_key]
320377

321378

322379
SseServerParams = SseConnectionParams

0 commit comments

Comments
 (0)
0