Skip to content

Commit 5edafc9

Browse files
committed
✨ avoid second roundtrip in function callbacks
1 parent 14d620e commit 5edafc9

File tree

12 files changed

+216
-40
lines changed

12 files changed

+216
-40
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.HashSet;
2121
import java.util.List;
2222
import java.util.Map;
23+
import java.util.Optional;
2324
import java.util.Set;
2425
import java.util.concurrent.atomic.AtomicReference;
2526
import java.util.stream.Collectors;
@@ -48,6 +49,7 @@
4849
import org.springframework.ai.chat.prompt.Prompt;
4950
import org.springframework.ai.model.ModelOptionsUtils;
5051
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
52+
import org.springframework.ai.model.function.FunctionCallback;
5153
import org.springframework.ai.model.function.FunctionCallbackContext;
5254
import org.springframework.ai.retry.RetryUtils;
5355
import org.springframework.http.ResponseEntity;
@@ -391,6 +393,16 @@ public ChatCompletion build() {
391393

392394
}
393395

396+
@Override
397+
protected boolean hasReturningFunction(RequestMessage responseMessage) {
398+
return responseMessage.content()
399+
.stream()
400+
.filter(c -> c.type() == MediaContent.Type.TOOL_USE)
401+
.map(MediaContent::name)
402+
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
403+
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
404+
}
405+
394406
@Override
395407
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
396408
RequestMessage responseMessage, List<RequestMessage> conversationHistory) {
@@ -414,8 +426,9 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques
414426

415427
String functionResponse = this.functionCallbackRegister.get(functionName)
416428
.call(ModelOptionsUtils.toJsonString(functionArguments));
417-
418-
toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
429+
if (functionResponse != null) {
430+
toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
431+
}
419432
}
420433

421434
// Add the function response to the conversation.

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import com.azure.ai.openai.models.ChatRequestSystemMessage;
2929
import com.azure.ai.openai.models.ChatRequestToolMessage;
3030
import com.azure.ai.openai.models.ChatRequestUserMessage;
31+
import com.azure.ai.openai.models.ChatResponseMessage;
3132
import com.azure.ai.openai.models.CompletionsFinishReason;
3233
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
3334
import com.azure.ai.openai.models.FunctionCall;
@@ -50,6 +51,7 @@
5051
import org.springframework.ai.chat.prompt.Prompt;
5152
import org.springframework.ai.model.ModelOptionsUtils;
5253
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
54+
import org.springframework.ai.model.function.FunctionCallback;
5355
import org.springframework.ai.model.function.FunctionCallbackContext;
5456
import org.springframework.util.Assert;
5557
import org.springframework.util.CollectionUtils;
@@ -513,6 +515,15 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
513515
return copyOptions;
514516
}
515517

518+
@Override
519+
protected boolean hasReturningFunction(ChatRequestMessage responseMessage) {
520+
return ((ChatRequestAssistantMessage) responseMessage).getToolCalls()
521+
.stream()
522+
.map(toolCall -> ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getName())
523+
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
524+
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
525+
}
526+
516527
@Override
517528
protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
518529
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {
@@ -530,8 +541,10 @@ protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOpti
530541

531542
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
532543

533-
// Add the function response to the conversation.
534-
conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
544+
if (functionResponse != null) {
545+
// Add the function response to the conversation.
546+
conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
547+
}
535548
}
536549

537550
// Recursively call chatCompletionWithTools until the model doesn't call a

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatClientFunctionCallIT.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,29 @@ void streamFunctionCallTest() {
118118
assertThat(content).containsAnyOf("15.0", "15");
119119
}
120120

