|
1 | 1 | import json
|
2 | 2 | import logging
|
3 | 3 | from contextlib import AsyncExitStack
|
| 4 | +from datetime import timedelta |
4 | 5 | from pathlib import Path
|
5 |
| -from typing import TYPE_CHECKING, AsyncIterable, Dict, List, Optional, Union |
| 6 | +from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload |
6 | 7 |
|
7 |
| -from typing_extensions import TypeAlias |
| 8 | +from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack |
8 | 9 |
|
9 | 10 | from ...utils._runtime import get_hf_hub_version
|
10 | 11 | from .._generated._async_client import AsyncInferenceClient
|
|
26 | 27 | # Type alias for tool names
|
27 | 28 | ToolName: TypeAlias = str
|
28 | 29 |
|
| 30 | +ServerType: TypeAlias = Literal["stdio", "sse", "http"] |
| 31 | + |
| 32 | + |
| 33 | +class StdioServerParameters_T(TypedDict): |
| 34 | + command: str |
| 35 | + args: NotRequired[List[str]] |
| 36 | + env: NotRequired[Dict[str, str]] |
| 37 | + cwd: NotRequired[Union[str, Path, None]] |
| 38 | + |
| 39 | + |
| 40 | +class SSEServerParameters_T(TypedDict): |
| 41 | + url: str |
| 42 | + headers: NotRequired[Dict[str, Any]] |
| 43 | + timeout: NotRequired[float] |
| 44 | + sse_read_timeout: NotRequired[float] |
| 45 | + |
| 46 | + |
| 47 | +class StreamableHTTPParameters_T(TypedDict): |
| 48 | + url: str |
| 49 | + headers: NotRequired[dict[str, Any]] |
| 50 | + timeout: NotRequired[timedelta] |
| 51 | + sse_read_timeout: NotRequired[timedelta] |
| 52 | + terminate_on_close: NotRequired[bool] |
| 53 | + |
29 | 54 |
|
30 | 55 | class MCPClient:
|
31 | 56 | """
|
@@ -64,39 +89,84 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
|
64 | 89 | await self.client.__aexit__(exc_type, exc_val, exc_tb)
|
65 | 90 | await self.cleanup()
|
66 | 91 |
|
67 |
| - async def add_mcp_server( |
68 |
| - self, |
69 |
| - *, |
70 |
| - command: str, |
71 |
| - args: Optional[List[str]] = None, |
72 |
| - env: Optional[Dict[str, str]] = None, |
73 |
| - cwd: Union[str, Path, None] = None, |
74 |
| - ): |
| 92 | + @overload |
| 93 | + async def add_mcp_server(self, type: Literal["stdio"], **params: Unpack[StdioServerParameters_T]): ... |
| 94 | + |
| 95 | + @overload |
| 96 | + async def add_mcp_server(self, type: Literal["sse"], **params: Unpack[SSEServerParameters_T]): ... |
| 97 | + |
| 98 | + @overload |
| 99 | + async def add_mcp_server(self, type: Literal["http"], **params: Unpack[StreamableHTTPParameters_T]): ... |
| 100 | + |
| 101 | + async def add_mcp_server(self, type: ServerType, **params: Any): |
75 | 102 | """Connect to an MCP server
|
76 | 103 |
|
77 | 104 | Args:
|
78 |
| - command (str): |
79 |
| - The command to run the MCP server. |
80 |
| - args (List[str], optional): |
81 |
| - Arguments for the command. |
82 |
| - env (Dict[str, str], optional): |
83 |
| - Environment variables for the command. Default is to inherit the parent environment. |
84 |
| - cwd (Union[str, Path, None], optional): |
85 |
| - Working directory for the command. Default to current directory. |
| 105 | + type (`str`): |
| 106 | + Type of the server to connect to. Can be one of: |
| 107 | + - "stdio": Standard input/output server (local) |
| 108 | + - "sse": Server-sent events (SSE) server |
| 109 | + - "http": StreamableHTTP server |
| 110 | + **params: Server parameters that can be either: |
| 111 | + - For stdio servers: |
| 112 | + - command (str): The command to run the MCP server |
| 113 | + - args (List[str], optional): Arguments for the command |
| 114 | + - env (Dict[str, str], optional): Environment variables for the command |
| 115 | + - cwd (Union[str, Path, None], optional): Working directory for the command |
| 116 | + - For SSE servers: |
| 117 | + - url (str): The URL of the SSE server |
| 118 | + - headers (Dict[str, Any], optional): Headers for the SSE connection |
| 119 | + - timeout (float, optional): Connection timeout |
| 120 | + - sse_read_timeout (float, optional): SSE read timeout |
| 121 | + - For StreamableHTTP servers: |
| 122 | + - url (str): The URL of the StreamableHTTP server |
| 123 | + - headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection |
| 124 | + - timeout (timedelta, optional): Connection timeout |
| 125 | + - sse_read_timeout (timedelta, optional): SSE read timeout |
| 126 | + - terminate_on_close (bool, optional): Whether to terminate on close |
86 | 127 | """
|
87 | 128 | from mcp import ClientSession, StdioServerParameters
|
88 | 129 | from mcp import types as mcp_types
|
89 |
| - from mcp.client.stdio import stdio_client |
90 |
| - |
91 |
| - logger.info(f"Connecting to MCP server with command: {command} {args}") |
92 |
| - server_params = StdioServerParameters( |
93 |
| - command=command, |
94 |
| - args=args if args is not None else [], |
95 |
| - env=env, |
96 |
| - cwd=cwd, |
97 |
| - ) |
98 | 130 |
|
99 |
| - read, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) |
| 131 | + # Determine server type and create appropriate parameters |
| 132 | + if type == "stdio": |
| 133 | + # Handle stdio server |
| 134 | + from mcp.client.stdio import stdio_client |
| 135 | + |
| 136 | + logger.info(f"Connecting to stdio MCP server with command: {params['command']} {params.get('args', [])}") |
| 137 | + |
| 138 | + client_kwargs = {"command": params["command"]} |
| 139 | + for key in ["args", "env", "cwd"]: |
| 140 | + if params.get(key) is not None: |
| 141 | + client_kwargs[key] = params[key] |
| 142 | + server_params = StdioServerParameters(**client_kwargs) |
| 143 | + read, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) |
| 144 | + elif type == "sse": |
| 145 | + # Handle SSE server |
| 146 | + from mcp.client.sse import sse_client |
| 147 | + |
| 148 | + logger.info(f"Connecting to SSE MCP server at: {params['url']}") |
| 149 | + |
| 150 | + client_kwargs = {"url": params["url"]} |
| 151 | + for key in ["headers", "timeout", "sse_read_timeout"]: |
| 152 | + if params.get(key) is not None: |
| 153 | + client_kwargs[key] = params[key] |
| 154 | + read, write = await self.exit_stack.enter_async_context(sse_client(**client_kwargs)) |
| 155 | + elif type == "http": |
| 156 | + # Handle StreamableHTTP server |
| 157 | + from mcp.client.streamable_http import streamablehttp_client |
| 158 | + |
| 159 | + logger.info(f"Connecting to StreamableHTTP MCP server at: {params['url']}") |
| 160 | + |
| 161 | + client_kwargs = {"url": params["url"]} |
| 162 | + for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]: |
| 163 | + if params.get(key) is not None: |
| 164 | + client_kwargs[key] = params[key] |
| 165 | + read, write, _ = await self.exit_stack.enter_async_context(streamablehttp_client(**client_kwargs)) |
| 166 | + # ^ TODO: should be handle `get_session_id_callback`? (function to retrieve the current session ID) |
| 167 | + else: |
| 168 | + raise ValueError(f"Unsupported server type: {type}") |
| 169 | + |
100 | 170 | session = await self.exit_stack.enter_async_context(
|
101 | 171 | ClientSession(
|
102 | 172 | read_stream=read,
|
|
0 commit comments