fix: Fix event loop closed bug in McpSessionManager · google/adk-python@4aa4751 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4aa4751

Browse files
wukath authored and copybara-github committed
fix: Fix event loop closed bug in McpSessionManager
Sessions were being erroneously cached and reused across different asyncio event loops, causing "Event loop is closed" in environments with transient loops. This updates the session caching to be loop-aware: before reusing a cached session, check that the stored loop matches the current loop. Also, if session is disconnected and loops do not match, discard the cached entry without calling aclose(). Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 868380746
1 parent 7110336 commit 4aa4751

File tree

2 files changed

+261
-35
lines changed

2 files changed

+261
-35
lines changed

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 95 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import json
2424
import logging
2525
import sys
26+
import threading
2627
from typing import Any
2728
from typing import Dict
2829
from typing import Optional
@@ -220,11 +221,24 @@ def __init__(
220221
self._connection_params = connection_params
221222
self._errlog = errlog
222223

223-
# Session pool: maps session keys to (session, exit_stack) tuples
224-
self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {}
225-
226-
# Lock to prevent race conditions in session creation
227-
self._session_lock = asyncio.Lock()
224+
# Session pool: maps session keys to (session, exit_stack, loop) tuples
225+
self._sessions: Dict[
226+
str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop]
227+
] = {}
228+
229+
# Map of event loops to their respective locks to prevent race conditions
230+
# across different event loops in session creation.
231+
self._session_lock_map: dict[asyncio.AbstractEventLoop, asyncio.Lock] = {}
232+
self._lock_map_lock = threading.Lock()
233+
234+
@property
235+
def _session_lock(self) -> asyncio.Lock:
236+
"""Returns an asyncio.Lock bound to the current event loop."""
237+
current_loop = asyncio.get_running_loop()
238+
with self._lock_map_lock:
239+
if current_loop not in self._session_lock_map:
240+
self._session_lock_map[current_loop] = asyncio.Lock()
241+
return self._session_lock_map[current_loop]
228242

