Skip to content

Commit 78e24fd

Browse files
committed
✨ support and identify Functions<I, Void> or Consumer<I> to avoid second round trip
1 parent 493e2ea commit 78e24fd

File tree

19 files changed

+175
-233
lines changed

19 files changed

+175
-233
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,7 @@ public ChatResponse call(Prompt prompt) {
148148

149149
ChatCompletionRequest request = createRequest(prompt, false);
150150
return this.retryTemplate.execute(ctx -> {
151-
boolean completeRoundTrip = true;
152-
if (prompt.getOptions() instanceof AnthropicChatOptions anthropicChatOptions) {
153-
completeRoundTrip = anthropicChatOptions.isCompleteRoundTrip();
154-
}
155-
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request, completeRoundTrip);
151+
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
156152
return toChatResponse(completionEntity.getBody());
157153
});
158154
}
@@ -395,9 +391,10 @@ public ChatCompletion build() {
395391
}
396392

397393
@Override
398-
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
399-
RequestMessage responseMessage, List<RequestMessage> conversationHistory) {
400-
394+
protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
395+
ChatCompletionRequest previousRequest, RequestMessage responseMessage,
396+
List<RequestMessage> conversationHistory) {
397+
boolean needCompleteRoundTrip = false;
401398
List<MediaContent> toolToUseList = responseMessage.content()
402399
.stream()
403400
.filter(c -> c.type() == MediaContent.Type.TOOL_USE)
@@ -417,16 +414,19 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques
417414

418415
String functionResponse = this.functionCallbackRegister.get(functionName)
419416
.call(ModelOptionsUtils.toJsonString(functionArguments));
420-
421-
toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
417+
if (functionResponse != null) {
418+
needCompleteRoundTrip = true;
419+
toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
420+
}
422421
}
423422

424423
// Add the function response to the conversation.
425424
conversationHistory.add(new RequestMessage(toolResults, Role.USER));
426425

427426
// Recursively call chatCompletionWithTools until the model doesn't call a
428427
// functions anymore.
429-
return ChatCompletionRequest.from(previousRequest).withMessages(conversationHistory).build();
428+
final var build = ChatCompletionRequest.from(previousRequest).withMessages(conversationHistory).build();
429+
return new CompleteRoundTripBox<>(needCompleteRoundTrip, build);
430430
}
431431

432432
@Override

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,6 @@ public class AnthropicChatOptions implements ChatOptions, FunctionCallingOptions
7676
@JsonIgnore
7777
private Set<String> functions = new HashSet<>();
7878

79-
@JsonIgnore
80-
private boolean completeRoundTrip = true;
8179
// @formatter:on
8280

8381
public static Builder builder() {
@@ -140,11 +138,6 @@ public Builder withFunction(String functionName) {
140138
return this;
141139
}
142140

143-
public Builder withCompleteRoundTrip(boolean completeRoundTrip) {
144-
this.options.completeRoundTrip = completeRoundTrip;
145-
return this;
146-
}
147-
148141
public AnthropicChatOptions build() {
149142
return this.options;
150143
}
@@ -231,14 +224,4 @@ public void setFunctions(Set<String> functions) {
231224
this.functions = functions;
232225
}
233226

234-
@Override
235-
public boolean isCompleteRoundTrip() {
236-
return completeRoundTrip;
237-
}
238-
239-
@Override
240-
public void setCompleteRoundTrip(boolean completeRoundTrip) {
241-
this.completeRoundTrip = completeRoundTrip;
242-
}
243-
244227
}

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@
1515
*/
1616
package org.springframework.ai.azure.openai;
1717

18-
import java.util.Collections;
19-
import java.util.HashSet;
20-
import java.util.List;
21-
import java.util.Set;
18+
import java.util.*;
2219

2320
import com.azure.ai.openai.OpenAIClient;
2421
import com.azure.ai.openai.models.ChatChoice;
@@ -134,11 +131,7 @@ public ChatResponse call(Prompt prompt) {
134131
options.setStream(false);
135132

136133
logger.trace("Azure ChatCompletionsOptions: {}", options);
137-
boolean completeRoundTrip = true;
138-
if (prompt.getOptions() instanceof AzureOpenAiChatOptions azureOpenAiChatOptions) {
139-
completeRoundTrip = azureOpenAiChatOptions.isCompleteRoundTrip();
140-
}
141-
ChatCompletions chatCompletions = this.callWithFunctionSupport(options, completeRoundTrip);
134+
ChatCompletions chatCompletions = this.callWithFunctionSupport(options);
142135
logger.trace("Azure ChatCompletions: {}", chatCompletions);
143136

144137
List<Generation> generations = chatCompletions.getChoices()
@@ -431,9 +424,11 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCom
431424
}
432425

