Skip to content

Commit 45d60a1

Browse files
seanzhougooglecopybara-github
authored andcommitted
refactor: extract mcp client creation logic to a separate method
PiperOrigin-RevId: 780609120
1 parent bf39c00 commit 45d60a1

File tree

1 file changed

+44
-30
lines changed

1 file changed

+44
-30
lines changed

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,49 @@ def _is_session_disconnected(self, session: ClientSession) -> bool:
251251
"""
252252
return session._read_stream._closed or session._write_stream._closed
253253

254+
def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
255+
"""Creates an MCP client based on the connection parameters.
256+
257+
Args:
258+
merged_headers: Optional headers to include in the connection.
259+
Only applicable for SSE and StreamableHTTP connections.
260+
261+
Returns:
262+
The appropriate MCP client instance.
263+
264+
Raises:
265+
ValueError: If the connection parameters are not supported.
266+
"""
267+
if isinstance(self._connection_params, StdioConnectionParams):
268+
client = stdio_client(
269+
server=self._connection_params.server_params,
270+
errlog=self._errlog,
271+
)
272+
elif isinstance(self._connection_params, SseConnectionParams):
273+
client = sse_client(
274+
url=self._connection_params.url,
275+
headers=merged_headers,
276+
timeout=self._connection_params.timeout,
277+
sse_read_timeout=self._connection_params.sse_read_timeout,
278+
)
279+
elif isinstance(self._connection_params, StreamableHTTPConnectionParams):
280+
client = streamablehttp_client(
281+
url=self._connection_params.url,
282+
headers=merged_headers,
283+
timeout=timedelta(seconds=self._connection_params.timeout),
284+
sse_read_timeout=timedelta(
285+
seconds=self._connection_params.sse_read_timeout
286+
),
287+
terminate_on_close=self._connection_params.terminate_on_close,
288+
)
289+
else:
290+
raise ValueError(
291+
'Unable to initialize connection. Connection should be'
292+
' StdioServerParameters or SseServerParams, but got'
293+
f' {self._connection_params}'
294+
)
295+
return client
296+
254297
async def create_session(
255298
self, headers: Optional[Dict[str, str]] = None
256299
) -> ClientSession:
@@ -298,36 +341,7 @@ async def create_session(
298341
exit_stack = AsyncExitStack()
299342

300343
try:
301-
if isinstance(self._connection_params, StdioConnectionParams):
302-
client = stdio_client(
303-
server=self._connection_params.server_params,
304-
errlog=self._errlog,
305-
)
306-
elif isinstance(self._connection_params, SseConnectionParams):
307-
client = sse_client(
308-
url=self._connection_params.url,
309-
headers=merged_headers,
310-
timeout=self._connection_params.timeout,
311-
sse_read_timeout=self._connection_params.sse_read_timeout,
312-
)
313-
elif isinstance(
314-
self._connection_params, StreamableHTTPConnectionParams
315-
):
316-
client = streamablehttp_client(
317-
url=self._connection_params.url,
318-
headers=merged_headers,
319-
timeout=timedelta(seconds=self._connection_params.timeout),
320-
sse_read_timeout=timedelta(
321-
seconds=self._connection_params.sse_read_timeout
322-
),
323-
terminate_on_close=self._connection_params.terminate_on_close,
324-
)
325-
else:
326-
raise ValueError(
327-
'Unable to initialize connection. Connection should be'
328-
' StdioServerParameters or SseServerParams, but got'
329-
f' {self._connection_params}'
330-
)
344+
client = self._create_client(merged_headers)
331345

332346
transports = await exit_stack.enter_async_context(client)
333347
# The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams

0 commit comments

Comments
 (0)