Commit 2765eee

Add support for prompt caching (#1051)
1 parent 1ac816a commit 2765eee

File tree

4 files changed: +179 additions, -47 deletions

packages/cdk/lambda/utils/models.ts

Lines changed: 86 additions & 42 deletions
@@ -26,6 +26,10 @@ import {
   ContentBlock,
 } from '@aws-sdk/client-bedrock-runtime';
 import { modelFeatureFlags } from '@generative-ai-use-cases/common';
+import {
+  applyAutoCacheToMessages,
+  applyAutoCacheToSystem,
+} from './promptCache';

 // Default Models

@@ -121,72 +125,104 @@ const RINNA_PROMPT: PromptTemplate = {
 // Model Params

 const CLAUDE_3_5_DEFAULT_PARAMS: ConverseInferenceParams = {
-  maxTokens: 8192,
-  temperature: 0.6,
-  topP: 0.8,
+  inferenceConfig: {
+    maxTokens: 8192,
+    temperature: 0.6,
+    topP: 0.8,
+  },
 };

 const CLAUDE_DEFAULT_PARAMS: ConverseInferenceParams = {
-  maxTokens: 4096,
-  temperature: 0.6,
-  topP: 0.8,
+  inferenceConfig: {
+    maxTokens: 4096,
+    temperature: 0.6,
+    topP: 0.8,
+  },
 };

 const TITAN_TEXT_DEFAULT_PARAMS: ConverseInferenceParams = {
   // Converse API only accepts 3000, instead of 3072, which is described in the doc.
   // If 3072 is accepted, revert to 3072.
   // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
-  maxTokens: 3000,
-  temperature: 0.7,
-  topP: 1.0,
+  inferenceConfig: {
+    maxTokens: 3000,
+    temperature: 0.7,
+    topP: 1.0,
+  },
 };

 const LLAMA_DEFAULT_PARAMS: ConverseInferenceParams = {
-  maxTokens: 2048,
-  temperature: 0.5,
-  topP: 0.9,
-  stopSequences: ['<|eot_id|>'],
+  inferenceConfig: {
+    maxTokens: 2048,
+    temperature: 0.5,
+    topP: 0.9,
+    stopSequences: ['<|eot_id|>'],
+  },
 };

 const MISTRAL_DEFAULT_PARAMS: ConverseInferenceParams = {
-  maxTokens: 8192,
-  temperature: 0.6,
-  topP: 0.99,
+  inferenceConfig: {
+    maxTokens: 8192,
+    temperature: 0.6,
+    topP: 0.99,
+  },
 };

 const MIXTRAL_DEFAULT_PARAMS: ConverseInferenceParams = {
-  maxTokens: 4096,
-  temperature: 0.6,
-  topP: 0.99,
+  inferenceConfig: {
+    maxTokens: 4096,
+    temperature: 0.6,
+    topP: 0.99,
+  },
 };

 const COMMANDR_DEFAULT_PARAMS: ConverseInferenceParams = {
-  maxTokens: 4000,
-  temperature: 0.3,
-  topP: 0.75,
+  inferenceConfig: {
+    maxTokens: 4000,
+    temperature: 0.3,
+    topP: 0.75,
+  },
 };

 const NOVA_DEFAULT_PARAMS: ConverseInferenceParams = {
-  maxTokens: 5120,
-  temperature: 0.7,
-  topP: 0.9,
+  inferenceConfig: {
+    maxTokens: 5120,
+    temperature: 0.7,
+    topP: 0.9,
+  },
 };

 const DEEPSEEK_DEFAULT_PARAMS: ConverseInferenceParams = {
-  maxTokens: 32768,
-  temperature: 0.6,
-  topP: 0.95,
+  inferenceConfig: {
+    maxTokens: 32768,
+    temperature: 0.6,
+    topP: 0.95,
+  },
 };

 const PALMYRA_DEFAULT_PARAMS: ConverseInferenceParams = {
-  maxTokens: 8192,
-  temperature: 1,
-  topP: 0.9,
+  inferenceConfig: {
+    maxTokens: 8192,
+    temperature: 1,
+    topP: 0.9,
+  },
 };

 const USECASE_DEFAULT_PARAMS: UsecaseConverseInferenceParams = {
+  '/chat': {
+    promptCachingConfig: {
+      autoCacheFields: ['system', 'messages'],
+    },
+  },
   '/rag': {
-    temperature: 0.0,
+    inferenceConfig: {
+      temperature: 0.0,
+    },
+  },
+  '/diagram': {
+    promptCachingConfig: {
+      autoCacheFields: ['system'],
+    },
   },
 };

@@ -313,32 +349,40 @@ const createConverseCommandInput = (
     };
   });

-  const usecaseParams = usecaseConverseInferenceParams[normalizeId(id)];
-  const inferenceConfig = usecaseParams
-    ? { ...defaultConverseInferenceParams, ...usecaseParams }
-    : defaultConverseInferenceParams;
+  // Merge model's default params with use-case specific ones
+  const usecaseParams = usecaseConverseInferenceParams[normalizeId(id)] || {};
+  const params = { ...defaultConverseInferenceParams, ...usecaseParams };
+
+  // Apply prompt caching
+  const autoCacheFields = params.promptCachingConfig?.autoCacheFields || [];
+  const conversationWithCache = autoCacheFields.includes('messages')
+    ? applyAutoCacheToMessages(conversation, model.modelId)
+    : conversation;
+  const systemContextWithCache = autoCacheFields.includes('system')
+    ? applyAutoCacheToSystem(systemContext, model.modelId)
+    : systemContext;

   const guardrailConfig = createGuardrailConfig();

   const converseCommandInput: ConverseCommandInput = {
     modelId: model.modelId,
-    messages: conversation,
-    system: systemContext,
-    inferenceConfig: inferenceConfig,
-    guardrailConfig: guardrailConfig,
+    messages: conversationWithCache,
+    system: systemContextWithCache,
+    inferenceConfig: params.inferenceConfig,
+    guardrailConfig,
   };

   if (
     modelFeatureFlags[model.modelId].reasoning &&
     model.modelParameters?.reasoningConfig?.type === 'enabled'
   ) {
     converseCommandInput.inferenceConfig = {
-      ...inferenceConfig,
+      ...(params.inferenceConfig || {}),
       temperature: 1, // reasoning requires temperature to be 1
       topP: undefined, // reasoning does not require topP
       maxTokens:
         (model.modelParameters?.reasoningConfig?.budgetTokens || 0) +
-        (inferenceConfig?.maxTokens || 0),
+        (params.inferenceConfig?.maxTokens || 0),
     };
     converseCommandInput.additionalModelRequestFields = {
       reasoning_config: {
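
For context, a minimal sketch of the shallow merge that createConverseCommandInput now performs. The values mirror NOVA_DEFAULT_PARAMS and the '/chat' entry above; the ParamsSketch type is a stand-in for the updated ConverseInferenceParams, used here only for illustration:

// Use-case entries only contribute the fields they declare. Because the spread
// is shallow, an entry that sets its own inferenceConfig replaces the model
// default object wholesale, while one that only sets promptCachingConfig keeps
// the model's inference defaults intact.
type ParamsSketch = {
  inferenceConfig?: { maxTokens?: number; temperature?: number; topP?: number };
  promptCachingConfig?: { autoCacheFields: ('messages' | 'system' | 'tools')[] };
};

const modelDefaults: ParamsSketch = {
  inferenceConfig: { maxTokens: 5120, temperature: 0.7, topP: 0.9 },
};
const chatOverrides: ParamsSketch = {
  promptCachingConfig: { autoCacheFields: ['system', 'messages'] },
};

const params: ParamsSketch = { ...modelDefaults, ...chatOverrides };
// params.inferenceConfig     -> from the model defaults
// params.promptCachingConfig -> from the '/chat' use-case entry
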
packages/cdk/lambda/utils/promptCache.ts (new file)

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+import {
+  ContentBlock,
+  Message,
+  SystemContentBlock,
+} from '@aws-sdk/client-bedrock-runtime';
+import { SUPPORTED_CACHE_FIELDS } from '@generative-ai-use-cases/common';
+
+const CACHE_POINT = {
+  cachePoint: { type: 'default' },
+} as ContentBlock.CachePointMember;
+
+const SYSTEM_CACHE_POINT = {
+  cachePoint: { type: 'default' },
+} as SystemContentBlock.CachePointMember;
+
+const getSupportedCacheFields = (modelId: string) => {
+  // Remove CRI prefix
+  const baseModelId = modelId.replace(/^(us|eu|apac)\./, '');
+  return SUPPORTED_CACHE_FIELDS[baseModelId] || [];
+};
+
+export const applyAutoCacheToMessages = (
+  messages: Message[],
+  modelId: string
+) => {
+  const cacheFields = getSupportedCacheFields(modelId);
+  if (!cacheFields.includes('messages') || messages.length === 0) {
+    return messages;
+  }
+
+  // Insert cachePoint into the last two user messages (for cache read and write respectively)
+  const isToolsSupported = cacheFields.includes('tools');
+  const cachableIndices = messages
+    .map((message, index) => ({ message, index }))
+    .filter(({ message }) => message.role === 'user')
+    .filter(
+      ({ message }) =>
+        isToolsSupported ||
+        // For Amazon Nova, placing cachePoint after toolResult is not supported
+        !message.content?.some((content) => content.toolResult)
+    )
+    .slice(-2)
+    .map(({ index }) => index);
+
+  return messages.map((message, index) => {
+    if (
+      !cachableIndices.includes(index) ||
+      message.content?.at(-1)?.cachePoint // Already inserted
+    ) {
+      return message;
+    }
+    return {
+      ...message,
+      content: [...(message.content || []), CACHE_POINT],
+    };
+  });
+};
+
+export const applyAutoCacheToSystem = (
+  system: SystemContentBlock[],
+  modelId: string
+) => {
+  const cacheFields = getSupportedCacheFields(modelId);
+  if (
+    !cacheFields.includes('system') ||
+    system.length === 0 ||
+    system.at(-1)?.cachePoint // Already inserted
+  ) {
+    return system;
+  }
+  return [...system, SYSTEM_CACHE_POINT];
+};
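
A minimal usage sketch for the two helpers above (the conversation content is invented; the import path assumes a caller located alongside promptCache.ts):

import { Message, SystemContentBlock } from '@aws-sdk/client-bedrock-runtime';
import {
  applyAutoCacheToMessages,
  applyAutoCacheToSystem,
} from './promptCache';

const system: SystemContentBlock[] = [{ text: 'You are a helpful assistant.' }];
const messages: Message[] = [
  { role: 'user', content: [{ text: 'Summarize this long document ...' }] },
  { role: 'assistant', content: [{ text: 'Here is a summary ...' }] },
  { role: 'user', content: [{ text: 'Now translate the summary.' }] },
];

// amazon.nova-pro-v1:0 supports caching for 'messages' and 'system', so the last
// two user messages and the system prompt each gain a trailing cachePoint block.
const cachedMessages = applyAutoCacheToMessages(messages, 'amazon.nova-pro-v1:0');
const cachedSystem = applyAutoCacheToSystem(system, 'amazon.nova-pro-v1:0');

// For models without an entry in SUPPORTED_CACHE_FIELDS, both helpers return
// their input unchanged.
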

packages/common/src/application/model.ts

Lines changed: 11 additions & 1 deletion
@@ -1,4 +1,4 @@
-import { FeatureFlags } from 'generative-ai-use-cases';
+import { FeatureFlags, PromptCacheField } from 'generative-ai-use-cases';

 // Manage Model Feature
 // https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html
@@ -214,3 +214,13 @@ export const BEDROCK_RERANKING_MODELS = Object.keys(modelFeatureFlags).filter(
 export const BEDROCK_SPEECH_TO_SPEECH_MODELS = Object.keys(
   modelFeatureFlags
 ).filter((model) => modelFeatureFlags[model].speechToSpeech);
+
+// Prompt caching
+// https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html
+export const SUPPORTED_CACHE_FIELDS: Record<string, PromptCacheField[]> = {
+  'anthropic.claude-3-7-sonnet-20250219-v1:0': ['messages', 'system', 'tools'],
+  'anthropic.claude-3-5-haiku-20241022-v1:0': ['messages', 'system', 'tools'],
+  'amazon.nova-pro-v1:0': ['messages', 'system'],
+  'amazon.nova-lite-v1:0': ['messages', 'system'],
+  'amazon.nova-micro-v1:0': ['messages', 'system'],
+};
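
As an illustration of how this table is keyed (mirroring getSupportedCacheFields in promptCache.ts above; the cross-region inference model ID is only an example):

import { SUPPORTED_CACHE_FIELDS } from '@generative-ai-use-cases/common';

// Cross-region inference profiles prefix the base model ID with us./eu./apac.,
// so the prefix is stripped before looking the model up.
const modelId = 'us.anthropic.claude-3-7-sonnet-20250219-v1:0';
const baseModelId = modelId.replace(/^(us|eu|apac)\./, '');
console.log(SUPPORTED_CACHE_FIELDS[baseModelId]); // ['messages', 'system', 'tools']
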

packages/types/src/text.d.ts

Lines changed: 10 additions & 4 deletions
@@ -1,10 +1,16 @@
 // ConverseAPI
+
+import { InferenceConfiguration } from '@aws-sdk/client-bedrock-runtime';
+
+export type PromptCacheField = 'messages' | 'system' | 'tools';
+export type PromptCachingConfig = {
+  autoCacheFields: PromptCacheField[];
+};
+
 // https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestSyntax
 export type ConverseInferenceParams = {
-  maxTokens?: number;
-  stopSequences?: string[];
-  temperature?: number;
-  topP?: number;
+  inferenceConfig?: InferenceConfiguration;
+  promptCachingConfig?: PromptCachingConfig;
 };

 export type UsecaseConverseInferenceParams = {
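
For reference, a value conforming to the reshaped ConverseInferenceParams might look like this. A sketch that assumes the types are exported from the generative-ai-use-cases types package (as PromptCacheField is in model.ts above); the numbers mirror CLAUDE_DEFAULT_PARAMS:

import { InferenceConfiguration } from '@aws-sdk/client-bedrock-runtime';
import {
  ConverseInferenceParams,
  PromptCachingConfig,
} from 'generative-ai-use-cases';

// Inference settings now live under inferenceConfig, typed with the SDK's own
// InferenceConfiguration ...
const inferenceConfig: InferenceConfiguration = {
  maxTokens: 4096,
  temperature: 0.6,
  topP: 0.8,
};

// ... and prompt caching is configured separately per use case.
const promptCachingConfig: PromptCachingConfig = {
  autoCacheFields: ['system', 'messages'],
};

const params: ConverseInferenceParams = { inferenceConfig, promptCachingConfig };
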
