Skip to content

Commit 24a0c75

Browse files
Recursive filter_none in Inference Providers (#3178)
* Recursive filter_none in Inference Providers * filter none values from messages as well * update filter_none and add test cases * don't drop empty dicts in list * better typing and refactor logic into a list of comrpehension --------- Co-authored-by: Celina Hanouti <hanouticelina@gmail.com>
1 parent 4ada292 commit 24a0c75

File tree

3 files changed

+155
-6
lines changed

3 files changed

+155
-6
lines changed

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from functools import lru_cache
2-
from typing import Any, Dict, List, Optional, Union
2+
from typing import Any, Dict, List, Optional, Union, overload
33

44
from huggingface_hub import constants
55
from huggingface_hub.hf_api import InferenceProviderMapping
66
from huggingface_hub.inference._common import RequestParameters
7+
from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputMessage
78
from huggingface_hub.utils import build_hf_headers, get_token, logging
89

910

@@ -36,8 +37,30 @@
3637
}
3738

3839

39-
def filter_none(d: Dict[str, Any]) -> Dict[str, Any]:
40-
return {k: v for k, v in d.items() if v is not None}
40+
@overload
41+
def filter_none(obj: Dict[str, Any]) -> Dict[str, Any]: ...
42+
@overload
43+
def filter_none(obj: List[Any]) -> List[Any]: ...
44+
45+
46+
def filter_none(obj: Union[Dict[str, Any], List[Any]]) -> Union[Dict[str, Any], List[Any]]:
47+
if isinstance(obj, dict):
48+
cleaned: Dict[str, Any] = {}
49+
for k, v in obj.items():
50+
if v is None:
51+
continue
52+
if isinstance(v, (dict, list)):
53+
v = filter_none(v)
54+
# remove empty nested dicts
55+
if isinstance(v, dict) and not v:
56+
continue
57+
cleaned[k] = v
58+
return cleaned
59+
60+
if isinstance(obj, list):
61+
return [filter_none(v) if isinstance(v, (dict, list)) else v for v in obj]
62+
63+
raise ValueError(f"Expected dict or list, got {type(obj)}")
4164

4265

4366
class TaskProviderHelper:
@@ -224,9 +247,12 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str:
224247
return "/v1/chat/completions"
225248

226249
def _prepare_payload_as_dict(
227-
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
250+
self,
251+
inputs: List[Union[Dict, ChatCompletionInputMessage]],
252+
parameters: Dict,
253+
provider_mapping_info: InferenceProviderMapping,
228254
) -> Optional[Dict]:
229-
return {"messages": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id}
255+
return filter_none({"messages": inputs, **parameters, "model": provider_mapping_info.provider_id})
230256

231257

232258
class BaseTextGenerationTask(TaskProviderHelper):

src/huggingface_hub/inference/_providers/hf_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def _prepare_payload_as_bytes(
7575
provider_mapping_info: InferenceProviderMapping,
7676
extra_payload: Optional[Dict],
7777
) -> Optional[bytes]:
78-
parameters = filter_none({k: v for k, v in parameters.items() if v is not None})
78+
parameters = filter_none(parameters)
7979
extra_payload = extra_payload or {}
8080
has_parameters = len(parameters) > 0 or len(extra_payload) > 0
8181

tests/test_inference_providers.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
BaseConversationalTask,
1414
BaseTextGenerationTask,
1515
TaskProviderHelper,
16+
filter_none,
1617
recursive_merge,
1718
)
1819
from huggingface_hub.inference._providers.black_forest_labs import BlackForestLabsTextToImageTask
@@ -1152,6 +1153,98 @@ def test_prepare_payload(self):
11521153
"model": "test-provider-id",
11531154
}
11541155

