diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py index 97b774a4f2..8ce62d56ea 100644 --- a/src/huggingface_hub/inference/_providers/_common.py +++ b/src/huggingface_hub/inference/_providers/_common.py @@ -1,9 +1,10 @@ from functools import lru_cache -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, overload from huggingface_hub import constants from huggingface_hub.hf_api import InferenceProviderMapping from huggingface_hub.inference._common import RequestParameters +from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputMessage from huggingface_hub.utils import build_hf_headers, get_token, logging @@ -36,8 +37,30 @@ } -def filter_none(d: Dict[str, Any]) -> Dict[str, Any]: - return {k: v for k, v in d.items() if v is not None} +@overload +def filter_none(obj: Dict[str, Any]) -> Dict[str, Any]: ... +@overload +def filter_none(obj: List[Any]) -> List[Any]: ... + + +def filter_none(obj: Union[Dict[str, Any], List[Any]]) -> Union[Dict[str, Any], List[Any]]: + if isinstance(obj, dict): + cleaned: Dict[str, Any] = {} + for k, v in obj.items(): + if v is None: + continue + if isinstance(v, (dict, list)): + v = filter_none(v) + # remove empty nested dicts + if isinstance(v, dict) and not v: + continue + cleaned[k] = v + return cleaned + + if isinstance(obj, list): + return [filter_none(v) if isinstance(v, (dict, list)) else v for v in obj] + + raise ValueError(f"Expected dict or list, got {type(obj)}") class TaskProviderHelper: @@ -224,9 +247,12 @@ def _prepare_route(self, mapped_model: str, api_key: str) -> str: return "/v1/chat/completions" def _prepare_payload_as_dict( - self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + self, + inputs: List[Union[Dict, ChatCompletionInputMessage]], + parameters: Dict, + provider_mapping_info: InferenceProviderMapping, ) -> Optional[Dict]: - return {"messages": inputs, **filter_none(parameters), "model": provider_mapping_info.provider_id} + return filter_none({"messages": inputs, **parameters, "model": provider_mapping_info.provider_id}) class BaseTextGenerationTask(TaskProviderHelper): diff --git a/src/huggingface_hub/inference/_providers/hf_inference.py b/src/huggingface_hub/inference/_providers/hf_inference.py index c5549e9b3d..f5531a02c7 100644 --- a/src/huggingface_hub/inference/_providers/hf_inference.py +++ b/src/huggingface_hub/inference/_providers/hf_inference.py @@ -75,7 +75,7 @@ def _prepare_payload_as_bytes( provider_mapping_info: InferenceProviderMapping, extra_payload: Optional[Dict], ) -> Optional[bytes]: - parameters = filter_none({k: v for k, v in parameters.items() if v is not None}) + parameters = filter_none(parameters) extra_payload = extra_payload or {} has_parameters = len(parameters) > 0 or len(extra_payload) > 0 diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index 33fdb3a921..3b3a8f671c 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -13,6 +13,7 @@ BaseConversationalTask, BaseTextGenerationTask, TaskProviderHelper, + filter_none, recursive_merge, ) from huggingface_hub.inference._providers.black_forest_labs import BlackForestLabsTextToImageTask @@ -1152,6 +1153,98 @@ def test_prepare_payload(self): "model": "test-provider-id", } + @pytest.mark.parametrize( + "raw_messages, expected_messages", + [ + ( + [ + { + "role": "assistant", + "content": "", + "tool_calls": None, + } + ], + [ + { + "role": "assistant", + "content": "", + } + ], + ), + ( + [ + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": '{"location": "San Francisco, CA", "unit": "celsius"}', + }, + }, + ], + }, + { + "role": "tool", + "content": "pong", + "tool_call_id": "abc123", + "name": "dummy_tool", + "tool_calls": None, + }, + ], + [ + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": '{"location": "San Francisco, CA", "unit": "celsius"}', + }, + } + ], + }, + { + "role": "tool", + "content": "pong", + "tool_call_id": "abc123", + "name": "dummy_tool", + }, + ], + ), + ], + ) + def test_prepare_payload_filters_messages(self, raw_messages, expected_messages): + helper = BaseConversationalTask(provider="test-provider", base_url="https://api.test.com") + + parameters = { + "temperature": 0.2, + "max_tokens": None, + "top_p": None, + } + + payload = helper._prepare_payload_as_dict( + inputs=raw_messages, + parameters=parameters, + provider_mapping_info=InferenceProviderMapping( + provider="test-provider", + hf_model_id="test-model", + providerId="test-provider-id", + task="conversational", + status="live", + ), + ) + + assert payload["messages"] == expected_messages + assert payload["temperature"] == 0.2 + assert "max_tokens" not in payload + assert "top_p" not in payload + class TestBaseTextGenerationTask: def test_prepare_route(self): @@ -1236,6 +1329,36 @@ def test_recursive_merge(dict1: Dict, dict2: Dict, expected: Dict): assert dict2 == initial_dict2 +@pytest.mark.parametrize( + "data, expected", + [ + ({}, {}), # empty dictionary remains empty + ({"a": 1, "b": None, "c": 3}, {"a": 1, "c": 3}), # remove None at root level + ({"a": None, "b": {"x": None, "y": 2}}, {"b": {"y": 2}}), # remove nested None + ({"a": {"b": {"c": None}}}, {}), # remove empty nested dict + ( + {"a": "", "b": {"x": {"y": None}, "z": 0}, "c": []}, # do not remove 0, [] and "" values + {"a": "", "b": {"z": 0}, "c": []}, + ), + ( + {"a": [0, 1, None]}, # do not remove None in lists + {"a": [0, 1, None]}, + ), + # dicts inside list are cleaned, list level None kept + ({"a": [{"x": None, "y": 1}, None]}, {"a": [{"y": 1}, None]}), + # remove every None that is the value of a dict key + ( + [None, {"x": None, "y": 5}, [None, 6]], + [None, {"y": 5}, [None, 6]], + ), + ({"a": [None, {"x": None}]}, {"a": [None, {}]}), + ], +) +def test_filter_none(data: Dict, expected: Dict): + """Test that filter_none removes None values from nested dictionaries.""" + assert filter_none(data) == expected + + def test_get_provider_helper_auto(mocker): """Test the 'auto' provider selection logic."""