From 611f468e59c54138ea158c5883a097e6928bd364 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 7 Jul 2025 13:52:29 -0700 Subject: [PATCH] refactor: extract mcp client creation logic to a separate method PiperOrigin-RevId: 780246670 --- .../adk/tools/mcp_tool/mcp_session_manager.py | 74 +++++++++++-------- 1 file changed, 44 insertions(+), 30 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 90b39e6cb..1853fb1a7 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -251,6 +251,49 @@ def _is_session_disconnected(self, session: ClientSession) -> bool: """ return session._read_stream._closed or session._write_stream._closed + def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): + """Creates an MCP client based on the connection parameters. + + Args: + merged_headers: Optional headers to include in the connection. + Only applicable for SSE and StreamableHTTP connections. + + Returns: + The appropriate MCP client instance. + + Raises: + ValueError: If the connection parameters are not supported. + """ + if isinstance(self._connection_params, StdioConnectionParams): + client = stdio_client( + server=self._connection_params.server_params, + errlog=self._errlog, + ) + elif isinstance(self._connection_params, SseConnectionParams): + client = sse_client( + url=self._connection_params.url, + headers=merged_headers, + timeout=self._connection_params.timeout, + sse_read_timeout=self._connection_params.sse_read_timeout, + ) + elif isinstance(self._connection_params, StreamableHTTPConnectionParams): + client = streamablehttp_client( + url=self._connection_params.url, + headers=merged_headers, + timeout=timedelta(seconds=self._connection_params.timeout), + sse_read_timeout=timedelta( + seconds=self._connection_params.sse_read_timeout + ), + terminate_on_close=self._connection_params.terminate_on_close, + ) + else: + raise ValueError( + 'Unable to initialize connection. Connection should be' + ' StdioServerParameters or SseServerParams, but got' + f' {self._connection_params}' + ) + return client + async def create_session( self, headers: Optional[Dict[str, str]] = None ) -> ClientSession: @@ -298,36 +341,7 @@ async def create_session( exit_stack = AsyncExitStack() try: - if isinstance(self._connection_params, StdioConnectionParams): - client = stdio_client( - server=self._connection_params.server_params, - errlog=self._errlog, - ) - elif isinstance(self._connection_params, SseConnectionParams): - client = sse_client( - url=self._connection_params.url, - headers=merged_headers, - timeout=self._connection_params.timeout, - sse_read_timeout=self._connection_params.sse_read_timeout, - ) - elif isinstance( - self._connection_params, StreamableHTTPConnectionParams - ): - client = streamablehttp_client( - url=self._connection_params.url, - headers=merged_headers, - timeout=timedelta(seconds=self._connection_params.timeout), - sse_read_timeout=timedelta( - seconds=self._connection_params.sse_read_timeout - ), - terminate_on_close=self._connection_params.terminate_on_close, - ) - else: - raise ValueError( - 'Unable to initialize connection. Connection should be' - ' StdioServerParameters or SseServerParams, but got' - f' {self._connection_params}' - ) + client = self._create_client(merged_headers) transports = await exit_stack.enter_async_context(client) # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams