fix:fix mcp toolset close issue · ditinagrawal/adk-python@05a853b · GitHub


Commit 05a853b

seanzhougoogle authored and copybara-github committed
fix:fix mcp toolset close issue
PiperOrigin-RevId: 759636772
1 parent 12507dc commit 05a853b

File tree

3 files changed: +293 -22 lines changed


src/google/adk/cli/fast_api.py

Lines changed: 65 additions & 7 deletions

```diff
@@ -21,6 +21,7 @@
 import logging
 import os
 from pathlib import Path
+import signal
 import sys
 import time
 import traceback
```

```diff
@@ -221,7 +222,7 @@ def get_fast_api_app(
     )
     provider.add_span_processor(processor)
   else:
-    logging.warning(
+    logger.warning(
         "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
         " not be enabled."
     )
```
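
The one-line change above swaps the root `logging` call for the module-level `logger`, so the warning is attributed to this module and follows its handler configuration. A minimal sketch of the difference (standalone illustration, not ADK code):

```python
import logging

# Root-logger call: the record is attributed to "root".
logging.warning("GOOGLE_CLOUD_PROJECT environment variable is not set.")

# Module-level logger: the record carries this module's __name__,
# so it can be filtered and routed like the rest of the package's logs.
logger = logging.getLogger(__name__)
logger.warning("GOOGLE_CLOUD_PROJECT environment variable is not set.")
```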

```diff
@@ -232,14 +233,71 @@ def get_fast_api_app(
 
   @asynccontextmanager
   async def internal_lifespan(app: FastAPI):
-    if lifespan:
-      async with lifespan(app) as lifespan_context:
+    # Set up signal handlers for graceful shutdown
+    original_sigterm = signal.getsignal(signal.SIGTERM)
+    original_sigint = signal.getsignal(signal.SIGINT)
+
+    def cleanup_handler(sig, frame):
+      # Log the signal
+      logger.info("Received signal %s, performing pre-shutdown cleanup", sig)
+      # Do synchronous cleanup if needed
+      # Then call original handler if it exists
+      if sig == signal.SIGTERM and callable(original_sigterm):
+        original_sigterm(sig, frame)
+      elif sig == signal.SIGINT and callable(original_sigint):
+        original_sigint(sig, frame)
+
+    # Install cleanup handlers
+    signal.signal(signal.SIGTERM, cleanup_handler)
+    signal.signal(signal.SIGINT, cleanup_handler)
+
+    try:
+      if lifespan:
+        async with lifespan(app) as lifespan_context:
+          yield lifespan_context
+      else:
         yield
+    finally:
+      # During shutdown, properly clean up all toolsets
+      logger.info(
+          "Server shutdown initiated, cleaning up %s toolsets",
+          len(toolsets_to_close),
+      )
 
-        for toolset in toolsets_to_close:
-          await toolset.close()
-    else:
-      yield
+      # Create tasks for all toolset closures to run concurrently
+      cleanup_tasks = []
+      for toolset in toolsets_to_close:
+        task = asyncio.create_task(close_toolset_safely(toolset))
+        cleanup_tasks.append(task)
+
+      if cleanup_tasks:
+        # Wait for all cleanup tasks with timeout
+        done, pending = await asyncio.wait(
+            cleanup_tasks,
+            timeout=10.0,  # 10 second timeout for cleanup
+            return_when=asyncio.ALL_COMPLETED,
+        )
+
+        # If any tasks are still pending, log it
+        if pending:
+          logger.warn(
+              f"{len(pending)} toolset cleanup tasks didn't complete in time"
+          )
+          for task in pending:
+            task.cancel()
+
+      # Restore original signal handlers
+      signal.signal(signal.SIGTERM, original_sigterm)
+      signal.signal(signal.SIGINT, original_sigint)
+
+  async def close_toolset_safely(toolset):
+    """Safely close a toolset with error handling."""
+    try:
+      logger.info(f"Closing toolset: {type(toolset).__name__}")
+      await toolset.close()
+      logger.info(f"Successfully closed toolset: {type(toolset).__name__}")
+    except Exception as e:
+      logger.error(f"Error closing toolset {type(toolset).__name__}: {e}")
 
   # Run the FastAPI server.
   app = FastAPI(lifespan=internal_lifespan)
```
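
The new shutdown path closes every registered toolset concurrently and bounds the whole cleanup with a timeout, so one stuck MCP connection cannot hang server shutdown indefinitely. A self-contained sketch of that pattern outside of FastAPI, using a hypothetical `SlowToolset` stand-in rather than the ADK toolset API:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


class SlowToolset:
  """Hypothetical toolset whose close() may take arbitrarily long."""

  def __init__(self, name: str, close_delay: float):
    self.name = name
    self.close_delay = close_delay

  async def close(self) -> None:
    await asyncio.sleep(self.close_delay)
    logger.info("closed %s", self.name)


async def close_all(toolsets, timeout: float = 10.0) -> None:
  # Fan the closures out as tasks so one slow toolset doesn't serialize the rest.
  tasks = [asyncio.create_task(t.close()) for t in toolsets]
  if not tasks:
    return
  done, pending = await asyncio.wait(tasks, timeout=timeout)
  # Anything still pending after the budget is cancelled rather than awaited forever.
  for task in pending:
    task.cancel()
  logger.info("%d closed, %d cancelled", len(done), len(pending))


if __name__ == "__main__":
  logging.basicConfig(level=logging.INFO)
  toolsets = [SlowToolset("fast", 0.1), SlowToolset("stuck", 60.0)]
  asyncio.run(close_all(toolsets, timeout=1.0))
```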

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 140 additions & 8 deletions

```diff
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import AsyncExitStack
+import asyncio
+from contextlib import AsyncExitStack, asynccontextmanager
 import functools
+import logging
 import sys
-from typing import Any, TextIO
+from typing import Any, Optional, TextIO
 import anyio
 from pydantic import BaseModel
 
```

```diff
@@ -34,6 +36,8 @@
 else:
   raise e
 
+logger = logging.getLogger(__name__)
+
 
 class SseServerParams(BaseModel):
   """Parameters for the MCP SSE connection.
```

```diff
@@ -108,6 +112,45 @@ async def wrapper(self, *args, **kwargs):
   return decorator
 
 
+@asynccontextmanager
+async def tracked_stdio_client(server, errlog, process=None):
+  """A wrapper around stdio_client that ensures proper process tracking and cleanup."""
+  our_process = process
+
+  # If no process was provided, create one
+  if our_process is None:
+    our_process = await asyncio.create_subprocess_exec(
+        server.command,
+        *server.args,
+        stdin=asyncio.subprocess.PIPE,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=errlog,
+    )
+
+  # Use the original stdio_client, but ensure process cleanup
+  try:
+    async with stdio_client(server=server, errlog=errlog) as client:
+      yield client, our_process
+  finally:
+    # Ensure the process is properly terminated if it still exists
+    if our_process and our_process.returncode is None:
+      try:
+        logger.info(
+            f'Terminating process {our_process.pid} from tracked_stdio_client'
+        )
+        our_process.terminate()
+        try:
+          await asyncio.wait_for(our_process.wait(), timeout=3.0)
+        except asyncio.TimeoutError:
+          # Force kill if it doesn't terminate quickly
+          if our_process.returncode is None:
+            logger.warning(f'Forcing kill of process {our_process.pid}')
+            our_process.kill()
+      except ProcessLookupError:
+        # Process already gone, that's fine
+        logger.info(f'Process {our_process.pid} already terminated')
+
+
 class MCPSessionManager:
   """Manages MCP client sessions.
 
```
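
The `finally` block above escalates from `terminate()` to `kill()` when the MCP server process ignores the first signal. The same escalation as a standalone asyncio sketch (assumes a POSIX `sleep` binary for the demo child process; not ADK code):

```python
import asyncio


async def stop_process(process: asyncio.subprocess.Process, grace: float = 3.0) -> None:
  """Terminate a child process, force-killing it if it outlives the grace period."""
  if process.returncode is not None:
    return  # already exited
  try:
    process.terminate()
    try:
      await asyncio.wait_for(process.wait(), timeout=grace)
    except asyncio.TimeoutError:
      if process.returncode is None:
        process.kill()  # SIGKILL: the process did not exit in time
        await process.wait()
  except ProcessLookupError:
    pass  # the process disappeared between the check and the signal


async def main() -> None:
  proc = await asyncio.create_subprocess_exec("sleep", "60")
  await stop_process(proc, grace=1.0)
  print("return code:", proc.returncode)


if __name__ == "__main__":
  asyncio.run(main())
```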

```diff
@@ -138,25 +181,39 @@ def __init__(
       errlog: (Optional) TextIO stream for error logging. Use only for
         initializing a local stdio MCP session.
     """
+
     self._connection_params = connection_params
     self._exit_stack = exit_stack
     self._errlog = errlog
+    self._process = None  # Track the subprocess
+    self._active_processes = set()  # Track all processes created
+    self._active_file_handles = set()  # Track file handles
 
-  async def create_session(self) -> ClientSession:
-    return await MCPSessionManager.initialize_session(
+  async def create_session(
+      self,
+  ) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
+    """Creates a new MCP session and tracks the associated process."""
+    session, process = await self._initialize_session(
         connection_params=self._connection_params,
         exit_stack=self._exit_stack,
         errlog=self._errlog,
     )
+    self._process = process  # Store reference to process
+
+    # Track the process
+    if process:
+      self._active_processes.add(process)
+
+    return session, process
 
   @classmethod
-  async def initialize_session(
+  async def _initialize_session(
       cls,
       *,
       connection_params: StdioServerParameters | SseServerParams,
       exit_stack: AsyncExitStack,
       errlog: TextIO = sys.stderr,
-  ) -> ClientSession:
+  ) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
     """Initializes an MCP client session.
 
     Args:
```
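
Because `create_session` now returns a `(session, process)` pair instead of a bare `ClientSession`, callers have to unpack the result and may use the subprocess handle for forced cleanup. A hypothetical call-site sketch (illustrative only; the real callers live elsewhere in the ADK codebase):

```python
# Hypothetical caller; assumes `session_manager` is an MCPSessionManager
# built with the connection params and exit stack shown above.
async def connect(session_manager):
  session, process = await session_manager.create_session()
  if process is not None:
    # The subprocess handle is what later allows a forced terminate/kill.
    print("MCP server subprocess pid:", process.pid)
  return session
```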

```diff
@@ -168,9 +225,17 @@ async def initialize_session(
     Returns:
       ClientSession: The initialized MCP client session.
     """
+    process = None
+
     if isinstance(connection_params, StdioServerParameters):
-      client = stdio_client(server=connection_params, errlog=errlog)
+      # For stdio connections, we need to track the subprocess
+      client, process = await cls._create_stdio_client(
+          server=connection_params,
+          errlog=errlog,
+          exit_stack=exit_stack,
+      )
     elif isinstance(connection_params, SseServerParams):
+      # For SSE connections, create the client without a subprocess
       client = sse_client(
           url=connection_params.url,
           headers=connection_params.headers,
```

```diff
@@ -184,7 +249,74 @@ async def initialize_session(
           f' {connection_params}'
       )
 
+    # Create the session with the client
     transports = await exit_stack.enter_async_context(client)
     session = await exit_stack.enter_async_context(ClientSession(*transports))
     await session.initialize()
-    return session
+
+    return session, process
+
+  @staticmethod
+  async def _create_stdio_client(
+      server: StdioServerParameters,
+      errlog: TextIO,
+      exit_stack: AsyncExitStack,
+  ) -> tuple[Any, asyncio.subprocess.Process]:
+    """Create stdio client and return both the client and process.
+
+    This implementation adapts to how the MCP stdio_client is created.
+    The actual implementation may need to be adjusted based on the MCP library
+    structure.
+    """
+    # Create the subprocess directly so we can track it
+    process = await asyncio.create_subprocess_exec(
+        server.command,
+        *server.args,
+        stdin=asyncio.subprocess.PIPE,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=errlog,
+    )
+
+    # Create the stdio client using the MCP library
+    try:
+      # Method 1: Try using the existing process if stdio_client supports it
+      client = stdio_client(server=server, errlog=errlog, process=process)
+    except TypeError:
+      # Method 2: If the above doesn't work, let stdio_client create its own process
+      # and we'll need to terminate both processes later
+      logger.warning(
+          'Using stdio_client with its own process - may lead to duplicate'
+          ' processes'
+      )
+      client = stdio_client(server=server, errlog=errlog)
+
+    return client, process
+
+  async def _emergency_cleanup(self):
+    """Perform emergency cleanup of resources when normal cleanup fails."""
+    logger.info('Performing emergency cleanup of MCPSessionManager resources')
+
+    # Clean up any tracked processes
+    for proc in list(self._active_processes):
+      try:
+        if proc and proc.returncode is None:
+          logger.info(f'Emergency termination of process {proc.pid}')
+          proc.terminate()
+          try:
+            await asyncio.wait_for(proc.wait(), timeout=1.0)
+          except asyncio.TimeoutError:
+            logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
+            proc.kill()
+        self._active_processes.remove(proc)
+      except Exception as e:
+        logger.error(f'Error during process cleanup: {e}')
+
+    # Clean up any tracked file handles
+    for handle in list(self._active_file_handles):
+      try:
+        if not handle.closed:
+          logger.info('Closing file handle')
+          handle.close()
+        self._active_file_handles.remove(handle)
+      except Exception as e:
+        logger.error(f'Error closing file handle: {e}')
```
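
Both loops in `_emergency_cleanup` wrap each resource in its own try/except, so a single failing close cannot abort the rest of the sweep. A generic sketch of that per-item isolation (the `resources` set and its `close()` objects here are hypothetical, not part of ADK or MCP):

```python
import logging

logger = logging.getLogger(__name__)


async def sweep(resources: set) -> None:
  """Close each tracked resource independently; one failure never stops the sweep."""
  for resource in list(resources):  # iterate a snapshot so the set can shrink as we go
    try:
      await resource.close()
      resources.discard(resource)
    except Exception as e:
      # Log and continue: a single broken resource must not abort the cleanup.
      logger.error("Error closing %r: %s", resource, e)
```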

0 commit comments
