12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from contextlib import AsyncExitStack
15
+ import asyncio
16
+ from contextlib import AsyncExitStack , asynccontextmanager
16
17
import functools
18
+ import logging
17
19
import sys
18
- from typing import Any , TextIO
20
+ from typing import Any , Optional , TextIO
19
21
import anyio
20
22
from pydantic import BaseModel
21
23
34
36
else :
35
37
raise e
36
38
39
+ logger = logging .getLogger (__name__ )
40
+
37
41
38
42
class SseServerParams (BaseModel ):
39
43
"""Parameters for the MCP SSE connection.
@@ -108,6 +112,45 @@ async def wrapper(self, *args, **kwargs):
108
112
return decorator
109
113
110
114
115
+ @asynccontextmanager
116
+ async def tracked_stdio_client (server , errlog , process = None ):
117
+ """A wrapper around stdio_client that ensures proper process tracking and cleanup."""
118
+ our_process = process
119
+
120
+ # If no process was provided, create one
121
+ if our_process is None :
122
+ our_process = await asyncio .create_subprocess_exec (
123
+ server .command ,
124
+ * server .args ,
125
+ stdin = asyncio .subprocess .PIPE ,
126
+ stdout = asyncio .subprocess .PIPE ,
127
+ stderr = errlog ,
128
+ )
129
+
130
+ # Use the original stdio_client, but ensure process cleanup
131
+ try :
132
+ async with stdio_client (server = server , errlog = errlog ) as client :
133
+ yield client , our_process
134
+ finally :
135
+ # Ensure the process is properly terminated if it still exists
136
+ if our_process and our_process .returncode is None :
137
+ try :
138
+ logger .info (
139
+ f'Terminating process { our_process .pid } from tracked_stdio_client'
140
+ )
141
+ our_process .terminate ()
142
+ try :
143
+ await asyncio .wait_for (our_process .wait (), timeout = 3.0 )
144
+ except asyncio .TimeoutError :
145
+ # Force kill if it doesn't terminate quickly
146
+ if our_process .returncode is None :
147
+ logger .warning (f'Forcing kill of process { our_process .pid } ' )
148
+ our_process .kill ()
149
+ except ProcessLookupError :
150
+ # Process already gone, that's fine
151
+ logger .info (f'Process { our_process .pid } already terminated' )
152
+
153
+
111
154
class MCPSessionManager :
112
155
"""Manages MCP client sessions.
113
156
@@ -138,25 +181,39 @@ def __init__(
138
181
errlog: (Optional) TextIO stream for error logging. Use only for
139
182
initializing a local stdio MCP session.
140
183
"""
184
+
141
185
self ._connection_params = connection_params
142
186
self ._exit_stack = exit_stack
143
187
self ._errlog = errlog
188
+ self ._process = None # Track the subprocess
189
+ self ._active_processes = set () # Track all processes created
190
+ self ._active_file_handles = set () # Track file handles
144
191
145
- async def create_session (self ) -> ClientSession :
146
- return await MCPSessionManager .initialize_session (
192
+ async def create_session (
193
+ self ,
194
+ ) -> tuple [ClientSession , Optional [asyncio .subprocess .Process ]]:
195
+ """Creates a new MCP session and tracks the associated process."""
196
+ session , process = await self ._initialize_session (
147
197
connection_params = self ._connection_params ,
148
198
exit_stack = self ._exit_stack ,
149
199
errlog = self ._errlog ,
150
200
)
201
+ self ._process = process # Store reference to process
202
+
203
+ # Track the process
204
+ if process :
205
+ self ._active_processes .add (process )
206
+
207
+ return session , process
151
208
152
209
@classmethod
153
- async def initialize_session (
210
+ async def _initialize_session (
154
211
cls ,
155
212
* ,
156
213
connection_params : StdioServerParameters | SseServerParams ,
157
214
exit_stack : AsyncExitStack ,
158
215
errlog : TextIO = sys .stderr ,
159
- ) -> ClientSession :
216
+ ) -> tuple [ ClientSession , Optional [ asyncio . subprocess . Process ]] :
160
217
"""Initializes an MCP client session.
161
218
162
219
Args:
@@ -168,9 +225,17 @@ async def initialize_session(
168
225
Returns:
169
226
ClientSession: The initialized MCP client session.
170
227
"""
228
+ process = None
229
+
171
230
if isinstance (connection_params , StdioServerParameters ):
172
- client = stdio_client (server = connection_params , errlog = errlog )
231
+ # For stdio connections, we need to track the subprocess
232
+ client , process = await cls ._create_stdio_client (
233
+ server = connection_params ,
234
+ errlog = errlog ,
235
+ exit_stack = exit_stack ,
236
+ )
173
237
elif isinstance (connection_params , SseServerParams ):
238
+ # For SSE connections, create the client without a subprocess
174
239
client = sse_client (
175
240
url = connection_params .url ,
176
241
headers = connection_params .headers ,
@@ -184,7 +249,74 @@ async def initialize_session(
184
249
f' { connection_params } '
185
250
)
186
251
252
+ # Create the session with the client
187
253
transports = await exit_stack .enter_async_context (client )
188
254
session = await exit_stack .enter_async_context (ClientSession (* transports ))
189
255
await session .initialize ()
190
- return session
256
+
257
+ return session , process
258
+
259
+ @staticmethod
260
+ async def _create_stdio_client (
261
+ server : StdioServerParameters ,
262
+ errlog : TextIO ,
263
+ exit_stack : AsyncExitStack ,
264
+ ) -> tuple [Any , asyncio .subprocess .Process ]:
265
+ """Create stdio client and return both the client and process.
266
+
267
+ This implementation adapts to how the MCP stdio_client is created.
268
+ The actual implementation may need to be adjusted based on the MCP library
269
+ structure.
270
+ """
271
+ # Create the subprocess directly so we can track it
272
+ process = await asyncio .create_subprocess_exec (
273
+ server .command ,
274
+ * server .args ,
275
+ stdin = asyncio .subprocess .PIPE ,
276
+ stdout = asyncio .subprocess .PIPE ,
277
+ stderr = errlog ,
278
+ )
279
+
280
+ # Create the stdio client using the MCP library
281
+ try :
282
+ # Method 1: Try using the existing process if stdio_client supports it
283
+ client = stdio_client (server = server , errlog = errlog , process = process )
284
+ except TypeError :
285
+ # Method 2: If the above doesn't work, let stdio_client create its own process
286
+ # and we'll need to terminate both processes later
287
+ logger .warning (
288
+ 'Using stdio_client with its own process - may lead to duplicate'
289
+ ' processes'
290
+ )
291
+ client = stdio_client (server = server , errlog = errlog )
292
+
293
+ return client , process
294
+
295
+ async def _emergency_cleanup (self ):
296
+ """Perform emergency cleanup of resources when normal cleanup fails."""
297
+ logger .info ('Performing emergency cleanup of MCPSessionManager resources' )
298
+
299
+ # Clean up any tracked processes
300
+ for proc in list (self ._active_processes ):
301
+ try :
302
+ if proc and proc .returncode is None :
303
+ logger .info (f'Emergency termination of process { proc .pid } ' )
304
+ proc .terminate ()
305
+ try :
306
+ await asyncio .wait_for (proc .wait (), timeout = 1.0 )
307
+ except asyncio .TimeoutError :
308
+ logger .warning (f"Process { proc .pid } didn't terminate, forcing kill" )
309
+ proc .kill ()
310
+ self ._active_processes .remove (proc )
311
+ except Exception as e :
312
+ logger .error (f'Error during process cleanup: { e } ' )
313
+
314
+ # Clean up any tracked file handles
315
+ for handle in list (self ._active_file_handles ):
316
+ try :
317
+ if not handle .closed :
318
+ logger .info ('Closing file handle' )
319
+ handle .close ()
320
+ self ._active_file_handles .remove (handle )
321
+ except Exception as e :
322
+ logger .error (f'Error closing file handle: { e } ' )
0 commit comments