433426
@Override
434-
protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
435-
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {
427+
protected CompleteRoundTripBox<ChatCompletionsOptions> doCreateToolResponseRequest(
428+
ChatCompletionsOptions previousRequest, ChatRequestMessage responseMessage,
429+
List<ChatRequestMessage> conversationHistory) {
436430

431+
boolean needCompleteRoundTrip = false;
437432
// Every tool-call item requires a separate function call and a response (TOOL)
438433
// message.
439434
for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage) responseMessage).getToolCalls()) {
@@ -447,8 +442,11 @@ protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOpti
447442

448443
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
449444

450-
// Add the function response to the conversation.
451-
conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
445+
if (functionResponse != null) {
446+
needCompleteRoundTrip = true;
447+
// Add the function response to the conversation.
448+
conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
449+
}
452450
}
453451

454452
// Recursively call chatCompletionWithTools until the model doesn't call a
@@ -457,7 +455,7 @@ protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOpti
457455

458456
newRequest = merge(previousRequest, newRequest);
459457

460-
return newRequest;
458+
return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
461459
}
462460

463461
@Override

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,6 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio
152152
@JsonIgnore
153153
private Set<String> functions = new HashSet<>();
154154

155-
@JsonIgnore
156-
private boolean completeRoundTrip = true;
157-
158155
public static Builder builder() {
159156
return new Builder();
160157
}
@@ -242,11 +239,6 @@ public Builder withFunction(String functionName) {
242239
return this;
243240
}
244241

245-
public Builder withCompleteRoundTrip(boolean completeRoundTrip) {
246-
this.options.completeRoundTrip = completeRoundTrip;
247-
return this;
248-
}
249-
250242
public AzureOpenAiChatOptions build() {
251243
return this.options;
252244
}
@@ -364,14 +356,4 @@ public void setFunctions(Set<String> functions) {
364356
this.functions = functions;
365357
}
366358

367-
@Override
368-
public boolean isCompleteRoundTrip() {
369-
return completeRoundTrip;
370-
}
371-
372-
@Override
373-
public void setCompleteRoundTrip(boolean completeRoundTrip) {
374-
this.completeRoundTrip = completeRoundTrip;
375-
}
376-
377359
}

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/AzureOpenAiChatClientFunctionCallIT.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import java.util.ArrayList;
1919
import java.util.List;
20+
import java.util.Optional;
21+
import java.util.function.Consumer;
2022

2123
import com.azure.ai.openai.OpenAIClient;
2224
import com.azure.ai.openai.OpenAIClientBuilder;
@@ -50,6 +52,9 @@ class AzureOpenAiChatClientFunctionCallIT {
5052
@Autowired
5153
private AzureOpenAiChatClient chatClient;
5254

55+
@Autowired
56+
private String selectedModel;
57+
5358
@Test
5459
void functionCallTest() {
5560

@@ -58,7 +63,7 @@ void functionCallTest() {
5863
List<Message> messages = new ArrayList<>(List.of(userMessage));
5964

6065
var promptOptions = AzureOpenAiChatOptions.builder()
61-
.withDeploymentName("gpt-4-0125-preview")
66+
.withDeploymentName(selectedModel)
6267
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
6368
.withName("getCurrentWeather")
6469
.withDescription("Get the current weather in a given location")
@@ -84,13 +89,11 @@ void functionCallWithoutCompleteRoundTrip() {
8489

8590
final var spyingMockWeatherService = new SpyingMockWeatherService();
8691
var promptOptions = AzureOpenAiChatOptions.builder()
87-
.withDeploymentName("gpt-4-0125-preview")
92+
.withDeploymentName(selectedModel)
8893
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(spyingMockWeatherService)
8994
.withName("getCurrentWeather")
9095
.withDescription("Get the current weather in a given location")
91-
.withResponseConverter((response) -> "" + response.temp() + response.unit())
9296
.build()))
93-
.withCompleteRoundTrip(false)
9497
.build();
9598

9699
ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
@@ -111,12 +114,14 @@ public OpenAIClient openAIClient() {
111114
}
112115

113116
@Bean
114-
public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient) {
117+
public AzureOpenAiChatClient azureOpenAiChatClient(OpenAIClient openAIClient, String selectedModel) {
115118
return new AzureOpenAiChatClient(openAIClient,
116-
AzureOpenAiChatOptions.builder()
117-
.withDeploymentName("gpt-4-0125-preview")
118-
.withMaxTokens(500)
119-
.build());
119+
AzureOpenAiChatOptions.builder().withDeploymentName(selectedModel).withMaxTokens(500).build());
120+
}
121+
122+
@Bean
123+
public String selectedModel() {
124+
return Optional.ofNullable(System.getenv("AZURE_OPENAI_MODEL")).orElse("gpt-4-0125-preview");
120125
}
121126

122127
}

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/function/SpyingMockWeatherService.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,14 @@
1717

