diff --git a/packages/inference/src/providers/nebius.ts b/packages/inference/src/providers/nebius.ts index a5b31c6cf6..93376f9170 100644 --- a/packages/inference/src/providers/nebius.ts +++ b/packages/inference/src/providers/nebius.ts @@ -25,6 +25,7 @@ import { type TextToImageTaskHelper, } from "./providerHelper.js"; import { InferenceClientProviderOutputError } from "../errors.js"; +import type { ChatCompletionInput } from "../../../tasks/dist/commonjs/index.js"; const NEBIUS_API_BASE_URL = "https://api.studio.nebius.ai"; @@ -44,6 +45,17 @@ export class NebiusConversationalTask extends BaseConversationalTask { constructor() { super("nebius", NEBIUS_API_BASE_URL); } + + override preparePayload(params: BodyParams): Record { + const payload = super.preparePayload(params) as Record; + + const responseFormat = params.args.response_format; + if (responseFormat?.type === "json_schema" && responseFormat.json_schema?.schema) { + payload["guided_json"] = responseFormat.json_schema.schema; + } + + return payload; + } } export class NebiusTextGenerationTask extends BaseTextGenerationTask { diff --git a/packages/inference/src/providers/sambanova.ts b/packages/inference/src/providers/sambanova.ts index d198ff53c6..7c66f81733 100644 --- a/packages/inference/src/providers/sambanova.ts +++ b/packages/inference/src/providers/sambanova.ts @@ -19,11 +19,25 @@ import type { BodyParams } from "../types.js"; import type { FeatureExtractionTaskHelper } from "./providerHelper.js"; import { BaseConversationalTask, TaskProviderHelper } from "./providerHelper.js"; import { InferenceClientProviderOutputError } from "../errors.js"; +import type { ChatCompletionInput } from "../../../tasks/dist/commonjs/index.js"; export class SambanovaConversationalTask extends BaseConversationalTask { constructor() { super("sambanova", "https://api.sambanova.ai"); } + + override preparePayload(params: BodyParams): Record { + const responseFormat = params.args.response_format; + + if (responseFormat?.type === "json_schema" && responseFormat.json_schema) { + if (responseFormat.json_schema.strict ?? true) { + responseFormat.json_schema.strict = false; + } + } + const payload = super.preparePayload(params) as Record; + + return payload; + } } export class SambanovaFeatureExtractionTask extends TaskProviderHelper implements FeatureExtractionTaskHelper { diff --git a/packages/inference/src/providers/together.ts b/packages/inference/src/providers/together.ts index 6777fdb169..bec666186c 100644 --- a/packages/inference/src/providers/together.ts +++ b/packages/inference/src/providers/together.ts @@ -24,6 +24,7 @@ import { type TextToImageTaskHelper, } from "./providerHelper.js"; import { InferenceClientProviderOutputError } from "../errors.js"; +import type { ChatCompletionInput } from "../../../tasks/dist/commonjs/index.js"; const TOGETHER_API_BASE_URL = "https://api.together.xyz"; @@ -47,6 +48,22 @@ export class TogetherConversationalTask extends BaseConversationalTask { constructor() { super("together", TOGETHER_API_BASE_URL); } + + override preparePayload(params: BodyParams): Record { + const payload = super.preparePayload(params); + const response_format = payload.response_format as + | { type: "json_schema"; json_schema: { schema: unknown } } + | undefined; + + if (response_format?.type === "json_schema" && response_format?.json_schema?.schema) { + payload.response_format = { + type: "json_schema", + schema: response_format.json_schema.schema, + }; + } + + return payload; + } } export class TogetherTextGenerationTask extends BaseTextGenerationTask {