Skip to content

Commit 4ada292

Browse files
mishig25Wauplin
authored andcommitted
fix: update payload preparation to merge parameters into the output dictionary (#3160)
* fix: update payload preparation to merge parameters into the output dictionary * dev change * Update only feature-extraction task + add test + add specs link * [dev change] * Revert "[dev change]" This reverts commit 50dba39. --------- Co-authored-by: Lucain Pouget <lucainp@gmail.com>
1 parent d5dff4e commit 4ada292

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/huggingface_hub/inference/_providers/hf_inference.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,18 @@ class HFInferenceFeatureExtractionTask(HFInferenceTask):
194194
def __init__(self):
195195
super().__init__("feature-extraction")
196196

197+
def _prepare_payload_as_dict(
198+
self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
199+
) -> Optional[Dict]:
200+
if isinstance(inputs, bytes):
201+
raise ValueError(f"Unexpected binary input for task {self.task}.")
202+
if isinstance(inputs, Path):
203+
raise ValueError(f"Unexpected path input for task {self.task} (got {inputs})")
204+
205+
# Parameters are sent at root-level for feature-extraction task
206+
# See specs: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/feature-extraction/spec/input.json
207+
return {"inputs": inputs, **filter_none(parameters)}
208+
197209
def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any:
198210
if isinstance(response, bytes):
199211
return _bytes_to_dict(response)

tests/test_inference_providers.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import base64
22
import logging
33
from typing import Dict
4-
from unittest.mock import patch
4+
from unittest.mock import MagicMock, patch
55

66
import pytest
77
from pytest import LogCaptureFixture
@@ -33,6 +33,7 @@
3333
from huggingface_hub.inference._providers.hf_inference import (
3434
HFInferenceBinaryInputTask,
3535
HFInferenceConversational,
36+
HFInferenceFeatureExtractionTask,
3637
HFInferenceTask,
3738
)
3839
from huggingface_hub.inference._providers.hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask
@@ -654,6 +655,15 @@ def test_prepare_payload_as_dict_conversational(self, mapped_model, parameters,
654655
assert payload["model"] == expected_model
655656
assert payload["messages"] == messages
656657

658+
def test_prepare_payload_feature_extraction(self):
659+
helper = HFInferenceFeatureExtractionTask()
660+
payload = helper._prepare_payload_as_dict(
661+
inputs="This is a test sentence.",
662+
parameters={"truncate": True},
663+
provider_mapping_info=MagicMock(),
664+
)
665+
assert payload == {"inputs": "This is a test sentence.", "truncate": True} # not under "parameters"
666+
657667
@pytest.mark.parametrize(
658668
"pipeline_tag,tags,task,should_raise",
659669
[

0 commit comments

Comments
 (0)