Skip to content

Commit a7eb28a

Browse files
Grogdunn authored and tzolov committed
Add real Function Calling Streaming support
- Add Java reflection merge utilities that can access Azure private constructors and fields.
- Azure merging, creation of Flux windows.
- Function call grouping for function processing.
- Do not perform greedy operations on Flux.
- Use "real" streaming on all clients for function responses.
- Gemini: fix missing method implementation.
- Mistral AI, OpenAI: fix missing stream flag in doCreateToolResponseRequest.
- Fix code formatting. No wildcard imports.
- Add Grogdunn to the javadoc authors.
- Anthropic 3 API does not support streaming function calling yet.
1 parent b0799ba commit a7eb28a

File tree

11 files changed

+610
-126
lines changed

11 files changed

+610
-126
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,4 +450,11 @@ protected boolean isToolFunctionCall(ResponseEntity<ChatCompletion> response) {
450450
return response.getBody().content().stream().anyMatch(content -> content.type() == MediaContent.Type.TOOL_USE);
451451
}
452452

453+
@Override
454+
protected Flux<ResponseEntity<ChatCompletion>> doChatCompletionStream(ChatCompletionRequest request) {
455+
// https://docs.anthropic.com/en/docs/tool-use
456+
throw new UnsupportedOperationException(
457+
"Streaming (stream=true) is not yet supported. We plan to add streaming support in a future beta version.");
458+
}
459+
453460
}

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.springframework.ai.model.ModelOptionsUtils;
3030
import org.springframework.ai.retry.RetryUtils;
3131
import org.springframework.http.HttpHeaders;
32+
import org.springframework.http.HttpStatusCode;
3233
import org.springframework.http.MediaType;
3334
import org.springframework.http.ResponseEntity;
3435
import org.springframework.util.Assert;
@@ -100,7 +101,13 @@ public AnthropicApi(String baseUrl, String anthropicApiKey, String anthropicVers
100101
.defaultStatusHandler(responseErrorHandler)
101102
.build();
102103

103-
this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build();
104+
this.webClient = WebClient.builder()
105+
.baseUrl(baseUrl)
106+
.defaultHeaders(jsonContentHeaders)
107+
.defaultStatusHandler(HttpStatusCode::isError,
108+
resp -> Mono.just(new RuntimeException("Response exception, Status: [" + resp.statusCode()
109+
+ "], Body:[" + resp.bodyToMono(java.lang.String.class) + "]")))
110+
.build();
104111
}
105112

106113
/**

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatClientIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,15 +205,15 @@ void functionCallTest() {
205205
.withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue())
206206
.withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService())
207207
.withName("getCurrentWeather")
208-
.withDescription("Get the weather in location")
208+
.withDescription("Get the weather in location. Return temperature in 36°F or 36°C format.")
209209
.build()))
210210
.build();
211211

212212
ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
213213

214214
logger.info("Response: {}", response);
215215

216-
Generation generation = response.getResults().get(0);
216+
Generation generation = response.getResult();
217217
assertThat(generation.getOutput().getContent()).containsAnyOf("30.0", "30");
218218
assertThat(generation.getOutput().getContent()).containsAnyOf("10.0", "10");
219219
assertThat(generation.getOutput().getContent()).containsAnyOf("15.0", "15");

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@
1515
*/
1616
package org.springframework.ai.azure.openai;
1717

18-
import java.util.Collections;
19-
import java.util.HashSet;
20-
import java.util.List;
21-
import java.util.Set;
22-
2318
import com.azure.ai.openai.OpenAIClient;
2419
import com.azure.ai.openai.models.ChatChoice;
2520
import com.azure.ai.openai.models.ChatCompletions;
@@ -33,15 +28,14 @@
3328
import com.azure.ai.openai.models.ChatRequestSystemMessage;
3429
import com.azure.ai.openai.models.ChatRequestToolMessage;
3530
import com.azure.ai.openai.models.ChatRequestUserMessage;
36-
import com.azure.ai.openai.models.ChatResponseMessage;
3731
import com.azure.ai.openai.models.CompletionsFinishReason;
3832
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
33+
import com.azure.ai.openai.models.FunctionCall;
3934
import com.azure.ai.openai.models.FunctionDefinition;
4035
import com.azure.core.util.BinaryData;
4136
import com.azure.core.util.IterableStream;
4237
import org.slf4j.Logger;
4338
import org.slf4j.LoggerFactory;
44-
import reactor.core.publisher.Flux;
4539

4640
import org.springframework.ai.azure.openai.metadata.AzureOpenAiChatResponseMetadata;
4741
import org.springframework.ai.chat.ChatClient;
@@ -59,6 +53,14 @@
5953
import org.springframework.ai.model.function.FunctionCallbackContext;
6054
import org.springframework.util.Assert;
6155
import org.springframework.util.CollectionUtils;
56+
import reactor.core.publisher.Flux;
57+
58+
import java.util.Collections;
59+
import java.util.HashSet;
60+
import java.util.List;
61+
import java.util.Optional;
62+
import java.util.Set;
63+
import java.util.concurrent.atomic.AtomicBoolean;
6264

