Skip to content

Commit 04b419c

Browse files
[Inference Providers] Fix structured output (#1579)
equivalent PR to huggingface/huggingface_hub#3082. `sambanova` and `nebius` don’t fully follow OpenAI’s spec for the `response_format` field. This PR adds internal mappings for each provider. Seems like `fireworks-ai` is now following OpenAI specs for this field (double checked their docs as well). --------- Co-authored-by: SBrandeis <simon@huggingface.co>
1 parent 5f26f9d commit 04b419c

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

packages/inference/src/providers/nebius.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import {
2525
type TextToImageTaskHelper,
2626
} from "./providerHelper.js";
2727
import { InferenceClientProviderOutputError } from "../errors.js";
28+
import type { ChatCompletionInput } from "../../../tasks/dist/commonjs/index.js";
2829

2930
const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai";
3031

@@ -50,6 +51,17 @@ export class NebiusConversationalTask extends BaseConversationalTask {
5051
constructor() {
5152
super("nebius", NEBIUS_API_BASE_URL);
5253
}
54+
55+
override preparePayload(params: BodyParams<ChatCompletionInput>): Record<string, unknown> {
56+
const payload = super.preparePayload(params) as Record<string, unknown>;
57+
58+
const responseFormat = params.args.response_format;
59+
if (responseFormat?.type === "json_schema" && responseFormat.json_schema?.schema) {
60+
payload["guided_json"] = responseFormat.json_schema.schema;
61+
}
62+
63+
return payload;
64+
}
5365
}
5466

5567
export class NebiusTextGenerationTask extends BaseTextGenerationTask {

packages/inference/src/providers/sambanova.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,25 @@ import type { BodyParams } from "../types.js";
1919
import type { FeatureExtractionTaskHelper } from "./providerHelper.js";
2020
import { BaseConversationalTask, TaskProviderHelper } from "./providerHelper.js";
2121
import { InferenceClientProviderOutputError } from "../errors.js";
22+
import type { ChatCompletionInput } from "../../../tasks/dist/commonjs/index.js";
2223

2324
export class SambanovaConversationalTask extends BaseConversationalTask {
2425
constructor() {
2526
super("sambanova", "https://api.sambanova.ai");
2627
}
28+
29+
override preparePayload(params: BodyParams<ChatCompletionInput>): Record<string, unknown> {
30+
const responseFormat = params.args.response_format;
31+
32+
if (responseFormat?.type === "json_schema" && responseFormat.json_schema) {
33+
if (responseFormat.json_schema.strict ?? true) {
34+
responseFormat.json_schema.strict = false;
35+
}
36+
}
37+
const payload = super.preparePayload(params) as Record<string, unknown>;
38+
39+
return payload;
40+
}
2741
}
2842

2943
export class SambanovaFeatureExtractionTask extends TaskProviderHelper implements FeatureExtractionTaskHelper {

packages/inference/src/providers/together.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import {
2424
type TextToImageTaskHelper,
2525
} from "./providerHelper.js";
2626
import { InferenceClientProviderOutputError } from "../errors.js";
27+
import type { ChatCompletionInput } from "../../../tasks/dist/commonjs/index.js";
2728

2829
const TOGETHER_API_BASE_URL = "https://api.together.xyz";
2930

@@ -47,6 +48,22 @@ export class TogetherConversationalTask extends BaseConversationalTask {
4748
constructor() {
4849
super("together", TOGETHER_API_BASE_URL);
4950
}
51+
52+
override preparePayload(params: BodyParams<ChatCompletionInput>): Record<string, unknown> {
53+
const payload = super.preparePayload(params);
54+
const response_format = payload.response_format as
55+
| { type: "json_schema"; json_schema: { schema: unknown } }
56+
| undefined;
57+
58+
if (response_format?.type === "json_schema" && response_format?.json_schema?.schema) {
59+
payload.response_format = {
60+
type: "json_schema",
61+
schema: response_format.json_schema.schema,
62+
};
63+
}
64+
65+
return payload;
66+
}
5067
}
5168

5269
export class TogetherTextGenerationTask extends BaseTextGenerationTask {

0 commit comments

Comments
 (0)