Skip to content

Commit 5a70e32

Browse files
committed
✨ avoid second round trip for function call, let developer decide
1 parent d7dad6e commit 5a70e32

File tree

13 files changed

+188
-40
lines changed

13 files changed

+188
-40
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,12 @@ public AnthropicChatClient(AnthropicApi anthropicApi, AnthropicChatOptions defau
147147
public ChatResponse call(Prompt prompt) {
148148

149149
ChatCompletionRequest request = createRequest(prompt, false);
150-
151150
return this.retryTemplate.execute(ctx -> {
152-
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
151+
boolean completeRoundTrip = false;
152+
if (prompt.getOptions() instanceof AnthropicChatOptions anthropicChatOptions) {
153+
completeRoundTrip = anthropicChatOptions.isCompleteRoundTrip();
154+
}
155+
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request, completeRoundTrip);
153156
return toChatResponse(completionEntity.getBody());
154157
});
155158
}

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions
7575
@NestedConfigurationProperty
7676
@JsonIgnore
7777
private Set<String> functions = new HashSet<>();
78+
79+
@JsonIgnore
80+
private boolean completeRoundTrip = false;
7881
// @formatter:on
7982

8083
public static Builder builder() {
@@ -137,6 +140,11 @@ public Builder withFunction(String functionName) {
137140
return this;
138141
}
139142

143+
public Builder withCompleteRoundTrip(boolean completeRoundTrip) {
144+
this.options.completeRoundTrip = completeRoundTrip;
145+
return this;
146+
}
147+
140148
public AnthropicChatOptions build() {
141149
return this.options;
142150
}
@@ -223,4 +231,14 @@ public void setFunctions(Set<String> functions) {
223231
this.functions = functions;
224232
}
225233

234+
@Override
235+
public boolean isCompleteRoundTrip() {
236+
return completeRoundTrip;
237+
}
238+
239+
@Override
240+
public void setCompleteRoundTrip(boolean completeRoundTrip) {
241+
this.completeRoundTrip = completeRoundTrip;
242+
}
243+
226244
}

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,11 @@ public ChatResponse call(Prompt prompt) {
134134
options.setStream(false);
135135

136136
logger.trace("Azure ChatCompletionsOptions: {}", options);
137-
ChatCompletions chatCompletions = this.callWithFunctionSupport(options);
137+
boolean completeRoundTrip = false;
138+
if (prompt.getOptions() instanceof AzureOpenAiChatOptions azureOpenAiChatOptions) {
139+
completeRoundTrip = azureOpenAiChatOptions.isCompleteRoundTrip();
140+
}
141+
ChatCompletions chatCompletions = this.callWithFunctionSupport(options, completeRoundTrip);
138142
logger.trace("Azure ChatCompletions: {}", chatCompletions);
139143

140144
List<Generation> generations = chatCompletions.getChoices()
@@ -549,4 +553,4 @@ protected boolean isToolFunctionCall(ChatCompletions chatCompletions) {
549553
return choice.getFinishReason() == CompletionsFinishReason.TOOL_CALLS;
550554
}
551555

552-
}
556+
}

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
152152
@JsonIgnore
153153
private Set<String> functions = new HashSet<>();
154154

155+
@JsonIgnore
156+
private boolean completeRoundTrip = false;
157+
155158
public static Builder builder() {
156159
return new Builder();
157160
}
@@ -239,6 +242,11 @@ public Builder withFunction(String functionName) {
239242
return this;
240243
}
241244

245+
public Builder withCompleteRoundTrip(boolean completeRoundTrip) {
246+
this.options.completeRoundTrip = completeRoundTrip;
247+
return this;
248+
}
249+
242250
public AzureOpenAiChatOptions build() {
243251
return this.options;
244252
}
@@ -356,4 +364,14 @@ public void setFunctions(Set<String> functions) {
356364
this.functions = functions;
357365
}
358366

367+
@Override
368+
public boolean isCompleteRoundTrip() {
369+
return completeRoundTrip;
370+
}
371+
372+
@Override
373+
public void setCompleteRoundTrip(boolean completeRoundTrip) {
374+
this.completeRoundTrip = completeRoundTrip;
375+
}
376+
359377
}

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,11 @@ public ChatResponse call(Prompt prompt) {
103103
var request = createRequest(prompt, false);
104104

105105
return retryTemplate.execute(ctx -> {
106-
107-
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
106+
boolean completeRoundTrip = false;
107+
if (prompt.getOptions() instanceof MistralAiChatOptions mistralAiChatOptions) {
108+
completeRoundTrip = mistralAiChatOptions.isCompleteRoundTrip();
109+
}
110+
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request, completeRoundTrip);
108111

109112
var chatCompletion = completionEntity.getBody();
110113
if (chatCompletion == null) {
@@ -149,8 +152,12 @@ public Flux<ChatResponse> stream(Prompt prompt) {
149152
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
150153

151154
return completionChunks.map(chunk -> toChatCompletion(chunk)).map(chatCompletion -> {
152-
153-
chatCompletion = handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chatCompletion)))
155+
boolean completeRoundTrip = false;
156+
if (prompt.getOptions() instanceof MistralAiChatOptions mistralAiChatOptions) {
157+
completeRoundTrip = mistralAiChatOptions.isCompleteRoundTrip();
158+
}
159+
chatCompletion = handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chatCompletion)),
160+
completeRoundTrip)
154161
.getBody();
155162

