@@ -26,6 +26,10 @@ import {
26
26
ContentBlock ,
27
27
} from '@aws-sdk/client-bedrock-runtime' ;
28
28
import { modelFeatureFlags } from '@generative-ai-use-cases/common' ;
29
+ import {
30
+ applyAutoCacheToMessages ,
31
+ applyAutoCacheToSystem ,
32
+ } from './promptCache' ;
29
33
30
34
// Default Models
31
35
@@ -121,72 +125,104 @@ const RINNA_PROMPT: PromptTemplate = {
121
125
// Model Params
122
126
123
127
const CLAUDE_3_5_DEFAULT_PARAMS : ConverseInferenceParams = {
124
- maxTokens : 8192 ,
125
- temperature : 0.6 ,
126
- topP : 0.8 ,
128
+ inferenceConfig : {
129
+ maxTokens : 8192 ,
130
+ temperature : 0.6 ,
131
+ topP : 0.8 ,
132
+ } ,
127
133
} ;
128
134
129
135
const CLAUDE_DEFAULT_PARAMS : ConverseInferenceParams = {
130
- maxTokens : 4096 ,
131
- temperature : 0.6 ,
132
- topP : 0.8 ,
136
+ inferenceConfig : {
137
+ maxTokens : 4096 ,
138
+ temperature : 0.6 ,
139
+ topP : 0.8 ,
140
+ } ,
133
141
} ;
134
142
135
143
const TITAN_TEXT_DEFAULT_PARAMS : ConverseInferenceParams = {
136
144
// Converse API only accepts 3000, instead of 3072, which is described in the doc.
137
145
// If 3072 is accepted, revert to 3072.
138
146
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-text.html
139
- maxTokens : 3000 ,
140
- temperature : 0.7 ,
141
- topP : 1.0 ,
147
+ inferenceConfig : {
148
+ maxTokens : 3000 ,
149
+ temperature : 0.7 ,
150
+ topP : 1.0 ,
151
+ } ,
142
152
} ;
143
153
144
154
const LLAMA_DEFAULT_PARAMS : ConverseInferenceParams = {
145
- maxTokens : 2048 ,
146
- temperature : 0.5 ,
147
- topP : 0.9 ,
148
- stopSequences : [ '<|eot_id|>' ] ,
155
+ inferenceConfig : {
156
+ maxTokens : 2048 ,
157
+ temperature : 0.5 ,
158
+ topP : 0.9 ,
159
+ stopSequences : [ '<|eot_id|>' ] ,
160
+ } ,
149
161
} ;
150
162
151
163
const MISTRAL_DEFAULT_PARAMS : ConverseInferenceParams = {
152
- maxTokens : 8192 ,
153
- temperature : 0.6 ,
154
- topP : 0.99 ,
164
+ inferenceConfig : {
165
+ maxTokens : 8192 ,
166
+ temperature : 0.6 ,
167
+ topP : 0.99 ,
168
+ } ,
155
169
} ;
156
170
157
171
const MIXTRAL_DEFAULT_PARAMS : ConverseInferenceParams = {
158
- maxTokens : 4096 ,
159
- temperature : 0.6 ,
160
- topP : 0.99 ,
172
+ inferenceConfig : {
173
+ maxTokens : 4096 ,
174
+ temperature : 0.6 ,
175
+ topP : 0.99 ,
176
+ } ,
161
177
} ;
162
178
163
179
const COMMANDR_DEFAULT_PARAMS : ConverseInferenceParams = {
164
- maxTokens : 4000 ,
165
- temperature : 0.3 ,
166
- topP : 0.75 ,
180
+ inferenceConfig : {
181
+ maxTokens : 4000 ,
182
+ temperature : 0.3 ,
183
+ topP : 0.75 ,
184
+ } ,
167
185
} ;
168
186
169
187
const NOVA_DEFAULT_PARAMS : ConverseInferenceParams = {
170
- maxTokens : 5120 ,
171
- temperature : 0.7 ,
172
- topP : 0.9 ,
188
+ inferenceConfig : {
189
+ maxTokens : 5120 ,
190
+ temperature : 0.7 ,
191
+ topP : 0.9 ,
192
+ } ,
173
193
} ;
174
194
175
195
const DEEPSEEK_DEFAULT_PARAMS : ConverseInferenceParams = {
176
- maxTokens : 32768 ,
177
- temperature : 0.6 ,
178
- topP : 0.95 ,
196
+ inferenceConfig : {
197
+ maxTokens : 32768 ,
198
+ temperature : 0.6 ,
199
+ topP : 0.95 ,
200
+ } ,
179
201
} ;
180
202
181
203
const PALMYRA_DEFAULT_PARAMS : ConverseInferenceParams = {
182
- maxTokens : 8192 ,
183
- temperature : 1 ,
184
- topP : 0.9 ,
204
+ inferenceConfig : {
205
+ maxTokens : 8192 ,
206
+ temperature : 1 ,
207
+ topP : 0.9 ,
208
+ } ,
185
209
} ;
186
210
187
211
const USECASE_DEFAULT_PARAMS : UsecaseConverseInferenceParams = {
212
+ '/chat' : {
213
+ promptCachingConfig : {
214
+ autoCacheFields : [ 'system' , 'messages' ] ,
215
+ } ,
216
+ } ,
188
217
'/rag' : {
189
- temperature : 0.0 ,
218
+ inferenceConfig : {
219
+ temperature : 0.0 ,
220
+ } ,
221
+ } ,
222
+ '/diagram' : {
223
+ promptCachingConfig : {
224
+ autoCacheFields : [ 'system' ] ,
225
+ } ,
190
226
} ,
191
227
} ;
192
228
@@ -313,32 +349,40 @@ const createConverseCommandInput = (
313
349
} ;
314
350
} ) ;
315
351
316
- const usecaseParams = usecaseConverseInferenceParams [ normalizeId ( id ) ] ;
317
- const inferenceConfig = usecaseParams
318
- ? { ...defaultConverseInferenceParams , ...usecaseParams }
319
- : defaultConverseInferenceParams ;
352
+ // Merge model's default params with use-case specific ones
353
+ const usecaseParams = usecaseConverseInferenceParams [ normalizeId ( id ) ] || { } ;
354
+ const params = { ...defaultConverseInferenceParams , ...usecaseParams } ;
355
+
356
+ // Apply prompt caching
357
+ const autoCacheFields = params . promptCachingConfig ?. autoCacheFields || [ ] ;
358
+ const conversationWithCache = autoCacheFields . includes ( 'messages' )
359
+ ? applyAutoCacheToMessages ( conversation , model . modelId )
360
+ : conversation ;
361
+ const systemContextWithCache = autoCacheFields . includes ( 'system' )
362
+ ? applyAutoCacheToSystem ( systemContext , model . modelId )
363
+ : systemContext ;
320
364
321
365
const guardrailConfig = createGuardrailConfig ( ) ;
322
366
323
367
const converseCommandInput : ConverseCommandInput = {
324
368
modelId : model . modelId ,
325
- messages : conversation ,
326
- system : systemContext ,
327
- inferenceConfig : inferenceConfig ,
328
- guardrailConfig : guardrailConfig ,
369
+ messages : conversationWithCache ,
370
+ system : systemContextWithCache ,
371
+ inferenceConfig : params . inferenceConfig ,
372
+ guardrailConfig,
329
373
} ;
330
374
331
375
if (
332
376
modelFeatureFlags [ model . modelId ] . reasoning &&
333
377
model . modelParameters ?. reasoningConfig ?. type === 'enabled'
334
378
) {
335
379
converseCommandInput . inferenceConfig = {
336
- ...inferenceConfig ,
380
+ ...( params . inferenceConfig || { } ) ,
337
381
temperature : 1 , // reasoning requires temperature to be 1
338
382
topP : undefined , // reasoning does not require topP
339
383
maxTokens :
340
384
( model . modelParameters ?. reasoningConfig ?. budgetTokens || 0 ) +
341
- ( inferenceConfig ?. maxTokens || 0 ) ,
385
+ ( params . inferenceConfig ?. maxTokens || 0 ) ,
342
386
} ;
343
387
converseCommandInput . additionalModelRequestFields = {
344
388
reasoning_config : {
0 commit comments