Skip to content

ToolCallingChatOptions support internalToolExecutionMaxIterations, to limit the maximum number of tool calls and prevent infinite recursive calls to LLM in special cases #3380

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 33 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
72bceab
ToolCallingChatOptions support internalToolExecutionMaxIterations
lambochen May 29, 2025
33847aa
rename internalToolExecutionMaxAttempts
lambochen May 29, 2025
e243d42
ToolExecutionEligibilityChecker: add logic for attempts
lambochen May 29, 2025
670d691
OpenAiChatModel support internalToolExecutionMaxAttempts
lambochen May 29, 2025
092eb6c
internalToolExecutionEnabled set default value is Integer.MAX_VALUE
lambochen May 29, 2025
2d128ba
OpenAiChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
253d19e
AnthropicChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
93efce5
AzureOpenAiChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
d660152
BedrockProxyChatModel support internalToolExecutionMaxAttempts
lambochen May 29, 2025
2c7fc3d
DeepSeekChatModel support internalToolExecutionMaxAttempts
lambochen May 29, 2025
1a11eca
MiniMaxChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
69d0d5e
MistralAiChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
f808731
OllamaChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
bf9df71
VertexAiGeminiChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
8dd816d
ZhiPuAiChatModel support InternalToolExecutionMaxAttempts
lambochen May 29, 2025
0ec9490
fix: api compatibility
lambochen May 29, 2025
8aaf132
UT for IsToolExecutionRequiredWithAttempts
lambochen May 30, 2025
4269f37
UT for ToolExecutionEligibilityChecker attempts
lambochen May 30, 2025
8e1f251
merge main
lambochen May 31, 2025
fcf1a76
UT config for openai
lambochen May 31, 2025
ae8d1ab
add UT for internalToolExecutionMaxAttempts
lambochen May 31, 2025
3b1162e
add UT for internalToolCallingExecutionMaxAttempts in some options
lambochen May 31, 2025
e894e67
code format by spring-javaformat plugin
lambochen May 31, 2025
c63341b
fix attempts check logic
lambochen May 31, 2025
56fba7c
ToolCallingChatOptions: rename attempts to iterations for tool execution
lambochen Jun 1, 2025
80c2e26
rename iterations to toolExecutionIterations
lambochen Jun 2, 2025
0b364b3
merge origin main to refresh code
lambochen Jun 18, 2025
ee79d7b
merge main to fix conflict
lambochen Jun 23, 2025
86ba656
ChatModel: throw ToolExecutionLimitExceededException when over limit …
lambochen Jun 24, 2025
b737af9
fix ut for org.springframework.ai.model.tool.ToolExecutionEligibility…
lambochen Jun 24, 2025
3d4b4be
merge main and fix conflict
lambochen Jul 3, 2025
28abaa6
ChatOptions: internalToolExecutionMaxIterations rename to toolExecuti…
lambochen Jul 3, 2025
5dc08fa
fix: MiniMaxChatModel call requestPrompt
lambochen Jul 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.model.tool.ToolExecutionLimitExceededException;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
Expand Down Expand Up @@ -91,6 +92,7 @@
* @author Alexandros Pappas
* @author Jonghoon Park
* @author Soby Chacko
* @author lambochen
* @since 1.0.0
*/
public class AnthropicChatModel implements ChatModel {
Expand Down Expand Up @@ -175,6 +177,10 @@ public ChatResponse call(Prompt prompt) {
}

/**
 * Entry point for the synchronous call loop: starts the tool-execution iteration
 * counter at 1 (the initial model request) and delegates to
 * {@code internalCall(Prompt, ChatResponse, int)}.
 * @param prompt the prompt to send to the model
 * @param previousChatResponse the response from an earlier round of this loop,
 * presumably {@code null} on the first invocation — confirm with callers
 * @return the final chat response after any internal tool-execution rounds
 */
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
	// Iteration 1 == the first request; subsequent tool rounds pass iterations + 1.
	return this.internalCall(prompt, previousChatResponse, 1);
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) {
ChatCompletionRequest request = createRequest(prompt, false);

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
Expand Down Expand Up @@ -204,7 +210,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
return chatResponse;
});

