26
26
import java .util .Base64 ;
27
27
import java .util .List ;
28
28
import java .util .Map ;
29
+ import java .util .Set ;
29
30
30
31
import io .micrometer .observation .Observation ;
31
32
import io .micrometer .observation .ObservationRegistry ;
57
58
import software .amazon .awssdk .services .bedrockruntime .model .InferenceConfiguration ;
58
59
import software .amazon .awssdk .services .bedrockruntime .model .Message ;
59
60
import software .amazon .awssdk .services .bedrockruntime .model .S3Location ;
61
+ import software .amazon .awssdk .services .bedrockruntime .model .StopReason ;
60
62
import software .amazon .awssdk .services .bedrockruntime .model .SystemContentBlock ;
61
63
import software .amazon .awssdk .services .bedrockruntime .model .Tool ;
62
64
import software .amazon .awssdk .services .bedrockruntime .model .ToolConfiguration ;
@@ -262,7 +264,8 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon
262
264
});
263
265
264
266
if (ToolCallingChatOptions .isInternalToolExecutionEnabled (prompt .getOptions ()) && chatResponse != null
265
- && chatResponse .hasToolCalls ()) {
267
+ && chatResponse .hasToolCalls ()
268
+ && chatResponse .hasFinishReasons (Set .of (StopReason .TOOL_USE .toString ()))) {
266
269
var toolExecutionResult = this .toolCallingManager .executeToolCalls (prompt , chatResponse );
267
270
if (toolExecutionResult .returnDirect ()) {
268
271
// Return tool execution result directly to the client.
@@ -280,22 +283,6 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon
280
283
return chatResponse ;
281
284
}
282
285
283
- // private ToolCallingChatOptions buildRequestOptions(ConverseRequest request) {
284
-
285
- // ToolCallingChatOptions toolCallbackChatOptions = ToolCallingChatOptions.builder()
286
- // .model(request.modelId())
287
- // .maxTokens(request.inferenceConfig().maxTokens())
288
- // .stopSequences(request.inferenceConfig().stopSequences())
289
- // .temperature(request.inferenceConfig().temperature() != null
290
- // ? request.inferenceConfig().temperature().doubleValue()
291
- // : null)
292
- // .topP(request.inferenceConfig().topP() != null ?
293
- // request.inferenceConfig().topP().doubleValue() : null)
294
- // .build();
295
-
296
- // return toolCallbackChatOptions;
297
- // }
298
-
299
286
@ Override
300
287
public ChatOptions getDefaultOptions () {
301
288
return this .defaultOptions ;
@@ -708,28 +695,34 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
708
695
709
696
Flux <ConverseStreamOutput > response = converseStream (converseStreamRequest );
710
697
711
- // @formatter:off
712
698
Flux <ChatResponse > chatResponses = ConverseApiUtils .toChatResponse (response , perviousChatResponse );
713
699
714
700
Flux <ChatResponse > chatResponseFlux = chatResponses .switchMap (chatResponse -> {
715
701
716
- if (ToolCallingChatOptions .isInternalToolExecutionEnabled (prompt .getOptions ()) && chatResponse .hasToolCalls ()) {
702
+ if (ToolCallingChatOptions .isInternalToolExecutionEnabled (prompt .getOptions ())
703
+ && chatResponse .hasToolCalls ()
704
+ && chatResponse .hasFinishReasons (Set .of (StopReason .TOOL_USE .toString ()))) {
705
+
717
706
var toolExecutionResult = this .toolCallingManager .executeToolCalls (prompt , chatResponse );
707
+
718
708
if (toolExecutionResult .returnDirect ()) {
719
709
// Return tool execution result directly to the client.
720
- return Flux .just (ChatResponse .builder ().from (chatResponse )
721
- .generations (ToolExecutionResult .buildGenerations (toolExecutionResult ))
722
- .build ());
723
- } else {
710
+ return Flux .just (ChatResponse .builder ()
711
+ .from (chatResponse )
712
+ .generations (ToolExecutionResult .buildGenerations (toolExecutionResult ))
713
+ .build ());
714
+ }
715
+ else {
724
716
// Send the tool execution result back to the model.
725
- return this .internalStream (new Prompt (toolExecutionResult .conversationHistory (), prompt .getOptions ()),
726
- chatResponse );
717
+ return this .internalStream (
718
+ new Prompt (toolExecutionResult .conversationHistory (), prompt .getOptions ()),
719
+ chatResponse );
727
720
}
728
721
}
729
722
else {
730
723
return Flux .just (chatResponse );
731
724
}
732
- })
725
+ })// @formatter:off
733
726
.doOnError (observation ::error )
734
727
.doFinally (s -> observation .stop ())
735
728
.contextWrite (ctx -> ctx .put (ObservationThreadLocalAccessor .KEY , observation ));
0 commit comments