1156+
@pytest.mark.parametrize(
1157+
"raw_messages, expected_messages",
1158+
[
1159+
(
1160+
[
1161+
{
1162+
"role": "assistant",
1163+
"content": "",
1164+
"tool_calls": None,
1165+
}
1166+
],
1167+
[
1168+
{
1169+
"role": "assistant",
1170+
"content": "",
1171+
}
1172+
],
1173+
),
1174+
(
1175+
[
1176+
{
1177+
"role": "assistant",
1178+
"content": None,
1179+
"tool_calls": [
1180+
{
1181+
"id": "call_1",
1182+
"type": "function",
1183+
"function": {
1184+
"name": "get_current_weather",
1185+
"arguments": '{"location": "San Francisco, CA", "unit": "celsius"}',
1186+
},
1187+
},
1188+
],
1189+
},
1190+
{
1191+
"role": "tool",
1192+
"content": "pong",
1193+
"tool_call_id": "abc123",
1194+
"name": "dummy_tool",
1195+
"tool_calls": None,
1196+
},
1197+
],
1198+
[
1199+
{
1200+
"role": "assistant",
1201+
"tool_calls": [
1202+
{
1203+
"id": "call_1",
1204+
"type": "function",
1205+
"function": {
1206+
"name": "get_current_weather",
1207+
"arguments": '{"location": "San Francisco, CA", "unit": "celsius"}',
1208+
},
1209+
}
1210+
],
1211+
},
1212+
{
1213+
"role": "tool",
1214+
"content": "pong",
1215+
"tool_call_id": "abc123",
1216+
"name": "dummy_tool",
1217+
},
1218+
],
1219+
),
1220+
],
1221+
)
1222+
def test_prepare_payload_filters_messages(self, raw_messages, expected_messages):
1223+
helper = BaseConversationalTask(provider="test-provider", base_url="https://api.test.com")
1224+
1225+
parameters = {
1226+
"temperature": 0.2,
1227+
"max_tokens": None,
1228+
"top_p": None,
1229+
}
1230+
1231+
payload = helper._prepare_payload_as_dict(
1232+
inputs=raw_messages,
1233+
parameters=parameters,
1234+
provider_mapping_info=InferenceProviderMapping(
1235+
provider="test-provider",
1236+
hf_model_id="test-model",
1237+
providerId="test-provider-id",
1238+
task="conversational",
1239+
status="live",
1240+
),
1241+
)
1242+
1243+
assert payload["messages"] == expected_messages
1244+
assert payload["temperature"] == 0.2
1245+
assert "max_tokens" not in payload
1246+
assert "top_p" not in payload
1247+
11551248

11561249
class TestBaseTextGenerationTask:
11571250
def test_prepare_route(self):
@@ -1236,6 +1329,36 @@ def test_recursive_merge(dict1: Dict, dict2: Dict, expected: Dict):
12361329
assert dict2 == initial_dict2
12371330

12381331

1332+
@pytest.mark.parametrize(
1333+
"data, expected",
1334+
[
1335+
({}, {}), # empty dictionary remains empty
1336+
({"a": 1, "b": None, "c": 3}, {"a": 1, "c": 3}), # remove None at root level
1337+
({"a": None, "b": {"x": None, "y": 2}}, {"b": {"y": 2}}), # remove nested None
1338+
({"a": {"b": {"c": None}}}, {}), # remove empty nested dict
1339+
(
1340+
{"a": "", "b": {"x": {"y": None}, "z": 0}, "c": []}, # do not remove 0, [] and "" values
1341+
{"a": "", "b": {"z": 0}, "c": []},
1342+
),
1343+
(
1344+
{"a": [0, 1, None]}, # do not remove None in lists
1345+
{"a": [0, 1, None]},
1346+
),
1347+
# dicts inside list are cleaned, list level None kept
1348+
({"a": [{"x": None, "y": 1}, None]}, {"a": [{"y": 1}, None]}),
1349+
# remove every None that is the value of a dict key
1350+
(
1351+
[None, {"x": None, "y": 5}, [None, 6]],
1352+
[None, {"y": 5}, [None, 6]],
1353+
),
1354+
({"a": [None, {"x": None}]}, {"a": [None, {}]}),
1355+
],
1356+
)
1357+
def test_filter_none(data: Dict, expected: Dict):
1358+
"""Test that filter_none removes None values from nested dictionaries."""
1359+
assert filter_none(data) == expected
1360+
1361+
12391362
def test_get_provider_helper_auto(mocker):
12401363
"""Test the 'auto' provider selection logic."""
12411364

0 commit comments

Comments
 (0)