diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 2ded856a05f..4156a6ad03f 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -30,6 +30,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.ToolExecutionLimitExceededException; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -91,6 +92,7 @@ * @author Alexandros Pappas * @author Jonghoon Park * @author Soby Chacko + * @author lambochen * @since 1.0.0 */ public class AnthropicChatModel implements ChatModel { @@ -175,6 +177,10 @@ public ChatResponse call(Prompt prompt) { } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + return this.internalCall(prompt, previousChatResponse, 1); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { ChatCompletionRequest request = createRequest(prompt, false); ChatModelObservationContext observationContext = ChatModelObservationContext.builder() @@ -204,7 +210,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return chatResponse; }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -216,9 +222,12 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons else { // Send the tool execution result back to the model. 
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } + else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); + } return response; } @@ -237,6 +246,10 @@ public Flux stream(Prompt prompt) { } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return this.internalStream(prompt, previousChatResponse, 1); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(prompt, true); @@ -261,7 +274,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse); ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse, iterations)) { if (chatResponse.hasFinishReasons(Set.of("tool_use"))) { // FIXME: bounded elastic needs to be used since tool calling @@ -288,10 +301,12 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha } }).subscribeOn(Schedulers.boundedElastic()); - } else { + } else { return Mono.empty(); } + } else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)){ + throw new ToolExecutionLimitExceededException(iterations); } else { // If internal tool execution is not required, just return the chat response. 
return Mono.just(chatResponse); @@ -453,6 +468,8 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setToolExecutionMaxIterations(ModelOptionsUtils.mergeOption( + runtimeOptions.getToolExecutionMaxIterations(), defaultOptions.getToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -463,6 +480,7 @@ Prompt buildRequestPrompt(Prompt prompt) { else { requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders()); requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions.setToolExecutionMaxIterations(this.defaultOptions.getToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index dbfbee561c8..4f2513570ac 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -44,6 +44,7 @@ * @author Thomas Vitale * @author Alexandros Pappas * @author Ilayaperumal Gopinathan + * @author lambochen * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) @@ -79,6 +80,9 @@ public class AnthropicChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer toolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + @JsonIgnore private Map toolContext = new HashMap<>(); @@ -109,6 +113,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolExecutionMaxIterations(fromOptions.getToolExecutionMaxIterations()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .httpHeaders(fromOptions.getHttpHeaders() != null ? 
new HashMap<>(fromOptions.getHttpHeaders()) : null) .build(); @@ -226,6 +231,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getToolExecutionMaxIterations() { + return this.toolExecutionMaxIterations; + } + + @Override + public void setToolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.toolExecutionMaxIterations = toolExecutionMaxIterations; + } + @Override @JsonIgnore public Double getFrequencyPenalty() { @@ -281,6 +296,7 @@ public boolean equals(Object o) { && Objects.equals(this.toolCallbacks, that.toolCallbacks) && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.toolExecutionMaxIterations, that.toolExecutionMaxIterations) && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.httpHeaders, that.httpHeaders); } @@ -289,7 +305,7 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP, this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, - this.toolContext, this.httpHeaders); + this.toolExecutionMaxIterations, this.toolContext, this.httpHeaders); } public static class Builder { @@ -374,6 +390,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder toolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.options.setToolExecutionMaxIterations(toolExecutionMaxIterations); + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java index 62d97b459e4..09f036272cd 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java @@ -22,6 +22,7 @@ import org.junit.jupiter.api.Test; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.Metadata; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import static org.assertj.core.api.Assertions.assertThat; @@ -29,6 +30,7 @@ * Tests for {@link AnthropicChatOptions}. 
* * @author Alexandros Pappas + * @author lambochen */ class AnthropicChatOptionsTests { @@ -42,10 +44,13 @@ void testBuilderWithAllFields() { .topP(0.8) .topK(50) .metadata(new Metadata("userId_123")) + .toolExecutionMaxIterations(3) .build(); - assertThat(options).extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata") - .containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123")); + assertThat(options) + .extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata", + "toolExecutionMaxIterations") + .containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123"), 3); } @Test @@ -59,6 +64,7 @@ void testCopy() { .topK(50) .metadata(new Metadata("userId_123")) .toolContext(Map.of("key1", "value1")) + .toolExecutionMaxIterations(3) .build(); AnthropicChatOptions copied = original.copy(); @@ -67,6 +73,8 @@ void testCopy() { // Ensure deep copy assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences()); assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + + assertThat(copied.getToolExecutionMaxIterations()).isEqualTo(3); } @Test @@ -79,6 +87,7 @@ void testSetters() { options.setTopP(0.8); options.setStopSequences(List.of("stop1", "stop2")); options.setMetadata(new Metadata("userId_123")); + options.setToolExecutionMaxIterations(3); assertThat(options.getModel()).isEqualTo("test-model"); assertThat(options.getMaxTokens()).isEqualTo(100); @@ -87,6 +96,7 @@ void testSetters() { assertThat(options.getTopP()).isEqualTo(0.8); assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2")); assertThat(options.getMetadata()).isEqualTo(new Metadata("userId_123")); + assertThat(options.getToolExecutionMaxIterations()).isEqualTo(3); } @Test @@ -99,6 +109,8 @@ void testDefaultValues() { assertThat(options.getTopP()).isNull(); assertThat(options.getStopSequences()).isNull(); assertThat(options.getMetadata()).isNull(); + assertThat(options.getToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); } } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index 3f659671c4d..d1b81212626 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -62,6 +62,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.ToolExecutionLimitExceededException; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -123,8 +124,10 @@ * @author Berjan Jonker * @author Andres da Silva Santos * @author Bart Veenstra + * @author lambochen * @see ChatModel * @see com.azure.ai.openai.OpenAIClient + * @see ToolCallingChatOptions * @since 1.0.0 */ public class AzureOpenAiChatModel implements ChatModel { @@ -252,6 +255,10 @@ public ChatResponse call(Prompt prompt) { } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + return internalCall(prompt, previousChatResponse, 1); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse 
previousChatResponse, int iterations) { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) @@ -271,7 +278,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return chatResponse; }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -283,9 +290,12 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons else { // Send the tool execution result back to the model. return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } + else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); + } return response; } @@ -299,6 +309,10 @@ public Flux stream(Prompt prompt) { } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return this.internalStream(prompt, previousChatResponse, 1); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt); @@ -378,7 +392,8 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha }); return chatResponseFlux.flatMap(chatResponse -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse, + iterations)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.deferContextual((ctx) -> { @@ -401,10 +416,13 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // Send the tool execution result back to the model. 
return this.internalStream( new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - chatResponse); + chatResponse, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); } + else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); + } Flux flux = Flux.just(chatResponse) .doOnError(observation::error) @@ -674,6 +692,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getToolExecutionMaxIterations(), + this.defaultOptions.getToolExecutionMaxIterations())); requestOptions.setStreamUsage(ModelOptionsUtils.mergeOption(runtimeOptions.getStreamUsage(), this.defaultOptions.getStreamUsage())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), @@ -685,6 +706,7 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions.setToolExecutionMaxIterations(this.defaultOptions.getToolExecutionMaxIterations()); requestOptions.setStreamUsage(this.defaultOptions.getStreamUsage()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index da442b4ad4d..40168aba8e9 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -48,6 +48,7 @@ * @author Ilayaperumal Gopinathan * @author Alexandros Pappas * @author Andres da Silva Santos + * @author lambochen */ @JsonInclude(Include.NON_NULL) public class AzureOpenAiChatOptions implements ToolCallingChatOptions { @@ -200,6 +201,9 @@ public class AzureOpenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer toolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + /** * Whether to include token usage information in streaming chat completion responses. * Only applies to streaming responses. @@ -257,6 +261,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getToolExecutionMaxIterations() { + return this.toolExecutionMaxIterations; + } + + @Override + public void setToolExecutionMaxIterations(Integer toolExecutionMaxIterations) { + this.toolExecutionMaxIterations = toolExecutionMaxIterations; + } + public static Builder builder() { return new Builder(); } @@ -284,6 +298,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .enhancements(fromOptions.getEnhancements()) .toolContext(fromOptions.getToolContext() != null ?
new HashMap<>(fromOptions.getToolContext()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolExecutionMaxIterations(fromOptions.getToolExecutionMaxIterations()) .streamOptions(fromOptions.getStreamOptions()) .toolCallbacks( fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null) @@ -504,6 +519,7 @@ public boolean equals(Object o) { && Objects.equals(this.toolCallbacks, that.toolCallbacks) && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.toolExecutionMaxIterations, that.toolExecutionMaxIterations) && Objects.equals(this.logprobs, that.logprobs) && Objects.equals(this.topLogProbs, that.topLogProbs) && Objects.equals(this.enhancements, that.enhancements) && Objects.equals(this.streamOptions, that.streamOptions) @@ -518,10 +534,10 @@ public boolean equals(Object o) { @Override public int hashCode() { return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat, - this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs, - this.topLogProbs, this.enhancements, this.streamOptions, this.reasoningEffort, this.enableStreamUsage, - this.toolContext, this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature, - this.topP); + this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.toolExecutionMaxIterations, + this.seed, this.logprobs, this.topLogProbs, this.enhancements, this.streamOptions, this.reasoningEffort, + this.enableStreamUsage, this.toolContext, this.maxTokens, this.frequencyPenalty, this.presencePenalty, + this.temperature, this.topP); } public static class Builder { @@ -664,6 +680,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder toolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.options.setToolExecutionMaxIterations(toolExecutionMaxIterations); + return this; + } + public AzureOpenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java index 789635d358e..71203f9de12 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java @@ -24,6 +24,7 @@ import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration; import com.azure.ai.openai.models.ChatCompletionStreamOptions; import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import static org.assertj.core.api.Assertions.assertThat; @@ -31,6 +32,7 @@ * Tests for {@link AzureOpenAiChatOptions}. 
* * @author Alexandros Pappas + * @author lambochen */ class AzureOpenAiChatOptionsTests { @@ -65,15 +67,16 @@ void testBuilderWithAllFields() { .topLogprobs(5) .enhancements(enhancements) .streamOptions(streamOptions) + .toolExecutionMaxIterations(3) .build(); assertThat(options) .extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop", "temperature", "topP", "user", "responseFormat", "streamUsage", "reasoningEffort", "seed", - "logprobs", "topLogProbs", "enhancements", "streamOptions") + "logprobs", "topLogProbs", "enhancements", "streamOptions", "toolExecutionMaxIterations") .containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8, List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, "low", 12345L, true, 5, - enhancements, streamOptions); + enhancements, streamOptions, 3); } @Test @@ -107,6 +110,7 @@ void testCopy() { .topLogprobs(5) .enhancements(enhancements) .streamOptions(streamOptions) + .toolExecutionMaxIterations(3) .build(); AzureOpenAiChatOptions copiedOptions = originalOptions.copy(); @@ -115,6 +119,8 @@ void testCopy() { // Ensure deep copy assertThat(copiedOptions.getStop()).isNotSameAs(originalOptions.getStop()); assertThat(copiedOptions.getToolContext()).isNotSameAs(originalOptions.getToolContext()); + + assertThat(copiedOptions.getToolExecutionMaxIterations()).isEqualTo(3); } @Test @@ -145,6 +151,7 @@ void testSetters() { options.setTopLogProbs(5); options.setEnhancements(enhancements); options.setStreamOptions(streamOptions); + options.setToolExecutionMaxIterations(3); assertThat(options.getDeploymentName()).isEqualTo("test-deployment"); options.setModel("test-model"); @@ -168,6 +175,7 @@ void testSetters() { assertThat(options.getEnhancements()).isEqualTo(enhancements); assertThat(options.getStreamOptions()).isEqualTo(streamOptions); assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getToolExecutionMaxIterations()).isEqualTo(3); } @Test @@ -193,6 +201,8 @@ void testDefaultValues() { assertThat(options.getEnhancements()).isNull(); assertThat(options.getStreamOptions()).isNull(); assertThat(options.getModel()).isNull(); + assertThat(options.getToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); } } diff --git a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java index 484e979385e..f853799faff 100644 --- a/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java +++ b/models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java @@ -33,6 +33,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.ToolExecutionLimitExceededException; import reactor.core.publisher.Flux; import reactor.core.publisher.Sinks; import reactor.core.publisher.Sinks.EmitFailureHandler; @@ -134,6 +135,7 @@ * @author Alexandros Pappas * @author Jihoon Kim * @author Soby Chacko + * @author lambochen * @since 1.0.0 */ public class BedrockProxyChatModel implements ChatModel { @@ -218,6 +220,10 @@ public ChatResponse call(Prompt prompt) { } private ChatResponse internalCall(Prompt prompt, ChatResponse 
perviousChatResponse) { + return this.internalCall(prompt, perviousChatResponse, 1); + } + + private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatResponse, int iterations) { ConverseRequest converseRequest = this.createRequest(prompt); @@ -242,8 +248,8 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon return response; }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) - && chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse, + iterations) && chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -255,9 +261,13 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse perviousChatRespon else { // Send the tool execution result back to the model. return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - chatResponse); + chatResponse, iterations + 1); } } + else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); + } + return chatResponse; } @@ -311,6 +321,9 @@ Prompt buildRequestPrompt(Prompt prompt) { .internalToolExecutionEnabled(runtimeOptions.getInternalToolExecutionEnabled() != null ? runtimeOptions.getInternalToolExecutionEnabled() : this.defaultOptions.getInternalToolExecutionEnabled()) + .toolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getToolExecutionMaxIterations(), + this.defaultOptions.getToolExecutionMaxIterations())) .build(); } @@ -645,6 +658,10 @@ public Flux stream(Prompt prompt) { } private Flux internalStream(Prompt prompt, ChatResponse perviousChatResponse) { + return this.internalStream(prompt, perviousChatResponse, 1); + } + + private Flux internalStream(Prompt prompt, ChatResponse perviousChatResponse, int iterations) { Assert.notNull(prompt, "'prompt' must not be null"); return Flux.deferContextual(contextView -> { @@ -677,8 +694,8 @@ private Flux internalStream(Prompt prompt, ChatResponse perviousCh Flux chatResponseFlux = chatResponses.switchMap(chatResponse -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse) - && chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse, + iterations) && chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous @@ -703,10 +720,13 @@ private Flux internalStream(Prompt prompt, ChatResponse perviousCh // Send the tool execution result back to the model. 
return this.internalStream( new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - chatResponse); + chatResponse, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); } + else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); + } else { return Flux.just(chatResponse); } diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java index 6295666e07f..7a703db9cf5 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java @@ -25,6 +25,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.ToolExecutionLimitExceededException; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -76,6 +77,7 @@ * backed by {@link DeepSeekApi}. * * @author Geng Rong + * @author lambochen */ public class DeepSeekChatModel implements ChatModel { @@ -152,6 +154,10 @@ public ChatResponse call(Prompt prompt) { } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + return internalCall(prompt, previousChatResponse, 1); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { ChatCompletionRequest request = createRequest(prompt, false); @@ -206,7 +212,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -218,9 +224,12 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons else { // Send the tool execution result back to the model. 
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } + else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); + } return response; } @@ -232,6 +241,10 @@ public Flux stream(Prompt prompt) { } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return internalStream(prompt, previousChatResponse, 1); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(prompt, true); @@ -286,7 +299,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux flux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.deferContextual((ctx) -> { @@ -306,9 +319,11 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha else { // Send the tool execution result back to the model. return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); + } else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)){ + throw new ToolExecutionLimitExceededException(iterations); } else { return Flux.just(response); @@ -404,6 +419,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getToolExecutionMaxIterations(), + this.defaultOptions.getToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -413,6 +431,7 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions.setToolExecutionMaxIterations(this.defaultOptions.getToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java index b9c7a3d4962..e33d607d8c7 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatOptions.java @@ -43,6 +43,7 @@ * chat completion * * @author Geng Rong + * @author lambochen */ @JsonInclude(Include.NON_NULL) public class DeepSeekChatOptions implements 
ToolCallingChatOptions { @@ -122,6 +123,9 @@ public class DeepSeekChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer toolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + /** * Tool Function Callbacks to register with the ChatModel. * For Prompt Options the toolCallbacks are automatically enabled for the duration of the prompt execution. @@ -289,6 +293,18 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + @JsonIgnore + public Integer getToolExecutionMaxIterations() { + return this.toolExecutionMaxIterations; + } + + @Override + @JsonIgnore + public void setToolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.toolExecutionMaxIterations = toolExecutionMaxIterations; + } + public Boolean getLogprobs() { return this.logprobs; } @@ -332,7 +348,9 @@ public int hashCode() { return Objects.hash(this.model, this.frequencyPenalty, this.logprobs, this.topLogprobs, this.maxTokens, this.presencePenalty, this.responseFormat, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, - this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.toolContext); + this.toolCallbacks, this.toolNames, + this.internalToolExecutionEnabled, this.toolExecutionMaxIterations, + this.toolContext); } @@ -357,7 +375,9 @@ public boolean equals(Object o) { && Objects.equals(this.toolCallbacks, other.toolCallbacks) && Objects.equals(this.toolNames, other.toolNames) && Objects.equals(this.toolContext, other.toolContext) - && Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled); + && Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled) + && Objects.equals(this.toolExecutionMaxIterations, other.toolExecutionMaxIterations) + ; } public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) { @@ -378,6 +398,7 @@ public static DeepSeekChatOptions fromOptions(DeepSeekChatOptions fromOptions) { fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolExecutionMaxIterations(fromOptions.getToolExecutionMaxIterations()) .toolContext(fromOptions.getToolContext() != null ? 
new HashMap<>(fromOptions.getToolContext()) : null) .build(); } @@ -487,6 +508,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder toolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.options.setToolExecutionMaxIterations(toolExecutionMaxIterations); + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatOptionsTest.java b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatOptionsTest.java new file mode 100644 index 00000000000..8805039378a --- /dev/null +++ b/models/spring-ai-deepseek/src/test/java/org/springframework/ai/deepseek/DeepSeekChatOptionsTest.java @@ -0,0 +1,38 @@ +package org.springframework.ai.deepseek; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.ToolCallingChatOptions; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * @author lambochen + */ +class DeepSeekChatOptionsTest { + + @Test + void fromOptions() { + var original = new DeepSeekChatOptions(); + original.setToolExecutionMaxIterations(3); + + var copy = DeepSeekChatOptions.fromOptions(original); + assertNotSame(original, copy); + assertEquals(original.getToolExecutionMaxIterations(), copy.getToolExecutionMaxIterations()); + } + + @Test + void optionsDefault() { + var options = new DeepSeekChatOptions(); + + assertEquals(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS, + options.getToolExecutionMaxIterations()); + } + + @Test + void optionsBuilder() { + var options = DeepSeekChatOptions.builder().toolExecutionMaxIterations(3).build(); + + assertEquals(3, options.getToolExecutionMaxIterations()); + } + +} diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index 19f821b7fb3..c896f3d5282 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -26,6 +26,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.ToolExecutionLimitExceededException; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -80,9 +81,11 @@ * @author Geng Rong * @author Alexandros Pappas * @author Ilayaperumal Gopinathan + * @author lambochen * @see ChatModel * @see StreamingChatModel * @see MiniMaxApi + * @see ToolCallingChatOptions * @since 1.0.0 M1 */ public class MiniMaxChatModel implements ChatModel { @@ -237,6 +240,10 @@ public ChatResponse call(Prompt prompt) { // Before moving any further, build the final request Prompt, // merging runtime and default options.
Prompt requestPrompt = buildRequestPrompt(prompt); + return internalCall(requestPrompt, 1); + } + + private ChatResponse internalCall(Prompt requestPrompt, int iterations) { ChatCompletionRequest request = createRequest(requestPrompt, false); ChatModelObservationContext observationContext = ChatModelObservationContext.builder() @@ -293,7 +300,8 @@ else if (!CollectionUtils.isEmpty(choice.messages())) { return chatResponse; }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response, + iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -304,9 +312,14 @@ else if (!CollectionUtils.isEmpty(choice.messages())) { } else { // Send the tool execution result back to the model. - return this.call(new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions())); + return this.internalCall( + new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions()), + iterations + 1); } } + else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(requestPrompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); + } return response; } @@ -321,6 +334,10 @@ public Flux stream(Prompt prompt) { // Before moving any further, build the final request Prompt, // merging runtime and default options. Prompt requestPrompt = buildRequestPrompt(prompt); + return internalStream(requestPrompt, 1); + } + + private Flux internalStream(Prompt requestPrompt, int iterations) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(requestPrompt, true); @@ -362,22 +379,22 @@ public Flux stream(Prompt prompt) { return buildGeneration(choice, metadata); }).toList(); return new ChatResponse(generations, from(chatCompletion2)); - } - catch (Exception e) { + } + catch (Exception e) { logger.error("Error processing chat completion", e); return new ChatResponse(List.of()); } })); Flux flux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response, iterations)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.deferContextual((ctx) -> { ToolExecutionResult toolExecutionResult; try { ToolCallReactiveContextHolder.setContext(ctx); - toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response); } finally { ToolCallReactiveContextHolder.clearContext(); } @@ -389,10 +406,13 @@ public Flux stream(Prompt prompt) { } else { // Send the tool execution result back to the model. 
- return this.stream(new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions())); + return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions()), iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); + } else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(requestPrompt.getOptions(), iterations)){ + throw new ToolExecutionLimitExceededException(iterations); } + return Flux.just(response); }) .doOnError(observation::error) @@ -479,6 +499,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getToolExecutionMaxIterations(), + this.defaultOptions.getToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -488,6 +511,7 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions.setToolExecutionMaxIterations(this.defaultOptions.getToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index a8f1e62e77e..6ef6b131733 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -48,6 +48,7 @@ * @author Thomas Vitale * @author Ilayaperumal Gopinathan * @author Alexandros Pappas + * @author lambochen * @since 1.0.0 M1 */ @JsonInclude(Include.NON_NULL) @@ -156,6 +157,9 @@ public class MiniMaxChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer toolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + // @formatter:on public static Builder builder() { @@ -179,6 +183,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { .toolCallbacks(fromOptions.getToolCallbacks()) .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolExecutionMaxIterations(fromOptions.getToolExecutionMaxIterations()) .toolContext(fromOptions.getToolContext()) .build(); } @@ -352,6 +357,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getToolExecutionMaxIterations() { + return this.toolExecutionMaxIterations; + } + + @Override + public void setToolExecutionMaxIterations(Integer toolExecutionMaxIterations) { + this.toolExecutionMaxIterations = toolExecutionMaxIterations; + } + @Override public Map getToolContext() { return (this.toolContext != null) ? 
Collections.unmodifiableMap(this.toolContext) : null; @@ -366,7 +381,7 @@ public void setToolContext(Map toolContext) { public int hashCode() { return Objects.hash(model, frequencyPenalty, maxTokens, n, presencePenalty, responseFormat, seed, stop, temperature, topP, maskSensitiveInfo, tools, toolChoice, toolCallbacks, toolNames, toolContext, - internalToolExecutionEnabled); + internalToolExecutionEnabled, toolExecutionMaxIterations); } @Override @@ -385,7 +400,8 @@ public boolean equals(Object o) { && Objects.equals(tools, that.tools) && Objects.equals(toolChoice, that.toolChoice) && Objects.equals(toolCallbacks, that.toolCallbacks) && Objects.equals(toolNames, that.toolNames) && Objects.equals(toolContext, that.toolContext) - && Objects.equals(internalToolExecutionEnabled, that.internalToolExecutionEnabled); + && Objects.equals(internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(toolExecutionMaxIterations, that.toolExecutionMaxIterations); } @Override @@ -498,6 +514,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder toolExecutionMaxIterations(Integer toolExecutionMaxIterations) { + this.options.setToolExecutionMaxIterations(toolExecutionMaxIterations); + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java index fa3a0489409..dcb69f762c3 100644 --- a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/chat/MiniMaxChatOptionsTests.java @@ -25,6 +25,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import reactor.core.publisher.Flux; import org.springframework.ai.chat.messages.AssistantMessage; @@ -116,4 +117,26 @@ void testToolCallingStream() { assertThat(content).contains("15"); } + @Test + void testOptionsDefaultValue() { + var options = new MiniMaxChatOptions(); + + assertThat(options.getToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); + } + + @Test + void testOptionsSetter() { + var options = new MiniMaxChatOptions(); + options.setToolExecutionMaxIterations(3); + assertThat(options.getToolExecutionMaxIterations()).isEqualTo(3); + } + + @Test + void testOptionsBuilder() { + var options = MiniMaxChatOptions.builder().toolExecutionMaxIterations(3).build(); + + assertThat(options.getToolExecutionMaxIterations()).isEqualTo(3); + } + } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index b1449fe580a..2729752e820 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -27,6 +27,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import 
org.springframework.ai.model.tool.ToolExecutionLimitExceededException; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -84,7 +85,9 @@ * @author luocongqiu * @author Ilayaperumal Gopinathan * @author Alexandros Pappas + * @author lambochen * @since 1.0.0 + * @see ToolCallingChatOptions */ public class MistralAiChatModel implements ChatModel { @@ -182,6 +185,10 @@ public ChatResponse call(Prompt prompt) { } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + return internalCall(prompt, previousChatResponse, 1); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { MistralAiApi.ChatCompletionRequest request = createRequest(prompt, false); @@ -226,7 +233,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons return chatResponse; }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -238,9 +245,12 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons else { // Send the tool execution result back to the model. return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } + else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); + } return response; } @@ -254,6 +264,10 @@ public Flux stream(Prompt prompt) { } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return internalStream(prompt, previousChatResponse, 1); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { var request = createRequest(prompt, true); @@ -314,7 +328,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux chatResponseFlux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.deferContextual((ctx) -> { @@ -334,9 +348,11 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha else { // Send the tool execution result back to the model. 
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); + } else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)){ + throw new ToolExecutionLimitExceededException(iterations); } else { return Flux.just(response); @@ -401,6 +417,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getToolExecutionMaxIterations(), + this.defaultOptions.getToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -410,6 +429,7 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions.setToolExecutionMaxIterations(this.defaultOptions.getToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index 801c35f2118..ae7fe15d853 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -45,6 +45,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Alexandros Pappas + * @author lambochen * @since 0.8.1 */ @JsonInclude(JsonInclude.Include.NON_NULL) @@ -156,6 +157,9 @@ public class MistralAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer toolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + @JsonIgnore private Map toolContext = new HashMap<>(); @@ -181,6 +185,7 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolExecutionMaxIterations(fromOptions.getToolExecutionMaxIterations()) .toolContext(fromOptions.getToolContext() != null ? 
new HashMap<>(fromOptions.getToolContext()) : null) .build(); } @@ -348,6 +353,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getToolExecutionMaxIterations() { + return this.toolExecutionMaxIterations; + } + + @Override + public void setToolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.toolExecutionMaxIterations = toolExecutionMaxIterations; + } + @Override @JsonIgnore public Integer getTopK() { @@ -376,7 +391,8 @@ public MistralAiChatOptions copy() { public int hashCode() { return Objects.hash(this.model, this.temperature, this.topP, this.maxTokens, this.safePrompt, this.randomSeed, this.responseFormat, this.stop, this.frequencyPenalty, this.presencePenalty, this.n, this.tools, - this.toolChoice, this.toolCallbacks, this.tools, this.internalToolExecutionEnabled, this.toolContext); + this.toolChoice, this.toolCallbacks, this.tools, this.internalToolExecutionEnabled, + this.toolExecutionMaxIterations, this.toolContext); } @Override @@ -402,6 +418,7 @@ public boolean equals(Object obj) { && Objects.equals(this.toolCallbacks, other.toolCallbacks) && Objects.equals(this.toolNames, other.toolNames) && Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled) + && Objects.equals(this.toolExecutionMaxIterations, other.toolExecutionMaxIterations) && Objects.equals(this.toolContext, other.toolContext); } @@ -507,6 +524,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder toolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.options.setToolExecutionMaxIterations(toolExecutionMaxIterations); + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java index e6bf2490cc0..3147eb2eae0 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatCompletionRequestTest.java @@ -35,6 +35,7 @@ * @author Ricken Bazolo * @author Alexandros Pappas * @author Thomas Vitale + * @author lambochen * @since 0.8.1 */ @SpringBootTest(classes = MistralAiTestConfiguration.class) @@ -73,6 +74,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { MistralAiChatOptions defaultOptions = MistralAiChatOptions.builder() .model("DEFAULT_MODEL") .internalToolExecutionEnabled(true) + .toolExecutionMaxIterations(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS) .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) .toolNames("tool1", "tool2") .toolContext(Map.of("key1", "value1", "key2", "valueA")) @@ -85,6 +87,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { MistralAiChatOptions runtimeOptions = MistralAiChatOptions.builder() .internalToolExecutionEnabled(false) + .toolExecutionMaxIterations(3) .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) .toolNames("tool3") .toolContext(Map.of("key2", "valueB")) @@ -93,6 +96,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { 
assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionEnabled()).isFalse(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolExecutionMaxIterations()).isEqualTo(3); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() .stream() diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java index 3177e85a442..c3bb225f5a0 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java @@ -16,12 +16,13 @@ package org.springframework.ai.mistralai; +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import java.util.List; import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; -import org.junit.jupiter.api.Test; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.mistralai.api.MistralAiApi; @@ -30,6 +31,7 @@ * Tests for {@link MistralAiChatOptions}. * * @author Alexandros Pappas + * @author lambochen */ class MistralAiChatOptionsTests { @@ -124,4 +126,28 @@ void testDefaultValues() { assertThat(options.getResponseFormat()).isNull(); } + @Test + void testOptionsDefault() { + var options = new MistralAiChatOptions(); + + assertThat(options.getToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); + } + + @Test + void testOptionsCustom() { + var options = new MistralAiChatOptions(); + + options.setToolExecutionMaxIterations(3); + + assertThat(options.getToolExecutionMaxIterations()).isEqualTo(3); + } + + @Test + void testBuilder() { + var options = MistralAiChatOptions.builder().toolExecutionMaxIterations(3).build(); + + assertThat(options.getToolExecutionMaxIterations()).isEqualTo(3); + } + } diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index c6bd6c2676e..d089be03415 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -28,6 +28,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.ToolExecutionLimitExceededException; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -91,7 +92,9 @@ * @author Alexandros Pappas * @author Ilayaperumal Gopinathan * @author Sun Yuhan + * @author lambochen * @since 1.0.0 + * @see ToolCallingChatOptions */ public class OllamaChatModel implements ChatModel { @@ -231,10 +234,10 @@ public ChatResponse call(Prompt prompt) { // Before moving any further, build the final request Prompt, // merging runtime and default options. 
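The Mistral hunks above wire the new toolExecutionMaxIterations option through the options class, its builder, and the request-merge test. For orientation, here is a minimal, hedged usage sketch of the feature this patch introduces; it is not part of the patch, and the model id, prompt text, and injected chatModel are placeholders.

```java
// Hedged usage sketch, not part of the patch: assumes a MistralAiChatModel is configured
// elsewhere and a tool is registered; "mistral-small-latest" is a placeholder model id.
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.mistralai.MistralAiChatOptions;
import org.springframework.ai.model.tool.ToolExecutionLimitExceededException;

class BoundedToolCallingExample {

	String askWithToolBudget(ChatModel chatModel) {
		MistralAiChatOptions options = MistralAiChatOptions.builder()
			.model("mistral-small-latest")   // placeholder model id
			.toolExecutionMaxIterations(3)   // allow at most 3 internal tool-calling rounds
			.build();
		try {
			ChatResponse response = chatModel.call(new Prompt("What is the weather in Paris?", options));
			return response.getResult().getOutput().getText();
		}
		catch (ToolExecutionLimitExceededException ex) {
			// Thrown once the model keeps requesting tools past the configured limit.
			return "Tool-calling loop stopped: " + ex.getMessage();
		}
	}

}
```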
Prompt requestPrompt = buildRequestPrompt(prompt); - return this.internalCall(requestPrompt, null); + return this.internalCall(requestPrompt, null, 1); } - private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { OllamaApi.ChatRequest request = ollamaChatRequest(prompt, false); @@ -277,7 +280,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -289,9 +292,12 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon else { // Send the tool execution result back to the model. return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } + else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); + } return response; } @@ -301,10 +307,10 @@ public Flux stream(Prompt prompt) { // Before moving any further, build the final request Prompt, // merging runtime and default options. Prompt requestPrompt = buildRequestPrompt(prompt); - return this.internalStream(requestPrompt, null); + return this.internalStream(requestPrompt, null, 1); } - private Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + private Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { OllamaApi.ChatRequest request = ollamaChatRequest(prompt, true); @@ -349,7 +355,7 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh // @formatter:off Flux chatResponseFlux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.deferContextual((ctx) -> { @@ -369,9 +375,11 @@ private Flux internalStream(Prompt prompt, ChatResponse previousCh else { // Send the tool execution result back to the model. 
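The Ollama hunks here mirror the Anthropic ones earlier: internalCall gains an iterations parameter that starts at 1 and is incremented on every tool round trip, and the limit check sits in an else-if after the tool-execution branch. Stripped of provider detail, the control flow looks roughly like the sketch below (simplified names, returnDirect branch elided; not the actual framework code).

```java
// Simplified shape of the bounded tool-calling loop this patch adds to each ChatModel
// implementation. `callApi` stands in for the provider-specific request handling, and the
// surrounding class is assumed to hold the eligibility predicate and ToolCallingManager as fields.
private ChatResponse callWithBoundedTools(Prompt prompt, int iterations) {
	ChatResponse response = callApi(prompt); // placeholder for the provider call

	if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) {
		var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
		// (returnDirect handling elided, unchanged by the patch)
		// Feed the tool results back to the model and count the extra round trip.
		return callWithBoundedTools(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
				iterations + 1);
	}
	else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) {
		// The model still asks for tools, but the budget is spent.
		throw new ToolExecutionLimitExceededException(iterations);
	}
	return response;
}
```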
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); + } else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)){ + throw new ToolExecutionLimitExceededException(iterations); } else { return Flux.just(response); @@ -411,6 +419,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getToolExecutionMaxIterations(), + this.defaultOptions.getToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -420,6 +431,7 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions.setToolExecutionMaxIterations(this.defaultOptions.getToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index a71be1ce2b2..04553fee8b4 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -44,6 +44,7 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author lambochen * @since 0.8.0 * @see Ollama @@ -321,6 +322,9 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer toolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + /** * Tool Function Callbacks to register with the ChatModel. * For Prompt Options the toolCallbacks are automatically enabled for the duration of the prompt execution. 
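buildRequestPrompt merges the per-request value with the model default through ModelOptionsUtils.mergeOption, exactly as it already does for internalToolExecutionEnabled. A small illustration of the merge semantics (arbitrary values; the caveat in the comments follows from the field initializers added in this patch):

```java
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.ollama.api.OllamaOptions;

class MergeIllustration {

	public static void main(String[] args) {
		OllamaOptions defaults = OllamaOptions.builder().toolExecutionMaxIterations(5).build();
		OllamaOptions runtime = OllamaOptions.builder().toolExecutionMaxIterations(2).build();

		// The runtime value is non-null, so it wins: merged == 2.
		Integer merged = ModelOptionsUtils.mergeOption(runtime.getToolExecutionMaxIterations(),
				defaults.getToolExecutionMaxIterations());
		System.out.println(merged);

		// Caveat: the options classes in this patch initialize the field to
		// DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS (Integer.MAX_VALUE), so a runtime instance that
		// never touches the setter still carries a non-null value and the model default above
		// would not kick in; only an explicit null falls back to the default.
	}

}
```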
@@ -397,6 +401,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) { .stop(fromOptions.getStop()) .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolExecutionMaxIterations(fromOptions.getToolExecutionMaxIterations()) .toolCallbacks(fromOptions.getToolCallbacks()) .toolContext(fromOptions.getToolContext()).build(); } @@ -746,6 +751,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getToolExecutionMaxIterations() { + return this.toolExecutionMaxIterations; + } + + @Override + public void setToolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.toolExecutionMaxIterations = toolExecutionMaxIterations; + } + @Override @JsonIgnore public Integer getDimensions() { @@ -809,6 +824,7 @@ public boolean equals(Object o) { && Objects.equals(this.penalizeNewline, that.penalizeNewline) && Objects.equals(this.stop, that.stop) && Objects.equals(this.toolCallbacks, that.toolCallbacks) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.toolExecutionMaxIterations, that.toolExecutionMaxIterations) && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.toolContext, that.toolContext); } @@ -820,7 +836,7 @@ public int hashCode() { this.topP, this.minP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty, this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta, this.penalizeNewline, this.stop, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, - this.toolContext); + this.toolExecutionMaxIterations, this.toolContext); } public static class Builder { @@ -1029,6 +1045,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder toolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.options.setToolExecutionMaxIterations(toolExecutionMaxIterations); + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index dbc65e1fb25..429c7e9819c 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -50,6 +50,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { OllamaOptions defaultOptions = OllamaOptions.builder() .model("MODEL_NAME") .internalToolExecutionEnabled(true) + .toolExecutionMaxIterations(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS) .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) .toolNames("tool1", "tool2") .toolContext(Map.of("key1", "value1", "key2", "valueA")) @@ -61,6 +62,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { OllamaOptions runtimeOptions = OllamaOptions.builder() .internalToolExecutionEnabled(false) + .toolExecutionMaxIterations(3) .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) .toolNames("tool3") .toolContext(Map.of("key2", "valueB")) @@ -69,6 +71,7 @@ void 
whenToolRuntimeOptionsThenMergeWithDefaults() { assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionEnabled()).isFalse(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolExecutionMaxIterations()).isEqualTo(3); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() .stream() diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 2ad584fa82f..62af1f92d56 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -29,6 +29,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.ToolExecutionLimitExceededException; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -106,9 +107,11 @@ * @author Alexandros Pappas * @author Soby Chacko * @author Jonghoon Park + * @author lambochen * @see ChatModel * @see StreamingChatModel * @see OpenAiApi + * @see ToolCallingChatOptions */ public class OpenAiChatModel implements ChatModel { @@ -183,6 +186,10 @@ public ChatResponse call(Prompt prompt) { } public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + return internalCall(prompt, previousChatResponse, 1); + } + + public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { ChatCompletionRequest request = createRequest(prompt, false); @@ -241,7 +248,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -253,9 +260,12 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons else { // Send the tool execution result back to the model. 
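The next stretch applies the same change to OpenAiChatModel and OpenAiChatOptions. As with the other providers, a ceiling can be set once as a default option; a hedged configuration sketch follows (builder method names follow the existing OpenAiChatModel API, the model id is a placeholder, and openAiApi is assumed to be built elsewhere).

```java
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;

class OpenAiDefaultLimitConfig {

	OpenAiChatModel chatModel(OpenAiApi openAiApi) {
		return OpenAiChatModel.builder()
			.openAiApi(openAiApi)
			.defaultOptions(OpenAiChatOptions.builder()
				.model("gpt-4o-mini")           // placeholder model id
				.toolExecutionMaxIterations(10) // used when a prompt supplies no tool-calling options of its own
				.build())
			.build();
	}

}
```

Because the runtime options class also initializes the field to DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS, a prompt that passes its own OpenAiChatOptions effectively overrides this default unless it sets the field explicitly.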
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } + else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); + } return response; } @@ -269,6 +279,10 @@ public Flux stream(Prompt prompt) { } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return internalStream(prompt, previousChatResponse, 1); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(prompt, true); @@ -363,7 +377,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux flux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.deferContextual((ctx) -> { @@ -383,9 +397,11 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha else { // Send the tool execution result back to the model. return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); + } else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)){ + throw new ToolExecutionLimitExceededException(iterations); } else { return Flux.just(response); @@ -527,6 +543,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getToolExecutionMaxIterations(), + this.defaultOptions.getToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -537,6 +556,7 @@ Prompt buildRequestPrompt(Prompt prompt) { else { requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders()); requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions.setToolExecutionMaxIterations(this.defaultOptions.getToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index afbbd803ec6..3e0137b9d85 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -219,6 +219,9 @@ public class OpenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean 
internalToolExecutionEnabled; + @JsonIgnore + private Integer toolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + /** * Optional HTTP headers to be added to the chat completion request. */ @@ -263,6 +266,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolExecutionMaxIterations(fromOptions.getToolExecutionMaxIterations()) .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .store(fromOptions.getStore()) .metadata(fromOptions.getMetadata()) @@ -506,6 +510,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getToolExecutionMaxIterations() { + return this.toolExecutionMaxIterations; + } + + @Override + public void setToolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.toolExecutionMaxIterations = toolExecutionMaxIterations; + } + public Map getHttpHeaders() { return this.httpHeaders; } @@ -575,8 +589,9 @@ public int hashCode() { this.maxTokens, this.maxCompletionTokens, this.n, this.presencePenalty, this.responseFormat, this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders, - this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio, - this.store, this.metadata, this.reasoningEffort, this.webSearchOptions); + this.internalToolExecutionEnabled, this.toolExecutionMaxIterations, this.toolContext, + this.outputModalities, this.outputAudio, this.store, this.metadata, this.reasoningEffort, + this.webSearchOptions); } @Override @@ -605,6 +620,7 @@ public boolean equals(Object o) { && Objects.equals(this.httpHeaders, other.httpHeaders) && Objects.equals(this.toolContext, other.toolContext) && Objects.equals(this.internalToolExecutionEnabled, other.internalToolExecutionEnabled) + && Objects.equals(this.toolExecutionMaxIterations, other.toolExecutionMaxIterations) && Objects.equals(this.outputModalities, other.outputModalities) && Objects.equals(this.outputAudio, other.outputAudio) && Objects.equals(this.store, other.store) && Objects.equals(this.metadata, other.metadata) @@ -767,6 +783,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder toolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.options.setToolExecutionMaxIterations(toolExecutionMaxIterations); + return this; + } + public Builder httpHeaders(Map httpHeaders) { this.options.httpHeaders = httpHeaders; return this; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java index 3d7623c96f4..09e7d16dd60 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java @@ -36,6 +36,7 @@ /** * @author Christian Tzolov * 
@author Thomas Vitale + * @author lambochen */ class ChatCompletionRequestTests { @@ -44,6 +45,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { OpenAiChatOptions defaultOptions = OpenAiChatOptions.builder() .model("DEFAULT_MODEL") .internalToolExecutionEnabled(true) + .toolExecutionMaxIterations(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS) .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) .toolNames("tool1", "tool2") .toolContext(Map.of("key1", "value1", "key2", "valueA")) @@ -56,6 +58,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { OpenAiChatOptions runtimeOptions = OpenAiChatOptions.builder() .internalToolExecutionEnabled(false) + .toolExecutionMaxIterations(10) .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) .toolNames("tool3") .toolContext(Map.of("key2", "valueB")) @@ -64,6 +67,7 @@ void whenToolRuntimeOptionsThenMergeWithDefaults() { assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getInternalToolExecutionEnabled()).isFalse(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolExecutionMaxIterations()).isEqualTo(10); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(2); assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks() .stream() diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java index 70b7f1fad66..b0eba0b9781 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions; @@ -81,6 +82,7 @@ void testBuilderWithAllFields() { .metadata(metadata) .reasoningEffort("medium") .internalToolExecutionEnabled(false) + .toolExecutionMaxIterations(10) .httpHeaders(Map.of("header1", "value1")) .toolContext(toolContext) .build(); @@ -90,10 +92,10 @@ void testBuilderWithAllFields() { "maxCompletionTokens", "n", "outputModalities", "outputAudio", "presencePenalty", "responseFormat", "streamOptions", "seed", "stop", "temperature", "topP", "tools", "toolChoice", "user", "parallelToolCalls", "store", "metadata", "reasoningEffort", "internalToolExecutionEnabled", - "httpHeaders", "toolContext") + "toolExecutionMaxIterations", "httpHeaders", "toolContext") .containsExactly("test-model", 0.5, logitBias, true, 5, 100, 50, 2, outputModalities, outputAudio, 0.8, responseFormat, streamOptions, 12345, stopSequences, 0.7, 0.9, tools, toolChoice, "test-user", true, - false, metadata, "medium", false, Map.of("header1", "value1"), toolContext); + false, metadata, "medium", false, 10, Map.of("header1", "value1"), toolContext); assertThat(options.getStreamUsage()).isTrue(); assertThat(options.getStreamOptions()).isEqualTo(StreamOptions.INCLUDE_USAGE); @@ -140,6 +142,7 @@ void testCopy() { .metadata(metadata) .reasoningEffort("low") .internalToolExecutionEnabled(true) + .toolExecutionMaxIterations(3) .httpHeaders(Map.of("header1", 
"value1")) .build(); @@ -188,6 +191,7 @@ void testSetters() { options.setMetadata(metadata); options.setReasoningEffort("high"); options.setInternalToolExecutionEnabled(false); + options.setToolExecutionMaxIterations(3); options.setHttpHeaders(Map.of("header2", "value2")); assertThat(options.getModel()).isEqualTo("test-model"); @@ -215,6 +219,7 @@ void testSetters() { assertThat(options.getMetadata()).isEqualTo(metadata); assertThat(options.getReasoningEffort()).isEqualTo("high"); assertThat(options.getInternalToolExecutionEnabled()).isFalse(); + assertThat(options.getToolExecutionMaxIterations()).isEqualTo(3); assertThat(options.getHttpHeaders()).isEqualTo(Map.of("header2", "value2")); assertThat(options.getStreamUsage()).isTrue(); options.setStreamUsage(false); @@ -254,6 +259,8 @@ void testDefaultValues() { assertThat(options.getReasoningEffort()).isNull(); assertThat(options.getToolCallbacks()).isNotNull().isEmpty(); assertThat(options.getInternalToolExecutionEnabled()).isNull(); + assertThat(options.getToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); assertThat(options.getHttpHeaders()).isNotNull().isEmpty(); assertThat(options.getToolContext()).isEqualTo(new HashMap<>()); assertThat(options.getStreamUsage()).isFalse(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index e7401d9d81b..e10dfdd2307 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -32,22 +32,48 @@ public class OpenAiTestConfiguration { @Bean public OpenAiApi openAiApi() { - return OpenAiApi.builder().apiKey(getApiKey()).build(); + var builder = OpenAiApi.builder().apiKey(getApiKey()); + + String baseUrl = getBaseUrl(); + if (StringUtils.hasText(baseUrl)) { + builder.baseUrl(baseUrl); + } + String completionsPath = getCompletionsPath(); + if (StringUtils.hasText(completionsPath)) { + builder.completionsPath(completionsPath); + } + + return builder.build(); } @Bean public OpenAiImageApi openAiImageApi() { - return OpenAiImageApi.builder().apiKey(getApiKey()).build(); + var builder = OpenAiImageApi.builder().apiKey(getApiKey()); + String baseUrl = getBaseUrl(); + if (StringUtils.hasText(baseUrl)) { + builder.baseUrl(baseUrl); + } + return builder.build(); } @Bean public OpenAiAudioApi openAiAudioApi() { - return OpenAiAudioApi.builder().apiKey(getApiKey()).build(); + var builder = OpenAiAudioApi.builder().apiKey(getApiKey()); + String baseUrl = getBaseUrl(); + if (StringUtils.hasText(baseUrl)) { + builder.baseUrl(baseUrl); + } + return builder.build(); } @Bean public OpenAiModerationApi openAiModerationApi() { - return OpenAiModerationApi.builder().apiKey(getApiKey()).build(); + var builder = OpenAiModerationApi.builder().apiKey(getApiKey()); + String baseUrl = getBaseUrl(); + if (StringUtils.hasText(baseUrl)) { + builder.baseUrl(baseUrl); + } + return builder.build(); } private ApiKey getApiKey() { @@ -59,6 +85,22 @@ private ApiKey getApiKey() { return new SimpleApiKey(apiKey); } + private String getBaseUrl() { + String baseUrl = System.getenv("OPENAI_BASE_URL"); + if (StringUtils.hasText(baseUrl)) { + return baseUrl; + } + return null; + } + + private String getCompletionsPath() { + String path = System.getenv("OPENAI_COMPLETIONS_PATH"); + if 
(StringUtils.hasText(path)) { + return path; + } + return null; + } + @Bean public OpenAiChatModel openAiChatModel(OpenAiApi api) { return OpenAiChatModel.builder() diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index 852678a1da3..61f15d71f21 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -49,6 +49,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.ToolExecutionLimitExceededException; import reactor.core.publisher.Flux; import reactor.core.scheduler.Schedulers; @@ -137,10 +138,12 @@ * @author Jihoon Kim * @author Alexandros Pappas * @author Ilayaperumal Gopinathan + * @author lambochen * @since 0.8.1 * @see VertexAiGeminiChatOptions * @see ToolCallingManager * @see ChatModel + * @see ToolCallingChatOptions */ public class VertexAiGeminiChatModel implements ChatModel, DisposableBean { @@ -394,6 +397,10 @@ public ChatResponse call(Prompt prompt) { } private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) { + return this.internalCall(prompt, previousChatResponse, 1); + } + + private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() .prompt(prompt) @@ -426,7 +433,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon return chatResponse; })); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -438,9 +445,12 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon else { // Send the tool execution result back to the model. 
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); + response, iterations + 1); } } + else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); + } return response; @@ -470,6 +480,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getToolExecutionMaxIterations(), + this.defaultOptions.getToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -484,6 +497,7 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions.setToolExecutionMaxIterations(this.defaultOptions.getToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); @@ -504,6 +518,10 @@ public Flux stream(Prompt prompt) { } public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse) { + return this.internalStream(prompt, previousChatResponse, 1); + } + + public Flux internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) { return Flux.deferContextual(contextView -> { ChatModelObservationContext observationContext = ChatModelObservationContext.builder() @@ -539,7 +557,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha // @formatter:off Flux flux = chatResponseFlux.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.deferContextual((ctx) -> { @@ -558,9 +576,14 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha } else { // Send the tool execution result back to the model. 
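In the streaming paths (here Vertex AI Gemini, and the other providers above follow the same pattern), the ToolExecutionLimitExceededException is thrown inside a flatMap operator, so subscribers observe it as an onError signal rather than a thrown exception. A hedged handling sketch; the empty-ChatResponse fallback is just one possible choice:

```java
import java.util.List;

import reactor.core.publisher.Flux;

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.ToolExecutionLimitExceededException;

class StreamingLimitHandling {

	// Reactor delivers the exception thrown inside flatMap as an onError signal, so it is
	// handled with the usual error operators rather than try/catch around stream().
	Flux<ChatResponse> streamWithToolBudget(ChatModel chatModel, Prompt prompt) {
		return chatModel.stream(prompt)
			.onErrorResume(ToolExecutionLimitExceededException.class,
					ex -> Flux.just(new ChatResponse(List.of())));
	}

}
```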
- return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response); + return this.internalStream( + new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response, + iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); + } else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)){ + throw new ToolExecutionLimitExceededException(iterations); } else { return Flux.just(response); diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 68ae24a92e2..25a624f922d 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -45,6 +45,7 @@ * @author Grogdunn * @author Ilayaperumal Gopinathan * @author Soby Chacko + * @author lambochen * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) @@ -126,6 +127,9 @@ public class VertexAiGeminiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer toolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + @JsonIgnore private Map toolContext = new HashMap<>(); @@ -161,6 +165,7 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval()); options.setSafetySettings(fromOptions.getSafetySettings()); options.setInternalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()); + options.setToolExecutionMaxIterations(fromOptions.getToolExecutionMaxIterations()); options.setToolContext(fromOptions.getToolContext()); return options; } @@ -281,6 +286,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getToolExecutionMaxIterations() { + return this.toolExecutionMaxIterations; + } + + @Override + public void setToolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.toolExecutionMaxIterations = toolExecutionMaxIterations; + } + @Override public Double getFrequencyPenalty() { return this.frequencyPenalty; @@ -346,6 +361,7 @@ public boolean equals(Object o) { && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.safetySettings, that.safetySettings) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.toolExecutionMaxIterations, that.toolExecutionMaxIterations) && Objects.equals(this.toolContext, that.toolContext); } @@ -354,7 +370,7 @@ public int hashCode() { return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount, this.frequencyPenalty, this.presencePenalty, this.maxOutputTokens, this.model, this.responseMimeType, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval, this.safetySettings, - this.internalToolExecutionEnabled, this.toolContext); + this.internalToolExecutionEnabled, this.toolExecutionMaxIterations, this.toolContext); } @Override @@ -478,6 +494,11 @@ public Builder internalToolExecutionEnabled(boolean internalToolExecutionEnabled return this; } + 
public Builder toolExecutionMaxIterations(Integer toolExecutionMaxIterations) { + this.options.toolExecutionMaxIterations = toolExecutionMaxIterations; + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptionsTest.java b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptionsTest.java new file mode 100644 index 00000000000..9a6320f53bf --- /dev/null +++ b/models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptionsTest.java @@ -0,0 +1,43 @@ +package org.springframework.ai.vertexai.gemini; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.ToolCallingChatOptions; + +import static org.junit.jupiter.api.Assertions.*; + +class VertexAiGeminiChatOptionsTest { + + @Test + void optionsDefault() { + var options = new VertexAiGeminiChatOptions(); + + assertEquals(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS, + options.getToolExecutionMaxIterations()); + } + + @Test + void builderDefault() { + var options = VertexAiGeminiChatOptions.builder().build(); + + assertEquals(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS, + options.getToolExecutionMaxIterations()); + } + + @Test + void testBuilder() { + var options = VertexAiGeminiChatOptions.builder().toolExecutionMaxIterations(3).build(); + + assertEquals(3, options.getToolExecutionMaxIterations()); + } + + @Test + void fromOptions() { + var original = new VertexAiGeminiChatOptions(); + original.setToolExecutionMaxIterations(3); + + var copied = VertexAiGeminiChatOptions.fromOptions(original); + + assertEquals(original.getToolExecutionMaxIterations(), copied.getToolExecutionMaxIterations()); + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 01402acc36a..db4b3f3b56d 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -27,6 +27,7 @@ import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.model.tool.ToolExecutionLimitExceededException; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -83,9 +84,11 @@ * @author Geng Rong * @author Alexandros Pappas * @author Ilayaperumal Gopinathan + * @author lambochen * @see ChatModel * @see StreamingChatModel * @see ZhiPuAiApi + * @see ToolCallingChatOptions * @since 1.0.0 M1 */ public class ZhiPuAiChatModel implements ChatModel { @@ -238,6 +241,10 @@ public ChatResponse call(Prompt prompt) { // Before moving any further, build the final request Prompt, // merging runtime and default options. 
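The ZhiPuAi hunks below and the shared helpers near the end of the patch (the static ToolCallingChatOptions.isInternalToolExecutionEnabled(ChatOptions, int) and ToolExecutionEligibilityChecker.isLimitExceeded) pin down the boundary semantics: counting starts at 1, iteration N is still allowed when the maximum is N, and round N + 1 triggers the exception. A worked example against the new static helper; it assumes the existing ToolCallingChatOptions.builder() factory in spring-ai-model.

```java
import org.springframework.ai.model.tool.ToolCallingChatOptions;

class IterationBoundaryExample {

	public static void main(String[] args) {
		ToolCallingChatOptions options = ToolCallingChatOptions.builder()
			.toolExecutionMaxIterations(2)
			.build();

		// Counting starts at 1: two tool-calling rounds are allowed, the third is not.
		System.out.println(ToolCallingChatOptions.isInternalToolExecutionEnabled(options, 1)); // true
		System.out.println(ToolCallingChatOptions.isInternalToolExecutionEnabled(options, 2)); // true
		System.out.println(ToolCallingChatOptions.isInternalToolExecutionEnabled(options, 3)); // false, the model throws ToolExecutionLimitExceededException
	}

}
```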
Prompt requestPrompt = buildRequestPrompt(prompt); + return internalCall(requestPrompt, 1); + } + + private ChatResponse internalCall(Prompt requestPrompt, int iterations) { ChatCompletionRequest request = createRequest(requestPrompt, false); ChatModelObservationContext observationContext = ChatModelObservationContext.builder() @@ -256,7 +263,7 @@ public ChatResponse call(Prompt prompt) { var chatCompletion = completionEntity.getBody(); if (chatCompletion == null) { - logger.warn("No chat completion returned for prompt: {}", prompt); + logger.warn("No chat completion returned for prompt: {}", requestPrompt); return new ChatResponse(List.of()); } @@ -264,12 +271,12 @@ public ChatResponse call(Prompt prompt) { List generations = choices.stream().map(choice -> { // @formatter:off - Map metadata = Map.of( - "id", chatCompletion.id(), - "role", choice.message().role() != null ? choice.message().role().name() : "", - "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "" - ); - // @formatter:on + Map metadata = Map.of( + "id", chatCompletion.id(), + "role", choice.message().role() != null ? choice.message().role().name() : "", + "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "" + ); + // @formatter:on return buildGeneration(choice, metadata); }).toList(); @@ -279,7 +286,8 @@ public ChatResponse call(Prompt prompt) { return chatResponse; }); - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response, + iterations)) { var toolExecutionResult = this.toolCallingManager.executeToolCalls(requestPrompt, response); if (toolExecutionResult.returnDirect()) { // Return tool execution result directly to the client. @@ -290,9 +298,15 @@ public ChatResponse call(Prompt prompt) { } else { // Send the tool execution result back to the model. - return this.call(new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions())); + return this.internalCall( + new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions()), + iterations + 1); } } + else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(requestPrompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); + } + return response; } @@ -303,6 +317,10 @@ public ChatOptions getDefaultOptions() { @Override public Flux stream(Prompt prompt) { + return internalStream(prompt, 1); + } + + private Flux internalStream(Prompt prompt, int iterations) { return Flux.deferContextual(contextView -> { // Before moving any further, build the final request Prompt, // merging runtime and default options. @@ -333,18 +351,18 @@ public Flux stream(Prompt prompt) { String id = chatCompletion2.id(); // @formatter:off - List generations = chatCompletion2.choices().stream().map(choice -> { - if (choice.message().role() != null) { - roleMap.putIfAbsent(id, choice.message().role().name()); - } - Map metadata = Map.of( - "id", chatCompletion2.id(), - "role", roleMap.getOrDefault(id, ""), - "finishReason", choice.finishReason() != null ? 
choice.finishReason().name() : "" - ); - return buildGeneration(choice, metadata); - }).toList(); - // @formatter:on + List generations = chatCompletion2.choices().stream().map(choice -> { + if (choice.message().role() != null) { + roleMap.putIfAbsent(id, choice.message().role().name()); + } + Map metadata = Map.of( + "id", chatCompletion2.id(), + "role", roleMap.getOrDefault(id, ""), + "finishReason", choice.finishReason() != null ? choice.finishReason().name() : "" + ); + return buildGeneration(choice, metadata); + }).toList(); + // @formatter:on return new ChatResponse(generations, from(chatCompletion2)); } @@ -357,7 +375,7 @@ public Flux stream(Prompt prompt) { // @formatter:off Flux flux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response)) { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(requestPrompt.getOptions(), response, iterations)) { // FIXME: bounded elastic needs to be used since tool calling // is currently only synchronous return Flux.deferContextual((ctx) -> { @@ -376,15 +394,19 @@ public Flux stream(Prompt prompt) { } else { // Send the tool execution result back to the model. - return this.stream(new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions())); + return this.internalStream( + new Prompt(toolExecutionResult.conversationHistory(), requestPrompt.getOptions()), + iterations + 1); } }).subscribeOn(Schedulers.boundedElastic()); + } else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(requestPrompt.getOptions(), iterations)) { + throw new ToolExecutionLimitExceededException(iterations); } return Flux.just(response); - }) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + }) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); // @formatter:on return new MessageAggregator().aggregate(flux, observationContext::setResponse); @@ -456,6 +478,9 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setToolExecutionMaxIterations( + ModelOptionsUtils.mergeOption(runtimeOptions.getToolExecutionMaxIterations(), + this.defaultOptions.getToolExecutionMaxIterations())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -465,6 +490,7 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions.setToolExecutionMaxIterations(this.defaultOptions.getToolExecutionMaxIterations()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index c31320defe1..2269a8a45c2 100644 --- 
a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -42,6 +42,7 @@ * @author Geng Rong * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author lambochen * @since 1.0.0 M1 */ @JsonInclude(Include.NON_NULL) @@ -122,6 +123,9 @@ public class ZhiPuAiChatOptions implements ToolCallingChatOptions { @JsonIgnore private Boolean internalToolExecutionEnabled; + @JsonIgnore + private Integer toolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + @JsonIgnore private Map toolContext = new HashMap<>(); // @formatter:on @@ -145,6 +149,7 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { .toolCallbacks(fromOptions.getToolCallbacks()) .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled()) + .toolExecutionMaxIterations(fromOptions.getToolExecutionMaxIterations()) .toolContext(fromOptions.getToolContext()) .build(); } @@ -304,6 +309,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getToolExecutionMaxIterations() { + return this.toolExecutionMaxIterations; + } + + @Override + public void setToolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.toolExecutionMaxIterations = toolExecutionMaxIterations; + } + @Override public Map getToolContext() { return this.toolContext; @@ -328,6 +343,8 @@ public int hashCode() { result = prime * result + ((this.user == null) ? 0 : this.user.hashCode()); result = prime * result + ((this.internalToolExecutionEnabled == null) ? 0 : this.internalToolExecutionEnabled.hashCode()); + result = prime * result + + ((this.toolExecutionMaxIterations == null) ? 0 : this.toolExecutionMaxIterations.hashCode()); result = prime * result + ((this.toolCallbacks == null) ? 0 : this.toolCallbacks.hashCode()); result = prime * result + ((this.toolNames == null) ? 0 : this.toolNames.hashCode()); result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); @@ -434,6 +451,14 @@ else if (!this.doSample.equals(other.doSample)) { else if (!this.internalToolExecutionEnabled.equals(other.internalToolExecutionEnabled)) { return false; } + if (this.toolExecutionMaxIterations == null) { + if (other.toolExecutionMaxIterations != null) { + return false; + } + } + else if (!this.toolExecutionMaxIterations.equals(other.toolExecutionMaxIterations)) { + return false; + } if (this.toolContext == null) { if (other.toolContext != null) { return false; @@ -465,6 +490,8 @@ public ToolCallingChatOptions merge(ChatOptions options) { builder.internalToolExecutionEnabled(toolCallingChatOptions.getInternalToolExecutionEnabled() != null ? (toolCallingChatOptions).getInternalToolExecutionEnabled() : this.getInternalToolExecutionEnabled()); + builder.toolExecutionMaxIterations(toolCallingChatOptions.getToolExecutionMaxIterations() != null + ? 
toolCallingChatOptions.getToolExecutionMaxIterations() : this.getToolExecutionMaxIterations()); Set toolNames = new HashSet<>(); if (this.toolNames != null) { @@ -495,6 +522,7 @@ public ToolCallingChatOptions merge(ChatOptions options) { } else { builder.internalToolExecutionEnabled(this.internalToolExecutionEnabled); + builder.toolExecutionMaxIterations(this.toolExecutionMaxIterations); builder.toolNames(this.toolNames != null ? new HashSet<>(this.toolNames) : null); builder.toolCallbacks(this.toolCallbacks != null ? new ArrayList<>(this.toolCallbacks) : null); builder.toolContext(this.toolContext != null ? new HashMap<>(this.toolContext) : null); @@ -600,6 +628,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut return this; } + public Builder toolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.options.setToolExecutionMaxIterations(toolExecutionMaxIterations); + return this; + } + public Builder toolContext(Map toolContext) { if (this.options.toolContext == null) { this.options.toolContext = toolContext; diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatOptionsTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatOptionsTests.java new file mode 100644 index 00000000000..c60429fc6b7 --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatOptionsTests.java @@ -0,0 +1,36 @@ +package org.springframework.ai.zhipuai.chat; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.model.tool.ToolCallingChatOptions; +import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +/** + * @author lambochen + */ +class ZhiPuAiChatOptionsTests { + + @Test + void testDefaultValue() { + var options = new ZhiPuAiChatOptions(); + + assertThat(options.getToolExecutionMaxIterations()) + .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS); + } + + @Test + void testSetter() { + var options = new ZhiPuAiChatOptions(); + options.setToolExecutionMaxIterations(3); + assertThat(options.getToolExecutionMaxIterations()).isEqualTo(3); + } + + @Test + void testBuilder() { + var options = ZhiPuAiChatOptions.builder().toolExecutionMaxIterations(3).build(); + + assertThat(options.getToolExecutionMaxIterations()).isEqualTo(3); + } + +} diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java index 870db6931b9..f3f6d626403 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptions.java @@ -33,6 +33,7 @@ * Default implementation of {@link ToolCallingChatOptions}. 
* * @author Thomas Vitale + * @author lambochen * @since 1.0.0 */ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { @@ -46,6 +47,9 @@ public class DefaultToolCallingChatOptions implements ToolCallingChatOptions { @Nullable private Boolean internalToolExecutionEnabled; + @Nullable + private Integer toolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS; + @Nullable private String model; @@ -118,6 +122,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut this.internalToolExecutionEnabled = internalToolExecutionEnabled; } + @Override + public Integer getToolExecutionMaxIterations() { + return this.toolExecutionMaxIterations; + } + + @Override + public void setToolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.toolExecutionMaxIterations = toolExecutionMaxIterations; + } + @Override @Nullable public String getModel() { @@ -206,6 +220,7 @@ public T copy() { options.setToolNames(getToolNames()); options.setToolContext(getToolContext()); options.setInternalToolExecutionEnabled(getInternalToolExecutionEnabled()); + options.setToolExecutionMaxIterations(getToolExecutionMaxIterations()); options.setModel(getModel()); options.setFrequencyPenalty(getFrequencyPenalty()); options.setMaxTokens(getMaxTokens()); @@ -277,6 +292,12 @@ public ToolCallingChatOptions.Builder internalToolExecutionEnabled( return this; } + @Override + public ToolCallingChatOptions.Builder toolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) { + this.options.setToolExecutionMaxIterations(toolExecutionMaxIterations); + return this; + } + @Override public ToolCallingChatOptions.Builder model(@Nullable String model) { this.options.setModel(model); diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index f06e71aa869..bb113fc6e1b 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -37,12 +37,20 @@ * * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author lambochen * @since 1.0.0 */ public interface ToolCallingChatOptions extends ChatOptions { boolean DEFAULT_TOOL_EXECUTION_ENABLED = true; + /** + * No limit for tool execution attempts. + */ + int TOOL_EXECUTION_NO_LIMIT = Integer.MAX_VALUE; + + int DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS = TOOL_EXECUTION_NO_LIMIT; + /** * ToolCallbacks to be registered with the ChatModel. */ @@ -76,6 +84,23 @@ public interface ToolCallingChatOptions extends ChatOptions { */ void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled); + /** + * Get the maximum number of iteration for tool execution. If the number of iterations + * exceeds the limit, an {@link ToolExecutionLimitExceededException} will be thrown. + * @return the maximum number of iteration. + * @see #getInternalToolExecutionEnabled() + * @see ToolExecutionLimitExceededException + */ + @Nullable + Integer getToolExecutionMaxIterations(); + + /** + * Set the maximum number of iteration for tool execution. If the number of iterations + * exceeds the limit, an {@link ToolExecutionLimitExceededException} will be thrown. + * @param toolExecutionMaxIterations the maximum number of iteration. 
diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java
index f06e71aa869..bb113fc6e1b 100644
--- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java
+++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java
@@ -37,12 +37,20 @@
 *
 * @author Thomas Vitale
 * @author Ilayaperumal Gopinathan
+ * @author lambochen
 * @since 1.0.0
 */
public interface ToolCallingChatOptions extends ChatOptions {

    boolean DEFAULT_TOOL_EXECUTION_ENABLED = true;

+   /**
+    * No limit for tool execution iterations.
+    */
+   int TOOL_EXECUTION_NO_LIMIT = Integer.MAX_VALUE;
+
+   int DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS = TOOL_EXECUTION_NO_LIMIT;
+
    /**
     * ToolCallbacks to be registered with the ChatModel.
     */
@@ -76,6 +84,23 @@ public interface ToolCallingChatOptions extends ChatOptions {
     */
    void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled);

+   /**
+    * Get the maximum number of iterations for tool execution. If the number of iterations
+    * exceeds the limit, a {@link ToolExecutionLimitExceededException} will be thrown.
+    * @return the maximum number of iterations.
+    * @see #getInternalToolExecutionEnabled()
+    * @see ToolExecutionLimitExceededException
+    */
+   @Nullable
+   Integer getToolExecutionMaxIterations();
+
+   /**
+    * Set the maximum number of iterations for tool execution. If the number of iterations
+    * exceeds the limit, a {@link ToolExecutionLimitExceededException} will be thrown.
+    * @param toolExecutionMaxIterations the maximum number of iterations.
+    */
+   void setToolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations);
+
    /**
     * Get the configured tool context.
     * @return the tool context map.
@@ -109,6 +134,21 @@ static boolean isInternalToolExecutionEnabled(ChatOptions chatOptions) {
        return internalToolExecutionEnabled;
    }

+   static boolean isInternalToolExecutionEnabled(ChatOptions chatOptions, int toolExecutionIterations) {
+       boolean isInternalToolExecutionEnabled = isInternalToolExecutionEnabled(chatOptions);
+       if (!isInternalToolExecutionEnabled) {
+           return false;
+       }
+
+       if (chatOptions instanceof ToolCallingChatOptions toolCallingChatOptions
+               && toolCallingChatOptions.getToolExecutionMaxIterations() != null) {
+           int maxIterations = toolCallingChatOptions.getToolExecutionMaxIterations();
+           return toolExecutionIterations <= maxIterations;
+       }
+
+       return DEFAULT_TOOL_EXECUTION_ENABLED;
+   }
+
    static Set<String> mergeToolNames(Set<String> runtimeToolNames, Set<String> defaultToolNames) {
        Assert.notNull(runtimeToolNames, "runtimeToolNames cannot be null");
        Assert.notNull(defaultToolNames, "defaultToolNames cannot be null");
@@ -178,6 +218,13 @@ interface Builder extends ChatOptions.Builder {
         */
        Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled);

+       /**
+        * Set the maximum number of iterations for tool execution.
+        * @param toolExecutionMaxIterations the maximum number of iterations.
+        * @return the {@link ToolCallingChatOptions} Builder.
+        */
+       Builder toolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations);
+
        /**
         * Add a {@link Map} of context values into tool context.
         * @param context the map representing the tool context.
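The intended semantics of the new static helper, written out as a small illustrative snippet; the iteration counter itself would be maintained by whichever ChatModel implementation drives the tool-calling loop.

import org.springframework.ai.model.tool.DefaultToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;

class MaxIterationsHelperSketch {

    void demo() {
        ToolCallingChatOptions options = new DefaultToolCallingChatOptions();
        options.setToolExecutionMaxIterations(2);

        // The limit is inclusive: iterations 1 and 2 are allowed, iteration 3 is not.
        boolean first = ToolCallingChatOptions.isInternalToolExecutionEnabled(options, 1);  // true
        boolean second = ToolCallingChatOptions.isInternalToolExecutionEnabled(options, 2); // true
        boolean third = ToolCallingChatOptions.isInternalToolExecutionEnabled(options, 3);  // false
    }

}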
diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityChecker.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityChecker.java
index 6ba92766929..884fa2cb227 100644
--- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityChecker.java
+++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityChecker.java
@@ -27,6 +27,7 @@
 * responses.
 *
 * @author Christian Tzolov
+ * @author lambochen
 */
public interface ToolExecutionEligibilityChecker extends Function<ChatResponse, Boolean> {
@@ -43,6 +44,23 @@ default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse
        return this.isInternalToolExecutionEnabled(promptOptions) && this.isToolCallResponse(chatResponse);
    }

+   /**
+    * Determines if tool execution should be performed based on the prompt options, the
+    * chat response, and the current number of tool execution iterations.
+    * @param promptOptions The options from the prompt
+    * @param chatResponse The response from the chat model
+    * @param toolExecutionIterations The current number of tool execution
+    * iterations
+    * @return true if tool execution should be performed, false otherwise
+    */
+   default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse chatResponse,
+           int toolExecutionIterations) {
+       Assert.notNull(promptOptions, "promptOptions cannot be null");
+       Assert.notNull(chatResponse, "chatResponse cannot be null");
+       return this.isInternalToolExecutionEnabled(promptOptions, toolExecutionIterations)
+               && this.isToolCallResponse(chatResponse);
+   }
+
    /**
     * Determines if the response is a tool call message response.
     * @param chatResponse The response from the chat model call
@@ -74,4 +92,36 @@ default boolean isInternalToolExecutionEnabled(ChatOptions chatOptions) {
        return internalToolExecutionEnabled;
    }

+   /**
+    * Determines if tool execution should be performed by Spring AI or by the client.
+    * @param chatOptions The options from the chat
+    * @param toolExecutionIterations The current number of tool execution
+    * iterations
+    * @return true if tool execution should be performed by Spring AI, false if it should
+    * be performed by the client
+    */
+   default boolean isInternalToolExecutionEnabled(ChatOptions chatOptions, int toolExecutionIterations) {
+       boolean internalToolExecutionEnabled = isInternalToolExecutionEnabled(chatOptions);
+       if (!internalToolExecutionEnabled) {
+           return false;
+       }
+
+       return !isLimitExceeded(chatOptions, toolExecutionIterations);
+   }
+
+   /**
+    * Determines if the tool execution limit has been exceeded.
+    * @param promptOptions The options from the prompt
+    * @param toolExecutionIterations The current number of tool execution iterations
+    * @return true if the tool execution limit has been exceeded, false otherwise
+    */
+   default boolean isLimitExceeded(ChatOptions promptOptions, int toolExecutionIterations) {
+       if (promptOptions instanceof ToolCallingChatOptions toolCallingChatOptions) {
+           return toolCallingChatOptions.getToolExecutionMaxIterations() != null
+                   && toolExecutionIterations > toolCallingChatOptions.getToolExecutionMaxIterations();
+       }
+
+       return false;
+   }
+
}
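A sketch of a custom checker built on the defaults above; the only call assumed beyond this interface is ChatResponse.hasToolCalls(), used here as one plausible way to recognize a tool-call response.

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.model.tool.ToolExecutionEligibilityChecker;

class ToolCallOnlyEligibilityChecker implements ToolExecutionEligibilityChecker {

    // The single abstract piece: decide whether the response looks like a tool-call response.
    // isToolExecutionRequired(options, response, iterations) then combines this result with
    // isInternalToolExecutionEnabled(options, iterations), which in turn consults isLimitExceeded(...).
    @Override
    public Boolean apply(ChatResponse chatResponse) {
        return chatResponse.hasToolCalls();
    }

}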
diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicate.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicate.java
index e3f048ebd41..0d30620a624 100644
--- a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicate.java
+++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicate.java
@@ -43,4 +43,39 @@ default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse
        return test(promptOptions, chatResponse);
    }

+   /**
+    * Determines if tool execution should be performed based on the prompt options, the
+    * chat response, and the current number of tool execution iterations.
+    * @param promptOptions The options from the prompt
+    * @param chatResponse The response from the chat model
+    * @param toolExecutionIterations The current number of tool execution iterations
+    * @return true if tool execution should be performed, false otherwise
+    * @see ToolCallingChatOptions#getToolExecutionMaxIterations()
+    * @see #isToolExecutionRequired(ChatOptions, ChatResponse)
+    */
+   default boolean isToolExecutionRequired(ChatOptions promptOptions, ChatResponse chatResponse,
+           int toolExecutionIterations) {
+       boolean isToolExecutionRequired = isToolExecutionRequired(promptOptions, chatResponse);
+       if (!isToolExecutionRequired) {
+           return false;
+       }
+
+       return !isLimitExceeded(promptOptions, toolExecutionIterations);
+   }
+
+   /**
+    * Determines if the tool execution limit has been exceeded.
+    * @param promptOptions The options from the prompt
+    * @param toolExecutionIterations The current number of tool execution iterations
+    * @return true if the tool execution limit has been exceeded, false otherwise
+    */
+   default boolean isLimitExceeded(ChatOptions promptOptions, int toolExecutionIterations) {
+       if (promptOptions instanceof ToolCallingChatOptions toolCallingChatOptions) {
+           return toolCallingChatOptions.getToolExecutionMaxIterations() != null
+                   && toolExecutionIterations > toolCallingChatOptions.getToolExecutionMaxIterations();
+       }
+
+       return false;
+   }
+
}
diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionLimitExceededException.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionLimitExceededException.java
new file mode 100644
index 00000000000..042fc7230c8
--- /dev/null
+++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionLimitExceededException.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright 2025-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.model.tool;
+
+/**
+ * Exception thrown when the tool execution limit is exceeded.
+ *
+ * @author lambochen
+ * @see ToolCallingChatOptions#getToolExecutionMaxIterations()
+ */
+public class ToolExecutionLimitExceededException extends RuntimeException {
+
+   private final Integer maxIterations;
+
+   public ToolExecutionLimitExceededException(Integer maxIterations) {
+       this("Tool execution limit exceeded: " + maxIterations, maxIterations);
+   }
+
+   public ToolExecutionLimitExceededException(String message, Integer maxIterations) {
+       super(message);
+       this.maxIterations = maxIterations;
+   }
+
+   public Integer getMaxIterations() {
+       return maxIterations;
+   }
+
+}
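How calling code might react when the cap is hit, sketched against the exception introduced above; the chatModel instance and a prompt whose options carry a toolExecutionMaxIterations value are assumed to exist.

import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.ToolExecutionLimitExceededException;

class LimitExceededHandlingSketch {

    ChatResponse callWithLimitHandling(ChatModel chatModel, Prompt prompt) {
        try {
            return chatModel.call(prompt);
        }
        catch (ToolExecutionLimitExceededException ex) {
            // The exception exposes the configured maximum, useful for logging or a fallback answer.
            throw new IllegalStateException(
                    "Tool calling did not settle within " + ex.getMaxIterations() + " iterations", ex);
        }
    }

}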
diff --git a/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionRequest.java b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionRequest.java
new file mode 100644
index 00000000000..5d349846f7f
--- /dev/null
+++ b/spring-ai-model/src/main/java/org/springframework/ai/model/tool/ToolExecutionRequest.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2023-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.model.tool;
+
+/**
+ * @author lambochen
+ */
+public interface ToolExecutionRequest {
+
+}
diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java
index 45557f23a6d..5bf394f0383
--- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java
+++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/DefaultToolCallingChatOptionsTests.java
@@ -33,6 +33,7 @@
 * Unit tests for {@link DefaultToolCallingChatOptions}.
 *
 * @author Thomas Vitale
+ * @author lambochen
 */
class DefaultToolCallingChatOptionsTests {
@@ -140,6 +141,7 @@ void copyShouldCreateNewInstanceWithSameValues() {
        original.setToolNames(Set.of("tool1"));
        original.setToolContext(Map.of("key", "value"));
        original.setInternalToolExecutionEnabled(true);
+       original.setToolExecutionMaxIterations(ToolCallingChatOptions.TOOL_EXECUTION_NO_LIMIT);
        original.setModel("gpt-4");
        original.setTemperature(0.7);
@@ -150,6 +152,7 @@ void copyShouldCreateNewInstanceWithSameValues() {
            assertThat(c.getToolNames()).isEqualTo(original.getToolNames());
            assertThat(c.getToolContext()).isEqualTo(original.getToolContext());
            assertThat(c.getInternalToolExecutionEnabled()).isEqualTo(original.getInternalToolExecutionEnabled());
+           assertThat(c.getToolExecutionMaxIterations()).isEqualTo(original.getToolExecutionMaxIterations());
            assertThat(c.getModel()).isEqualTo(original.getModel());
            assertThat(c.getTemperature()).isEqualTo(original.getTemperature());
        });
@@ -180,6 +183,7 @@ void builderShouldCreateOptionsWithAllProperties() {
            .toolNames(Set.of("tool1"))
            .toolContext(context)
            .internalToolExecutionEnabled(true)
+           .toolExecutionMaxIterations(3)
            .model("gpt-4")
            .temperature(0.7)
            .maxTokens(100)
@@ -195,6 +199,7 @@ void builderShouldCreateOptionsWithAllProperties() {
            assertThat(o.getToolNames()).containsExactly("tool1");
            assertThat(o.getToolContext()).isEqualTo(context);
            assertThat(o.getInternalToolExecutionEnabled()).isTrue();
+           assertThat(o.getToolExecutionMaxIterations()).isEqualTo(3);
            assertThat(o.getModel()).isEqualTo("gpt-4");
            assertThat(o.getTemperature()).isEqualTo(0.7);
            assertThat(o.getMaxTokens()).isEqualTo(100);
@@ -233,6 +238,13 @@ void deprecatedMethodsShouldWorkCorrectly() {
        options.setInternalToolExecutionEnabled(true);
        assertThat(options.getInternalToolExecutionEnabled()).isTrue();
+
+       // default value check
+       assertThat(options.getToolExecutionMaxIterations())
+           .isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS);
+
+       options.setToolExecutionMaxIterations(3);
+       assertThat(options.getToolExecutionMaxIterations()).isEqualTo(3);
    }

}
diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java
index 6d5d599dccd..904a852dc70
--- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java
+++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java
@@ -33,6 +33,7 @@
 * Unit tests for {@link ToolCallingChatOptions}.
 *
 * @author Thomas Vitale
+ * @author lambochen
 */
class ToolCallingChatOptionsTests {
@@ -50,6 +51,20 @@ void whenToolCallingChatOptionsAndExecutionEnabledFalse() {
        assertThat(ToolCallingChatOptions.isInternalToolExecutionEnabled(options)).isFalse();
    }

+   @Test
+   void whenToolCallingChatOptionsAndMaxIterationsOver() {
+       ToolCallingChatOptions options = new DefaultToolCallingChatOptions();
+       options.setToolExecutionMaxIterations(1);
+       // iteration 3 exceeds the limit of 1
+       assertThat(ToolCallingChatOptions.isInternalToolExecutionEnabled(options, 3)).isFalse();
+   }
+
+   @Test
+   void whenToolCallingChatOptionsAndMaxIterationsDefault() {
+       ToolCallingChatOptions options = new DefaultToolCallingChatOptions();
+       assertThat(ToolCallingChatOptions.isInternalToolExecutionEnabled(options, 1)).isTrue();
+   }
+
    @Test
    void whenToolCallingChatOptionsAndExecutionEnabledDefault() {
        ToolCallingChatOptions options = new DefaultToolCallingChatOptions();
diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityCheckerTest.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityCheckerTest.java
new file mode 100644
index 00000000000..c6a52807ffb
--- /dev/null
+++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityCheckerTest.java
@@ -0,0 +1,54 @@
+package org.springframework.ai.model.tool;
+
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.model.Generation;
+
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+class ToolExecutionEligibilityCheckerTest {
+
+   @Test
+   void isToolExecutionRequired() {
+       ToolExecutionEligibilityChecker checker = new TestToolExecutionEligibilityChecker();
+
+       ToolCallingChatOptions promptOptions = ToolCallingChatOptions.builder().build();
+       ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("test"))));
+       promptOptions.setToolExecutionMaxIterations(2);
+
+       assertThat(checker.isToolExecutionRequired(promptOptions, chatResponse, 1)).isTrue();
+       assertThat(checker.isToolExecutionRequired(promptOptions, chatResponse, 2)).isTrue();
+
+       // the iteration count exceeds the configured maximum
+       assertThat(checker.isToolExecutionRequired(promptOptions, chatResponse, 3)).isFalse();
+   }
+
+   @Test
+   void isInternalToolExecutionEnabled() {
+
+       ToolExecutionEligibilityChecker checker = new TestToolExecutionEligibilityChecker();
+
+       ToolCallingChatOptions promptOptions = ToolCallingChatOptions.builder().build();
+       promptOptions.setToolExecutionMaxIterations(2);
+
+       assertThat(checker.isInternalToolExecutionEnabled(promptOptions, 1)).isTrue();
+       assertThat(checker.isInternalToolExecutionEnabled(promptOptions, 2)).isTrue();
+
+       // the iteration count exceeds the configured maximum
+       assertThat(checker.isInternalToolExecutionEnabled(promptOptions, 3)).isFalse();
+
+   }
+
+   class TestToolExecutionEligibilityChecker implements ToolExecutionEligibilityChecker {
+
+       @Override
+       public Boolean apply(ChatResponse chatResponse) {
+           return true;
+       }
+
+   }
+
+}
diff --git a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java
index d347f9190f1..a0cae386210
--- a/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java
+++ b/spring-ai-model/src/test/java/org/springframework/ai/model/tool/ToolExecutionEligibilityPredicateTests.java
@@ -45,6 +45,20 @@ void whenIsToolExecutionRequiredWithNullPromptOptions() {
            .hasMessageContaining("promptOptions cannot be null");
    }

+   @Test
+   void whenIsToolExecutionRequiredWithAttempts() {
+       ToolExecutionEligibilityPredicate predicate = new TestToolExecutionEligibilityPredicate();
+       ToolCallingChatOptions promptOptions = ToolCallingChatOptions.builder().build();
+       ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("test"))));
+       promptOptions.setToolExecutionMaxIterations(2);
+
+       assertThat(predicate.isToolExecutionRequired(promptOptions, chatResponse, 1)).isTrue();
+       assertThat(predicate.isToolExecutionRequired(promptOptions, chatResponse, 2)).isTrue();
+
+       // the iteration count exceeds the configured maximum
+       assertThat(predicate.isToolExecutionRequired(promptOptions, chatResponse, 3)).isFalse();
+   }
+
    @Test
    void whenIsToolExecutionRequiredWithNullChatResponse() {
        ToolExecutionEligibilityPredicate predicate = new TestToolExecutionEligibilityPredicate();