Skip to content

Commit 5948013

Browse files
committed
🧑‍💻 better support for non-returning functions
1 parent 7ce90d0 commit 5948013

File tree

9 files changed

+86
-59
lines changed

9 files changed

+86
-59
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.ai.chat.prompt.Prompt;
3838
import org.springframework.ai.model.ModelOptionsUtils;
3939
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
40+
import org.springframework.ai.model.function.FunctionCallback;
4041
import org.springframework.ai.model.function.FunctionCallbackContext;
4142
import org.springframework.ai.retry.RetryUtils;
4243
import org.springframework.http.ResponseEntity;
@@ -50,6 +51,7 @@
5051
import java.util.HashSet;
5152
import java.util.List;
5253
import java.util.Map;
54+
import java.util.Optional;
5355
import java.util.Set;
5456
import java.util.concurrent.atomic.AtomicReference;
5557
import java.util.stream.Collectors;
@@ -390,10 +392,18 @@ public ChatCompletion build() {
390392
}
391393

392394
@Override
393-
protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
394-
ChatCompletionRequest previousRequest, RequestMessage responseMessage,
395-
List<RequestMessage> conversationHistory) {
396-
boolean needCompleteRoundTrip = false;
395+
protected boolean hasReturningFunction(RequestMessage responseMessage) {
396+
return responseMessage.content()
397+
.stream()
398+
.filter(c -> c.type() == MediaContent.Type.TOOL_USE)
399+
.map(MediaContent::name)
400+
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
401+
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
402+
}
403+
404+
@Override
405+
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
406+
RequestMessage responseMessage, List<RequestMessage> conversationHistory) {
397407
List<MediaContent> toolToUseList = responseMessage.content()
398408
.stream()
399409
.filter(c -> c.type() == MediaContent.Type.TOOL_USE)
@@ -414,7 +424,6 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
414424
String functionResponse = this.functionCallbackRegister.get(functionName)
415425
.call(ModelOptionsUtils.toJsonString(functionArguments));
416426
if (functionResponse != null) {
417-
needCompleteRoundTrip = true;
418427
toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
419428
}
420429
}
@@ -425,7 +434,7 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
425434
// Recursively call chatCompletionWithTools until the model doesn't call a
426435
// functions anymore.
427436
final var build = ChatCompletionRequest.from(previousRequest).withMessages(conversationHistory).build();
428-
return new CompleteRoundTripBox<>(needCompleteRoundTrip, build);
437+
return build;
429438
}
430439

431440
@Override

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.springframework.ai.chat.prompt.Prompt;
5050
import org.springframework.ai.model.ModelOptionsUtils;
5151
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
52+
import org.springframework.ai.model.function.FunctionCallback;
5253
import org.springframework.ai.model.function.FunctionCallbackContext;
5354
import org.springframework.util.Assert;
5455
import org.springframework.util.CollectionUtils;
@@ -57,6 +58,7 @@
5758
import java.util.Collections;
5859
import java.util.HashSet;
5960
import java.util.List;
61+
import java.util.Optional;
6062
import java.util.Set;
6163

