1414
1515from __future__ import annotations
1616
17+ import asyncio
1718from contextlib import AsyncExitStack
1819from datetime import timedelta
1920import functools
3435 from mcp .client .stdio import stdio_client
3536 from mcp .client .streamable_http import streamablehttp_client
3637except ImportError as e :
37- import sys
3838
3939 if sys .version_info < (3 , 10 ):
4040 raise ImportError (
@@ -105,30 +105,29 @@ class StreamableHTTPConnectionParams(BaseModel):
105105 terminate_on_close : bool = True
106106
107107
108- def retry_on_closed_resource (async_reinit_func_name : str ):
108+ def retry_on_closed_resource (session_manager_field_name : str ):
109109 """Decorator to automatically reinitialize session and retry action.
110110
111111 When MCP session was closed, the decorator will automatically recreate the
112112 session and retry the action with the same parameters.
113113
114114 Note:
115- 1. async_reinit_func_name is the name of the class member function that
116- reinitializes the MCP session.
117- 2. Both the decorated function and the async_reinit_func_name must be async
118- functions.
115+ 1. session_manager_field_name is the name of the class member field that
116+ contains the MCPSessionManager instance.
117+ 2. The session manager must have a reinitialize_session() async method.
119118
120119 Usage:
121120 class MCPTool:
122- ...
123- async def create_session(self):
124- self.session = ...
121+ def __init__(self):
122+ self._mcp_session_manager = MCPSessionManager(...)
125123
126- @retry_on_closed_resource('create_session ')
124+ @retry_on_closed_resource('_mcp_session_manager ')
127125 async def use_session(self):
128- await self.session.call_tool()
126+ session = await self._mcp_session_manager.create_session()
127+ await session.call_tool()
129128
130129 Args:
131- async_reinit_func_name : The name of the async function to recreate session .
130+ session_manager_field_name : The name of the session manager field .
132131
133132 Returns:
134133 The decorated function.
@@ -141,15 +140,21 @@ async def wrapper(self, *args, **kwargs):
141140 return await func (self , * args , ** kwargs )
142141 except anyio .ClosedResourceError as close_err :
143142 try :
144- if hasattr (self , async_reinit_func_name ) and callable (
145- getattr (self , async_reinit_func_name )
146- ):
147- async_init_fn = getattr (self , async_reinit_func_name )
148- await async_init_fn ()
143+
45C0
if hasattr (self , session_manager_field_name ):
144+ session_manager = getattr (self , session_manager_field_name )
145+ if hasattr (session_manager , 'reinitialize_session' ) and callable (
146+ getattr (session_manager , 'reinitialize_session' )
147+ ):
148+ await session_manager .reinitialize_session ()
149+ else :
150+ raise ValueError (
151+ f'Session manager { session_manager_field_name } does not have'
152+ ' reinitialize_session method.'
153+ ) from close_err
149154 else :
150155 raise ValueError (
C600
151- f'Function { async_reinit_func_name } does not exist in decorated '
152- ' class. Please check the function name in'
156+ f'Session manager field { session_manager_field_name } does not'
157+ ' exist in decorated class. Please check the field name in'
153158 ' retry_on_closed_resource decorator.'
154159 ) from close_err
155160 except Exception as reinit_err :
@@ -207,90 +212,111 @@ def __init__(
207212 # Each session manager maintains its own exit stack for proper cleanup
208213 self ._exit_stack : Optional [AsyncExitStack ] = None
209214 self ._session : Optional [ClientSession ] = None
215+ # Lock to prevent race conditions in session creation
216+ self ._session_lock = asyncio .Lock ()
210217
211218 async def create_session (self ) -> ClientSession :
212219 """Creates and initializes an MCP client session.
213220
214221 Returns:
215222 ClientSession: The initialized MCP client session.
216223 """
224+ # Fast path: if session already exists, return it without acquiring lock
217225 if self ._session is not None :
218226 return self ._session
219227
220- # Create a new exit stack for this session
221- self ._exit_stack = AsyncExitStack ()
222-
223- try :
224- if isinstance (self ._connection_params , StdioConnectionParams ):
225- client = stdio_client (
226- server = self ._connection_params .server_params ,
227- errlog = self ._errlog ,
228- )
229- elif isinstance (self ._connection_params , SseConnectionParams ):
230- client = sse_client (
231- url = self ._connection_params .url ,
232- headers = self ._connection_params .headers ,
233- timeout = self ._connection_params .timeout ,
234- sse_read_timeout = self ._connection_params .sse_read_timeout ,
235- )
236- elif isinstance (self ._connection_params , StreamableHTTPConnectionParams ):
237- client = streamablehttp_client (
238- url = self ._connection_params .url ,
239- headers = self ._connection_params .headers ,
240- timeout = timedelta (seconds = self ._connection_params .timeout ),
241- sse_read_timeout = timedelta (
242- seconds = self ._connection_params .sse_read_timeout
243- ),
244- terminate_on_close = self ._connection_params .terminate_on_close ,
245- )
246- else :
247- raise ValueError (
248- 'Unable to initialize connection. Connection should be'
249- ' StdioServerParameters or SseServerParams, but got'
250- f' { self ._connection_params } '
251- )
252-
253- transports = await self ._exit_stack .enter_async_context (client )
254- # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
255- # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients.
256- if isinstance (self ._connection_params , StdioConnectionParams ):
257- session = await self ._exit_stack .enter_async_context (
258- ClientSession (
259- * transports [:2 ],
260- read_timeout_seconds = timedelta (
261- seconds = self ._connection_params .timeout
262- ),
263- )
264- )
265- else :
266- session = await self ._exit_stack .enter_async_context (
267- ClientSession (* transports [:2 ])
268- )
269- await session .initialize ()
270-
271- self ._session = session
272- return session
273-
274- except Exception :
275- # If session creation fails, clean up the exit stack
276- if self ._exit_stack :
277- await self ._exit_stack .aclose ()
278- self ._exit_stack = None
279- raise
228+ # Use async lock to prevent race conditions
229+ async with self ._session_lock :
230+ # Double-check: session might have been created while waiting for lock
231+ if self ._session is not None :
232+ return self ._session
233+
234+ # Create a new exit stack for this session
235+ self ._exit_stack = AsyncExitStack ()
236+
237+ try :
238+ if isinstance (self ._connection_params , StdioConnectionParams ):
239+ client = stdio_client (
240+ server = self ._connection_params .server_params ,
241+ errlog = self ._errlog ,
242+ )
243+ elif isinstance (self ._connection_params , SseConnectionParams ):
244+ client = sse_client (
245+ url = self ._connection_params .url ,
246+ headers = self ._connection_params .headers ,
247+ timeout = self ._connection_params .timeout ,
248+ sse_read_timeout = self ._connection_params .sse_read_timeout ,
249+ )
250+ elif isinstance (
251+ self ._connection_params , StreamableHTTPConnectionParams
252+ ):
253+ client = streamablehttp_client (
254+ url = self ._connection_params .url ,
255+ headers = self ._connection_params .headers ,
256+ timeout = timedelta (seconds = self ._connection_params .timeout ),
257+ sse_read_timeout = timedelta (
258+ seconds = self ._connection_params .sse_read_timeout
259+ ),
260+ terminate_on_close = self ._connection_params .terminate_on_close ,
261+ )
262+ else :
263+ raise ValueError (
264+ 'Unable to initialize connection. Connection should be'
265+ ' StdioServerParameters or SseServerParams, but got'
266+ f' { self ._connection_params } '
267+ )
268+
269+ transports = await self ._exit_stack .enter_async_context (client )
270+ # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
271+ # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients.
272+ if isinstance (self ._connection_params , StdioConnectionParams ):
273+ session = await self ._exit_stack .enter_async_context (
274+ ClientSession (
275+ * transports [:2 ],
276+ read_timeout_seconds = timedelta (
277+ seconds = self ._connection_params .timeout
278+ ),
279+ )
280+ )
281+ else :
282+ session = await self ._exit_stack .enter_async_context (
283+ ClientSession (* transports [:2 ])
284+ )
285+ await session .initialize ()
286+
287+ self ._session = session
288+ return session
289+
290+ except Exception :
291+ # If session creation fails, clean up the exit stack
292+ if self ._exit_stack :
293+ await self ._exit_stack .aclose ()
294+ self ._exit_stack = None
295+ raise
280296
281297 async def close (self ):
282298 """Closes the session and cleans up resources."""
283- if self ._exit_stack :
284- try :
285- await self ._exit_stack .aclose ()
286- except Exception as e :
287- # Log the error but don't re-raise to avoid blocking shutdown
288- print (
289- f'Warning: Error during MCP session cleanup: { e } ' , file = self ._errlog
290- )
291- finally :
292- self ._exit_stack = None
293- self ._session = None
299+ if not self ._exit_stack :
300+ return
301+ async with self ._session_lock :
302+ if self ._exit_stack :
303+ try :
304+ await self ._exit_stack .aclose ()
305+ except Exception as e :
306+ # Log the error but don't re-raise to avoid blocking shutdown
307+ print (
308+ f'Warning: Error during MCP session cleanup: { e } ' ,
309+ file = self ._errlog ,
310+ )
311+ finally :
312+ self ._exit_stack = None
313+ self ._session = None
314+
315+ async def reinitialize_session (self ):
316+ """Reinitializes the session when connection is lost."""
317+ # Close the old session and create a new one
318+ await self .close ()
319+ await self .create_session ()
294320
295321
296322SseServerParams = SseConnectionParams
0 commit comments