if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) {
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
Expand All @@ -216,9 +222,12 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
else {
// Send the tool execution result back to the model.
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
response, iterations + 1);
}
}
else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) {
throw new ToolExecutionLimitExceededException(iterations);
}

return response;
}
Expand All @@ -237,6 +246,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
}

/**
 * Entry point for the streaming call loop: starts the tool-execution iteration
 * counter at 1 (the initial model request) and delegates to
 * {@code internalStream(Prompt, ChatResponse, int)}.
 * @param prompt the prompt to send to the model
 * @param previousChatResponse the response from an earlier round of this loop,
 * used for cumulative usage accounting; presumably {@code null} on the first
 * invocation — confirm with callers
 * @return the streamed chat responses, after any internal tool-execution rounds
 */
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
	// Iteration 1 == the first request; subsequent tool rounds pass iterations + 1.
	return this.internalStream(prompt, previousChatResponse, 1);
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) {
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);

Expand All @@ -261,7 +274,7 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);

if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse, iterations)) {

if (chatResponse.hasFinishReasons(Set.of("tool_use"))) {
// FIXME: bounded elastic needs to be used since tool calling
Expand All @@ -288,10 +301,12 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
}
}).subscribeOn(Schedulers.boundedElastic());

} else {
} else {
return Mono.empty();
}

} else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)){
throw new ToolExecutionLimitExceededException(iterations);
} else {
// If internal tool execution is not required, just return the chat response.
return Mono.just(chatResponse);
Expand Down Expand Up @@ -453,6 +468,8 @@ Prompt buildRequestPrompt(Prompt prompt) {
requestOptions.setInternalToolExecutionEnabled(
ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
this.defaultOptions.getInternalToolExecutionEnabled()));
requestOptions.setToolExecutionMaxIterations(ModelOptionsUtils.mergeOption(
runtimeOptions.getToolExecutionMaxIterations(), defaultOptions.getToolExecutionMaxIterations()));
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
this.defaultOptions.getToolNames()));
requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
Expand All @@ -463,6 +480,7 @@ Prompt buildRequestPrompt(Prompt prompt) {
else {
requestOptions.setHttpHeaders(this.defaultOptions.getHttpHeaders());
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
requestOptions.setToolExecutionMaxIterations(this.defaultOptions.getToolExecutionMaxIterations());
requestOptions.setToolNames(this.defaultOptions.getToolNames());
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
requestOptions.setToolContext(this.defaultOptions.getToolContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
* @author Thomas Vitale
* @author Alexandros Pappas
* @author Ilayaperumal Gopinathan
* @author lambochen
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
Expand Down Expand Up @@ -79,6 +80,9 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
@JsonIgnore
private Boolean internalToolExecutionEnabled;

@JsonIgnore
private Integer toolExecutionMaxIterations = ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS;

@JsonIgnore
private Map<String, Object> toolContext = new HashMap<>();

Expand Down Expand Up @@ -109,6 +113,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
.toolExecutionMaxIterations(fromOptions.getToolExecutionMaxIterations())
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
.httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null)
.build();
Expand Down Expand Up @@ -226,6 +231,16 @@ public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecut
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
}

/**
 * Returns the maximum number of tool-execution iterations allowed for a single
 * chat request. Defaults to
 * {@code ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS} (see the
 * field initializer); may be {@code null} if explicitly cleared via the setter.
 */
@Override
public Integer getToolExecutionMaxIterations() {
return this.toolExecutionMaxIterations;
}

/**
 * Sets the maximum number of tool-execution iterations for a single chat
 * request. A {@code null} value overwrites the default rather than restoring it.
 * @param toolExecutionMaxIterations the iteration limit, or {@code null}
 */
@Override
public void setToolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) {
this.toolExecutionMaxIterations = toolExecutionMaxIterations;
}

