14
14
15
15
from __future__ import annotations
16
16
17
+ import asyncio
17
18
from contextlib import AsyncExitStack
18
19
from datetime import timedelta
19
20
import functools
34
35
from mcp .client .stdio import stdio_client
35
36
from mcp .client .streamable_http import streamablehttp_client
36
37
except ImportError as e :
37
- import sys
38
38
39
39
if sys .version_info < (3 , 10 ):
40
40
raise ImportError (
@@ -105,30 +105,29 @@ class StreamableHTTPConnectionParams(BaseModel):
105
105
terminate_on_close : bool = True
106
106
107
107
108
- def retry_on_closed_resource (async_reinit_func_name : str ):
108
+ def retry_on_closed_resource (session_manager_field_name : str ):
109
109
"""Decorator to automatically reinitialize session and retry action.
110
110
111
111
When MCP session was closed, the decorator will automatically recreate the
112
112
session and retry the action with the same parameters.
113
113
114
114
Note:
115
- 1. async_reinit_func_name is the name of the class member function that
116
- reinitializes the MCP session.
117
- 2. Both the decorated function and the async_reinit_func_name must be async
118
- functions.
115
+ 1. session_manager_field_name is the name of the class member field that
116
+ contains the MCPSessionManager instance.
117
+ 2. The session manager must have a reinitialize_session() async method.
119
118
120
119
Usage:
121
120
class MCPTool:
122
- ...
123
- async def create_session(self):
124
- self.session = ...
121
+ def __init__(self):
122
+ self._mcp_session_manager = MCPSessionManager(...)
125
123
126
- @retry_on_closed_resource('create_session ')
124
+ @retry_on_closed_resource('_mcp_session_manager ')
127
125
async def use_session(self):
128
- await self.session.call_tool()
126
+ session = await self._mcp_session_manager.create_session()
127
+ await session.call_tool()
129
128
130
129
Args:
131
- async_reinit_func_name : The name of the async function to recreate session .
130
+ session_manager_field_name : The name of the session manager field .
132
131
133
132
Returns:
134
133
The decorated function.
@@ -141,15 +140,21 @@ async def wrapper(self, *args, **kwargs):
141
140
return await func (self , * args , ** kwargs )
142
141
except anyio .ClosedResourceError as close_err :
143
142
try :
144
- if hasattr (self , async_reinit_func_name ) and callable (
145
- getattr (self , async_reinit_func_name )
146
- ):
147
- async_init_fn = getattr (self , async_reinit_func_name )
148
- await async_init_fn ()
143
+ if hasattr (self , session_manager_field_name ):
144
+ session_manager = getattr (self , session_manager_field_name )
145
+ if hasattr (session_manager , 'reinitialize_session' ) and callable (
146
+ getattr (session_manager , 'reinitialize_session' )
147
+ ):
148
+ await session_manager .reinitialize_session ()
149
+ else :
150
+ raise ValueError (
151
+ f'Session manager { session_manager_field_name } does not have'
152
+ ' reinitialize_session method.'
153
+ ) from close_err
149
154
else :
150
155
raise ValueError (
151
- f'Function { async_reinit_func_name } does not exist in decorated '
152
- ' class. Please check the function name in'
156
+ f'Session manager field { session_manager_field_name } does not'
157
+ ' exist in decorated class. Please check the field name in'
153
158
' retry_on_closed_resource decorator.'
154
159
) from close_err
155
160
except Exception as reinit_err :
@@ -207,90 +212,111 @@ def __init__(
207
212
# Each session manager maintains its own exit stack for proper cleanup
208
213
self ._exit_stack : Optional [AsyncExitStack ] = None
209
214
self ._session : Optional [ClientSession ] = None
215
+ # Lock to prevent race conditions in session creation
216
+ self ._session_lock = asyncio .Lock ()
210
217
211
218
async def create_session (self ) -> ClientSession :
212
219
"""Creates and initializes an MCP client session.
213
220
214
221
Returns:
215
222
ClientSession: The initialized MCP client session.
216
223
"""
224
+ # Fast path: if session already exists, return it without acquiring lock
217
225
if self ._session is not None :
218
226
return self ._session
219
227
220
- # Create a new exit stack for this session
221
- self ._exit_stack = AsyncExitStack ()
222
-
223
- try :
224
- if isinstance (self ._connection_params , StdioConnectionParams ):
225
- client = stdio_client (
226
- server = self ._connection_params .server_params ,
227
- errlog = self ._errlog ,
228
- )
229
- elif isinstance (self ._connection_params , SseConnectionParams ):
230
- client = sse_client (
231
- url = self ._connection_params .url ,
232
- headers = self ._connection_params .headers ,
233
- timeout = self ._connection_params .timeout ,
234
- sse_read_timeout = self ._connection_params .sse_read_timeout ,
235
- )
236
- elif isinstance (self ._connection_params , StreamableHTTPConnectionParams ):
237
- client = streamablehttp_client (
238
- url = self ._connection_params .url ,
239
- headers = self ._connection_params .headers ,
240
- timeout = timedelta (seconds = self ._connection_params .timeout ),
241
- sse_read_timeout = timedelta (
242
- seconds = self ._connection_params .sse_read_timeout
243
- ),
244
- terminate_on_close = self ._connection_params .terminate_on_close ,
245
- )
246
- else :
247
- raise ValueError (
248
- 'Unable to initialize connection. Connection should be'
249
- ' StdioServerParameters or SseServerParams, but got'
250
- f' { self ._connection_params } '
251
- )
252
-
253
- transports = await self ._exit_stack .enter_async_context (client )
254
- # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
255
- # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients.
256
- if isinstance (self ._connection_params , StdioConnectionParams ):
257
- session = await self ._exit_stack .enter_async_context (
258
- ClientSession (
259
- * transports [:2 ],
260
- read_timeout_seconds = timedelta (
261
- seconds = self ._connection_params .timeout
262
- ),
263
- )
264
- )
265
- else :
266
- session = await self ._exit_stack .enter_async_context (
267
- ClientSession (* transports [:2 ])
268
- )
269
- await session .initialize ()
270
-
271
- self ._session = session
272
- return session
273
-
274
- except Exception :
275
- # If session creation fails, clean up the exit stack
276
- if self ._exit_stack :
277
- await self ._exit_stack .aclose ()
278
- self ._exit_stack = None
279
- raise
228
+ # Use async lock to prevent race conditions
229
+ async with self ._session_lock :
230
+ # Double-check: session might have been created while waiting for lock
231
+ if self ._session is not None :
232
+ return self ._session
233
+
234
+ # Create a new exit stack for this session
235
+ self ._exit_stack = AsyncExitStack ()
236
+
237
+ try :
238
+ if isinstance (self ._connection_params , StdioConnectionParams ):
239
+ client = stdio_client (
240
+ server = self ._connection_params .server_params ,
241
+ errlog = self ._errlog ,
242
+ )
243
+ elif isinstance (self ._connection_params , SseConnectionParams ):
244
+ client = sse_client (
245
+ url = self ._connection_params .url ,
246
+ headers = self ._connection_params .headers ,
247
+ timeout = self ._connection_params .timeout ,
248
+ sse_read_timeout = self ._connection_params .sse_read_timeout ,
249
+ )
250
+ elif isinstance (
251
+ self ._connection_params , StreamableHTTPConnectionParams
252
+ ):
253
+ client = streamablehttp_client (
254
+ url = self ._connection_params .url ,
255
+ headers = self ._connection_params .headers ,
<
10000
code> 256
+ timeout = timedelta (seconds = self ._connection_params .timeout ),
257
+ sse_read_timeout = timedelta (
258
+ seconds = self ._connection_params .sse_read_timeout
259
+ ),
260
+ terminate_on_close = self ._connection_params .terminate_on_close ,
261
+ )
262
+ else :
263
+ raise ValueError (
264
+ 'Unable to initialize connection. Connection should be'
265
+ ' StdioServerParameters or SseServerParams, but got'
266
+ f' { self ._connection_params } '
267
+ )
268
+
269
+ transports = await self ._exit_stack .enter_async_context (client )
270
+ # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
271
+ # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients.
272
+ if isinstance (self ._connection_params , StdioConnectionParams ):
273
+ session = await self ._exit_stack .enter_async_context (
274
+ ClientSession (
275
+ * transports [:2 ],
276
+ read_timeout_seconds = timedelta (
277
+ seconds = self ._connection_params .timeout
278
+ ),
279
+ )
280
+ )
281
+ else :
282
+ session = await self ._exit_stack .enter_async_context (
283
+ ClientSession (* transports [:2 ])
284
+ )
285
+ await session .initialize ()
286
+
287
+ self ._session = session
288
+ return session
289
+
290
+ except Exception :
291
+ # If session creation fails, clean up the exit stack
292
+ if self ._exit_stack :
293
+ await self ._exit_stack .aclose ()
294
+ self ._exit_stack = None
295
+ raise
280
296
281
297
async def close (self ):
282
298
"""Closes the session and cleans up resources."""
283
- if self ._exit_stack :
284
- try :
285
- await self ._exit_stack .aclose ()
286
- except Exception as e :
287
- # Log the error but don't re-raise to avoid blocking shutdown
288
- print (
289
- f'Warning: Error during MCP session cleanup: { e } ' , file = self ._errlog
290
- )
291
- finally :
292
- self ._exit_stack = None
293
- self ._session = None
299
+ if not self ._exit_stack :
300
+ return
301
+ async with self ._session_lock :
302
+ if self ._exit_stack :
303
+ try :
304
+ await self ._exit_stack .aclose ()
305
+ except Exception as e :
306
+ # Log the error but don't re-raise to avoid blocking shutdown
307
+ print (
308
+ f'Warning: Error during MCP session cleanup: { e } ' ,
309
+ file = self ._errlog ,
310
+ )
311
+ finally :
312
+ self ._exit_stack = None
313
+ self ._session = None
314
+
315
+ async def reinitialize_session (self ):
316
+ """Reinitializes the session when connection is lost."""
317
+ # Close the old session and create a new one
318
+ await self .close ()
319
+ await self .create_session ()
294
320
295
321
296
322
SseServerParams = SseConnectionParams
0 commit comments