229243
def _generate_session_key(
230244
self, merged_headers: Optional[Dict[str, str]] = None
@@ -293,6 +307,62 @@ def _is_session_disconnected(self, session: ClientSession) -> bool:
293307
"""
294308
return session._read_stream._closed or session._write_stream._closed
295309

310+
async def _cleanup_session(
311+
self,
312+
session_key: str,
313+
exit_stack: AsyncExitStack,
314+
stored_loop: asyncio.AbstractEventLoop,
315+
):
316+
"""Cleans up a session, handling different event loops safely.
317+
318+
Args:
319+
session_key: The session key to clean up.
320+
exit_stack: The AsyncExitStack managing the session resources.
321+
stored_loop: The event loop on which the session was created.
322+
"""
323+
current_loop = asyncio.get_running_loop()
324+
try:
325+
if stored_loop is current_loop:
326+
await exit_stack.aclose()
327+
elif stored_loop.is_closed():
328+
logger.warning(
329+
f'Error cleaning up session {session_key}: original event loop'
330+
' is closed, resources may be leaked.'
331+
)
332+
else:
333+
# The old loop is still running in another thread;
334+
# schedule cleanup on it.
335+
logger.info(
336+
f'Scheduling cleanup of session {session_key} on its original'
337+
' event loop.'
338+
)
339+
future = asyncio.run_coroutine_threadsafe(
340+
exit_stack.aclose(), stored_loop
341+
)
342+
343+
# Attach a callback so errors don't go unnoticed
344+
def cleanup_done(f: asyncio.Future):
345+
try:
346+
if f.exception():
347+
logger.warning(
348+
f'Error cleaning up session {session_key} on original'
349+
f' loop: {f.exception()}'
350+
)
351+
except Exception as e:
352+
logger.warning(
353+
f'Failed to check cleanup status for {session_key}: {e}'
354+
)
355+
356+
future.add_done_callback(cleanup_done)
357+
except Exception as e:
358+
logger.warning(
359+
f'Error during session cleanup for {session_key}: {e}',
360+
exc_info=True,
361+
)
362+
finally:
363+
if session_key in self._sessions:
364+
del self._sessions[session_key]
365+
296366
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
297367
"""Creates an MCP client based on the connection parameters.
298368
@@ -364,21 +434,22 @@ async def create_session(
364434
async with self._session_lock:
365435
# Check if we have an existing session
366436
if session_key in self._sessions:
367-
session, exit_stack = self._sessions[session_key]
437+
session, exit_stack, stored_loop = self._sessions[session_key]
368438

369-
# Check if the existing session is still connected
370-
if not self._is_session_disconnected(session):
439+
# Check if the existing session is still connected and bound to the current loop
440+
current_loop = asyncio.get_running_loop()
441+
if stored_loop is current_loop and not self._is_session_disconnected(
442+
session
443+
):
371444
# Session is still good, return it
372445
return session
373446
else:
374-
# Session is disconnected, clean it up
375-
logger.info('Cleaning up disconnected session: %s', session_key)
376-
try:
377-
await exit_stack.aclose()
378-
except Exception as e:
379-
logger.warning('Error during disconnected session cleanup: %s', e)
380-
finally:
381-
del self._sessions[session_key]
447+
# Session is disconnected or from a different loop, clean it up
448+
logger.info(
449+
'Cleaning up session (disconnected or different loop): %s',
450+
session_key,
451+
)
452+
await self._cleanup_session(session_key, exit_stack, stored_loop)
382453

383454
# Create a new session (either first time or replacing disconnected one)
384455
exit_stack = AsyncExitStack()
@@ -409,8 +480,12 @@ async def create_session(
409480
timeout=timeout_in_seconds,
410481
)
411482

412-
# Store session and exit stack in the pool
413-
self._sessions[session_key] = (session, exit_stack)
483+
# Store session, exit stack, and loop in the pool
484+
self._sessions[session_key] = (
485+
session,
486+
exit_stack,
487+
asyncio.get_running_loop(),
488+
)
414489
logger.debug('Created new session: %s', session_key)
415490
return session
416491

@@ -429,17 +504,8 @@ async def close(self):
429504
"""Closes all sessions and cleans up resources."""
430505
async with self._session_lock:
431506
for session_key in list(self._sessions.keys()):
432-
_, exit_stack = self._sessions[session_key]
433-
try:
434-
await exit_stack.aclose()
435-
except Exception as e:
436-
# Log the error but don't re-raise to avoid blocking shutdown
437-
logger.warning(
438-
f'Error during MCP session cleanup for {session_key}',
439-
exc_info=True,
440-
)
441-
finally:
442-
del self._sessions[session_key]
507+
_, exit_stack, stored_loop = self._sessions[session_key]
508+
await self._cleanup_session(session_key, exit_stack, stored_loop)
443509

444510

445511
SseServerParams = SseConnectionParams

tests/unittests/tools/mcp_tool/test_mcp_session_manager.py

Lines changed: 166 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
from io import StringIO
1919
import json
2020
import sys
21+
from unittest.mock import ANY
2122
from unittest.mock import AsyncMock
2223
from unittest.mock import Mock
2324
from unittest.mock import patch
2425

26+
from google.adk.platform import thread as platform_thread
2527
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
2628
from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_errors
2729
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
@@ -298,6 +300,10 @@ async def test_create_session_stdio_new(self):
298300
assert session == mock_session
299301
assert len(manager._sessions) == 1
300302
assert "stdio_session" in manager._sessions
303+
session_data = manager._sessions["stdio_session"]
304+
assert len(session_data) == 3
305+
assert session_data[0] == mock_session
306+
assert session_data[2] == asyncio.get_running_loop()
301307

302308
# Verify SessionContext was created
303309
mock_session_context_class.assert_called_once()
@@ -312,7 +318,11 @@ async def test_create_session_reuse_existing(self):
312318
# Create mock existing session
313319
existing_session = MockClientSession()
314320
existing_exit_stack = MockAsyncExitStack()
315-
manager._sessions["stdio_session"] = (existing_session, existing_exit_stack)
321+
manager._sessions["stdio_session"] = (
322+
existing_session,
323+
existing_exit_stack,
324+
asyncio.get_running_loop(),
325+
)
316326

317327
# Session is connected
318328
existing_session._read_stream._closed = False
@@ -377,8 +387,16 @@ async def test_close_success(self):
377387
session2 = MockClientSession()
378388
exit_stack2 = MockAsyncExitStack()
379389

380-
manager._sessions["session1"] = (session1, exit_stack1)
381-
manager._sessions["session2"] = (session2, exit_stack2)
390+
manager._sessions["session1"] = (
391+
session1,
392+
exit_stack1,
393+
asyncio.get_running_loop(),
394+
)
395+
manager._sessions["session2"] = (
396+
session2,
397+
exit_stack2,
398+
asyncio.get_running_loop(),
399+
)
382400

383401
await manager.close()
384402

@@ -401,8 +419,16 @@ async def test_close_with_errors(self, mock_logger):
401419
session2 = MockClientSession()
402420
exit_stack2 = MockAsyncExitStack()
403421

404-
manager._sessions["session1"] = (session1, exit_stack1)
405-
manager._sessions["session2"] = (session2, exit_stack2)
422+
manager._sessions["session1"] = (
423+
session1,
424+
exit_stack1,
425+
asyncio.get_running_loop(),
426+
)
427+
manager._sessions["session2"] = (
428+
session2,
429+
exit_stack2,
430+
asyncio.get_running_loop(),
431+
)
406432

407433
# Should not raise exception
408434
await manager.close()
@@ -414,7 +440,7 @@ async def test_close_with_errors(self, mock_logger):
414440
# Error should be logged via logger.warning
415441
mock_logger.warning.assert_called_once()
416442
args, kwargs = mock_logger.warning.call_args
417-
assert "Error during MCP session cleanup for session1" in args[0]
443+
assert "Error during session cleanup for session1: Close error 1" in args[0]
418444
assert kwargs.get("exc_info")
419445

420446
@pytest.mark.asyncio
@@ -447,6 +473,140 @@ async def test_create_and_close_session_in_different_tasks(
447473
# Verify session was closed
448474
assert not manager._sessions
449475

476+
@pytest.mark.asyncio
477+
async def test_session_lock_different_loops(self):
478+
"""Verify that _session_lock returns different locks for different loops."""
479+
480+
manager = MCPSessionManager(self.mock_stdio_connection_params)
481+
482+
# Access in current loop
483+
lock1 = manager._session_lock
484+
assert isinstance(lock1, asyncio.Lock)
485+
486+
# Access in a different loop (in a separate thread)
487+
lock_container = []
488+
489+
def run_in_thread():
490+
loop2 = asyncio.new_event_loop()
491+
asyncio.set_event_loop(loop2)
492+
try:
493+
494+
async def get_lock():
495+
return manager._session_lock
496+
497+
lock_container.append(loop2.run_until_complete(get_lock()))
498+
finally:
499+
loop2.close()
500+
501+
thread = platform_thread.create_thread(target=run_in_thread)
502+
thread.start()
503+
thread.join()
504+
505+
assert lock_container
506+
lock2 = lock_container[0]
507+
assert isinstance(lock2, asyncio.Lock)
508+
assert lock1 is not lock2
509+
510+
@pytest.mark.asyncio
511+
async def test_cleanup_session_cross_loop(self):
512+
"""Verify that _cleanup_session uses run_coroutine_threadsafe for different loops."""
513+
manager = MCPSessionManager(self.mock_stdio_connection_params)
514+
mock_exit_stack = MockAsyncExitStack()
515+
516+
# Create a dummy loop that is "running" in another thread
517+
loop2 = asyncio.new_event_loop()
518+
try:
519+
with patch(
520+
"google.adk.tools.mcp_tool.mcp_session_manager.asyncio.run_coroutine_threadsafe"
521+
) as mock_run_threadsafe:
522+
with patch(
523+
"google.adk.tools.mcp_tool.mcp_session_manager.logger"
524+
) as mock_logger:
525+
# We need to mock the return value of run_coroutine_threadsafe to be a future
526+
mock_future = Mock()
527+
mock_run_threadsafe.return_value = mock_future
528+
529+
await manager._cleanup_session("test_session", mock_exit_stack, loop2)
530+
531+
# Verify run_coroutine_threadsafe was called
532+
# ANY is used because a new coroutine object is created each time
533+
mock_run_threadsafe.assert_called_once_with(ANY, loop2)
534+
535+
mock_logger.info.assert_any_call(
536+
"Scheduling cleanup of session test_session on its original"
537+
" event loop."
538+
)
539+
mock_future.add_done_callback.assert_called_once()
540+
finally:
541+
loop2.close()
542+
543+
@pytest.mark.asyncio
544+
async def test_create_session_cleans_up_without_aclose_if_loop_is_different(
545+
self,
546+
):
547+
"""Verify that sessions from different loops are cleaned up without calling aclose()."""
548+
manager = MCPSessionManager(self.mock_stdio_connection_params)
549+
550+
# 1. Simulate a session created in a "different" loop
551+
mock_session = MockClientSession()
552+
mock_exit_stack = MockAsyncExitStack()
553+
# Use a dummy object as a different loop
554+
different_loop = Mock(spec=asyncio.AbstractEventLoop)
555+
556+
manager._sessions["stdio_session"] = (
557+
mock_session,
558+
mock_exit_stack,
559+
different_loop,
560+
)
561+
562+
# 2. Mock creation of a new session
563+
# We need to mock create_client, wait_for, and SessionContext
564+
with patch.object(manager, "_create_client") as mock_create_client:
565+
with patch(
566+
"google.adk.tools.mcp_tool.mcp_session_manager.asyncio.wait_for"
567+
) as mock_wait_for:
568+
with patch(
569+
"google.adk.tools.mcp_tool.mcp_session_manager.SessionContext"
570+
) as mock_session_context_class:
571+
# Setup mocks for new session creation
572+
mock_create_client.return_value = AsyncMock()
573+
new_session = MockClientSession()
574+
mock_wait_for.return_value = new_session
575+
mock_session_context_class.return_value = AsyncMock()
576+
577+
# 3. Call create_session
578+
session = await manager.create_session()
579+
580+
# 4. Verify results
581+
assert session == new_session
582+
assert len(manager._sessions) == 1
583+
# Verify that old exit_stack.aclose was NOT called since loop was different
584+
mock_exit_stack.aclose.assert_not_called()
585+
586+
@pytest.mark.asyncio
587+
async def test_close_skips_aclose_for_different_loop_sessions(self):
588+
"""Verify that close() skips aclose() for sessions from different loops."""
589+
manager = MCPSessionManager(self.mock_stdio_connection_params)
590+
591+
# Add one session from same loop and one from different loop
592+
current_loop = asyncio.get_running_loop()
593+
different_loop = Mock(spec=asyncio.AbstractEventLoop)
594+
595+
session1 = MockClientSession()
596+
exit_stack1 = MockAsyncExitStack()
597+
manager._sessions["session1"] = (session1, exit_stack1, current_loop)
598+
599+
session2 = MockClientSession()
600+
exit_stack2 = MockAsyncExitStack()
601+
manager._sessions["session2"] = (session2, exit_stack2, different_loop)
602+
603+
await manager.close()
604+
605+
# exit_stack1 should be closed, exit_stack2 should be skipped
606+
exit_stack1.aclose.assert_called_once()
607+
exit_stack2.aclose.assert_not_called()
608+
assert len(manager._sessions) == 0
609+
450610

451611
@pytest.mark.asyncio
452612
async def test_retry_on_errors_decorator():

0 commit comments

Comments
 (0)
0