121+
@Test
122+
void functionCallWithoutCompleteRoundTrip() {
123+
124+
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco?");
125+
126+
List<Message> messages = new ArrayList<>(List.of(userMessage));
127+
128+
final var spyingMockWeatherService = new SpyingMockWeatherService();
129+
var promptOptions = AzureOpenAiChatOptions.builder()
130+
.withDeploymentName(selectedModel)
131+
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(spyingMockWeatherService)
132+
.withName("getCurrentWeather")
133+
.withDescription("Get the current weather in a given location")
134+
.build()))
135+
.build();
136+
137+
ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
138+
139+
logger.info("Response: {}", response);
140+
final var interceptedRequest = spyingMockWeatherService.getInterceptedRequest();
141+
assertThat(interceptedRequest.location()).containsIgnoringCase("San Francisco");
142+
}
143+
121144
@SpringBootConfiguration
122145
public static class TestConfiguration {
123146

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.azure.openai.function;
17+
18+
import java.util.function.Function;
19+
20+
public class SpyingMockWeatherService implements Function<MockWeatherService.Request, Void> {
21+
22+
private MockWeatherService.Request interceptedRequest = null;
23+
24+
@Override
25+
public Void apply(MockWeatherService.Request request) {
26+
interceptedRequest = request;
27+
return null;
28+
}
29+
30+
public MockWeatherService.Request getInterceptedRequest() {
31+
return interceptedRequest;
32+
}
33+
34+
}

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
3434
import org.springframework.ai.model.ModelOptionsUtils;
3535
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
36+
import org.springframework.ai.model.function.FunctionCallback;
3637
import org.springframework.ai.model.function.FunctionCallbackContext;
3738
import org.springframework.ai.retry.RetryUtils;
3839
import org.springframework.http.ResponseEntity;
@@ -98,7 +99,6 @@ public ChatResponse call(Prompt prompt) {
9899
var request = createRequest(prompt, false);
99100

100101
return retryTemplate.execute(ctx -> {
101-
102102
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
103103

104104
var chatCompletion = completionEntity.getBody();
@@ -239,13 +239,18 @@ private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNam
239239
}).toList();
240240
}
241241

242-
//
243-
// Function Calling Support
244-
//
242+
@Override
243+
protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) {
244+
return responseMessage.toolCalls()
245+
.stream()
246+
.map(toolCall -> toolCall.function().name())
247+
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
248+
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
249+
}
250+
245251
@Override
246252
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
247253
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
248-
249254
// Every tool-call item requires a separate function call and a response (TOOL)
250255
// message.
251256
for (ToolCall toolCall : responseMessage.toolCalls()) {
@@ -258,10 +263,12 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques
258263
}
259264

260265
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
266+
if (functionResponse != null) {
267+
// Add the function response to the conversation.
268+
conversationHistory.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL,
269+
functionName, null));
270+
}
261271

262-
// Add the function response to the conversation.
263-
conversationHistory
264-
.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL, functionName, null));
265272
}
266273

267274
// Recursively call chatCompletionWithTools until the model doesn't call a

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.chat.prompt.Prompt;
2828
import org.springframework.ai.model.ModelOptionsUtils;
2929
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
30+
import org.springframework.ai.model.function.FunctionCallback;
3031
import org.springframework.ai.model.function.FunctionCallbackContext;
3132
import org.springframework.ai.openai.api.OpenAiApi;
3233
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
@@ -324,6 +325,15 @@ private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames)
324325
}).toList();
325326
}
326327

328+
@Override
329+
protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) {
330+
return responseMessage.toolCalls()
331+
.stream()
332+
.map(toolCall -> toolCall.function().name())
333+
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
334+
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
335+
}
336+
327337
@Override
328338
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
329339
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
@@ -340,10 +350,11 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques
340350
}
341351

342352
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
343-
344-
// Add the function response to the conversation.
345-
conversationHistory
346-
.add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null));
353+
if (functionResponse != null) {
354+
// Add the function response to the conversation.
355+
conversationHistory
356+
.add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null));
357+
}
347358
}
348359

