[Inference Providers] Fix structured output schema in chat completion #3082

Merged: 12 commits, May 22, 2025
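This PR replaces the bespoke ChatCompletionInputGrammarType dataclass with OpenAI-style response-format types and maps them onto each provider's structured-output parameters. A minimal usage sketch of what the fix enables; the provider, model name, and schema below are placeholders and assume the target deployment supports structured output:

from huggingface_hub import InferenceClient

client = InferenceClient(provider="fireworks-ai")

# OpenAI-style response_format; the "json_schema" variant is the one this PR fixes.
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": "person",
        "schema": {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "required": ["name", "age"],
        },
        "strict": True,
    },
}

completion = client.chat_completion(
    messages=[{"role": "user", "content": "Invent a fictional person."}],
    model="meta-llama/Llama-3.1-8B-Instruct",
    response_format=response_format,
)
print(completion.choices[0].message.content)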
docs/source/en/package_reference/inference_types.md (+8, -2)
@@ -57,8 +57,6 @@ This part of the lib is still under development and will be improved in future releases.

[[autodoc]] huggingface_hub.ChatCompletionInputFunctionName

-[[autodoc]] huggingface_hub.ChatCompletionInputGrammarType
-
[[autodoc]] huggingface_hub.ChatCompletionInputMessage

[[autodoc]] huggingface_hub.ChatCompletionInputMessageChunk
@@ -109,6 +107,14 @@ This part of the lib is still under development and will be improved in future releases.

[[autodoc]] huggingface_hub.ChatCompletionStreamOutputUsage

+[[autodoc]] huggingface_hub.JSONSchema
+
+[[autodoc]] huggingface_hub.ResponseFormatJSONObject
+
+[[autodoc]] huggingface_hub.ResponseFormatJSONSchema
+
+[[autodoc]] huggingface_hub.ResponseFormatText
+


## depth_estimation
docs/source/ko/package_reference/inference_types.md (+8, -2)
@@ -56,8 +56,6 @@ rendered properly in your Markdown viewer.

[[autodoc]] huggingface_hub.ChatCompletionInputFunctionName

-[[autodoc]] huggingface_hub.ChatCompletionInputGrammarType
-
[[autodoc]] huggingface_hub.ChatCompletionInputMessage

[[autodoc]] huggingface_hub.ChatCompletionInputMessageChunk
@@ -108,6 +106,14 @@ rendered properly in your Markdown viewer.

[[autodoc]] huggingface_hub.ChatCompletionStreamOutputUsage

+[[autodoc]] huggingface_hub.JSONSchema
+
+[[autodoc]] huggingface_hub.ResponseFormatJSONObject
+
+[[autodoc]] huggingface_hub.ResponseFormatJSONSchema
+
+[[autodoc]] huggingface_hub.ResponseFormatText
+


## depth_estimation[[huggingface_hub.DepthEstimationInput]]
src/huggingface_hub/__init__.py (+12, -3)
@@ -301,7 +301,6 @@
    "ChatCompletionInputFunctionDefinition",
    "ChatCompletionInputFunctionName",
    "ChatCompletionInputGrammarType",
-   "ChatCompletionInputGrammarTypeType",
    "ChatCompletionInputMessage",
    "ChatCompletionInputMessageChunk",
    "ChatCompletionInputMessageChunkType",
@@ -357,6 +356,7 @@
    "ImageToTextInput",
    "ImageToTextOutput",
    "ImageToTextParameters",
+   "JSONSchema",
    "ObjectDetectionBoundingBox",
    "ObjectDetectionInput",
    "ObjectDetectionOutputElement",
@@ -366,6 +366,9 @@
    "QuestionAnsweringInputData",
    "QuestionAnsweringOutputElement",
    "QuestionAnsweringParameters",
+   "ResponseFormatJSONObject",
+   "ResponseFormatJSONSchema",
+   "ResponseFormatText",
    "SentenceSimilarityInput",
    "SentenceSimilarityInputData",
    "SummarizationInput",
@@ -545,7 +548,6 @@
    "ChatCompletionInputFunctionDefinition",
    "ChatCompletionInputFunctionName",
    "ChatCompletionInputGrammarType",
-   "ChatCompletionInputGrammarTypeType",
    "ChatCompletionInputMessage",
    "ChatCompletionInputMessageChunk",
    "ChatCompletionInputMessageChunkType",
@@ -646,6 +648,7 @@
    "InferenceEndpointTimeoutError",
    "InferenceEndpointType",
    "InferenceTimeoutError",
+   "JSONSchema",
    "KerasModelHubMixin",
    "MCPClient",
    "ModelCard",
@@ -672,6 +675,9 @@
    "RepoCard",
    "RepoUrl",
    "Repository",
+   "ResponseFormatJSONObject",
+   "ResponseFormatJSONSchema",
+   "ResponseFormatText",
    "SentenceSimilarityInput",
    "SentenceSimilarityInputData",
    "SpaceCard",
@@ -1267,7 +1273,6 @@ def __dir__():
    ChatCompletionInputFunctionDefinition, # noqa: F401
    ChatCompletionInputFunctionName, # noqa: F401
    ChatCompletionInputGrammarType, # noqa: F401
-   ChatCompletionInputGrammarTypeType, # noqa: F401
    ChatCompletionInputMessage, # noqa: F401
    ChatCompletionInputMessageChunk, # noqa: F401
    ChatCompletionInputMessageChunkType, # noqa: F401
@@ -1323,6 +1328,7 @@ def __dir__():
    ImageToTextInput, # noqa: F401
    ImageToTextOutput, # noqa: F401
    ImageToTextParameters, # noqa: F401
+   JSONSchema, # noqa: F401
    ObjectDetectionBoundingBox, # noqa: F401
    ObjectDetectionInput, # noqa: F401
    ObjectDetectionOutputElement, # noqa: F401
@@ -1332,6 +1338,9 @@
    QuestionAnsweringInputData, # noqa: F401
    QuestionAnsweringOutputElement, # noqa: F401
    QuestionAnsweringParameters, # noqa: F401
+   ResponseFormatJSONObject, # noqa: F401
+   ResponseFormatJSONSchema, # noqa: F401
+   ResponseFormatText, # noqa: F401
    SentenceSimilarityInput, # noqa: F401
    SentenceSimilarityInputData, # noqa: F401
    SummarizationInput, # noqa: F401
src/huggingface_hub/inference/_generated/types/__init__.py (+4, -1)
@@ -24,7 +24,6 @@
    ChatCompletionInputFunctionDefinition,
    ChatCompletionInputFunctionName,
    ChatCompletionInputGrammarType,
-   ChatCompletionInputGrammarTypeType,
    ChatCompletionInputMessage,
    ChatCompletionInputMessageChunk,
    ChatCompletionInputMessageChunkType,
@@ -52,6 +51,10 @@
    ChatCompletionStreamOutputLogprobs,
    ChatCompletionStreamOutputTopLogprob,
    ChatCompletionStreamOutputUsage,
+   JSONSchema,
+   ResponseFormatJSONObject,
+   ResponseFormatJSONSchema,
+   ResponseFormatText,
)
from .depth_estimation import DepthEstimationInput, DepthEstimationOutput
from .document_question_answering import (
src/huggingface_hub/inference/_generated/types/chat_completion.py (+39, -9)
@@ -3,7 +3,7 @@
# See:
# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
-from typing import Any, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union

from .base import BaseInferenceType, dataclass_with_extra
@@ -45,17 +45,47 @@ class ChatCompletionInputMessage(BaseInferenceType):
    tool_calls: Optional[List[ChatCompletionInputToolCall]] = None


-ChatCompletionInputGrammarTypeType = Literal["json", "regex", "json_schema"]
-
-
-@dataclass_with_extra
-class ChatCompletionInputGrammarType(BaseInferenceType):
-    type: "ChatCompletionInputGrammarTypeType"
-    value: Any
-    """A string that represents a [JSON Schema](https://json-schema.org/).
-    JSON Schema is a declarative language that allows to annotate JSON documents
-    with types and descriptions.
-    """
+@dataclass_with_extra
+class JSONSchema(BaseInferenceType):
+    name: str
+    """
+    The name of the response format.
+    """
+    description: Optional[str] = None
+    """
+    A description of what the response format is for, used by the model to determine
+    how to respond in the format.
+    """
+    schema: Optional[Dict[str, object]] = None
+    """
+    The schema for the response format, described as a JSON Schema object. Learn how
+    to build JSON schemas [here](https://json-schema.org/).
+    """
+    strict: Optional[bool] = None
+    """
+    Whether to enable strict schema adherence when generating the output. If set to
+    true, the model will always follow the exact schema defined in the `schema`
+    field.
+    """
+
+
+@dataclass_with_extra
+class ResponseFormatText(BaseInferenceType):
+    type: Literal["text"]
+
+
+@dataclass_with_extra
+class ResponseFormatJSONSchema(BaseInferenceType):
+    type: Literal["json_schema"]
+    json_schema: JSONSchema
+
+
+@dataclass_with_extra
+class ResponseFormatJSONObject(BaseInferenceType):
+    type: Literal["json_object"]
+
+
+ChatCompletionInputGrammarType = Union[ResponseFormatText, ResponseFormatJSONSchema, ResponseFormatJSONObject]


@dataclass_with_extra
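For reference, a short sketch of constructing the new union members defined above; the field names come from this diff, while the schema itself is illustrative:

from huggingface_hub import JSONSchema, ResponseFormatJSONSchema

# Dataclass equivalent of the wire-format dict
# {"type": "json_schema", "json_schema": {"name": ..., "schema": ..., "strict": ...}}
response_format = ResponseFormatJSONSchema(
    type="json_schema",
    json_schema=JSONSchema(
        name="answer",
        schema={"type": "object", "properties": {"result": {"type": "number"}}},
        strict=True,
    ),
)

Any of the three variants (ResponseFormatText, ResponseFormatJSONSchema, ResponseFormatJSONObject) can be passed as response_format to chat_completion.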
src/huggingface_hub/inference/_providers/cerebras.py (+1, -1)
@@ -1,4 +1,4 @@
-from huggingface_hub.inference._providers._common import BaseConversationalTask
+from ._common import BaseConversationalTask


class CerebrasConversationalTask(BaseConversationalTask):
src/huggingface_hub/inference/_providers/cohere.py (+19, -3)
@@ -1,6 +1,8 @@
-from huggingface_hub.inference._providers._common import (
-    BaseConversationalTask,
-)
+from typing import Any, Dict, Optional
+
+from huggingface_hub.hf_api import InferenceProviderMapping
+
+from ._common import BaseConversationalTask


_PROVIDER = "cohere"
@@ -13,3 +15,17 @@ def __init__(self):

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/compatibility/v1/chat/completions"
+
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)
+        response_format = parameters.pop("response_format", None)
+        if response_format is not None and response_format["type"] == "json_schema":
+            json_schema_details = response_format.get("json_schema")
+            if isinstance(json_schema_details, dict) and "schema" in json_schema_details:
+                payload["response_format"] = {  # type: ignore [index]
+                    "type": "json_object",
+                    "schema": json_schema_details["schema"],
+                }
+        return payload
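In effect, the Cohere override unwraps the OpenAI-style json_schema block into the flatter shape the provider's compatibility endpoint accepts; the Fireworks AI and Together overrides below apply the same mapping. An illustrative before/after with a toy schema:

# What the caller passes (OpenAI-style):
incoming = {
    "type": "json_schema",
    "json_schema": {"name": "answer", "schema": {"type": "object"}},
}

# What _prepare_payload_as_dict puts in the payload:
outgoing = {
    "type": "json_object",
    "schema": incoming["json_schema"]["schema"],
}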
src/huggingface_hub/inference/_providers/fireworks_ai.py (+18, -0)
@@ -1,3 +1,7 @@
+from typing import Any, Dict, Optional
+
+from huggingface_hub.hf_api import InferenceProviderMapping
+
from ._common import BaseConversationalTask


@@ -7,3 +11,17 @@ def __init__(self):

    def _prepare_route(self, mapped_model: str, api_key: str) -> str:
        return "/inference/v1/chat/completions"
+
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)
+        response_format = parameters.pop("response_format", None)
+        if response_format is not None and response_format["type"] == "json_schema":
+            json_schema_details = response_format.get("json_schema")
+            if isinstance(json_schema_details, dict) and "schema" in json_schema_details:
+                payload["response_format"] = {  # type: ignore [index]
+                    "type": "json_object",
+                    "schema": json_schema_details["schema"],
+                }
+        return payload
src/huggingface_hub/inference/_providers/hf_inference.py (+8, -1)
@@ -96,13 +96,20 @@ def __init__(self):
    def _prepare_payload_as_dict(
        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
    ) -> Optional[Dict]:
+        payload = filter_none(parameters)
        mapped_model = provider_mapping_info.provider_id
        payload_model = parameters.get("model") or mapped_model

        if payload_model is None or payload_model.startswith(("http://", "https://")):
            payload_model = "dummy"

-        return {**filter_none(parameters), "model": payload_model, "messages": inputs}
+        response_format = parameters.pop("response_format", None)
+        if response_format is not None and response_format["type"] == "json_schema":
+            payload["response_format"] = {
+                "type": "json_object",
+                "value": response_format["json_schema"]["schema"],
+            }
+        return {**payload, "model": payload_model, "messages": inputs}

    def _prepare_url(self, api_key: str, mapped_model: str) -> str:
        base_url = (
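Note that the HF Inference backend (text-generation-inference) expects a different shape from the provider mappings above: the schema is nested under a "value" key rather than "schema". A small sketch of the rewrite, with a toy schema:

openai_style = {
    "type": "json_schema",
    "json_schema": {"name": "answer", "schema": {"type": "object"}},
}
# TGI-style grammar: the schema moves under "value"
tgi_style = {"type": "json_object", "value": openai_style["json_schema"]["schema"]}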
src/huggingface_hub/inference/_providers/nebius.py (+11, -0)
@@ -30,6 +30,17 @@ class NebiusConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai")

+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)
+        response_format = parameters.pop("response_format", None)
+        if response_format is not None and response_format["type"] == "json_schema":
+            json_schema_details = response_format.get("json_schema")
+            if isinstance(json_schema_details, dict) and "schema" in json_schema_details:
+                payload["guided_json"] = json_schema_details["schema"]  # type: ignore [index]
+        return payload
+

class NebiusTextToImageTask(TaskProviderHelper):
    def __init__(self):
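Nebius is handled differently again: instead of rewriting response_format, the schema is lifted into a top-level guided_json field (the guided-decoding parameter used by vLLM-compatible servers; the name here comes from this diff). Illustrative effect with a toy schema:

parameters = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "answer", "schema": {"type": "object"}},
    }
}
# After _prepare_payload_as_dict, the payload gains:
#     payload["guided_json"] == {"type": "object"}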
src/huggingface_hub/inference/_providers/sambanova.py (+14, -0)
@@ -9,6 +9,20 @@ class SambanovaConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider="sambanova", base_url="https://api.sambanova.ai")

+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        response_format_config = parameters.get("response_format")
+        if isinstance(response_format_config, dict) and response_format_config.get("type") == "json_schema":
+            json_schema_config = response_format_config.get("json_schema", {})
+            if isinstance(json_schema_config, dict):
+                # Downgrade strict=True (or unset) to strict=False before delegating.
+                strict = json_schema_config.get("strict")
+                if strict is True or strict is None:
+                    json_schema_config["strict"] = False
+
+        payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)
+        return payload


class SambanovaFeatureExtractionTask(TaskProviderHelper):
    def __init__(self):
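The SambaNova override leaves the response_format shape untouched and only forces strict off before delegating to the base implementation. A sketch of the in-place effect, assuming an OpenAI-style dict:

parameters = {
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "answer", "schema": {"type": "object"}, "strict": True},
    }
}
# After the override runs:
# parameters["response_format"]["json_schema"]["strict"] is False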
src/huggingface_hub/inference/_providers/together.py (+14, -0)
@@ -51,6 +51,20 @@ class TogetherConversationalTask(BaseConversationalTask):
    def __init__(self):
        super().__init__(provider=_PROVIDER, base_url=_BASE_URL)

+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)
+        response_format = parameters.pop("response_format", None)
+        if isinstance(response_format, dict) and response_format.get("type") == "json_schema":
+            json_schema_details = response_format.get("json_schema")
+            if isinstance(json_schema_details, dict) and "schema" in json_schema_details:
+                payload["response_format"] = {  # type: ignore [index]
+                    "type": "json_object",
+                    "schema": json_schema_details["schema"],
+                }
+        return payload
+

class TogetherTextToImageTask(TogetherTask):
    def __init__(self):