Skip to content

Commit a719d7b

Browse files
authored
Merge pull request #1070 from narengogi/fix/support-provider-specific-fields-in-snake-case
Fix/support provider specific fields for bedrock in snake case
2 parents cb94415 + 50744bc commit a719d7b

File tree

2 files changed

+49
-5
lines changed

2 files changed

+49
-5
lines changed

src/providers/bedrock/chatComplete.ts

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,18 @@ import {
3232

3333
export interface BedrockChatCompletionsParams extends Params {
3434
additionalModelRequestFields?: Record<string, any>;
35+
additional_model_request_fields?: Record<string, any>;
3536
additionalModelResponseFieldPaths?: string[];
3637
guardrailConfig?: {
3738
guardrailIdentifier: string;
3839
guardrailVersion: string;
3940
trace?: string;
4041
};
42+
guardrail_config?: {
43+
guardrailIdentifier: string;
44+
guardrailVersion: string;
45+
trace?: string;
46+
};
4147
anthropic_version?: string;
4248
countPenalty?: number;
4349
}
@@ -312,10 +318,18 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = {
312318
param: 'guardrailConfig',
313319
required: false,
314320
},
321+
guardrail_config: {
322+
param: 'guardrailConfig',
323+
required: false,
324+
},
315325
additionalModelResponseFieldPaths: {
316326
param: 'additionalModelResponseFieldPaths',
317327
required: false,
318328
},
329+
additional_model_response_field_paths: {
330+
param: 'additionalModelResponseFieldPaths',
331+
required: false,
332+
},
319333
max_tokens: {
320334
param: 'inferenceConfig',
321335
transform: (params: BedrockChatCompletionsParams) =>
@@ -346,6 +360,11 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = {
346360
transform: (params: BedrockChatCompletionsParams) =>
347361
transformAdditionalModelRequestFields(params),
348362
},
363+
additional_model_request_fields: {
364+
param: 'additionalModelRequestFields',
365+
transform: (params: BedrockChatCompletionsParams) =>
366+
transformAdditionalModelRequestFields(params),
367+
},
349368
top_k: {
350369
param: 'additionalModelRequestFields',
351370
transform: (params: BedrockChatCompletionsParams) =>
@@ -701,6 +720,11 @@ export const BedrockConverseAnthropicChatCompleteConfig: ProviderConfig = {
701720
transform: (params: BedrockConverseAnthropicChatCompletionsParams) =>
702721
transformAnthropicAdditionalModelRequestFields(params),
703722
},
723+
additional_model_request_fields: {
724+
param: 'additionalModelRequestFields',
725+
transform: (params: BedrockConverseAnthropicChatCompletionsParams) =>
726+
transformAnthropicAdditionalModelRequestFields(params),
727+
},
704728
top_k: {
705729
param: 'additionalModelRequestFields',
706730
transform: (params: BedrockConverseAnthropicChatCompletionsParams) =>
@@ -722,7 +746,9 @@ export const BedrockConverseAnthropicChatCompleteConfig: ProviderConfig = {
722746
transformAnthropicAdditionalModelRequestFields(params),
723747
},
724748
anthropic_beta: {
725-
param: 'anthropic_beta',
749+
param: 'additionalModelRequestFields',
750+
transform: (params: BedrockConverseAnthropicChatCompletionsParams) =>
751+
transformAnthropicAdditionalModelRequestFields(params),
726752
},
727753
};
728754

@@ -733,6 +759,11 @@ export const BedrockConverseCohereChatCompleteConfig: ProviderConfig = {
733759
transform: (params: BedrockConverseCohereChatCompletionsParams) =>
734760
transformCohereAdditionalModelRequestFields(params),
735761
},
762+
additional_model_request_fields: {
763+
param: 'additionalModelRequestFields',
764+
transform: (params: BedrockConverseCohereChatCompletionsParams) =>
765+
transformCohereAdditionalModelRequestFields(params),
766+
},
736767
top_k: {
737768
param: 'additionalModelRequestFields',
738769
transform: (params: BedrockConverseCohereChatCompletionsParams) =>
@@ -767,6 +798,11 @@ export const BedrockConverseAI21ChatCompleteConfig: ProviderConfig = {
767798
transform: (params: BedrockConverseAI21ChatCompletionsParams) =>
768799
transformAI21AdditionalModelRequestFields(params),
769800
},
801+
additional_model_request_fields: {
802+
param: 'additionalModelRequestFields',
803+
transform: (params: BedrockConverseAI21ChatCompletionsParams) =>
804+
transformAI21AdditionalModelRequestFields(params),
805+
},
770806
top_k: {
771807
param: 'additionalModelRequestFields',
772808
transform: (params: BedrockConverseAI21ChatCompletionsParams) =>

src/providers/bedrock/utils.ts

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ export const transformAdditionalModelRequestFields = (
9393
params: BedrockChatCompletionsParams
9494
) => {
9595
const additionalModelRequestFields: Record<string, any> =
96-
params.additionalModelRequestFields || {};
96+
params.additionalModelRequestFields ||
97+
params.additional_model_request_fields ||
98+
{};
9799
if (params['top_k']) {
98100
additionalModelRequestFields['top_k'] = params['top_k'];
99101
}
@@ -107,7 +109,9 @@ export const transformAnthropicAdditionalModelRequestFields = (
107109
params: BedrockConverseAnthropicChatCompletionsParams
108110
) => {
109111
const additionalModelRequestFields: Record<string, any> =
110-
params.additionalModelRequestFields || {};
112+
params.additionalModelRequestFields ||
113+
params.additional_model_request_fields ||
114+
{};
111115
if (params['top_k']) {
112116
additionalModelRequestFields['top_k'] = params['top_k'];
113117
}
@@ -130,7 +134,9 @@ export const transformCohereAdditionalModelRequestFields = (
130134
params: BedrockConverseCohereChatCompletionsParams
131135
) => {
132136
const additionalModelRequestFields: Record<string, any> =
133-
params.additionalModelRequestFields || {};
137+
params.additionalModelRequestFields ||
138+
params.additional_model_request_fields ||
139+
{};
134140
if (params['top_k']) {
135141
additionalModelRequestFields['top_k'] = params['top_k'];
136142
}
@@ -155,7 +161,9 @@ export const transformAI21AdditionalModelRequestFields = (
155161
params: BedrockConverseAI21ChatCompletionsParams
156162
) => {
157163
const additionalModelRequestFields: Record<string, any> =
158-
params.additionalModelRequestFields || {};
164+
params.additionalModelRequestFields ||
165+
params.additional_model_request_fields ||
166+
{};
159167
if (params['top_k']) {
160168
additionalModelRequestFields['top_k'] = params['top_k'];
161169
}

0 commit comments

Comments
 (0)