Skip to content

Commit 10d4494

Browse files
authored
[MCP] add support for SSE + HTTP (#3099)
* [MCP] add support for SSE + HTTP * http instead of streamablehttp
1 parent 5c3efe3 commit 10d4494

File tree

1 file changed

+98
-28
lines changed

1 file changed

+98
-28
lines changed

src/huggingface_hub/inference/_mcp/mcp_client.py

Lines changed: 98 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import json
22
import logging
33
from contextlib import AsyncExitStack
4+
from datetime import timedelta
45
from pathlib import Path
5-
from typing import TYPE_CHECKING, AsyncIterable, Dict, List, Optional, Union
6+
from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload
67

7-
from typing_extensions import TypeAlias
8+
from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack
89

910
from ...utils._runtime import get_hf_hub_version
1011
from .._generated._async_client import AsyncInferenceClient
@@ -26,6 +27,30 @@
2627
# Type alias for tool names
2728
ToolName: TypeAlias = str
2829

30+
ServerType: TypeAlias = Literal["stdio", "sse", "http"]
31+
32+
33+
class StdioServerParameters_T(TypedDict):
34+
command: str
35+
args: NotRequired[List[str]]
36+
env: NotRequired[Dict[str, str]]
37+
cwd: NotRequired[Union[str, Path, None]]
38+
39+
40+
class SSEServerParameters_T(TypedDict):
41+
url: str
42+
headers: NotRequired[Dict[str, Any]]
43+
timeout: NotRequired[float]
44+
sse_read_timeout: NotRequired[float]
45+
46+
47+
class StreamableHTTPParameters_T(TypedDict):
48+
url: str
49+
headers: NotRequired[dict[str, Any]]
50+
timeout: NotRequired[timedelta]
51+
sse_read_timeout: NotRequired[timedelta]
52+
terminate_on_close: NotRequired[bool]
53+
2954

3055
class MCPClient:
3156
"""
@@ -64,39 +89,84 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
6489
await self.client.__aexit__(exc_type, exc_val, exc_tb)
6590
await self.cleanup()
6691

67-
async def add_mcp_server(
68-
self,
69-
*,
70-
command: str,
71-
args: Optional[List[str]] = None,
72-
env: Optional[Dict[str, str]] = None,
73-
cwd: Union[str, Path, None] = None,
74-
):
92+
@overload
93+
async def add_mcp_server(self, type: Literal["stdio"], **params: Unpack[StdioServerParameters_T]): ...
94+
95+
@overload
96+
async def add_mcp_server(self, type: Literal["sse"], **params: Unpack[SSEServerParameters_T]): ...
97+
98+
@overload
99+
async def add_mcp_server(self, type: Literal["http"], **params: Unpack[StreamableHTTPParameters_T]): ...
100+
101+
async def add_mcp_server(self, type: ServerType, **params: Any):
75102
"""Connect to an MCP server
76103
77104
Args:
78-
command (str):
79-
The command to run the MCP server.
80-
args (List[str], optional):
81-
Arguments for the command.
82-
env (Dict[str, str], optional):
83-
Environment variables for the command. Default is to inherit the parent environment.
84-
cwd (Union[str, Path, None], optional):
85-
Working directory for the command. Default to current directory.
105+
type (`str`):
106+
Type of the server to connect to. Can be one of:
107+
- "stdio": Standard input/output server (local)
108+
- "sse": Server-sent events (SSE) server
109+
- "http": StreamableHTTP server
110+
**params: Server parameters that can be either:
111+
- For stdio servers:
112+
- command (str): The command to run the MCP server
113+
- args (List[str], optional): Arguments for the command
114+
- env (Dict[str, str], optional): Environment variables for the command
115+
- cwd (Union[str, Path, None], optional): Working directory for the command
116+
- For SSE servers:
117+
- url (str): The URL of the SSE server
118+
- headers (Dict[str, Any], optional): Headers for the SSE connection
119+
- timeout (float, optional): Connection timeout
120+
- sse_read_timeout (float, optional): SSE read timeout
121+
- For StreamableHTTP servers:
122+
- url (str): The URL of the StreamableHTTP server
123+
- headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection
124+
- timeout (timedelta, optional): Connection timeout
125+
- sse_read_timeout (timedelta, optional): SSE read timeout
126+
- terminate_on_close (bool, optional): Whether to terminate on close
86127
"""
87128
from mcp import ClientSession, StdioServerParameters
88129
from mcp import types as mcp_types
89-
from mcp.client.stdio import stdio_client
90-
91-
logger.info(f"Connecting to MCP server with command: {command} {args}")
92-
server_params = StdioServerParameters(
93-
command=command,
94-
args=args if args is not None else [],
95-
env=env,
96-
cwd=cwd,
97-
)
98130

99-
read, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
131+
# Determine server type and create appropriate parameters
132+
if type == "stdio":
133+
# Handle stdio server
134+
from mcp.client.stdio import stdio_client
135+
136+
logger.info(f"Connecting to stdio MCP server with command: {params['command']} {params.get('args', [])}")
137+
138+
client_kwargs = {"command": params["command"]}
139+
for key in ["args", "env", "cwd"]:
140+
if params.get(key) is not None:
141+
client_kwargs[key] = params[key]
142+
server_params = StdioServerParameters(**client_kwargs)
143+
read, write = await self.exit_stack.enter_async_context(stdio_client(server_params))
144+
elif type == "sse":
145+
# Handle SSE server
146+
from mcp.client.sse import sse_client
147+
148+
logger.info(f"Connecting to SSE MCP server at: {params['url']}")
149+
150+
client_kwargs = {"url": params["url"]}
151+
for key in ["headers", "timeout", "sse_read_timeout"]:
152+
if params.get(key) is not None:
153+
client_kwargs[key] = params[key]
154+
read, write = await self.exit_stack.enter_async_context(sse_client(**client_kwargs))
155+
elif type == "http":
156+
# Handle StreamableHTTP server
157+
from mcp.client.streamable_http import streamablehttp_client
158+
159+
logger.info(f"Connecting to StreamableHTTP MCP server at: {params['url']}")
160+
161+
client_kwargs = {"url": params["url"]}
162+
for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]:
163+
if params.get(key) is not None:
164+
client_kwargs[key] = params[key]
165+
read, write, _ = await self.exit_stack.enter_async_context(streamablehttp_client(**client_kwargs))
166+
# ^ TODO: should be handle `get_session_id_callback`? (function to retrieve the current session ID)
167+
else:
168+
raise ValueError(f"Unsupported server type: {type}")
169+
100170
session = await self.exit_stack.enter_async_context(
101171
ClientSession(
102172
read_stream=read,

0 commit comments

Comments
 (0)