Skip to content

Commit 85853e3

Browse files
authored
Inject env var in headers + better type annotations (#3142)
* Inject env var in headers + better type annotations * fix lint * fix thanks to copilot
1 parent 8809f44 commit 85853e3

File tree

3 files changed

+99
-26
lines changed

3 files changed

+99
-26
lines changed

src/huggingface_hub/inference/_mcp/cli.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import signal
44
import traceback
5-
from typing import Any, Dict, List, Optional
5+
from typing import Optional
66

77
import typer
88
from rich import print
@@ -40,8 +40,8 @@ async def run_agent(
4040

4141
config, prompt = _load_agent_config(agent_path)
4242

43-
inputs: List[Dict[str, Any]] = config.get("inputs", [])
44-
servers: List[Dict[str, Any]] = config.get("servers", [])
43+
inputs = config.get("inputs", [])
44+
servers = config.get("servers", [])
4545

4646
abort_event = asyncio.Event()
4747
exit_event = asyncio.Event()
@@ -82,38 +82,45 @@ def _sigint_handler() -> None:
8282
env_special_value = "${input:" + input_id + "}" # Special value to indicate env variable injection
8383

8484
# Check env variables that will use this input
85-
input_vars = list(
86-
{
87-
key
88-
for server in servers
89-
for key, value in server.get("config", {}).get("env", {}).items()
90-
if value == env_special_value
91-
}
92-
)
85+
input_vars = set()
86+
for server in servers:
87+
# Check stdio's "env" and http/sse's "headers" mappings
88+
env_or_headers = (
89+
server["config"].get("env", {})
90+
if server["type"] == "stdio"
91+
else server["config"].get("options", {}).get("requestInit", {}).get("headers", {})
92+
)
93+
for key, value in env_or_headers.items():
94+
if env_special_value in value:
95+
input_vars.add(key)
9396

9497
if not input_vars:
9598
print(f"[yellow]Input {input_id} defined in config but not used by any server.[/yellow]")
9699
continue
97100

98101
# Prompt user for input
99102
print(
100-
f"[blue] • {input_id}[/blue]: {description}. (default: load from {', '.join(input_vars)}).",
103+
f"[blue] • {input_id}[/blue]: {description}. (default: load from {', '.join(sorted(input_vars))}).",
101104
end=" ",
102105
)
103106
user_input = (await _async_prompt(exit_event=exit_event)).strip()
104107
if exit_event.is_set():
105108
return
106109

107-
# Inject user input (or env variable) into servers' env
110+
# Inject user input (or env variable) into stdio's env or http/sse's headers
108111
for server in servers:
109-
env = server.get("config", {}).get("env", {})
110-
for key, value in env.items():
111-
if value == env_special_value:
112+
env_or_headers = (
113+
server["config"].get("env", {})
114+
if server["type"] == "stdio"
115+
else server["config"].get("options", {}).get("requestInit", {}).get("headers", {})
116+
)
117+
for key, value in env_or_headers.items():
118+
if env_special_value in value:
112119
if user_input:
113-
env[key] = user_input
120+
env_or_headers[key] = env_or_headers[key].replace(env_special_value, user_input)
114121
else:
115122
value_from_env = os.getenv(key, "")
116-
env[key] = value_from_env
123+
env_or_headers[key] = env_or_headers[key].replace(env_special_value, value_from_env)
117124
if value_from_env:
118125
print(f"[green]Value successfully loaded from '{key}'[/green]")
119126
else:
@@ -125,10 +132,10 @@ def _sigint_handler() -> None:
125132

126133
# Main agent loop
127134
async with Agent(
128-
provider=config.get("provider"),
135+
provider=config.get("provider"), # type: ignore[arg-type]
129136
model=config.get("model"),
130-
base_url=config.get("endpointUrl"),
131-
servers=servers,
137+
base_url=config.get("endpointUrl"), # type: ignore[arg-type]
138+
servers=servers, # type: ignore[arg-type]
132139
prompt=prompt,
133140
) as agent:
134141
await agent.load_tools()
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from typing import Dict, List, Literal, TypedDict, Union
2+
3+
4+
# Input config
5+
class InputConfig(TypedDict, total=False):
6+
id: str
7+
description: str
8+
type: str
9+
password: bool
10+
11+
12+
# stdio server config
13+
class StdioServerConfig(TypedDict, total=False):
14+
command: str
15+
args: List[str]
16+
env: Dict[str, str]
17+
cwd: str
18+
19+
20+
class StdioServer(TypedDict):
21+
type: Literal["stdio"]
22+
config: StdioServerConfig
23+
24+
25+
# http server config
26+
class HTTPRequestInit(TypedDict, total=False):
27+
headers: Dict[str, str]
28+
29+
30+
class HTTPServerOptions(TypedDict, total=False):
31+
requestInit: HTTPRequestInit
32+
sessionId: str
33+
34+
35+
class HTTPServerConfig(TypedDict, total=False):
36+
url: str
37+
options: HTTPServerOptions
38+
39+
40+
class HTTPServer(TypedDict):
41+
type: Literal["http"]
42+
config: HTTPServerConfig
43+
44+
45+
# sse server config
46+
class SSEServerOptions(TypedDict, total=False):
47+
requestInit: HTTPRequestInit
48+
49+
50+
class SSEServerConfig(TypedDict):
51+
url: str
52+
options: SSEServerOptions
53+
54+
55+
class SSEServer(TypedDict):
56+
type: Literal["sse"]
57+
config: SSEServerConfig
58+
59+
60+
# AgentConfig root object
61+
class AgentConfig(TypedDict):
62+
model: str
63+
provider: str
64+
inputs: List[InputConfig]
65+
servers: List[Union[StdioServer, HTTPServer, SSEServer]]

src/huggingface_hub/inference/_mcp/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66

77
import json
88
from pathlib import Path
9-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
9+
from typing import TYPE_CHECKING, List, Optional, Tuple
1010

1111
from huggingface_hub import snapshot_download
1212
from huggingface_hub.errors import EntryNotFoundError
1313

1414
from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, FILENAME_PROMPT
15+
from .types import AgentConfig
1516

1617

1718
if TYPE_CHECKING:
@@ -83,21 +84,21 @@ def _get_base64_size(base64_str: str) -> int:
8384
return (len(base64_str) * 3) // 4 - padding
8485

8586

86-
def _load_agent_config(agent_path: Optional[str]) -> Tuple[Dict[str, Any], Optional[str]]:
87+
def _load_agent_config(agent_path: Optional[str]) -> Tuple[AgentConfig, Optional[str]]:
8788
"""Load server config and prompt."""
8889

89-
def _read_dir(directory: Path) -> Tuple[Dict[str, Any], Optional[str]]:
90+
def _read_dir(directory: Path) -> Tuple[AgentConfig, Optional[str]]:
9091
cfg_file = directory / FILENAME_CONFIG
9192
if not cfg_file.exists():
9293
raise FileNotFoundError(f" Config file not found in {directory}! Please make sure it exists locally")
9394

94-
config: Dict[str, Any] = json.loads(cfg_file.read_text(encoding="utf-8"))
95+
config: AgentConfig = json.loads(cfg_file.read_text(encoding="utf-8"))
9596
prompt_file = directory / FILENAME_PROMPT
9697
prompt: Optional[str] = prompt_file.read_text(encoding="utf-8") if prompt_file.exists() else None
9798
return config, prompt
9899

99100
if agent_path is None:
100-
return DEFAULT_AGENT, None
101+
return DEFAULT_AGENT, None # type: ignore[return-value]
101102

102103
path = Path(agent_path).expanduser()
103104

0 commit comments

Comments
 (0)