Skip to content

Commit 9e0493c

Browse files
[tiny-agents] Configure inference API key from inputs + keep empty dicts in chat completion payload (#3226)
* allow passing api key for inference in tiny-agents * keep empty dicts in chat completion payload * fix test * re-add check * Apply suggestions from code review Co-authored-by: Lucain <lucain@huggingface.co> --------- Co-authored-by: Lucain <lucain@huggingface.co>
1 parent 37d3278 commit 9e0493c

File tree

5 files changed

+44
-36
lines changed

5 files changed

+44
-36
lines changed

src/huggingface_hub/inference/_mcp/cli.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def _sigint_handler() -> None:
7171
signal.signal(signal.SIGINT, lambda *_: _sigint_handler())
7272

7373
# Handle inputs (i.e. env variables injection)
74+
resolved_inputs: dict[str, str] = {}
75+
7476
if len(inputs) > 0:
7577
print(
7678
"[bold blue]Some initial inputs are required by the agent. "
@@ -79,19 +81,26 @@ def _sigint_handler() -> None:
7981
for input_item in inputs:
8082
input_id = input_item["id"]
8183
description = input_item["description"]
82-
env_special_value = "${input:" + input_id + "}" # Special value to indicate env variable injection
84+
env_special_value = f"${{input:{input_id}}}"
8385

84-
# Check env variables that will use this input
85-
input_vars = set()
86+
# Check if the input is used by any server or as an apiKey
87+
input_usages = set()
8688
for server in servers:
8789
# Check stdio's "env" and http/sse's "headers" mappings
8890
env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})
8991
for key, value in env_or_headers.items():
9092
if env_special_value in value:
91-
input_vars.add(key)
93+
input_usages.add(key)
94+
95+
raw_api_key = config.get("apiKey")
96+
if isinstance(raw_api_key, str) and env_special_value in raw_api_key:
97+
input_usages.add("apiKey")
9298

93-
if not input_vars:
94-
print(f"[yellow]Input {input_id} defined in config but not used by any server.[/yellow]")
99+
if not input_usages:
100+
print(
101+
f"[yellow]Input '{input_id}' defined in config but not used by any server or as an API key."
102+
" Skipping.[/yellow]"
103+
)
95104
continue
96105

97106
# Prompt user for input
@@ -104,30 +113,39 @@ def _sigint_handler() -> None:
104113
if exit_event.is_set():
105114
return
106115

107-
# Inject user input (or env variable) into stdio's env or http/sse's headers
116+
# Fallback to environment variable when user left blank
117+
final_value = user_input
118+
if not final_value:
119+
final_value = os.getenv(env_variable_key, "")
120+
if final_value:
121+
print(f"[green]Value successfully loaded from '{env_variable_key}'[/green]")
122+
else:
123+
print(
124+
f"[yellow]No value found for '{env_variable_key}' in environment variables. Continuing.[/yellow]"
125+
)
126+
resolved_inputs[input_id] = final_value
127+
128+
# Inject resolved value (can be empty) into stdio's env or http/sse's headers
108129
for server in servers:
109130
env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})
110131
for key, value in env_or_headers.items():
111132
if env_special_value in value:
112-
if user_input:
113-
env_or_headers[key] = env_or_headers[key].replace(env_special_value, user_input)
114-
else:
115-
value_from_env = os.getenv(env_variable_key, "")
116-
env_or_headers[key] = env_or_headers[key].replace(env_special_value, value_from_env)
117-
if value_from_env:
118-
print(f"[green]Value successfully loaded from '{env_variable_key}'[/green]")
119-
else:
120-
print(
121-
f"[yellow]No value found for '{env_variable_key}' in environment variables. Continuing.[/yellow]"
122-
)
133+
env_or_headers[key] = env_or_headers[key].replace(env_special_value, final_value)
123134

124135
print()
125136

