Skip to content

Commit 5017749

Browse files
tzolovmarkpollack
authored andcommitted
Add proxy tool calls option to chat models
This commit introduces a new proxyToolCalls option for various chat models in the Spring AI project. When enabled, it allows the client to handle function calls externally instead of being processed internally by Spring AI. The change affects multiple chat model implementations, including: AnthropicChatModel AzureOpenAiChatModel MiniMaxChatModel MistralAiChatModel MoonshotChatModel OllamaChatModel OpenAiChatModel VertexAiGeminiChatModel ZhiPuAiChatModel The proxyToolCalls option is added to the respective chat options classes and integrated into the AbstractToolCallSupport class for consistent handling across different implementations. The proxyToolCalls option can be set either programmatically via the <ModelName>ChatOptions.builder().withProxyToolCalls() method or the spring.ai.<model-name>.chat.options.proxy-tool-calls application property. Documentation for the new option is also updated in the relevant Antora pages. Added ITs for proxy tool calls Remove ChatClientPromptRequestSpec and all ChatClient.prompt() overloads can how take advantage of the full fluent API. Docs updated Resolves #1367
1 parent acb31e7 commit 5017749

File tree

42 files changed

+1023
-245
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1023
-245
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ public ChatResponse call(Prompt prompt) {
225225
return chatResponse;
226226
});
227227

