Skip to content

Commit c19287d

Browse files
committed
feat(bedrock): Add StopReason check for tool execution
Add StopReason.TOOL_USE validation before executing tools to ensure proper tool execution flow. Remove unused commented code and improve code formatting. Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
1 parent b7dcfc1 commit c19287d

File tree

1 file changed

+19
-26
lines changed

1 file changed

+19
-26
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.Base64;
2727
import java.util.List;
2828
import java.util.Map;
29+
import java.util.Set;
2930

3031
import io.micrometer.observation.Observation;
3132
import io.micrometer.observation.ObservationRegistry;
@@ -57,6 +58,7 @@
5758
import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
5859
import software.amazon.awssdk.services.bedrockruntime.model.Message;
5960
import software.amazon.awssdk.services.bedrockruntime.model.S3Location;
61+
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
6062
import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;
6163
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
6264
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
@@ -262,7 +264,8 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon
262264
});
263265

264266
if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && chatResponse != null
265-
&& chatResponse.hasToolCalls()) {
267+
&& chatResponse.hasToolCalls()
268+
&& chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) {
266269
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
267270
if (toolExecutionResult.returnDirect()) {
268271
// Return tool execution result directly to the client.
@@ -280,22 +283,6 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon
280283
return chatResponse;
281284
}
282285

283-
// private ToolCallingChatOptions buildRequestOptions(ConverseRequest request) {
284-
285-
// ToolCallingChatOptions toolCallbackChatOptions = ToolCallingChatOptions.builder()
286-
// .model(request.modelId())
287-
// .maxTokens(request.inferenceConfig().maxTokens())
288-
// .stopSequences(request.inferenceConfig().stopSequences())
289-
// .temperature(request.inferenceConfig().temperature() != null
290-
// ? request.inferenceConfig().temperature().doubleValue()
291-
// : null)
292-
// .topP(request.inferenceConfig().topP() != null ?
293-
// request.inferenceConfig().topP().doubleValue() : null)
294-
// .build();
295-
296-
// return toolCallbackChatOptions;
297-
// }
298-
299286
@Override
300287
public ChatOptions getDefaultOptions() {
301288
return this.defaultOptions;
@@ -708,28 +695,34 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
708695

709696
Flux<ConverseStreamOutput> response = converseStream(converseStreamRequest);
710697

711-
// @formatter:off
712698
Flux<ChatResponse> chatResponses = ConverseApiUtils.toChatResponse(response, perviousChatResponse);
713699

714700
Flux<ChatResponse> chatResponseFlux = chatResponses.switchMap(chatResponse -> {
715701

716-
if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions()) && chatResponse.hasToolCalls()) {
702+
if (ToolCallingChatOptions.isInternalToolExecutionEnabled(prompt.getOptions())
703+
&& chatResponse.hasToolCalls()
704+
&& chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) {
705+
717706
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
707+
718708
if (toolExecutionResult.returnDirect()) {
719709
// Return tool execution result directly to the client.
720-
return Flux.just(ChatResponse.builder().from(chatResponse)
721-
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
722-
.build());
723-
} else {
710+
return Flux.just(ChatResponse.builder()
711+
.from(chatResponse)
712+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
713+
.build());
714+
}
715+
else {
724716
// Send the tool execution result back to the model.
725-
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
726-
chatResponse);
717+
return this.internalStream(
718+
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
719+
chatResponse);
727720
}
728721
}
729722
else {
730723
return Flux.just(chatResponse);
731724
}
732-
})
725+
})// @formatter:off
733726
.doOnError(observation::error)
734727
.doFinally(s -> observation.stop())
735728
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));

0 commit comments

Comments
 (0)