18
18
from contextlib import AsyncExitStack
19
19
from datetime import timedelta
20
20
import functools
21
+ import hashlib
22
+ import json
21
23
import logging
22
24
import sys
23
25
from typing import Any
26
+ from typing import Dict
24
27
from typing import Optional
25
28
from typing import TextIO
26
29
from typing import Union
@@ -105,74 +108,39 @@ class StreamableHTTPConnectionParams(BaseModel):
105
108
terminate_on_close : bool = True
106
109
107
110
108
- def retry_on_closed_resource (session_manager_field_name : str ):
109
- """Decorator to automatically reinitialize session and retry action .
111
+ def retry_on_closed_resource (func ):
112
+ """Decorator to automatically retry action when MCP session is closed .
110
113
111
- When MCP session was closed, the decorator will automatically recreate the
112
- session and retry the action with the same parameters.
113
-
114
- Note:
115
- 1. session_manager_field_name is the name of the class member field that
116
- contains the MCPSessionManager instance.
117
- 2. The session manager must have a reinitialize_session() async method.
118
-
119
- Usage:
120
- class MCPTool:
121
- def __init__(self):
122
- self._mcp_session_manager = MCPSessionManager(...)
123
-
124
- @retry_on_closed_resource('_mcp_session_manager')
125
- async def use_session(self):
126
- session = await self._mcp_session_manager.create_session()
127
- await session.call_tool()
114
+ When MCP session was closed, the decorator will automatically retry the
115
+ action once. The create_session method will handle creating a new session
116
+ if the old one was disconnected.
128
117
129
118
Args:
130
- session_manager_field_name : The name of the session manager field .
119
+ func : The function to decorate .
131
120
132
121
Returns:
133
122
The decorated function.
134
123
"""
135
124
136
- def decorator (func ):
137
- @functools .wraps (func ) # Preserves original function metadata
138
- async def wrapper (self , * args , ** kwargs ):
139
- try :
140
- return await func (self , * args , ** kwargs )
141
- except anyio .ClosedResourceError as close_err :
142
- try :
143
- if hasattr (self , session_manager_field_name ):
144
- session_manager = getattr (self , session_manager_field_name )
145
- if hasattr (session_manager , 'reinitialize_session' ) and callable (
146
- getattr (session_manager , 'reinitialize_session' )
147
- ):
148
- await session_manager .reinitialize_session ()
149
- else :
150
- raise ValueError (
151
- f'Session manager { session_manager_field_name } does not have'
152
- ' reinitialize_session method.'
153
- ) from close_err
154
- else :
155
- raise ValueError (
156
- f'Session manager field { session_manager_field_name } does not'
157
- ' exist in decorated class. Please check the field name in'
158
- ' retry_on_closed_resource decorator.'
159
- ) from close_err
160
- except Exception as reinit_err :
161
- raise RuntimeError (
162
- f'Error reinitializing: { reinit_err } '
163
- ) from reinit_err
164
- return await func (self , * args , ** kwargs )
165
-
166
- return wrapper
167
-
168
- return decorator
125
+ @functools .wraps (func ) # Preserves original function metadata
126
+ async def wrapper (self , * args , ** kwargs ):
127
+ try :
128
+ return await func (self , * args , ** kwargs )
129
+ except anyio .ClosedResourceError :
130
+ # Simply retry the function - create_session will handle
131
+ # detecting and replacing disconnected sessions
132
+ logger .info ('Retrying %s due to closed resource' , func .__name__ )
133
+ return await func (self , * args , ** kwargs )
134
+
135
+ return wrapper
169
136
170
137
171
138
class MCPSessionManager :
172
139
"""Manages MCP client sessions.
173
140
174
141
This class provides methods for creating and initializing MCP client sessions,
175
- handling different connection parameters (Stdio and SSE).
142
+ handling different connection parameters (Stdio and SSE) and supporting
143
+ session pooling based on authentication headers.
176
144
"""
177
145
178
146
def __init__ (
@@ -209,30 +177,125 @@ def __init__(
209
177
else :
210
178
self ._connection_params = connection_params
211
179
self ._errlog = errlog
212
- # Each session manager maintains its own exit stack for proper cleanup
213
- self ._exit_stack : Optional [AsyncExitStack ] = None
214
- self ._session : Optional [ClientSession ] = None
180
+
181
+ # Session pool: maps session keys to (session, exit_stack) tuples
182
+ self ._sessions : Dict [str , tuple [ClientSession , AsyncExitStack ]] = {}
183
+
215
184
# Lock to prevent race conditions in session creation
216
185
self ._session_lock = asyncio .Lock ()
217
186
218
- async def create_session (self ) -> ClientSession :
187
+ def _generate_session_key (
188
+ self , merged_headers : Optional [Dict [str , str ]] = None
189
+ ) -> str :
190
+ """Generates a session key based on connection params and merged headers.
191
+
192
+ For StdioConnectionParams, returns a constant key since headers are not
193
+ supported. For SSE and StreamableHTTP connections, generates a key based
194
+ on the provided merged headers.
195
+
196
+ Args:
197
+ merged_headers: Already merged headers (base + additional).
198
+
199
+ Returns:
200
+ A unique session key string.
201
+ """
202
+ if isinstance (self ._connection_params , StdioConnectionParams ):
203
+ # For stdio connections, headers are not supported, so use constant key
204
+ return 'stdio_session'
205
+
206
+ # For SSE and StreamableHTTP connections, use merged headers
207
+ if merged_headers :
208
+ headers_json = json .dumps (merged_headers , sort_keys = True )
209
+ headers_hash = hashlib .md5 (headers_json .encode ()).hexdigest ()
210
+ return f'session_{ headers_hash } '
211
+ else :
212
+ return 'session_no_headers'
213
+
214
+ def _merge_headers (
215
+ self , additional_headers : Optional [Dict [str , str ]] = None
216
+ ) -> Optional [Dict [str , str ]]:
217
+ """Merges base connection headers with additional headers.
218
+
219
+ Args:
220
+ additional_headers: Optional headers to merge with connection headers.
221
+
222
+ Returns:
223
+ Merged headers dictionary, or None if no headers are provided.
224
+ """
225
+ if isinstance (self ._connection_params , StdioConnectionParams ) or isinstance (
226
+ self ._connection_params , StdioServerParameters
227
+ ):
228
+ # Stdio connections don't support headers
229
+ return None
230
+
231
+ base_headers = {}
232
+ if (
233
+ hasattr (self ._connection_params , 'headers' )
234
+ and self ._connection_params .headers
235
+ ):
236
+ base_headers = self ._connection_params .headers .copy ()
237
+
238
+ if additional_headers :
239
+ base_headers .update (additional_headers )
240
+
241
+ return base_headers
242
+
243
+ def _is_session_disconnected (self , session : ClientSession ) -> bool :
244
+ """Checks if a session is disconnected or closed.
245
+
246
+ Args:
247
+ session: The ClientSession to check.
248
+
249
+ Returns:
250
+ True if the session is disconnected, False otherwise.
251
+ """
252
+ return session ._read_stream ._closed or session ._write_stream ._closed
253
+
254
+ async def create_session (
255
+ self , headers : Optional [Dict [str , str ]] = None
256
+ ) -> ClientSession :
219
257
"""Creates and initializes an MCP client session.
220
258
259
+ This method will check if an existing session for the given headers
260
+ is still connected. If it's disconnected, it will be cleaned up and
261
+ a new session will be created.
262
+
263
+ Args:
264
+ headers: Optional headers to include in the session. These will be
265
+ merged with any existing connection headers. Only applicable
266
+ for SSE and StreamableHTTP connections.
267
+
221
268
Returns:
222
269
ClientSession: The initialized MCP client session.
223
270
"""
224
- # Fast path: if session already exists, return it without acquiring lock
225
- if self ._session is not None :
226
- return self ._session
271
+ # Merge headers once at the beginning
272
+ merged_headers = self ._merge_headers (headers )
273
+
274
+ # Generate session key using merged headers
275
+ session_key = self ._generate_session_key (merged_headers )
227
276
228
277
# Use async lock to prevent race conditions
229
278
async with self ._session_lock :
230
- # Double-check: session might have been created while waiting for lock
231
- if self ._session is not None :
232
- return self ._session
233
-
234
- # Create a new exit stack for this session
235
- self ._exit_stack = AsyncExitStack ()
279
+ # Check if we have an existing session
280
+ if session_key in self ._sessions :
281
+ session , exit_stack = self ._sessions [session_key ]
282
+
283
+ # Check if the existing session is still connected
284
+ if not self ._is_session_disconnected (session ):
285
+ # Session is still good, return it
286
+ return session
287
+ else :
288
+ # Session is disconnected, clean it up
289
+ logger .info ('Cleaning up disconnected session: %s' , session_key )
290
+ try :
291
+ await exit_stack .aclose ()
292
+ except Exception as e :
293
+ logger .warning ('Error during disconnected session cleanup: %s' , e )
294
+ finally :
295
+ del self ._sessions [session_key ]
296
+
297
+ # Create a new session (either first time or replacing disconnected one)
298
+ exit_stack = AsyncExitStack ()
236
299
237
300
try :
238
301
if isinstance (self ._connection_params , StdioConnectionParams ):
@@ -243,7 +306,7 @@ async def create_session(self) -> ClientSession:
243
306
elif isinstance (self ._connection_params , SseConnectionParams ):
244
307
client = sse_client (
245
308
url = self ._connection_params .url ,
246
- headers = self . _connection_params . headers ,
309
+ headers = merged_headers ,
247
310
timeout = self ._connection_params .timeout ,
248
311
sse_read_timeout = self ._connection_params .sse_read_timeout ,
249
312
)
@@ -252,7 +315,7 @@ async def create_session(self) -> ClientSession:
252
315
):
253
316
client = streamablehttp_client (
254
317
url = self ._connection_params .url ,
255
- headers = self . _connection_params . headers ,
318
+ headers = merged_headers ,
256
319
timeout = timedelta (seconds = self ._connection_params .timeout ),
257
320
sse_read_timeout = timedelta (
258
321
seconds = self ._connection_params .sse_read_timeout
@@ -266,11 +329,11 @@ async def create_session(self) -> ClientSession:
266
329
f' { self ._connection_params } '
267
330
)
268
331
269
- transports = await self . _exit_stack .enter_async_context (client )
332
+ transports = await exit_stack .enter_async_context (client )
270
333
# The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
271
334
# needed to build the ClientSession, we limit then to the two first values to be compatible with all clients.
272
335
if isinstance (self ._connection_params , StdioConnectionParams ):
273
- session = await self . _exit_stack .enter_async_context (
336
+ session = await exit_stack .enter_async_context (
274
337
ClientSession (
275
338
* transports [:2 ],
276
339
read_timeout_seconds = timedelta (
@@ -279,44 +342,38 @@ async def create_session(self) -> ClientSession:
279
342
)
280
343
)
281
344
else :
282
- session = await self . _exit_stack .enter_async_context (
345
+ session = await exit_stack .enter_async_context (
283
346
ClientSession (* transports [:2 ])
284
347
)
285
348
await session .initialize ()
286
349
287
- self ._session = session
350
+ # Store session and exit stack in the pool
351
+ self ._sessions [session_key ] = (session , exit_stack )
352
+ logger .debug ('Created new session: %s' , session_key )
288
353
return session
289
354
290
355
except Exception :
291
356
# If session creation fails, clean up the exit stack
292
- if self ._exit_stack :
293
- await self ._exit_stack .aclose ()
294
- self ._exit_stack = None
357
+ if exit_stack :
358
+ await exit_stack .aclose ()
295
359
raise
296
360
297
361
async def close (self ):
298
- """Closes the session and cleans up resources."""
299
- if not self ._exit_stack :
300
- return
362
+ """Closes all sessions and cleans up resources."""
301
363
async with self ._session_lock :
302
- if self ._exit_stack :
364
+ for session_key in list (self ._sessions .keys ()):
365
+ _ , exit_stack = self ._sessions [session_key ]
303
366
try :
304
- await self . _exit_stack .aclose ()
367
+ await exit_stack .aclose ()
305
368
except Exception as e :
306
369
# Log the error but don't re-raise to avoid blocking shutdown
307
370
print (
308
- f'Warning: Error during MCP session cleanup: { e } ' ,
371
+ 'Warning: Error during MCP session cleanup for'
372
+ f' { session_key } : { e } ' ,
309
373
file = self ._errlog ,
310
374
)
311
375
finally :
312
- self ._exit_stack = None
313
- self ._session = None
314
-
315
- async def reinitialize_session (self ):
316
- """Reinitializes the session when connection is lost."""
317
- # Close the old session and create a new one
318
- await self .close ()
319
- await self .create_session ()
376
+ del self ._sessions [session_key ]
320
377
321
378
322
379
SseServerParams = SseConnectionParams
0 commit comments