Skip to content

Commit 51430c9

Browse files
authored
Merge pull request #1059 from b4s36t4/feat/bedrock-cache
feat: support bedrock prompt caching
2 parents 6021da6 + e065226 commit 51430c9

File tree

3 files changed

+92
-26
lines changed

3 files changed

+92
-26
lines changed

src/providers/anthropic/chatComplete.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ import {
33
Params,
44
Message,
55
ContentType,
6-
AnthropicPromptCache,
76
SYSTEM_MESSAGE_ROLES,
7+
PromptCache,
88
} from '../../types/requestBody';
99
import {
1010
ChatCompletionResponse,
@@ -19,7 +19,7 @@ import { AnthropicStreamState } from './types';
1919

2020
// TODO: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model.
2121

22-
interface AnthropicTool extends AnthropicPromptCache {
22+
interface AnthropicTool extends PromptCache {
2323
name: string;
2424
description: string;
2525
input_schema: {
@@ -69,7 +69,7 @@ type AnthropicMessageContentItem =
6969
| AnthropicUrlImageContentItem
7070
| AnthropicTextContentItem;
7171

72-
interface AnthropicMessage extends Message, AnthropicPromptCache {
72+
interface AnthropicMessage extends Message, PromptCache {
7373
content: AnthropicMessageContentItem[];
7474
}
7575

@@ -180,7 +180,7 @@ export const AnthropicChatCompleteConfig: ProviderConfig = {
180180
let messages: AnthropicMessage[] = [];
181181
// Transform the chat messages into a simple prompt
182182
if (!!params.messages) {
183-
params.messages.forEach((msg: Message & AnthropicPromptCache) => {
183+
params.messages.forEach((msg: Message & PromptCache) => {
184184
if (SYSTEM_MESSAGE_ROLES.includes(msg.role)) return;
185185

186186
if (msg.role === 'assistant') {
@@ -230,7 +230,7 @@ export const AnthropicChatCompleteConfig: ProviderConfig = {
230230
let systemMessages: AnthropicMessageContentItem[] = [];
231231
// Transform the chat messages into a simple prompt
232232
if (!!params.messages) {
233-
params.messages.forEach((msg: Message & AnthropicPromptCache) => {
233+
params.messages.forEach((msg: Message & PromptCache) => {
234234
if (
235235
SYSTEM_MESSAGE_ROLES.includes(msg.role) &&
236236
msg.content &&

src/providers/bedrock/chatComplete.ts

Lines changed: 84 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,30 @@ export interface BedrockConverseAI21ChatCompletionsParams
6969
countPenalty?: number;
7070
}
7171

72-
const getMessageTextContentArray = (message: Message): { text: string }[] => {
72+
const getMessageTextContentArray = (
73+
message: Message
74+
): Array<{ text: string } | { cachePoint: { type: string } }> => {
7375
if (message.content && typeof message.content === 'object') {
74-
return message.content
75-
.filter((item) => item.type === 'text')
76-
.map((item) => {
77-
return {
78-
text: item.text || '',
79-
};
76+
const filteredContentMessages = message.content.filter(
77+
(item) => item.type === 'text'
78+
);
79+
const finalContent: Array<
80+
{ text: string } | { cachePoint: { type: string } }
81+
> = [];
82+
filteredContentMessages.forEach((item) => {
83+
finalContent.push({
84+
text: item.text || '',
8085
});
86+
// push a cache point.
87+
if (item.cache_control) {
88+
finalContent.push({
89+
cachePoint: {
90+
type: 'default',
91+
},
92+
});
93+
}
94+
});
95+
return finalContent;
8196
}
8297
return [
8398
{
@@ -162,6 +177,15 @@ const getMessageContent = (message: Message) => {
162177
});
163178
}
164179
}
180+
181+
if (item.cache_control) {
182+
// if content item has `cache_control`, push the cache point to the out array
183+
out.push({
184+
cachePoint: {
185+
type: 'default',
186+
},
187+
});
188+
}
165189
});
166190
}
167191

@@ -219,7 +243,10 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = {
219243
transform: (params: BedrockChatCompletionsParams) => {
220244
if (!params.messages) return;
221245
const systemMessages = params.messages.reduce(
222-
(acc: { text: string }[], msg) => {
246+
(
247+
acc: Array<{ text: string } | { cachePoint: { type: string } }>,
248+
msg
249+
) => {
223250
if (SYSTEM_MESSAGE_ROLES.includes(msg.role))
224251
return acc.concat(...getMessageTextContentArray(msg));
225252
return acc;
@@ -234,17 +261,29 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = {
234261
tools: {
235262
param: 'toolConfig',
236263
transform: (params: BedrockChatCompletionsParams) => {
237-
const toolConfig = {
238-
tools: params.tools?.map((tool) => {
239-
if (!tool.function) return;
240-
return {
241-
toolSpec: {
242-
name: tool.function.name,
243-
description: tool.function.description,
244-
inputSchema: { json: tool.function.parameters },
264+
const canBeAmazonModel = params.model?.includes('amazon');
265+
const tools: Array<
266+
| { toolSpec: { name: string; description?: string; inputSchema: any } }
267+
| { cachePoint: { type: string } }
268+
> = [];
269+
params.tools?.forEach((tool) => {
270+
tools.push({
271+
toolSpec: {
272+
name: tool.function.name,
273+
description: tool.function.description,
274+
inputSchema: { json: tool.function.parameters },
275+
},
276+
});
277+
if (tool.cache_control && !canBeAmazonModel) {
278+
tools.push({
279+
cachePoint: {
280+
type: 'default',
245281
},
246-
};
247-
}),
282+
});
283+
}
284+
});
285+
const toolConfig = {
286+
tools: tools,
248287
};
249288
let toolChoice = undefined;
250289
if (params.tool_choice) {
@@ -341,6 +380,9 @@ type BedrockContentItem = {
341380
bytes: string;
342381
};
343382
};
383+
cachePoint?: {
384+
type: string;
385+
};
344386
};
345387

346388
interface BedrockChatCompletionResponse {
@@ -358,6 +400,10 @@ interface BedrockChatCompletionResponse {
358400
inputTokens: number;
359401
outputTokens: number;
360402
totalTokens: number;
403+
cacheReadInputTokenCount?: number;
404+
cacheReadInputTokens?: number;
405+
cacheWriteInputTokenCount?: number;
406+
cacheWriteInputTokens?: number;
361407
};
362408
}
363409

@@ -421,6 +467,10 @@ export const BedrockChatCompleteResponseTransform: (
421467
}
422468

423469
if ('output' in response) {
470+
const shouldSendCacheUsage =
471+
response.usage.cacheWriteInputTokens ||
472+
response.usage.cacheReadInputTokens;
473+
424474
let content: string = '';
425475
content = response.output.message.content
426476
.filter((item) => item.text)
@@ -453,6 +503,10 @@ export const BedrockChatCompleteResponseTransform: (
453503
prompt_tokens: response.usage.inputTokens,
454504
completion_tokens: response.usage.outputTokens,
455505
total_tokens: response.usage.totalTokens,
506+
...(shouldSendCacheUsage && {
507+
cache_read_input_tokens: response.usage.cacheReadInputTokens,
508+
cache_creation_input_tokens: response.usage.cacheWriteInputTokens,
509+
}),
456510
},
457511
};
458512
const toolCalls = response.output.message.content
@@ -503,6 +557,10 @@ export interface BedrockChatCompleteStreamChunk {
503557
inputTokens: number;
504558
outputTokens: number;
505559
totalTokens: number;
560+
cacheReadInputTokenCount?: number;
561+
cacheReadInputTokens?: number;
562+
cacheWriteInputTokenCount?: number;
563+
cacheWriteInputTokens?: number;
506564
};
507565
}
508566

@@ -534,6 +592,9 @@ export const BedrockChatCompleteStreamChunkTransform: (
534592
}
535593

536594
if (parsedChunk.usage) {
595+
const shouldSendCacheUsage =
596+
parsedChunk.usage.cacheWriteInputTokens ||
597+
parsedChunk.usage.cacheReadInputTokens;
537598
return [
538599
`data: ${JSON.stringify({
539600
id: fallbackId,
@@ -552,6 +613,11 @@ export const BedrockChatCompleteStreamChunkTransform: (
552613
prompt_tokens: parsedChunk.usage.inputTokens,
553614
completion_tokens: parsedChunk.usage.outputTokens,
554615
total_tokens: parsedChunk.usage.totalTokens,
616+
...(shouldSendCacheUsage && {
617+
cache_read_input_tokens: parsedChunk.usage.cacheReadInputTokens,
618+
cache_creation_input_tokens:
619+
parsedChunk.usage.cacheWriteInputTokens,
620+
}),
555621
},
556622
})}\n\n`,
557623
`data: [DONE]\n\n`,

src/types/requestBody.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ export interface Config {
219219
* A message content type.
220220
* @interface
221221
*/
222-
export interface ContentType {
222+
export interface ContentType extends PromptCache {
223223
type: string;
224224
text?: string;
225225
thinking?: string;
@@ -285,7 +285,7 @@ export interface Message {
285285
citationMetadata?: CitationMetadata;
286286
}
287287

288-
export interface AnthropicPromptCache {
288+
export interface PromptCache {
289289
cache_control?: { type: 'ephemeral' };
290290
}
291291

@@ -340,7 +340,7 @@ export type ToolChoice = ToolChoiceObject | 'none' | 'auto' | 'required';
340340
*
341341
* @interface
342342
*/
343-
export interface Tool extends AnthropicPromptCache {
343+
export interface Tool extends PromptCache {
344344
/** The name of the function. */
345345
type: string;
346346
/** A description of the function. */

0 commit comments

Comments (0)