Skip to content

Commit 97f7639

Browse files
committed
🧑‍💻 better support for non-returning functions
1 parent b94b5b8 commit 97f7639

File tree

9 files changed

+109
-74
lines changed

9 files changed

+109
-74
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,8 @@
1515
*/
1616
package org.springframework.ai.anthropic;
1717

18-
import java.util.ArrayList;
19-
import java.util.Base64;
20-
import java.util.HashSet;
21-
import java.util.List;
22-
import java.util.Map;
23-
import java.util.Set;
24-
import java.util.concurrent.atomic.AtomicReference;
25-
import java.util.stream.Collectors;
26-
2718
import org.slf4j.Logger;
2819
import org.slf4j.LoggerFactory;
29-
import reactor.core.publisher.Flux;
30-
3120
import org.springframework.ai.anthropic.api.AnthropicApi;
3221
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletion;
3322
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
@@ -48,12 +37,24 @@
4837
import org.springframework.ai.chat.prompt.Prompt;
4938
import org.springframework.ai.model.ModelOptionsUtils;
5039
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
40+
import org.springframework.ai.model.function.FunctionCallback;
5141
import org.springframework.ai.model.function.FunctionCallbackContext;
5242
import org.springframework.ai.retry.RetryUtils;
5343
import org.springframework.http.ResponseEntity;
5444
import org.springframework.retry.support.RetryTemplate;
5545
import org.springframework.util.Assert;
5646
import org.springframework.util.CollectionUtils;
47+
import reactor.core.publisher.Flux;
48+
49+
import java.util.ArrayList;
50+
import java.util.Base64;
51+
import java.util.HashSet;
52+
import java.util.List;
53+
import java.util.Map;
54+
import java.util.Optional;
55+
import java.util.Set;
56+
import java.util.concurrent.atomic.AtomicReference;
57+
import java.util.stream.Collectors;
5758

5859
/**
5960
* The {@link ChatClient} implementation for the Anthropic service.
@@ -391,10 +392,18 @@ public ChatCompletion build() {
391392
}
392393

393394
@Override
394-
protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
395-
ChatCompletionRequest previousRequest, RequestMessage responseMessage,
396-
List<RequestMessage> conversationHistory) {
397-
boolean needCompleteRoundTrip = false;
395+
protected boolean hasReturningFunction(RequestMessage responseMessage) {
396+
return responseMessage.content()
397+
.stream()
398+
.filter(c -> c.type() == MediaContent.Type.TOOL_USE)
399+
.map(MediaContent::name)
400+
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
401+
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
402+
}
403+
404+
@Override
405+
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
406+
RequestMessage responseMessage, List<RequestMessage> conversationHistory) {
398407
List<MediaContent> toolToUseList = responseMessage.content()
399408
.stream()
400409
.filter(c -> c.type() == MediaContent.Type.TOOL_USE)
@@ -415,7 +424,6 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
415424
String functionResponse = this.functionCallbackRegister.get(functionName)
416425
.call(ModelOptionsUtils.toJsonString(functionArguments));
417426
if (functionResponse != null) {
418-
needCompleteRoundTrip = true;
419427
toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
420428
}
421429
}
@@ -426,7 +434,7 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
426434
// Recursively call chatCompletionWithTools until the model doesn't call a
427435
// functions anymore.
428436
final var build = ChatCompletionRequest.from(previousRequest).withMessages(conversationHistory).build();
429-
return new CompleteRoundTripBox<>(needCompleteRoundTrip, build);
437+
return build;
430438
}
431439

432440
@Override

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
*/
1616
package org.springframework.ai.azure.openai;
1717

18-
import java.util.*;
19-
2018
import com.azure.ai.openai.OpenAIClient;
2119
import com.azure.ai.openai.models.ChatChoice;
2220
import com.azure.ai.openai.models.ChatCompletions;
@@ -38,8 +36,6 @@
3836
import com.azure.core.util.IterableStream;
3937
import org.slf4j.Logger;
4038
import org.slf4j.LoggerFactory;
41-
import reactor.core.publisher.Flux;
42-
4339
import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
4440
import org.springframework.ai.chat.ChatClient;
4541
import org.springframework.ai.chat.ChatResponse;
@@ -53,9 +49,17 @@
5349
import org.springframework.ai.chat.prompt.Prompt;
5450
import org.springframework.ai.model.ModelOptionsUtils;
5551
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
52+
import org.springframework.ai.model.function.FunctionCallback;
5653
import org.springframework.ai.model.function.FunctionCallbackContext;
5754
import org.springframework.util.Assert;
5855
import org.springframework.util.CollectionUtils;
56+
import reactor.core.publisher.Flux;
57+
58+
import java.util.Collections;
59+
import java.util.HashSet;
60+
import java.util.List;
61+
import java.util.Optional;
62+
import java.util.Set;
5963

