1818from io import StringIO
1919import json
2020import sys
21+ from unittest .mock import ANY
2122from unittest .mock import AsyncMock
2223from unittest .mock import Mock
2324from unittest .mock import patch
2425
26+ from google .adk .platform import thread as platform_thread
2527from google .adk .tools .mcp_tool .mcp_session_manager import MCPSessionManager
2628from google .adk .tools .mcp_tool .mcp_session_manager import retry_on_errors
2729from google .adk .tools .mcp_tool .mcp_session_manager import SseConnectionParams
@@ -298,6 +300,10 @@ async def test_create_session_stdio_new(self):
298300 assert session == mock_session
299301 assert len (manager ._sessions ) == 1
300302 assert "stdio_session" in manager ._sessions
303+ session_data = manager ._sessions ["stdio_session" ]
304+ assert len (session_data ) == 3
305+ assert session_data [0 ] == mock_session
306+ assert session_data [2 ] == asyncio .get_running_loop ()
301307
302308 # Verify SessionContext was created
303309 mock_session_context_class .assert_called_once ()
@@ -312,7 +318,11 @@ async def test_create_session_reuse_existing(self):
312318 # Create mock existing session
313319 existing_session = MockClientSession ()
314320 existing_exit_stack = MockAsyncExitStack ()
315- manager ._sessions ["stdio_session" ] = (existing_session , existing_exit_stack )
321+ manager ._sessions ["stdio_session" ] = (
322+ existing_session ,
323+ existing_exit_stack ,
324+ asyncio .get_running_loop (),
325+ )
316326
317327 # Session is connected
318328 existing_session ._read_stream ._closed = False
@@ -377,8 +387,16 @@ async def test_close_success(self):
377387 session2 = MockClientSession ()
378388 exit_stack2 = MockAsyncExitStack ()
379389
380- manager ._sessions ["session1" ] = (session1 , exit_stack1 )
381- manager ._sessions ["session2" ] = (session2 , exit_stack2 )
390+ manager ._sessions ["session1" ] = (
391+ session1 ,
392+ exit_stack1 ,
393+ asyncio .get_running_loop (),
394+ )
395+ manager ._sessions ["session2" ] = (
396+ session2 ,
397+ exit_stack2 ,
398+ asyncio .get_running_loop (),
399+ )
382400
383401 await manager .close ()
384402
@@ -401,8 +419,16 @@ async def test_close_with_errors(self, mock_logger):
401419 session2 = MockClientSession ()
402420 exit_stack2 = MockAsyncExitStack ()
403421
404- manager ._sessions ["session1" ] = (session1 , exit_stack1 )
405- manager ._sessions ["session2" ] = (session2 , exit_stack2 )
422+ manager ._sessions ["session1" ] = (
423+ session1 ,
424+ exit_stack1 ,
425+ asyncio .get_running_loop (),
426+ )
427+ manager ._sessions ["session2" ] = (
428+ session2 ,
429+ exit_stack2 ,
430+ asyncio .get_running_loop (),
431+ )
406432
407433 # Should not raise exception
408434 await manager .close ()
@@ -414,7 +440,7 @@ async def test_close_with_errors(self, mock_logger):
414440 # Error should be logged via logger.warning
415441 mock_logger .warning .assert_called_once ()
416442 args , kwargs = mock_logger .warning .call_args
417- assert "Error during MCP session cleanup for session1" in args [0 ]
443+ assert "Error during session cleanup for session1: Close error 1 " in args [0 ]
418444 assert kwargs .get ("exc_info" )
419445
420446 @pytest .mark .asyncio
@@ -447,6 +473,140 @@ async def test_create_and_close_session_in_different_tasks(
447473 # Verify session was closed
448474 assert not manager ._sessions
449475
476+ @pytest .mark .asyncio
477+ async def test_session_lock_different_loops (self ):
478+ """Verify that _session_lock returns different locks for different loops."""
479+
480+ manager = MCPSessionManager (self .mock_stdio_connection_params )
481+
482+ # Access in current loop
483+ lock1 = manager ._session_lock
484+ assert isinstance (lock1 , asyncio .Lock )
485+
486+ # Access in a different loop (in a separate thread)
487+ lock_container = []
488+
489+ def run_in_thread ():
490+ loop2 = asyncio .new_event_loop ()
491+ asyncio .set_event_loop (loop2 )
492+ try :
493+
494+ async def get_lock ():
495+ return manager ._session_lock
496+
497+ lock_container .append (loop2 .run_until_complete (get_lock ()))
498+ finally :
499+ loop2 .close ()
500+
501+ thread = platform_thread .create_thread (target = run_in_thread )
502+ thread .start ()
503+ thread .join ()
504+
505+ assert lock_container
506+ lock2 = lock_container [0 ]
507+ assert isinstance (lock2 , asyncio .Lock )
508+ assert lock1 is not lock2
509+
510+ @pytest .mark .asyncio
511+ async def test_cleanup_session_cross_loop (self ):
512+ """Verify that _cleanup_session uses run_coroutine_threadsafe for different loops."""
513+ manager = MCPSessionManager (self .mock_stdio_connection_params )
514+ mock_exit_stack = MockAsyncExitStack ()
515+
516+ # Create a dummy loop that is "running" in another thread
517+ loop2 = asyncio .new_event_loop ()
518+ try :
519+ with patch (
520+ "google.adk.tools.mcp_tool.mcp_session_manager.asyncio.run_coroutine_threadsafe"
521+ ) as mock_run_threadsafe :
522+ with patch (
523+ "google.adk.tools.mcp_tool.mcp_session_manager.logger"
524+ ) as mock_logger :
525+ # We need to mock the return value of run_coroutine_threadsafe to be a future
526+ mock_future = Mock ()
527+ mock_run_threadsafe .return_value = mock_future
528+
529+ await manager ._cleanup_session ("test_session" , mock_exit_stack , loop2 )
530+
531+ # Verify run_coroutine_threadsafe was called
532+ # ANY is used because a new coroutine object is created each time
533+ mock_run_threadsafe .assert_called_once_with (ANY , loop2 )
534+
535+ mock_logger .info .assert_any_call (
536+ "Scheduling cleanup of session test_session on its original"
537+ " event loop."
538+ )
539+ mock_future .add_done_callback .assert_called_once ()
540+ finally :
541+ loop2 .close ()
542+
543+ @pytest .mark .asyncio
544+ async def test_create_session_cleans_up_without_aclose_if_loop_is_different (
545+ self ,
546+ ):
547+ """Verify that sessions from different loops are cleaned up without calling aclose()."""
548+ manager = MCPSessionManager (self .mock_stdio_connection_params )
549+
550+ # 1. Simulate a session created in a "different" loop
551+ mock_session = MockClientSession ()
552+ mock_exit_stack = MockAsyncExitStack ()
553+ # Use a dummy object as a different loop
554+ different_loop = Mock (spec = asyncio .AbstractEventLoop )
555+
556+ manager ._sessions ["stdio_session" ] = (
557+ mock_session ,
558+ mock_exit_stack ,
559+ different_loop ,
560+ )
561+
562+ # 2. Mock creation of a new session
563+ # We need to mock create_client, wait_for, and SessionContext
564+ with patch .object (manager , "_create_client" ) as mock_create_client :
565+ with patch (
566+ "google.adk.tools.mcp_tool.mcp_session_manager.asyncio.wait_for"
567+ ) as mock_wait_for :
568+ with patch (
569+ "google.adk.tools.mcp_tool.mcp_session_manager.SessionContext"
570+ ) as mock_session_context_class :
571+ # Setup mocks for new session creation
572+ mock_create_client .return_value = AsyncMock ()
573+ new_session = MockClientSession ()
574+ mock_wait_for .return_value = new_session
575+ mock_session_context_class .return_value = AsyncMock ()
576+
577+ # 3. Call create_session
578+ session = await manager .create_session ()
579+
580+ # 4. Verify results
581+ assert session == new_session
582+ assert len (manager ._sessions ) == 1
583+ # Verify that old exit_stack.aclose was NOT called since loop was different
584+ mock_exit_stack .aclose .assert_not_called ()
585+
586+ @pytest .mark .asyncio
587+ async def test_close_skips_aclose_for_different_loop_sessions (self ):
588+ """Verify that close() skips aclose() for sessions from different loops."""
589+ manager = MCPSessionManager (self .mock_stdio_connection_params )
590+
591+ # Add one session from same loop and one from different loop
592+ current_loop = asyncio .get_running_loop ()
593+ different_loop = Mock (spec = asyncio .AbstractEventLoop )
594+
595+ session1 = MockClientSession ()
596+ exit_stack1 = MockAsyncExitStack ()
597+ manager ._sessions ["session1" ] = (session1 , exit_stack1 , current_loop )
598+
599+ session2 = MockClientSession ()
600+ exit_stack2 = MockAsyncExitStack ()
601+ manager ._sessions ["session2" ] = (session2 , exit_stack2 , different_loop )
602+
603+ await manager .close ()
604+
605+ # exit_stack1 should be closed, exit_stack2 should be skipped
606+ exit_stack1 .aclose .assert_called_once ()
607+ exit_stack2 .aclose .assert_not_called ()
608+ assert len (manager ._sessions ) == 0
609+
450610
451611@pytest .mark .asyncio
452612async def test_retry_on_errors_decorator ():
0 commit comments