6264
/**
@@ -426,11 +428,18 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCom
426428
}
427429

428430
@Override
429-
protected CompleteRoundTripBox<ChatCompletionsOptions> doCreateToolResponseRequest(
430-
ChatCompletionsOptions previousRequest, ChatRequestMessage responseMessage,
431-
List<ChatRequestMessage> conversationHistory) {
431+
protected boolean hasReturningFunction(ChatRequestMessage responseMessage) {
432+
return ((ChatRequestAssistantMessage) responseMessage).getToolCalls()
433+
.stream()
434+
.map(toolCall -> ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getName())
435+
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
436+
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
437+
}
438+
439+
@Override
440+
protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
441+
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {
432442

433-
boolean needCompleteRoundTrip = false;
434443
// Every tool-call item requires a separate function call and a response (TOOL)
435444
// message.
436445
for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage) responseMessage).getToolCalls()) {
@@ -445,7 +454,6 @@ protected CompleteRoundTripBox<ChatCompletionsOptions> doCreateToolResponseReque
445454
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
446455

447456
if (functionResponse != null) {
448-
needCompleteRoundTrip = true;
449457
// Add the function response to the conversation.
450458
conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
451459
}
@@ -457,7 +465,7 @@ protected CompleteRoundTripBox<ChatCompletionsOptions> doCreateToolResponseReque
457465

458466
newRequest = merge(previousRequest, newRequest);
459467

460-
return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
468+
return newRequest;
461469
}
462470

463471
@Override

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
3434
import org.springframework.ai.model.ModelOptionsUtils;
3535
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
36+
import org.springframework.ai.model.function.FunctionCallback;
3637
import org.springframework.ai.model.function.FunctionCallbackContext;
3738
import org.springframework.ai.retry.RetryUtils;
3839
import org.springframework.http.ResponseEntity;
@@ -241,14 +242,18 @@ private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNam
241242
}).toList();
242243
}
243244

244-
//
245-
// Function Calling Support
246-
//
247245
@Override
248-
protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
249-
ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage,
250-
List<ChatCompletionMessage> conversationHistory) {
251-
boolean needCompleteRoundTrip = false;
246+
protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) {
247+
return responseMessage.toolCalls()
248+
.stream()
249+
.map(toolCall -> toolCall.function().name())
250+
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
251+
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
252+
}
253+
254+
@Override
255+
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
256+
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
252257
// Every tool-call item requires a separate function call and a response (TOOL)
253258
// message.
254259
for (ToolCall toolCall : responseMessage.toolCalls()) {
@@ -262,7 +267,6 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
262267

263268
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
264269
if (functionResponse != null) {
265-
needCompleteRoundTrip = true;
266270
// Add the function response to the conversation.
267271
conversationHistory.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL,
268272
functionName, null));
@@ -275,7 +279,7 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
275279
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false);
276280
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);
277281

278-
return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
282+
return newRequest;
279283
}
280284

281285
@Override

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.chat.prompt.Prompt;
2828
import org.springframework.ai.model.ModelOptionsUtils;
2929
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
30+
import org.springframework.ai.model.function.FunctionCallback;
3031
import org.springframework.ai.model.function.FunctionCallbackContext;
3132
import org.springframework.ai.openai.api.OpenAiApi;
3233
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
@@ -323,11 +324,18 @@ private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames)
323324
}
324325

325326
@Override
326-
protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
327-
ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage,
328-
List<ChatCompletionMessage> conversationHistory) {
327+
protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) {
328+
return responseMessage.toolCalls()
329+
.stream()
330+
.map(toolCall -> toolCall.function().name())
331+
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
332+
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
333+
}
334+
335+
@Override
336+
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
337+
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
329338

330-
boolean needCompleteRoundTrip = false;
331339
// Every tool-call item requires a separate function call and a response (TOOL)
332340
// message.
333341
for (ToolCall toolCall : responseMessage.toolCalls()) {
@@ -341,7 +349,6 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
341349

342350
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
343351
if (functionResponse != null) {
344-
needCompleteRoundTrip = true;
345352
// Add the function response to the conversation.
346353
conversationHistory
347354
.add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null));
@@ -354,7 +361,7 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
354361

355362
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);
356363

357-
return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
364+
return newRequest;
358365
}
359366

360367
@Override

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.springframework.ai.chat.prompt.Prompt;
4545
import org.springframework.ai.model.ModelOptionsUtils;
4646
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
47+
import org.springframework.ai.model.function.FunctionCallback;
4748
import org.springframework.ai.model.function.FunctionCallbackContext;
4849
import org.springframework.ai.vertexai.gemini.metadata.VertexAiChatResponseMetadata;
4950
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
@@ -57,6 +58,7 @@
5758
import java.util.ArrayList;
5859
import java.util.HashSet;
5960
import java.util.List;
61+
import java.util.Optional;
6062
import java.util.Set;
6163
import java.util.stream.Collectors;
6264

@@ -400,9 +402,16 @@ public void destroy() throws Exception {
400402
}
401403

402404
@Override
403-
protected CompleteRoundTripBox<GeminiRequest> doCreateToolResponseRequest(GeminiRequest previousRequest,
404-
Content responseMessage, List<Content> conversationHistory) {
405-
boolean needCompleteRoundTrip = false;
405+
protected boolean hasReturningFunction(Content responseMessage) {
406+
final var functionName = responseMessage.getPartsList().get(0).getFunctionCall().getName();
407+
return Optional.ofNullable(this.functionCallbackRegister.get(functionName))
408+
.map(FunctionCallback::returningFunction)
409+
.orElse(false);
410+
}
411+
412+
@Override
413+
protected GeminiRequest doCreateToolResponseRequest(GeminiRequest previousRequest, Content responseMessage,
414+
List<Content> conversationHistory) {
406415
FunctionCall functionCall = responseMessage.getPartsList().iterator().next().getFunctionCall();
407416

408417
var functionName = functionCall.getName();
@@ -414,7 +423,6 @@ protected CompleteRoundTripBox<GeminiRequest> doCreateToolResponseRequest(Gemini
414423

415424
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
416425
if (functionResponse != null) {
417-
needCompleteRoundTrip = true;
418426
Content contentFnResp = Content.newBuilder()
419427
.addParts(Part.newBuilder()
420428
.setFunctionResponse(FunctionResponse.newBuilder()
@@ -428,7 +436,7 @@ protected CompleteRoundTripBox<GeminiRequest> doCreateToolResponseRequest(Gemini
428436
}
429437

430438
final var geminiRequest = new GeminiRequest(conversationHistory, previousRequest.model());
431-
return new CompleteRoundTripBox<>(needCompleteRoundTrip, geminiRequest);
439+
return geminiRequest;
432440
}
433441

434442
@Override

spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,17 @@ protected Resp handleFunctionCallOrReturn(Req request, Resp response) {
139139
// Add the assistant response to the message conversation history.
140140
conversationHistory.add(responseMessage);
141141

142-
CompleteRoundTripBox<Req> needRoundTripAndResponse = this.doCreateToolResponseRequest(request, responseMessage,
143-
conversationHistory);
144-
if (!needRoundTripAndResponse.completeRoundTrip) {
142+
Req newRequest = this.doCreateToolResponseRequest(request, responseMessage, conversationHistory);
143+
144+
if (!this.hasReturningFunction(responseMessage)) {
145145
return response;
146146
}
147-
return this.callWithFunctionSupport(needRoundTripAndResponse.getResponseMessage());
147+
return this.callWithFunctionSupport(newRequest);
148148
}
149149

150-
abstract protected CompleteRoundTripBox<Req> doCreateToolResponseRequest(Req previousRequest, Msg responseMessage,
150+
abstract protected boolean hasReturningFunction(Msg responseMessage);
151+
152+
abstract protected Req doCreateToolResponseRequest(Req previousRequest, Msg responseMessage,
151153
List<Msg> conversationHistory);
152154

153155
abstract protected List<Msg> doGetUserMessages(Req request);
@@ -158,25 +160,4 @@ abstract protected CompleteRoundTripBox<Req> doCreateToolResponseRequest(Req pre
158160

159161
abstract protected boolean isToolFunctionCall(Resp response);
160162

161-
public static class CompleteRoundTripBox<Resp> {
162-
163-
private final boolean completeRoundTrip;
164-
165-
private final Resp responseMessage;
166-
167-
public CompleteRoundTripBox(boolean completeRoundTrip, Resp responseMessage) {
168-
this.completeRoundTrip = completeRoundTrip;
169-
this.responseMessage = responseMessage;
170-
}
171-
172-
public Resp getResponseMessage() {
173-
return responseMessage;
174-
}
175-
176-
public boolean isCompleteRoundTrip() {
177-
return completeRoundTrip;
178-
}
179-
180-
}
181-
182163
}

spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ public String getInputTypeSchema() {
100100
return this.inputTypeSchema;
101101
}
102102

103+
@Override
104+
public boolean returningFunction() {
105+
return !outputType.isAssignableFrom(Void.class);
106+
}
107+
103108
@Override
104109
public String call(String functionArguments) {
105110

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,9 @@ public interface FunctionCallback {
4949
*/
5050
public String call(String functionInput);
5151

52+
/**
53+
* @return This function return a value or not
54+
*/
55+
boolean returningFunction();
56+
5257
}

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
119119
.withInputType(functionInputClass)
120120
.build();
121121
}
122-
if (bean instanceof Consumer consumer) {
122+
if (bean instanceof Consumer<?> consumer) {
123123
return FunctionCallbackWrapper.builder(consumer)
124124
.withName(functionName)
125125
.withSchemaType(this.schemaType)

0 commit comments

Comments
 (0)