
Commit 417ad89

[Inference Providers] Fix structured output schema in chat completion (#3082)
* fix structured output
* fix
* style
* run style again
* fix tests
* rename types
* review suggestions
* no need to mutate parameters
* docs
* better
1 parent 0f27bdf commit 417ad89

13 files changed: +264 −20 lines changed

docs/source/en/guides/inference.md

Lines changed: 104 additions & 0 deletions
@@ -308,6 +308,110 @@ You might wonder why using [`InferenceClient`] instead of OpenAI's client? There

</Tip>

## Function Calling

Function calling allows LLMs to interact with external tools, such as user-defined functions or APIs. This enables users to easily build applications tailored to specific use cases and real-world tasks. `InferenceClient` implements the same tool-calling interface as the OpenAI Chat Completions API. Here is a simple example of tool calling using [Nebius](https://nebius.com/) as the inference provider:

```python
from huggingface_hub import InferenceClient

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get current temperature for a given location.",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "City and country e.g. Paris, France",
                    }
                },
                "required": ["location"],
            },
        },
    }
]

client = InferenceClient(provider="nebius")

response = client.chat.completions.create(
    model="Qwen/Qwen2.5-72B-Instruct",
    messages=[
        {
            "role": "user",
            "content": "What's the weather like the next 3 days in London, UK?",
        }
    ],
    tools=tools,
    tool_choice="auto",
)

print(response.choices[0].message.tool_calls[0].function.arguments)
```

<Tip>

Please refer to each provider's documentation to verify which of their models support function/tool calling.

</Tip>

## Structured Outputs & JSON Mode

`InferenceClient` supports JSON mode for syntactically valid JSON responses and Structured Outputs for schema-enforced responses. JSON mode provides machine-readable data without a strict structure, while Structured Outputs guarantee both valid JSON and adherence to a predefined schema for reliable downstream processing.

We follow the OpenAI API specs for both JSON mode and Structured Outputs. You can enable them via the `response_format` argument. Here is an example of Structured Outputs using [Cerebras](https://www.cerebras.ai/) as the inference provider:

```python
from huggingface_hub import InferenceClient

json_schema = {
    "name": "book",
    "schema": {
        "properties": {
            "name": {
                "title": "Name",
                "type": "string",
            },
            "authors": {
                "items": {"type": "string"},
                "title": "Authors",
                "type": "array",
            },
        },
        "required": ["name", "authors"],
        "title": "Book",
        "type": "object",
    },
    "strict": True,
}

client = InferenceClient(provider="cerebras")

completion = client.chat.completions.create(
    model="Qwen/Qwen3-32B",
    messages=[
        {"role": "system", "content": "Extract the books information."},
        {"role": "user", "content": "I recently read 'The Great Gatsby' by F. Scott Fitzgerald."},
    ],
    response_format={
        "type": "json_schema",
        "json_schema": json_schema,
    },
)

print(completion.choices[0].message)
```

<Tip>

Please refer to each provider's documentation to verify which of their models support Structured Outputs and JSON mode.

</Tip>
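JSON mode itself takes no schema. A minimal sketch of enabling it through the same `response_format` argument (assuming the Cerebras setup above and a model that supports JSON mode):

```python
from huggingface_hub import InferenceClient

client = InferenceClient(provider="cerebras")

completion = client.chat.completions.create(
    model="Qwen/Qwen3-32B",
    messages=[
        {"role": "system", "content": "Answer in JSON."},
        {"role": "user", "content": "List two books by F. Scott Fitzgerald."},
    ],
    # JSON mode: the output is guaranteed to be valid JSON, but its structure
    # is not enforced beyond what the prompt requests.
    response_format={"type": "json_object"},
)

print(completion.choices[0].message)
```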
## Async client

docs/source/en/package_reference/inference_types.md

Lines changed: 7 additions & 1 deletion
@@ -57,12 +57,18 @@ This part of the lib is still under development and will be improved in future r

 [[autodoc]] huggingface_hub.ChatCompletionInputFunctionName

-[[autodoc]] huggingface_hub.ChatCompletionInputGrammarType
+[[autodoc]] huggingface_hub.ChatCompletionInputJSONSchema

 [[autodoc]] huggingface_hub.ChatCompletionInputMessage

 [[autodoc]] huggingface_hub.ChatCompletionInputMessageChunk

+[[autodoc]] huggingface_hub.ChatCompletionInputResponseFormatJSONObject
+
+[[autodoc]] huggingface_hub.ChatCompletionInputResponseFormatJSONSchema
+
+[[autodoc]] huggingface_hub.ChatCompletionInputResponseFormatText
+
 [[autodoc]] huggingface_hub.ChatCompletionInputStreamOptions

 [[autodoc]] huggingface_hub.ChatCompletionInputTool

docs/source/ko/package_reference/inference_types.md

Lines changed: 7 additions & 1 deletion
@@ -56,12 +56,18 @@ rendered properly in your Markdown viewer.

 [[autodoc]] huggingface_hub.ChatCompletionInputFunctionName

-[[autodoc]] huggingface_hub.ChatCompletionInputGrammarType
+[[autodoc]] huggingface_hub.ChatCompletionInputJSONSchema

 [[autodoc]] huggingface_hub.ChatCompletionInputMessage

 [[autodoc]] huggingface_hub.ChatCompletionInputMessageChunk

+[[autodoc]] huggingface_hub.ChatCompletionInputResponseFormatJSONObject
+
+[[autodoc]] huggingface_hub.ChatCompletionInputResponseFormatJSONSchema
+
+[[autodoc]] huggingface_hub.ChatCompletionInputResponseFormatText
+
 [[autodoc]] huggingface_hub.ChatCompletionInputStreamOptions

 [[autodoc]] huggingface_hub.ChatCompletionInputTool

src/huggingface_hub/__init__.py

Lines changed: 12 additions & 3 deletions
@@ -301,10 +301,13 @@
     "ChatCompletionInputFunctionDefinition",
     "ChatCompletionInputFunctionName",
     "ChatCompletionInputGrammarType",
-    "ChatCompletionInputGrammarTypeType",
+    "ChatCompletionInputJSONSchema",
     "ChatCompletionInputMessage",
     "ChatCompletionInputMessageChunk",
     "ChatCompletionInputMessageChunkType",
+    "ChatCompletionInputResponseFormatJSONObject",
+    "ChatCompletionInputResponseFormatJSONSchema",
+    "ChatCompletionInputResponseFormatText",
     "ChatCompletionInputStreamOptions",
     "ChatCompletionInputTool",
     "ChatCompletionInputToolCall",
@@ -545,10 +548,13 @@
     "ChatCompletionInputFunctionDefinition",
     "ChatCompletionInputFunctionName",
     "ChatCompletionInputGrammarType",
-    "ChatCompletionInputGrammarTypeType",
+    "ChatCompletionInputJSONSchema",
     "ChatCompletionInputMessage",
     "ChatCompletionInputMessageChunk",
     "ChatCompletionInputMessageChunkType",
+    "ChatCompletionInputResponseFormatJSONObject",
+    "ChatCompletionInputResponseFormatJSONSchema",
+    "ChatCompletionInputResponseFormatText",
     "ChatCompletionInputStreamOptions",
     "ChatCompletionInputTool",
     "ChatCompletionInputToolCall",
@@ -1267,10 +1273,13 @@ def __dir__():
         ChatCompletionInputFunctionDefinition,  # noqa: F401
         ChatCompletionInputFunctionName,  # noqa: F401
         ChatCompletionInputGrammarType,  # noqa: F401
-        ChatCompletionInputGrammarTypeType,  # noqa: F401
+        ChatCompletionInputJSONSchema,  # noqa: F401
         ChatCompletionInputMessage,  # noqa: F401
         ChatCompletionInputMessageChunk,  # noqa: F401
         ChatCompletionInputMessageChunkType,  # noqa: F401
+        ChatCompletionInputResponseFormatJSONObject,  # noqa: F401
+        ChatCompletionInputResponseFormatJSONSchema,  # noqa: F401
+        ChatCompletionInputResponseFormatText,  # noqa: F401
         ChatCompletionInputStreamOptions,  # noqa: F401
         ChatCompletionInputTool,  # noqa: F401
         ChatCompletionInputToolCall,  # noqa: F401

src/huggingface_hub/inference/_generated/types/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -24,10 +24,13 @@
     ChatCompletionInputFunctionDefinition,
     ChatCompletionInputFunctionName,
     ChatCompletionInputGrammarType,
-    ChatCompletionInputGrammarTypeType,
+    ChatCompletionInputJSONSchema,
     ChatCompletionInputMessage,
     ChatCompletionInputMessageChunk,
     ChatCompletionInputMessageChunkType,
+    ChatCompletionInputResponseFormatJSONObject,
+    ChatCompletionInputResponseFormatJSONSchema,
+    ChatCompletionInputResponseFormatText,
     ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
     ChatCompletionInputToolCall,

src/huggingface_hub/inference/_generated/types/chat_completion.py

Lines changed: 43 additions & 9 deletions
@@ -3,7 +3,7 @@
 # See:
 # - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts
 # - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks.
-from typing import Any, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union

 from .base import BaseInferenceType, dataclass_with_extra

@@ -45,17 +45,51 @@ class ChatCompletionInputMessage(BaseInferenceType):
     tool_calls: Optional[List[ChatCompletionInputToolCall]] = None


-ChatCompletionInputGrammarTypeType = Literal["json", "regex", "json_schema"]
+@dataclass_with_extra
+class ChatCompletionInputJSONSchema(BaseInferenceType):
+    name: str
+    """
+    The name of the response format.
+    """
+    description: Optional[str] = None
+    """
+    A description of what the response format is for, used by the model to determine
+    how to respond in the format.
+    """
+    schema: Optional[Dict[str, object]] = None
+    """
+    The schema for the response format, described as a JSON Schema object. Learn how
+    to build JSON schemas [here](https://json-schema.org/).
+    """
+    strict: Optional[bool] = None
+    """
+    Whether to enable strict schema adherence when generating the output. If set to
+    true, the model will always follow the exact schema defined in the `schema`
+    field.
+    """


 @dataclass_with_extra
-class ChatCompletionInputGrammarType(BaseInferenceType):
-    type: "ChatCompletionInputGrammarTypeType"
-    value: Any
-    """A string that represents a [JSON Schema](https://json-schema.org/).
-    JSON Schema is a declarative language that allows to annotate JSON documents
-    with types and descriptions.
-    """
+class ChatCompletionInputResponseFormatText(BaseInferenceType):
+    type: Literal["text"]
+
+
+@dataclass_with_extra
+class ChatCompletionInputResponseFormatJSONSchema(BaseInferenceType):
+    type: Literal["json_schema"]
+    json_schema: ChatCompletionInputJSONSchema
+
+
+@dataclass_with_extra
+class ChatCompletionInputResponseFormatJSONObject(BaseInferenceType):
+    type: Literal["json_object"]
+
+
+ChatCompletionInputGrammarType = Union[
+    ChatCompletionInputResponseFormatText,
+    ChatCompletionInputResponseFormatJSONSchema,
+    ChatCompletionInputResponseFormatJSONObject,
+]


 @dataclass_with_extra
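For reference, a minimal sketch of constructing the new response-format types directly (hypothetical schema values; the classes are exported from `huggingface_hub` as shown in the `__init__.py` diff above):

```python
from huggingface_hub import (
    ChatCompletionInputJSONSchema,
    ChatCompletionInputResponseFormatJSONSchema,
)

# Schema-enforced response format mirroring the OpenAI `json_schema` shape.
# The schema contents here are hypothetical.
response_format = ChatCompletionInputResponseFormatJSONSchema(
    type="json_schema",
    json_schema=ChatCompletionInputJSONSchema(
        name="book",
        schema={
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        },
        strict=True,
    ),
)
```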

src/huggingface_hub/inference/_providers/cerebras.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from huggingface_hub.inference._providers._common import BaseConversationalTask
+from ._common import BaseConversationalTask


 class CerebrasConversationalTask(BaseConversationalTask):

src/huggingface_hub/inference/_providers/cohere.py

Lines changed: 20 additions & 3 deletions
@@ -1,6 +1,8 @@
-from huggingface_hub.inference._providers._common import (
-    BaseConversationalTask,
-)
+from typing import Any, Dict, Optional
+
+from huggingface_hub.hf_api import InferenceProviderMapping
+
+from ._common import BaseConversationalTask


 _PROVIDER = "cohere"
@@ -13,3 +15,18 @@ def __init__(self):

     def _prepare_route(self, mapped_model: str, api_key: str) -> str:
         return "/compatibility/v1/chat/completions"
+
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)
+        response_format = parameters.get("response_format")
+        if isinstance(response_format, dict) and response_format.get("type") == "json_schema":
+            json_schema_details = response_format.get("json_schema")
+            if isinstance(json_schema_details, dict) and "schema" in json_schema_details:
+                payload["response_format"] = {  # type: ignore [index]
+                    "type": "json_object",
+                    "schema": json_schema_details["schema"],
+                }
+
+        return payload
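In effect, this override rewrites an OpenAI-style `json_schema` response format into the `json_object` + `schema` shape before the request is sent. A hypothetical before/after sketch:

```python
# Hypothetical illustration of the Cohere payload rewrite above.
# OpenAI-style response_format passed by the caller:
response_format_in = {
    "type": "json_schema",
    "json_schema": {"name": "book", "schema": {"type": "object"}, "strict": True},
}

# Shape forwarded to Cohere by _prepare_payload_as_dict:
response_format_out = {
    "type": "json_object",
    "schema": response_format_in["json_schema"]["schema"],
}
```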

src/huggingface_hub/inference/_providers/fireworks_ai.py

Lines changed: 18 additions & 0 deletions
@@ -1,3 +1,7 @@
+from typing import Any, Dict, Optional
+
+from huggingface_hub.hf_api import InferenceProviderMapping
+
 from ._common import BaseConversationalTask


@@ -7,3 +11,17 @@ def __init__(self):

     def _prepare_route(self, mapped_model: str, api_key: str) -> str:
         return "/inference/v1/chat/completions"
+
+    def _prepare_payload_as_dict(
+        self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
+    ) -> Optional[Dict]:
+        payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info)
+        response_format = parameters.get("response_format")
+        if isinstance(response_format, dict) and response_format.get("type") == "json_schema":
+            json_schema_details = response_format.get("json_schema")
+            if isinstance(json_schema_details, dict) and "schema" in json_schema_details:
+                payload["response_format"] = {  # type: ignore [index]
+                    "type": "json_object",
+                    "schema": json_schema_details["schema"],
+                }
+        return payload
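Fireworks AI gets the same rewrite as Cohere above: an OpenAI-style `json_schema` response format is converted to the provider's `json_object` shape, with the raw schema passed under the `schema` key.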

src/huggingface_hub/inference/_providers/hf_inference.py

Lines changed: 8 additions & 1 deletion
@@ -96,13 +96,20 @@ def __init__(self):
     def _prepare_payload_as_dict(
         self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping
     ) -> Optional[Dict]:
+        payload = filter_none(parameters)
         mapped_model = provider_mapping_info.provider_id
         payload_model = parameters.get("model") or mapped_model

         if payload_model is None or payload_model.startswith(("http://", "https://")):
             payload_model = "dummy"

-        return {**filter_none(parameters), "model": payload_model, "messages": inputs}
+        response_format = parameters.get("response_format")
+        if isinstance(response_format, dict) and response_format.get("type") == "json_schema":
+            payload["response_format"] = {
+                "type": "json_object",
+                "value": response_format["json_schema"]["schema"],
+            }
+        return {**payload, "model": payload_model, "messages": inputs}

     def _prepare_url(self, api_key: str, mapped_model: str) -> str:
         base_url = (
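Note the target shape differs here: the HF Inference backend expects the schema under a `value` key (`{"type": "json_object", "value": <schema>}`) rather than the `schema` key used for Cohere and Fireworks.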
