@@ -251,6 +251,49 @@ def _is_session_disconnected(self, session: ClientSession) -> bool:
251
251
"""
252
252
return session ._read_stream ._closed or session ._write_stream ._closed
253
253
254
+ def _create_client (self , merged_headers : Optional [Dict [str , str ]] = None ):
255
+ """Creates an MCP client based on the connection parameters.
256
+
257
+ Args:
258
+ merged_headers: Optional headers to include in the connection.
259
+ Only applicable for SSE and StreamableHTTP connections.
260
+
261
+ Returns:
262
+ The appropriate MCP client instance.
263
+
264
+ Raises:
265
+ ValueError: If the connection parameters are not supported.
266
+ """
267
+ if isinstance (self ._connection_params , StdioConnectionParams ):
268
+ client = stdio_client (
269
+ server = self ._connection_params .server_params ,
270
+ errlog = self ._errlog ,
271
+ )
272
+ elif isinstance (self ._connection_params , SseConnectionParams ):
273
+ client = sse_client (
274
+ url = self ._connection_params .url ,
275
+ headers = merged_headers ,
276
+ timeout = self ._connection_params .timeout ,
277
+ sse_read_timeout = self ._connection_params .sse_read_timeout ,
278
+ )
279
+ elif isinstance (self ._connection_params , StreamableHTTPConnectionParams ):
280
+ client = streamablehttp_client (
281
+ url = self ._connection_params .url ,
282
+ headers = merged_headers ,
283
+ timeout = timedelta (seconds = self ._connection_params .timeout ),
284
+ sse_read_timeout = timedelta (
285
+ seconds = self ._connection_params .sse_read_timeout
286
+ ),
287
+ terminate_on_close = self ._connection_params .terminate_on_close ,
288
+ )
289
+ else :
290
+ raise ValueError (
291
+ 'Unable to initialize connection. Connection should be'
292
+ ' StdioServerParameters or SseServerParams, but got'
293
+ f' { self ._connection_params } '
294
+ )
295
+ return client
296
+
254
297
async def create_session (
255
298
self , headers : Optional [Dict [str , str ]] = None
256
299
) -> ClientSession :
@@ -298,36 +341,7 @@ async def create_session(
298
341
exit_stack = AsyncExitStack ()
299
342
300
343
try :
301
- if isinstance (self ._connection_params , StdioConnectionParams ):
302
- client = stdio_client (
303
- server = self ._connection_params .server_params ,
304
- errlog = self ._errlog ,
305
- )
306
- elif isinstance (self ._connection_params , SseConnectionParams ):
307
- client = sse_client (
308
- url = self ._connection_params .url ,
309
- headers = merged_headers ,
310
- timeout = self ._connection_params .timeout ,
311
- sse_read_timeout = self ._connection_params .sse_read_timeout ,
312
- )
313
- elif isinstance (
314
- self ._connection_params , StreamableHTTPConnectionParams
315
- ):
316
- client = streamablehttp_client (
317
- url = self ._connection_params .url ,
318
- headers = merged_headers ,
319
- timeout = timedelta (seconds = self ._connection_params .timeout ),
320
- sse_read_timeout = timedelta (
321
- seconds = self ._connection_params .sse_read_timeout
322
- ),
323
- terminate_on_close = self ._connection_params .terminate_on_close ,
324
- )
325
- else :
326
- raise ValueError (
327
- 'Unable to initialize connection. Connection should be'
328
- ' StdioServerParameters or SseServerParams, but got'
329
- f' { self ._connection_params } '
330
- )
344
+ client = self ._create_client (merged_headers )
331
345
332
346
transports = await exit_stack .enter_async_context (client )
333
347
# The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
0 commit comments