@Override
@JsonIgnore
public Double getFrequencyPenalty() {
Expand Down Expand Up @@ -281,6 +296,7 @@ public boolean equals(Object o) {
&& Objects.equals(this.toolCallbacks, that.toolCallbacks)
&& Objects.equals(this.toolNames, that.toolNames)
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
&& Objects.equals(this.toolExecutionMaxIterations, that.toolExecutionMaxIterations)
&& Objects.equals(this.toolContext, that.toolContext)
&& Objects.equals(this.httpHeaders, that.httpHeaders);
}
Expand All @@ -289,7 +305,7 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP,
this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled,
this.toolContext, this.httpHeaders);
this.toolExecutionMaxIterations, this.toolContext, this.httpHeaders);
}

public static class Builder {
Expand Down Expand Up @@ -374,6 +390,11 @@ public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecut
return this;
}

/**
 * Sets the maximum number of tool-execution iterations on the options being
 * built.
 * @param toolExecutionMaxIterations the iteration limit, or {@code null}
 * @return this builder for chaining
 */
public Builder toolExecutionMaxIterations(@Nullable Integer toolExecutionMaxIterations) {
this.options.setToolExecutionMaxIterations(toolExecutionMaxIterations);
return this;
}

public Builder toolContext(Map<String, Object> toolContext) {
if (this.options.toolContext == null) {
this.options.toolContext = toolContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
import org.junit.jupiter.api.Test;

import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.Metadata;
import org.springframework.ai.model.tool.ToolCallingChatOptions;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Tests for {@link AnthropicChatOptions}.
*
* @author Alexandros Pappas
* @author lambochen
*/
class AnthropicChatOptionsTests {

Expand All @@ -42,10 +44,13 @@ void testBuilderWithAllFields() {
.topP(0.8)
.topK(50)
.metadata(new Metadata("userId_123"))
.toolExecutionMaxIterations(3)
.build();

assertThat(options).extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata")
.containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123"));
assertThat(options)
.extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata",
"toolExecutionMaxIterations")
.containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123"), 3);
}

@Test
Expand All @@ -59,6 +64,7 @@ void testCopy() {
.topK(50)
.metadata(new Metadata("userId_123"))
.toolContext(Map.of("key1", "value1"))
.toolExecutionMaxIterations(3)
.build();

AnthropicChatOptions copied = original.copy();
Expand All @@ -67,6 +73,8 @@ void testCopy() {
// Ensure deep copy
assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences());
assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext());

assertThat(copied.getToolExecutionMaxIterations()).isEqualTo(3);
}

@Test
Expand All @@ -79,6 +87,7 @@ void testSetters() {
options.setTopP(0.8);
options.setStopSequences(List.of("stop1", "stop2"));
options.setMetadata(new Metadata("userId_123"));
options.setToolExecutionMaxIterations(3);

assertThat(options.getModel()).isEqualTo("test-model");
assertThat(options.getMaxTokens()).isEqualTo(100);
Expand All @@ -87,6 +96,7 @@ void testSetters() {
assertThat(options.getTopP()).isEqualTo(0.8);
assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2"));
assertThat(options.getMetadata()).isEqualTo(new Metadata("userId_123"));
assertThat(options.getToolExecutionMaxIterations()).isEqualTo(3);
}

@Test
Expand All @@ -99,6 +109,8 @@ void testDefaultValues() {
assertThat(options.getTopP()).isNull();
assertThat(options.getStopSequences()).isNull();
assertThat(options.getMetadata()).isNull();
assertThat(options.getToolExecutionMaxIterations())
.isEqualTo(ToolCallingChatOptions.DEFAULT_TOOL_EXECUTION_MAX_ITERATIONS);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.model.tool.ToolExecutionLimitExceededException;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

Expand Down Expand Up @@ -123,8 +124,10 @@
* @author Berjan Jonker
* @author Andres da Silva Santos
* @author Bart Veenstra
* @author lambochen
* @see ChatModel
* @see com.azure.ai.openai.OpenAIClient
* @see ToolCallingChatOptions
* @since 1.0.0
*/
public class AzureOpenAiChatModel implements ChatModel {
Expand Down Expand Up @@ -252,6 +255,10 @@ public ChatResponse call(Prompt prompt) {
}

/**
 * Entry point for the synchronous call loop: starts the tool-execution iteration
 * counter at 1 (the initial model request) and delegates to
 * {@code internalCall(Prompt, ChatResponse, int)}.
 * @param prompt the prompt to send to the model
 * @param previousChatResponse the response from an earlier round of this loop,
 * presumably {@code null} on the first invocation — confirm with callers
 * @return the final chat response after any internal tool-execution rounds
 */
public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
	// Qualify with 'this.' for consistency with the sibling internalStream
	// delegate below and the project's spring-javaformat conventions.
	return this.internalCall(prompt, previousChatResponse, 1);
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse, int iterations) {

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
.prompt(prompt)
Expand All @@ -271,7 +278,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
return chatResponse;
});

if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response, iterations)) {
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
Expand All @@ -283,9 +290,12 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
else {
// Send the tool execution result back to the model.
return this.internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
response, iterations + 1);
}
}
else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) {
throw new ToolExecutionLimitExceededException(iterations);
}