137+
raw_api_key = config.get("apiKey")
138+
if isinstance(raw_api_key, str):
139+
substituted_api_key = raw_api_key
140+
for input_id, val in resolved_inputs.items():
141+
substituted_api_key = substituted_api_key.replace(f"${{input:{input_id}}}", val)
142+
config["apiKey"] = substituted_api_key
126143
# Main agent loop
127144
async with Agent(
128145
provider=config.get("provider"), # type: ignore[arg-type]
129146
model=config.get("model"),
130147
base_url=config.get("endpointUrl"), # type: ignore[arg-type]
148+
api_key=config.get("apiKey"),
131149
servers=servers, # type: ignore[arg-type]
132150
prompt=prompt,
133151
) as agent:

src/huggingface_hub/inference/_mcp/constants.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,7 @@
5656
"description": "Call this tool when the task given by the user is complete",
5757
"parameters": {
5858
"type": "object",
59-
"properties": {
60-
"trigger": {
61-
"type": "boolean",
62-
"description": "Set to true to trigger this function",
63-
}
64-
},
59+
"properties": {},
6560
},
6661
},
6762
}
@@ -75,12 +70,7 @@
7570
"description": "Ask the user for more info required to solve or clarify their problem.",
7671
"parameters": {
7772
"type": "object",
78-
"properties": {
79-
"trigger": {
80-
"type": "boolean",
81-
"description": "Set to true to trigger this function",
82-
}
83-
},
73+
"properties": {},
8474
},
8575
},
8676
}

src/huggingface_hub/inference/_mcp/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Dict, List, Literal, TypedDict, Union
22

3+
from typing_extensions import NotRequired
4+
35

46
class InputConfig(TypedDict, total=False):
57
id: str
@@ -35,5 +37,6 @@ class SSEServerConfig(TypedDict):
3537
class AgentConfig(TypedDict):
3638
model: str
3739
provider: str
40+
apiKey: NotRequired[str]
3841
inputs: List[InputConfig]
3942
servers: List[ServerConfig]

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ def filter_none(obj: Union[Dict[str, Any], List[Any]]) -> Union[Dict[str, Any],
5151
continue
5252
if isinstance(v, (dict, list)):
5353
v = filter_none(v)
54-
# remove empty nested dicts
55-
if isinstance(v, dict) and not v:
56-
continue
5754
cleaned[k] = v
5855
return cleaned
5956

tests/test_inference_providers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ def test_prepare_request(self, mocker):
645645
assert request.task == "text-classification"
646646
assert request.model == "username/repo_name"
647647
assert request.headers["authorization"] == "Bearer hf_test_token"
648-
assert request.json == {"inputs": "this is a dummy input"}
648+
assert request.json == {"inputs": "this is a dummy input", "parameters": {}}
649649

650650
def test_prepare_request_conversational(self, mocker):
651651
mocker.patch(
@@ -1445,10 +1445,10 @@ def test_recursive_merge(dict1: Dict, dict2: Dict, expected: Dict):
14451445
({}, {}), # empty dictionary remains empty
14461446
({"a": 1, "b": None, "c": 3}, {"a": 1, "c": 3}), # remove None at root level
14471447
({"a": None, "b": {"x": None, "y": 2}}, {"b": {"y": 2}}), # remove nested None
1448-
({"a": {"b": {"c": None}}}, {}), # remove empty nested dict
1448+
({"a": {"b": {"c": None}}}, {"a": {"b": {}}}), # keep empty nested dict
14491449
(
14501450
{"a": "", "b": {"x": {"y": None}, "z": 0}, "c": []}, # do not remove 0, [] and "" values
1451-
{"a": "", "b": {"z": 0}, "c": []},
1451+
{"a": "", "b": {"x": {}, "z": 0}, "c": []},
14521452
),
14531453
(
14541454
{"a": [0, 1, None]}, # do not remove None in lists

0 commit comments

Comments
 (0)