156163
@SuppressWarnings("null")

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions
126126
@JsonIgnore
127127
private Set<String> functions = new HashSet<>();
128128

129+
@JsonIgnore
130+
private boolean completeRoundTrip = false;
131+
129132
public static Builder builder() {
130133
return new Builder();
131134
}
@@ -196,6 +199,11 @@ public Builder withFunction(String functionName) {
196199
return this;
197200
}
198201

202+
public Builder withCompleteRoundTrip(boolean completeRoundTrip) {
203+
this.options.completeRoundTrip = completeRoundTrip;
204+
return this;
205+
}
206+
199207
public MistralAiChatOptions build() {
200208
return this.options;
201209
}
@@ -309,4 +317,14 @@ public void setFunctions(Set<String> functions) {
309317
this.functions = functions;
310318
}
311319

320+
@Override
321+
public boolean isCompleteRoundTrip() {
322+
return completeRoundTrip;
323+
}
324+
325+
@Override
326+
public void setCompleteRoundTrip(boolean completeRoundTrip) {
327+
this.completeRoundTrip = completeRoundTrip;
328+
}
329+
312330
}

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,11 @@ public ChatResponse call(Prompt prompt) {
139139
ChatCompletionRequest request = createRequest(prompt, false);
140140

141141
return this.retryTemplate.execute(ctx -> {
142-
143-
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
142+
boolean completeRoundTrip = false;
143+
if (prompt.getOptions() instanceof OpenAiChatOptions openAiChatOptions) {
144+
completeRoundTrip = openAiChatOptions.isCompleteRoundTrip();
145+
}
146+
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request, completeRoundTrip);
144147

145148
var chatCompletion = completionEntity.getBody();
146149
if (chatCompletion == null) {
@@ -191,7 +194,13 @@ public Flux<ChatResponse> stream(Prompt prompt) {
191194
// the function call handling logic.
192195
return completionChunks.map(chunk -> chunkToChatCompletion(chunk)).map(chatCompletion -> {
193196
try {
194-
chatCompletion = handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chatCompletion)))
197+
198+
boolean completeRoundTrip = false;
199+
if (prompt.getOptions() instanceof OpenAiChatOptions openAiChatOptions) {
200+
completeRoundTrip = openAiChatOptions.isCompleteRoundTrip();
201+
}
202+
chatCompletion = handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chatCompletion)),
203+
completeRoundTrip)
195204
.getBody();
196205

197206
@SuppressWarnings("null")

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ public class OpenAiChatOptions implements FunctionCallingOptions, ChatOptions {
155155
@NestedConfigurationProperty
156156
@JsonIgnore
157157
private Set<String> functions = new HashSet<>();
158+
@JsonIgnore
159+
private boolean completeRoundTrip;
158160
// @formatter:on
159161

160162
public static Builder builder() {
@@ -270,6 +272,11 @@ public Builder withFunction(String functionName) {
270272
return this;
271273
}
272274

275+
public Builder withCompleteRoundTrip(boolean completeRoundTrip) {
276+
this.options.completeRoundTrip = completeRoundTrip;
277+
return this;
278+
}
279+
273280
public OpenAiChatOptions build() {
274281
return this.options;
275282
}
@@ -425,6 +432,27 @@ public void setFunctions(Set<String> functionNames) {
425432
this.functions = functionNames;
426433
}
427434

435+
@Override
436+
@JsonIgnore
437+
public Integer getTopK() {
438+
throw new UnsupportedOperationException("Unimplemented method 'getTopK'");
439+
}
440+
441+
@JsonIgnore
442+
public void setTopK(Integer topK) {
443+
throw new UnsupportedOperationException("Unimplemented method 'setTopK'");
444+
}
445+
446+
@Override
447+
public boolean isCompleteRoundTrip() {
448+
return completeRoundTrip;
449+
}
450+
451+
@Override
452+
public void setCompleteRoundTrip(boolean completeRoundTrip) {
453+
this.completeRoundTrip = completeRoundTrip;
454+
}
455+
428456
@Override
429457
public int hashCode() {
430458
final int prime = 31;
@@ -556,15 +584,4 @@ else if (!this.user.equals(other.user))
556584
return true;
557585
}
558586

559-
@Override
560-
@JsonIgnore
561-
public Integer getTopK() {
562-
throw new UnsupportedOperationException("Unimplemented method 'getTopK'");
563-
}
564-
565-
@JsonIgnore
566-
public void setTopK(Integer topK) {
567-
throw new UnsupportedOperationException("Unimplemented method 'setTopK'");
568-
}
569-
570587
}

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,11 @@ public VertexAiGeminiChatClient(VertexAI vertexAI, VertexAiGeminiChatOptions opt
144144
public ChatResponse call(Prompt prompt) {
145145

146146
var geminiRequest = createGeminiRequest(prompt);
147-
148-
GenerateContentResponse response = this.callWithFunctionSupport(geminiRequest);
147+
boolean completeRoundTrip = false;
148+
if (prompt.getOptions() instanceof VertexAiGeminiChatOptions vertexAiGeminiChatOptions) {
149+
completeRoundTrip = vertexAiGeminiChatOptions.isCompleteRoundTrip();
150+
}
151+
GenerateContentResponse response = this.callWithFunctionSupport(geminiRequest, completeRoundTrip);
149152

150153
List<Generation> generations = response.getCandidatesList()
151154
.stream()
@@ -168,7 +171,11 @@ public Flux<ChatResponse> stream(Prompt prompt) {
168171
.generateContentStream(request.contents);
169172

170173
return Flux.fromStream(responseStream.stream()).map(response -> {
171-
response = handleFunctionCallOrReturn(request, response);
174+
boolean completeRoundTrip = false;
175+
if (prompt.getOptions() instanceof VertexAiGeminiChatOptions vertexAiGeminiChatOptions) {
176+
completeRoundTrip = vertexAiGeminiChatOptions.isCompleteRoundTrip();
177+
}
178+
response = handleFunctionCallOrReturn(request, response, completeRoundTrip);
172179
List<Generation> generations = response.getCandidatesList()
173180
.stream()
174181
.map(candidate -> candidate.getContent().getPartsList())

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ public enum TransportType {
9898
@NestedConfigurationProperty
9999
@JsonIgnore
100100
private Set<String> functions = new HashSet<>();
101-
101+
@JsonIgnore
102+
private boolean completeRoundTrip = false;
102103
// @formatter:on
103104

104105
public static Builder builder() {
@@ -161,6 +162,11 @@ public Builder withFunction(String functionName) {
161162
return this;
162163
}
163164

165+
public Builder withCompleteRoundTrip(boolean completeRoundTrip) {
166+
this.options.completeRoundTrip = completeRoundTrip;
167+
return this;
168+
}
169+
164170
public VertexAiGeminiChatOptions build() {
165171
return this.options;
166172
}
@@ -248,6 +254,16 @@ public void setFunctions(Set<String> functions) {
248254
this.functions = functions;
249255
}
250256

257+
@Override
258+
public boolean isCompleteRoundTrip() {
259+
return completeRoundTrip;
260+
}
261+
262+
@Override
263+
public void setCompleteRoundTrip(boolean completeRoundTrip) {
264+
this.completeRoundTrip = completeRoundTrip;
265+
}
266+
251267
@Override
252268
public int hashCode() {
253269
final int prime = 31;

spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,20 +57,17 @@ protected Set<String> handleFunctionCallbackConfigurations(FunctionCallingOption
5757

5858
if (options != null) {
5959
if (!CollectionUtils.isEmpty(options.getFunctionCallbacks())) {
60-
options.getFunctionCallbacks().stream().forEach(functionCallback -> {
61-
60+
options.getFunctionCallbacks().forEach(functionCallback -> {
6261
// Register the tool callback.
6362
if (isRuntimeCall) {
6463
this.functionCallbackRegister.put(functionCallback.getName(), functionCallback);
64+
// Automatically enable the function, usually from prompt
65+
// callback.
66+
functionToCall.add(functionCallback.getName());
6567
}
6668
else {
6769
this.functionCallbackRegister.putIfAbsent(functionCallback.getName(), functionCallback);
6870
}
69-
70-
// Automatically enable the function, usually from prompt callback.
71-
if (isRuntimeCall) {
72-
functionToCall.add(functionCallback.getName());
73-
}
7471
});
7572
}
7673

@@ -120,12 +117,12 @@ protected List<FunctionCallback> resolveFunctionCallbacks(Set<String> functionNa
120117
}
121118

122119
///
123-
protected Resp callWithFunctionSupport(Req request) {
120+
protected Resp callWithFunctionSupport(Req request, boolean completeRoundTrip) {
124121
Resp response = this.doChatCompletion(request);
125-
return this.handleFunctionCallOrReturn(request, response);
122+
return this.handleFunctionCallOrReturn(request, response, completeRoundTrip);
126123
}
127124

128-
protected Resp handleFunctionCallOrReturn(Req request, Resp response) {
125+
protected Resp handleFunctionCallOrReturn(Req request, Resp response, boolean completeRoundTrip) {
129126

130127
if (!this.isToolFunctionCall(response)) {
131128
return response;
@@ -143,8 +140,10 @@ protected Resp handleFunctionCallOrReturn(Req request, Resp response) {
143140
conversationHistory.add(responseMessage);
144141

145142
Req newRequest = this.doCreateToolResponseRequest(request, responseMessage, conversationHistory);
146-
147-
return this.callWithFunctionSupport(newRequest);
143+
if (!completeRoundTrip) {
144+
return response;
145+
}
146+
return this.callWithFunctionSupport(newRequest, completeRoundTrip);
148147
}
149148

150149
abstract protected Req doCreateToolResponseRequest(Req previousRequest, Msg responseMessage,

0 commit comments

Comments
 (0)