228-
if (response != null && this.isToolCall(response, Set.of("tool_use"))) {
228+
if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null
229+
&& this.isToolCall(response, Set.of("tool_use"))) {
229230
var toolCallConversation = handleToolCalls(prompt, response);
230231
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
231232
}
@@ -256,7 +257,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
256257
Flux<ChatResponse> chatResponseFlux = response.switchMap(chatCompletionResponse -> {
257258
ChatResponse chatResponse = toChatResponse(chatCompletionResponse);
258259

259-
if (this.isToolCall(chatResponse, Set.of("tool_use"))) {
260+
if (!isProxyToolCalls(prompt, this.defaultOptions) && this.isToolCall(chatResponse, Set.of("tool_use"))) {
260261
var toolCallConversation = handleToolCalls(prompt, chatResponse);
261262
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
262263
}

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions
7777
@NestedConfigurationProperty
7878
@JsonIgnore
7979
private Set<String> functions = new HashSet<>();
80+
81+
@JsonIgnore
82+
private Boolean proxyToolCalls;
8083
// @formatter:on
8184

8285
public static Builder builder() {
@@ -144,6 +147,11 @@ public Builder withFunction(String functionName) {
144147
return this;
145148
}
146149

150+
public Builder withProxyToolCalls(Boolean proxyToolCalls) {
151+
this.options.proxyToolCalls = proxyToolCalls;
152+
return this;
153+
}
154+
147155
public AnthropicChatOptions build() {
148156
return this.options;
149157
}
@@ -246,6 +254,15 @@ public Double getPresencePenalty() {
246254
return null;
247255
}
248256

257+
@Override
258+
public Boolean getProxyToolCalls() {
259+
return this.proxyToolCalls;
260+
}
261+
262+
public void setProxyToolCalls(Boolean proxyToolCalls) {
263+
this.proxyToolCalls = proxyToolCalls;
264+
}
265+
249266
@Override
250267
public AnthropicChatOptions copy() {
251268
return fromOptions(this);
@@ -261,6 +278,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
261278
.withTopK(fromOptions.getTopK())
262279
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
263280
.withFunctions(fromOptions.getFunctions())
281+
.withProxyToolCalls(fromOptions.getProxyToolCalls())
264282
.build();
265283
}
266284

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ public ChatResponse call(Prompt prompt) {
151151

152152
ChatResponse chatResponse = toChatResponse(chatCompletions);
153153

154-
if (isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
154+
if (!isProxyToolCalls(prompt, this.defaultOptions)
155+
&& isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
155156
var toolCallConversation = handleToolCalls(prompt, chatResponse);
156157
// Recursively call the call method with the tool call message
157158
// conversation that contains the call responses.
@@ -199,7 +200,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
199200

200201
ChatResponse chatResponse = toChatResponse(chatCompletions);
201202

202-
if (isToolCall(chatResponse, Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
203+
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse,
204+
Set.of(String.valueOf(CompletionsFinishReason.TOOL_CALLS).toLowerCase()))) {
203205
var toolCallConversation = handleToolCalls(prompt, chatResponse);
204206
// Recursively call the call method with the tool call message
205207
// conversation that contains the call responses.

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.springframework.ai.model.function.FunctionCallingOptions;
3232
import org.springframework.boot.context.properties.NestedConfigurationProperty;
3333
import org.springframework.util.Assert;
34+
import org.stringtemplate.v4.compiler.CodeGenerator.primary_return;
3435

3536
/**
3637
* The configuration information for a chat completions request. Completions support a
@@ -161,6 +162,9 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
161162
@JsonIgnore
162163
private Set<String> functions = new HashSet<>();
163164

165+
@JsonIgnore
166+
private Boolean proxyToolCalls;
167+
164168
public static Builder builder() {
165169
return new Builder();
166170
}
@@ -250,6 +254,11 @@ public Builder withResponseFormat(AzureOpenAiResponseFormat responseFormat) {
250254
return this;
251255
}
252256

257+
public Builder withProxyToolCalls(Boolean proxyToolCalls) {
258+
this.options.proxyToolCalls = proxyToolCalls;
259+
return this;
260+
}
261+
253262
public AzureOpenAiChatOptions build() {
254263
return this.options;
255264
}
@@ -395,6 +404,15 @@ public Integer getTopK() {
395404
return null;
396405
}
397406

407+
@Override
408+
public Boolean getProxyToolCalls() {
409+
return this.proxyToolCalls;
410+
}
411+
412+
public void setProxyToolCalls(Boolean proxyToolCalls) {
413+
this.proxyToolCalls = proxyToolCalls;
414+
}
415+
398416
@Override
399417
public AzureOpenAiChatOptions copy() {
400418
return fromOptions(this);
@@ -413,6 +431,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
413431
.withUser(fromOptions.getUser())
414432
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
415433
.withFunctions(fromOptions.getFunctions())
434+
.withResponseFormat(fromOptions.getResponseFormat())
416435
.build();
417436
}
418437

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ public ChatResponse call(Prompt prompt) {
190190

191191
ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
192192

193-
if (isToolCall(chatResponse,
193+
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(chatResponse,
194194
Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
195195
var toolCallConversation = handleToolCalls(prompt, chatResponse);
196196
// Recursively call the call method with the tool call message
@@ -254,7 +254,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
254254

255255
return chatResponse.flatMap(response -> {
256256

257-
if (isToolCall(response,
257+
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response,
258258
Set.of(ChatCompletionFinishReason.TOOL_CALLS.name(), ChatCompletionFinishReason.STOP.name()))) {
259259
var toolCallConversation = handleToolCalls(prompt, response);
260260
// Recursively call the stream method with the tool call message

models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ public class MiniMaxChatOptions implements FunctionCallingOptions, ChatOptions {
142142
@NestedConfigurationProperty
143143
@JsonIgnore
144144
private Set<String> functions = new HashSet<>();
145+
146+
@JsonIgnore
147+
private Boolean proxyToolCalls;
145148
// @formatter:on
146149

147150
public static Builder builder() {
@@ -242,6 +245,11 @@ public Builder withFunction(String functionName) {
242245
return this;
243246
}
244247

248+
public Builder withProxyToolCalls(Boolean proxyToolCalls) {
249+
this.options.proxyToolCalls = proxyToolCalls;
250+
return this;
251+
}
252+
245253
public MiniMaxChatOptions build() {
246254
return this.options;
247255
}
@@ -394,6 +402,15 @@ public Integer getTopK() {
394402
return null;
395403
}
396404

405+
@Override
406+
public Boolean getProxyToolCalls() {
407+
return this.proxyToolCalls;
408+
}
409+
410+
public void setProxyToolCalls(Boolean proxyToolCalls) {
411+
this.proxyToolCalls = proxyToolCalls;
412+
}
413+
397414
@Override
398415
public int hashCode() {
399416
final int prime = 31;
@@ -411,6 +428,7 @@ public int hashCode() {
411428
result = prime * result + ((maskSensitiveInfo == null) ? 0 : maskSensitiveInfo.hashCode());
412429
result = prime * result + ((tools == null) ? 0 : tools.hashCode());
413430
result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode());
431+
result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode());
414432
return result;
415433
}
416434

@@ -501,6 +519,12 @@ else if (!tools.equals(other.tools))
501519
}
502520
else if (!toolChoice.equals(other.toolChoice))
503521
return false;
522+
if (this.proxyToolCalls == null) {
523+
if (other.proxyToolCalls != null)
524+
return false;
525+
}
526+
else if (!proxyToolCalls.equals(other.proxyToolCalls))
527+
return false;
504528
return true;
505529
}
506530

@@ -525,6 +549,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) {
525549
.withToolChoice(fromOptions.getToolChoice())
526550
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
527551
.withFunctions(fromOptions.getFunctions())
552+
.withProxyToolCalls(fromOptions.getProxyToolCalls())
528553
.build();
529554
}
530555

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,9 @@ public ChatResponse call(Prompt prompt) {
183183
return chatResponse;
184184
});
185185

186-
if (response != null && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
187-
MistralAiApi.ChatCompletionFinishReason.STOP.name()))) {
186+
if (!isProxyToolCalls(prompt, this.defaultOptions) && response != null
187+
&& isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
188+
MistralAiApi.ChatCompletionFinishReason.STOP.name()))) {
188189
var toolCallConversation = handleToolCalls(prompt, response);
189190
// Recursively call the call method with the tool call message
190191
// conversation that contains the call responses.
@@ -255,7 +256,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
255256

256257
// @formatter:off
257258
Flux<ChatResponse> chatResponseFlux = chatResponse.flatMap(response -> {
258-
if (isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name()))) {
259+
if (!isProxyToolCalls(prompt, this.defaultOptions) && isToolCall(response, Set.of(MistralAiApi.ChatCompletionFinishReason.TOOL_CALLS.name()))) {
259260
var toolCallConversation = handleToolCalls(prompt, response);
260261
// Recursively call the stream method with the tool call message
261262
// conversation that contains the call responses.

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions
135135
@JsonIgnore
136136
private Set<String> functions = new HashSet<>();
137137

138+
@JsonIgnore
139+
private Boolean proxyToolCalls;
140+
138141
public static Builder builder() {
139142
return new Builder();
140143
}
@@ -215,6 +218,11 @@ public Builder withFunction(String functionName) {
215218
return this;
216219
}
217220

221+
public Builder withProxyToolCalls(Boolean proxyToolCalls) {
222+
this.options.proxyToolCalls = proxyToolCalls;
223+
return this;
224+
}
225+
218226
public MistralAiChatOptions build() {
219227
return this.options;
220228
}
@@ -356,6 +364,15 @@ public Integer getTopK() {
356364
return null;
357365
}
358366

367+
@Override
368+
public Boolean getProxyToolCalls() {
369+
return this.proxyToolCalls;
370+
}
371+
372+
public void setProxyToolCalls(Boolean proxyToolCalls) {
373+
this.proxyToolCalls = proxyToolCalls;
374+
}
375+
359376
@Override
360377
public MistralAiChatOptions copy() {
361378
return fromOptions(this);
@@ -374,7 +391,114 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions)
374391
.withToolChoice(fromOptions.getToolChoice())
375392
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
376393
.withFunctions(fromOptions.getFunctions())
394+
.withProxyToolCalls(fromOptions.getProxyToolCalls())
377395
.build();
378396
}
379397

398+
@Override
399+
public int hashCode() {
400+
final int prime = 31;
401+
int result = 1;
402+
result = prime * result + ((model == null) ? 0 : model.hashCode());
403+
result = prime * result + ((temperature == null) ? 0 : temperature.hashCode());
404+
result = prime * result + ((topP == null) ? 0 : topP.hashCode());
405+
result = prime * result + ((maxTokens == null) ? 0 : maxTokens.hashCode());
406+
result = prime * result + ((safePrompt == null) ? 0 : safePrompt.hashCode());
407+
result = prime * result + ((randomSeed == null) ? 0 : randomSeed.hashCode());
408+
result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode());
409+
result = prime * result + ((stop == null) ? 0 : stop.hashCode());
410+
result = prime * result + ((tools == null) ? 0 : tools.hashCode());
411+
result = prime * result + ((toolChoice == null) ? 0 : toolChoice.hashCode());
412+
result = prime * result + ((functionCallbacks == null) ? 0 : functionCallbacks.hashCode());
413+
result = prime * result + ((functions == null) ? 0 : functions.hashCode());
414+
result = prime * result + ((proxyToolCalls == null) ? 0 : proxyToolCalls.hashCode());
415+
return result;
416+
}
417+
418+
@Override
419+
public boolean equals(Object obj) {
420+
if (this == obj)
421+
return true;
422+
if (obj == null)
423+
return false;
424+
if (getClass() != obj.getClass())
425+
return false;
426+
MistralAiChatOptions other = (MistralAiChatOptions) obj;
427+
if (model == null) {
428+
if (other.model != null)
429+
return false;
430+
}
431+
else if (!model.equals(other.model))
432+
return false;
433+
if (temperature == null) {
434+
if (other.temperature != null)
435+
return false;
436+
}
437+
else if (!temperature.equals(other.temperature))
438+
return false;
439+
if (topP == null) {
440+
if (other.topP != null)
441+
return false;
442+
}
443+
else if (!topP.equals(other.topP))
444+
return false;
445+
if (maxTokens == null) {
446+
if (other.maxTokens != null)
447+
return false;
448+
}
449+
else if (!maxTokens.equals(other.maxTokens))
450+
return false;
451+
if (safePrompt == null) {
452+
if (other.safePrompt != null)
453+
return false;
454+
}
455+
else if (!safePrompt.equals(other.safePrompt))
456+
return false;
457+
if (randomSeed == null) {
458+
if (other.randomSeed != null)
459+
return false;
460+
}
461+
else if (!randomSeed.equals(other.randomSeed))
462+
return false;
463+
if (responseFormat == null) {
464+
if (other.responseFormat != null)
465+
return false;
466+
}
467+
else if (!responseFormat.equals(other.responseFormat))
468+
return false;
469+
if (stop == null) {
470+
if (other.stop != null)
471+
return false;
472+
}
473+
else if (!stop.equals(other.stop))
474+
return false;
475+
if (tools == null) {
476+
if (other.tools != null)
477+
return false;
478+
}
479+
else if (!tools.equals(other.tools))
480+
return false;
481+
if (toolChoice != other.toolChoice)
482+
return false;
483+
if (functionCallbacks == null) {
484+
if (other.functionCallbacks != null)
485+
return false;
486+
}
487+
else if (!functionCallbacks.equals(other.functionCallbacks))
488+
return false;
489+
if (functions == null) {
490+
if (other.functions != null)
491+
return false;
492+
}
493+
else if (!functions.equals(other.functions))
494+
return false;
495+
if (proxyToolCalls == null) {
496+
if (other.proxyToolCalls != null)
497+
return false;
498+
}
499+
else if (!proxyToolCalls.equals(other.proxyToolCalls))
500+
return false;
501+
return true;
502+
}
503+
380504
}

0 commit comments

Comments
 (0)