6064
/**
6165
* {@link ChatClient} implementation for {@literal Microsoft Azure AI} backed by
@@ -483,11 +487,18 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) {
483487
}
484488

485489
@Override
486-
protected CompleteRoundTripBox<ChatCompletionsOptions> doCreateToolResponseRequest(
487-
ChatCompletionsOptions previousRequest, ChatRequestMessage responseMessage,
488-
List<ChatRequestMessage> conversationHistory) {
490+
protected boolean hasReturningFunction(ChatRequestMessage responseMessage) {
491+
return ((ChatRequestAssistantMessage) responseMessage).getToolCalls()
492+
.stream()
493+
.map(toolCall -> ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getName())
494+
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
495+
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
496+
}
497+
498+
@Override
499+
protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
500+
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {
489501

490-
boolean needCompleteRoundTrip = false;
491502
// Every tool-call item requires a separate function call and a response (TOOL)
492503
// message.
493504
for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage) responseMessage).getToolCalls()) {
@@ -502,7 +513,6 @@ protected CompleteRoundTripBox<ChatCompletionsOptions> doCreateToolResponseReque
502513
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
503514

504515
if (functionResponse != null) {
505-
needCompleteRoundTrip = true;
506516
// Add the function response to the conversation.
507517
conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
508518
}
@@ -514,7 +524,7 @@ protected CompleteRoundTripBox<ChatCompletionsOptions> doCreateToolResponseReque
514524

515525
newRequest = merge(previousRequest, newRequest);
516526

517-
return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
527+
return newRequest;
518528
}
519529

520530
@Override

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
4444
import org.springframework.ai.model.ModelOptionsUtils;
4545
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
46+
import org.springframework.ai.model.function.FunctionCallback;
4647
import org.springframework.ai.model.function.FunctionCallbackContext;
4748
import org.springframework.ai.retry.RetryUtils;
4849
import org.springframework.http.ResponseEntity;
@@ -242,14 +243,18 @@ private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNam
242243
}).toList();
243244
}
244245

245-
//
246-
// Function Calling Support
247-
//
248246
@Override
249-
protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
250-
ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage,
251-
List<ChatCompletionMessage> conversationHistory) {
252-
boolean needCompleteRoundTrip = false;
247+
protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) {
248+
return responseMessage.toolCalls()
249+
.stream()
250+
.map(toolCall -> toolCall.function().name())
251+
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
252+
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
253+
}
254+
255+
@Override
256+
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
257+
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
253258
// Every tool-call item requires a separate function call and a response (TOOL)
254259
// message.
255260
for (ToolCall toolCall : responseMessage.toolCalls()) {
@@ -263,7 +268,6 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
263268

264269
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
265270
if (functionResponse != null) {
266-
needCompleteRoundTrip = true;
267271
// Add the function response to the conversation.
268272
conversationHistory.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL,
269273
functionName, null));
@@ -276,7 +280,7 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
276280
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false);
277281
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);
278282

279-
return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
283+
return newRequest;
280284
}
281285

282286
@Override

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.springframework.ai.chat.prompt.Prompt;
4040
import org.springframework.ai.model.ModelOptionsUtils;
4141
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
42+
import org.springframework.ai.model.function.FunctionCallback;
4243
import org.springframework.ai.model.function.FunctionCallbackContext;
4344
import org.springframework.ai.openai.api.OpenAiApi;
4445
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
@@ -324,11 +325,18 @@ private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames)
324325
}
325326

326327
@Override
327-
protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
328-
ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage,
329-
List<ChatCompletionMessage> conversationHistory) {
328+
protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) {
329+
return responseMessage.toolCalls()
330+
.stream()
331+
.map(toolCall -> toolCall.function().name())
332+
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
333+
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
334+
}
335+
336+
@Override
337+
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
338+
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
330339

331-
boolean needCompleteRoundTrip = false;
332340
// Every tool-call item requires a separate function call and a response (TOOL)
333341
// message.
334342
for (ToolCall toolCall : responseMessage.toolCalls()) {
@@ -342,7 +350,6 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
342350

343351
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
344352
if (functionResponse != null) {
345-
needCompleteRoundTrip = true;
346353
// Add the function response to the conversation.
347354
conversationHistory
348355
.add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null));
@@ -354,7 +361,7 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
354361
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false);
355362
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);
356363

357-
return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
364+
return newRequest;
358365
}
359366

360367
@Override

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import org.springframework.ai.chat.prompt.Prompt;
5353
import org.springframework.ai.model.ModelOptionsUtils;
5454
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
55+
import org.springframework.ai.model.function.FunctionCallback;
5556
import org.springframework.ai.model.function.FunctionCallbackContext;
5657
import org.springframework.ai.vertexai.gemini.metadata.VertexAiChatResponseMetadata;
5758
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
@@ -60,6 +61,14 @@
6061
import org.springframework.util.Assert;
6162
import org.springframework.util.CollectionUtils;
6263
import org.springframework.util.StringUtils;
64+
import reactor.core.publisher.Flux;
65+
66+
import java.util.ArrayList;
67+
import java.util.HashSet;
68+
import java.util.List;
69+
import java.util.Optional;
70+
import java.util.Set;
71+
import java.util.stream.Collectors;
6372

6473
/**
6574
* @author Christian Tzolov
@@ -401,9 +410,16 @@ public void destroy() throws Exception {
401410
}
402411

403412
@Override
404-
protected CompleteRoundTripBox<GeminiRequest> doCreateToolResponseRequest(GeminiRequest previousRequest,
405-
Content responseMessage, List<Content> conversationHistory) {
406-
boolean needCompleteRoundTrip = false;
413+
protected boolean hasReturningFunction(Content responseMessage) {
414+
final var functionName = responseMessage.getPartsList().get(0).getFunctionCall().getName();
415+
return Optional.ofNullable(this.functionCallbackRegister.get(functionName))
416+
.map(FunctionCallback::returningFunction)
417+
.orElse(false);
418+
}
419+
420+
@Override
421+
protected GeminiRequest doCreateToolResponseRequest(GeminiRequest previousRequest, Content responseMessage,
422+
List<Content> conversationHistory) {
407423
FunctionCall functionCall = responseMessage.getPartsList().iterator().next().getFunctionCall();
408424

409425
var functionName = functionCall.getName();
@@ -415,7 +431,6 @@ protected CompleteRoundTripBox<GeminiRequest> doCreateToolResponseRequest(Gemini
415431

416432
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
417433
if (functionResponse != null) {
418-
needCompleteRoundTrip = true;
419434
Content contentFnResp = Content.newBuilder()
420435
.addParts(Part.newBuilder()
421436
.setFunctionResponse(FunctionResponse.newBuilder()
@@ -429,7 +444,7 @@ protected CompleteRoundTripBox<GeminiRequest> doCreateToolResponseRequest(Gemini
429444
}
430445

431446
final var geminiRequest = new GeminiRequest(conversationHistory, previousRequest.model());
432-
return new CompleteRoundTripBox<>(needCompleteRoundTrip, geminiRequest);
447+
return geminiRequest;
433448
}
434449

435450
@Override

spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,17 @@ protected Resp handleFunctionCallOrReturn(Req request, Resp response) {
139139
// Add the assistant response to the message conversation history.
140140
conversationHistory.add(responseMessage);
141141

142-
CompleteRoundTripBox<Req> needRoundTripAndResponse = this.doCreateToolResponseRequest(request, responseMessage,
143-
conversationHistory);
144-
if (!needRoundTripAndResponse.completeRoundTrip) {
142+
Req newRequest = this.doCreateToolResponseRequest(request, responseMessage, conversationHistory);
143+
144+
if (!this.hasReturningFunction(responseMessage)) {
145145
return response;
146146
}
147-
return this.callWithFunctionSupport(needRoundTripAndResponse.getResponseMessage());
147+
return this.callWithFunctionSupport(newRequest);
148148
}
149149

150-
abstract protected CompleteRoundTripBox<Req> doCreateToolResponseRequest(Req previousRequest, Msg responseMessage,
150+
abstract protected boolean hasReturningFunction(Msg responseMessage);
151+
152+
abstract protected Req doCreateToolResponseRequest(Req previousRequest, Msg responseMessage,
151153
List<Msg> conversationHistory);
152154

153155
abstract protected List<Msg> doGetUserMessages(Req request);
@@ -158,25 +160,4 @@ abstract protected CompleteRoundTripBox<Req> doCreateToolResponseRequest(Req pre
158160

159161
abstract protected boolean isToolFunctionCall(Resp response);
160162

161-
public static class CompleteRoundTripBox<Resp> {
162-
163-
private final boolean completeRoundTrip;
164-
165-
private final Resp responseMessage;
166-
167-
public CompleteRoundTripBox(boolean completeRoundTrip, Resp responseMessage) {
168-
this.completeRoundTrip = completeRoundTrip;
169-
this.responseMessage = responseMessage;
170-
}
171-
172-
public Resp getResponseMessage() {
173-
return responseMessage;
174-
}
175-
176-
public boolean isCompleteRoundTrip() {
177-
return completeRoundTrip;
178-
}
179-
180-
}
181-
182163
}

spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@ public String getInputTypeSchema() {
101101
return this.inputTypeSchema;
102102
}
103103

104+
@Override
105+
public boolean returningFunction() {
106+
return !outputType.isAssignableFrom(Void.class);
107+
}
108+
104109
@Override
105110
public String call(String functionArguments) {
106111

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,9 @@ public interface FunctionCallback {
4949
*/
5050
public String call(String functionInput);
5151

52+
/**
53+
* @return This function return a value or not
54+
*/
55+
boolean returningFunction();
56+
5257
}

spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
120120
.withInputType(functionInputClass)
121121
.build();
122122
}
123-
if (bean instanceof Consumer consumer) {
123+
if (bean instanceof Consumer<?> consumer) {
124124
return FunctionCallbackWrapper.builder(consumer)
125125
.withName(functionName)
126126
.withSchemaType(this.schemaType)

0 commit comments

Comments
 (0)