349360
// Recursively call chatCompletionWithTools until the model doesn't call a

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.springframework.ai.chat.prompt.Prompt;
4545
import org.springframework.ai.model.ModelOptionsUtils;
4646
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
47+
import org.springframework.ai.model.function.FunctionCallback;
4748
import org.springframework.ai.model.function.FunctionCallbackContext;
4849
import org.springframework.ai.vertexai.gemini.metadata.VertexAiChatResponseMetadata;
4950
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
@@ -57,6 +58,7 @@
5758
import java.util.ArrayList;
5859
import java.util.HashSet;
5960
import java.util.List;
61+
import java.util.Optional;
6062
import java.util.Set;
6163
import java.util.stream.Collectors;
6264

@@ -406,6 +408,14 @@ public void destroy() throws Exception {
406408
}
407409
}
408410

411+
@Override
412+
protected boolean hasReturningFunction(Content responseMessage) {
413+
final var functionName = responseMessage.getPartsList().get(0).getFunctionCall().getName();
414+
return Optional.ofNullable(this.functionCallbackRegister.get(functionName))
415+
.map(FunctionCallback::returningFunction)
416+
.orElse(false);
417+
}
418+
409419
@Override
410420
protected GeminiRequest doCreateToolResponseRequest(GeminiRequest previousRequest, Content responseMessage,
411421
List<Content> conversationHistory) {
@@ -420,17 +430,18 @@ protected GeminiRequest doCreateToolResponseRequest(GeminiRequest previousReques
420430
}
421431

422432
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
423-
424-
Content contentFnResp = Content.newBuilder()
425-
.addParts(Part.newBuilder()
426-
.setFunctionResponse(FunctionResponse.newBuilder()
427-
.setName(functionCall.getName())
428-
.setResponse(jsonToStruct(functionResponse))
433+
if (functionResponse != null) {
434+
Content contentFnResp = Content.newBuilder()
435+
.addParts(Part.newBuilder()
436+
.setFunctionResponse(FunctionResponse.newBuilder()
437+
.setName(functionCall.getName())
438+
.setResponse(jsonToStruct(functionResponse))
439+
.build())
429440
.build())
430-
.build())
431-
.build();
441+
.build();
432442

433-
conversationHistory.add(contentFnResp);
443+
conversationHistory.add(contentFnResp);
444+
}
434445

435446
return new GeminiRequest(conversationHistory, previousRequest.model());
436447
}

spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,17 @@ protected Set<String> handleFunctionCallbackConfigurations(FunctionCallingOption
6060

6161
if (options != null) {
6262
if (!CollectionUtils.isEmpty(options.getFunctionCallbacks())) {
63-
options.getFunctionCallbacks().stream().forEach(functionCallback -> {
64-
63+
options.getFunctionCallbacks().forEach(functionCallback -> {
6564
// Register the tool callback.
6665
if (isRuntimeCall) {
6766
this.functionCallbackRegister.put(functionCallback.getName(), functionCallback);
67+
// Automatically enable the function, usually from prompt
68+
// callback.
69+
functionToCall.add(functionCallback.getName());
6870
}
6971
else {
7072
this.functionCallbackRegister.putIfAbsent(functionCallback.getName(), functionCallback);
7173
}
72-
73-
// Automatically enable the function, usually from prompt callback.
74-
if (isRuntimeCall) {
75-
functionToCall.add(functionCallback.getName());
76-
}
7774
});
7875
}
7976

@@ -147,6 +144,9 @@ protected Resp handleFunctionCallOrReturn(Req request, Resp response) {
147144

148145
Req newRequest = this.doCreateToolResponseRequest(request, responseMessage, conversationHistory);
149146

147+
if (!this.hasReturningFunction(responseMessage)) {
148+
return response;
149+
}
150150
return this.callWithFunctionSupport(newRequest);
151151
}
152152

@@ -180,6 +180,8 @@ protected Flux<Resp> handleFunctionCallOrReturnStream(Req request, Flux<Resp> re
180180

181181
}
182182

183+
abstract protected boolean hasReturningFunction(Msg responseMessage);
184+
183185
abstract protected Req doCreateToolResponseRequest(Req previousRequest, Msg responseMessage,
184186
List<Msg> conversationHistory);
185187

0 commit comments

Comments
 (0)