6365
/**
6466
* {@link ChatClient} implementation for {@literal Microsoft Azure AI} backed by
@@ -68,6 +70,7 @@
6870
* @author Ueibin Kim
6971
* @author John Blum
7072
* @author Christian Tzolov
73+
* @author Grogdunn
7174
* @see ChatClient
7275
* @see com.azure.ai.openai.OpenAIClient
7376
*/
@@ -158,17 +161,42 @@ public Flux<ChatResponse> stream(Prompt prompt) {
158161
IterableStream<ChatCompletions> chatCompletionsStream = this.openAIClient
159162
.getChatCompletionsStream(options.getModel(), options);
160163

161-
return Flux.fromStream(chatCompletionsStream.stream()
164+
Flux<ChatCompletions> chatCompletionsFlux = Flux.fromIterable(chatCompletionsStream);
165+
166+
final var isFunctionCall = new AtomicBoolean(false);
167+
final var accessibleChatCompletionsFlux = chatCompletionsFlux
162168
// Note: the first chat completions can be ignored when using Azure OpenAI
163169
// service which is a known service bug.
164170
.skip(1)
165-
.map(ChatCompletions::getChoices)
166-
.flatMap(List::stream)
171+
.map(chatCompletions -> {
172+
final var toolCalls = chatCompletions.getChoices().get(0).getDelta().getToolCalls();
173+
isFunctionCall.set(toolCalls != null && !toolCalls.isEmpty());
174+
return chatCompletions;
175+
})
176+
.windowUntil(chatCompletions -> {
177+
if (isFunctionCall.get() && chatCompletions.getChoices()
178+
.get(0)
179+
.getFinishReason() == CompletionsFinishReason.TOOL_CALLS) {
180+
isFunctionCall.set(false);
181+
return true;
182+
}
183+
return false;
184+
}, false)
185+
.concatMapIterable(window -> {
186+
final var reduce = window.reduce(MergeUtils.emptyChatCompletions(), MergeUtils::mergeChatCompletions);
187+
return List.of(reduce);
188+
})
189+
.flatMap(mono -> mono);
190+
return accessibleChatCompletionsFlux
191+
.switchMap(accessibleChatCompletions -> handleFunctionCallOrReturnStream(options,
192+
Flux.just(accessibleChatCompletions)))
193+
.flatMapIterable(ChatCompletions::getChoices)
167194
.map(choice -> {
168-
var content = (choice.getDelta() != null) ? choice.getDelta().getContent() : null;
195+
var content = Optional.ofNullable(choice.getMessage()).orElse(choice.getDelta()).getContent();
169196
var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice));
170197
return new ChatResponse(List.of(generation));
171-
}));
198+
});
199+
172200
}
173201

174202
/**
@@ -522,9 +550,17 @@ protected List<ChatRequestMessage> doGetUserMessages(ChatCompletionsOptions requ
522550

523551
@Override
524552
protected ChatRequestMessage doGetToolResponseMessage(ChatCompletions response) {
525-
ChatResponseMessage responseMessage = response.getChoices().get(0).getMessage();
553+
final var accessibleChatChoice = response.getChoices().get(0);
554+
var responseMessage = Optional.ofNullable(accessibleChatChoice.getMessage())
555+
.orElse(accessibleChatChoice.getDelta());
526556
ChatRequestAssistantMessage assistantMessage = new ChatRequestAssistantMessage("");
527-
assistantMessage.setToolCalls(responseMessage.getToolCalls());
557+
final var toolCalls = responseMessage.getToolCalls();
558+
assistantMessage.setToolCalls(toolCalls.stream().map(tc -> {
559+
final var tc1 = (ChatCompletionsFunctionToolCall) tc;
560+
var toDowncast = new ChatCompletionsFunctionToolCall(tc.getId(),
561+
new FunctionCall(tc1.getFunction().getName(), tc1.getFunction().getArguments()));
562+
return ((ChatCompletionsToolCall) toDowncast);
563+
}).toList());
528564
return assistantMessage;
529565
}
530566

@@ -533,6 +569,11 @@ protected ChatCompletions doChatCompletion(ChatCompletionsOptions request) {
533569
return this.openAIClient.getChatCompletions(request.getModel(), request);
534570
}
535571

572+
@Override
573+
protected Flux<ChatCompletions> doChatCompletionStream(ChatCompletionsOptions request) {
574+
return Flux.fromIterable(this.openAIClient.getChatCompletionsStream(request.getModel(), request));
575+
}
576+
536577
@Override
537578
protected boolean isToolFunctionCall(ChatCompletions chatCompletions) {
538579

@@ -549,4 +590,4 @@ protected boolean isToolFunctionCall(ChatCompletions chatCompletions) {
549590
return choice.getFinishReason() == CompletionsFinishReason.TOOL_CALLS;
550591
}
551592

552-
}
593+
}

0 commit comments

Comments
 (0)