1818
import java.util.function.Function;
1919

20-
public class SpyingMockWeatherService implements Function<MockWeatherService.Request, MockWeatherService.Response> {
21-
22-
private final MockWeatherService inner = new MockWeatherService();
20+
public class SpyingMockWeatherService implements Function<MockWeatherService.Request, Void> {
2321

2422
private MockWeatherService.Request interceptedRequest = null;
2523

2624
@Override
27-
public MockWeatherService.Response apply(MockWeatherService.Request request) {
25+
public Void apply(MockWeatherService.Request request) {
2826
interceptedRequest = request;
29-
return inner.apply(request);
27+
return null;
3028
}
3129

3230
public MockWeatherService.Request getInterceptedRequest() {

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,7 @@ public ChatResponse call(Prompt prompt) {
103103
var request = createRequest(prompt, false);
104104

105105
return retryTemplate.execute(ctx -> {
106-
boolean completeRoundTrip = true;
107-
if (prompt.getOptions() instanceof MistralAiChatOptions mistralAiChatOptions) {
108-
completeRoundTrip = mistralAiChatOptions.isCompleteRoundTrip();
109-
}
110-
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request, completeRoundTrip);
106+
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
111107

112108
var chatCompletion = completionEntity.getBody();
113109
if (chatCompletion == null) {
@@ -152,12 +148,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
152148
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
153149

154150
return completionChunks.map(chunk -> toChatCompletion(chunk)).map(chatCompletion -> {
155-
boolean completeRoundTrip = true;
156-
if (prompt.getOptions() instanceof MistralAiChatOptions mistralAiChatOptions) {
157-
completeRoundTrip = mistralAiChatOptions.isCompleteRoundTrip();
158-
}
159-
chatCompletion = handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chatCompletion)),
160-
completeRoundTrip)
151+
chatCompletion = handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chatCompletion)))
161152
.getBody();
162153

163154
@SuppressWarnings("null")
@@ -255,9 +246,10 @@ private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNam
255246
// Function Calling Support
256247
//
257248
@Override
258-
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
259-
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
260-
249+
protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
250+
ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage,
251+
List<ChatCompletionMessage> conversationHistory) {
252+
boolean needCompleteRoundTrip = false;
261253
// Every tool-call item requires a separate function call and a response (TOOL)
262254
// message.
263255
for (ToolCall toolCall : responseMessage.toolCalls()) {
@@ -270,18 +262,21 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques
270262
}
271263

272264
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
265+
if (functionResponse != null) {
266+
needCompleteRoundTrip = true;
267+
// Add the function response to the conversation.
268+
conversationHistory.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL,
269+
functionName, null));
270+
}
273271

274-
// Add the function response to the conversation.
275-
conversationHistory
276-
.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL, functionName, null));
277272
}
278273

279274
// Recursively call chatCompletionWithTools until the model doesn't call a
280275
// functions anymore.
281276
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false);
282277
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);
283278

284-
return newRequest;
279+
return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
285280
}
286281

287282
@Override

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,6 @@ public class MistralAiChatOptions implements FunctionCallingOptions, ChatOptions
126126
@JsonIgnore
127127
private Set<String> functions = new HashSet<>();
128128

129-
@JsonIgnore
130-
private boolean completeRoundTrip = true;
131-
132129
public static Builder builder() {
133130
return new Builder();
134131
}
@@ -199,11 +196,6 @@ public Builder withFunction(String functionName) {
199196
return this;
200197
}
201198

202-
public Builder withCompleteRoundTrip(boolean completeRoundTrip) {
203-
this.options.completeRoundTrip = completeRoundTrip;
204-
return this;
205-
}
206-
207199
public MistralAiChatOptions build() {
208200
return this.options;
209201
}
@@ -317,14 +309,4 @@ public void setFunctions(Set<String> functions) {
317309
this.functions = functions;
318310
}
319311

320-
@Override
321-
public boolean isCompleteRoundTrip() {
322-
return completeRoundTrip;
323-
}
324-
325-
@Override
326-
public void setCompleteRoundTrip(boolean completeRoundTrip) {
327-
this.completeRoundTrip = completeRoundTrip;
328-
}
329-
330312
}

0 commit comments

Comments
 (0)