return response;
}
Expand All @@ -299,6 +309,10 @@ public Flux<ChatResponse> stream(Prompt prompt) {
}

/**
 * Entry point for the streaming call loop: starts the tool-execution iteration
 * counter at 1 (the initial model request) and delegates to
 * {@code internalStream(Prompt, ChatResponse, int)}.
 * @param prompt the prompt to send to the model
 * @param previousChatResponse the response from an earlier round of this loop,
 * presumably {@code null} on the first invocation — confirm with callers
 * @return the streamed chat responses, after any internal tool-execution rounds
 */
public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
	// Iteration 1 == the first request; subsequent tool rounds pass iterations + 1.
	return this.internalStream(prompt, previousChatResponse, 1);
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse, int iterations) {

return Flux.deferContextual(contextView -> {
ChatCompletionsOptions options = toAzureChatCompletionsOptions(prompt);
Expand Down Expand Up @@ -378,7 +392,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
});

return chatResponseFlux.flatMap(chatResponse -> {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse,
iterations)) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.deferContextual((ctx) -> {
Expand All @@ -401,10 +416,13 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
// Send the tool execution result back to the model.
return this.internalStream(
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
chatResponse);
chatResponse, iterations + 1);
}
}).subscribeOn(Schedulers.boundedElastic());
}
else if (this.toolExecutionEligibilityPredicate.isLimitExceeded(prompt.getOptions(), iterations)) {
throw new ToolExecutionLimitExceededException(iterations);
}

Flux<ChatResponse> flux = Flux.just(chatResponse)
.doOnError(observation::error)
Expand Down Expand Up @@ -674,6 +692,9 @@ Prompt buildRequestPrompt(Prompt prompt) {
requestOptions.setInternalToolExecutionEnabled(
ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(),
this.defaultOptions.getInternalToolExecutionEnabled()));
runtimeOptions.setToolExecutionMaxIterations(
ModelOptionsUtils.mergeOption(runtimeOptions.getToolExecutionMaxIterations(),
this.defaultOptions.getToolExecutionMaxIterations()));
requestOptions.setStreamUsage(ModelOptionsUtils.mergeOption(runtimeOptions.getStreamUsage(),
this.defaultOptions.getStreamUsage()));
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
Expand All @@ -685,6 +706,7 @@ Prompt buildRequestPrompt(Prompt prompt) {
}
else {
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled());
requestOptions.setToolExecutionMaxIterations(this.defaultOptions.getToolExecutionMaxIterations());
requestOptions.setStreamUsage(this.defaultOptions.getStreamUsage());
requestOptions.setToolNames(this.defaultOptions.getToolNames());
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
Expand Down
Loading