Skip to content

Commit 0fdc6aa

Browse files
committed
Add streaming Function Calling support of OpenAI and Mistral AI
- Extends the reactor logic to allow aggregation of the chunked tool-call messages and leverage the existing function calling infrastructure. - Seamless experience for the streaming functionality. - Add Message ID and FinishReason to the returned Generation's properties.
1 parent f8f38d6 commit 0fdc6aa

File tree

14 files changed

+862
-128
lines changed

14 files changed

+862
-128
lines changed

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.HashSet;
1919
import java.util.List;
2020
import java.util.Map;
21+
import java.util.Optional;
2122
import java.util.Set;
2223
import java.util.concurrent.ConcurrentHashMap;
2324

@@ -34,6 +35,8 @@
3435
import org.springframework.ai.chat.prompt.Prompt;
3536
import org.springframework.ai.mistralai.api.MistralAiApi;
3637
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion;
38+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion.Choice;
39+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionChunk;
3740
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage;
3841
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall;
3942
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
@@ -131,13 +134,21 @@ public Flux<ChatResponse> stream(Prompt prompt) {
131134
// The rest of the chunks with same ID share the same role.
132135
ConcurrentHashMap<String, String> roleMap = new ConcurrentHashMap<>();
133136

134-
return completionChunks.map(chunk -> {
135-
String chunkId = chunk.id();
136-
List<Generation> generations = chunk.choices().stream().map(choice -> {
137-
if (choice.delta().role() != null) {
138-
roleMap.putIfAbsent(chunkId, choice.delta().role().name());
137+
return completionChunks.map(chunk -> toChatCompletion(chunk)).map(chatCompletion -> {
138+
139+
chatCompletion = handleFunctionCallOrReturn(request, ResponseEntity.of(Optional.of(chatCompletion)))
140+
.getBody();
141+
142+
@SuppressWarnings("null")
143+
String id = chatCompletion.id();
144+
145+
List<Generation> generations = chatCompletion.choices().stream().map(choice -> {
146+
if (choice.message().role() != null) {
147+
roleMap.putIfAbsent(id, choice.message().role().name());
139148
}
140-
var generation = new Generation(choice.delta().content(), Map.of("role", roleMap.get(chunkId)));
149+
String finish = (choice.finishReason() != null ? choice.finishReason().name() : "");
150+
var generation = new Generation(choice.message().content(),
151+
Map.of("id", id, "role", roleMap.get(id), "finishReason", finish));
141152
if (choice.finishReason() != null) {
142153
generation = generation
143154
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null));
@@ -149,6 +160,15 @@ public Flux<ChatResponse> stream(Prompt prompt) {
149160
});
150161
}
151162

163+
private ChatCompletion toChatCompletion(ChatCompletionChunk chunk) {
164+
List<Choice> choices = chunk.choices()
165+
.stream()
166+
.map(cc -> new Choice(cc.index(), cc.delta(), cc.finishReason()))
167+
.toList();
168+
169+
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null);
170+
}
171+
152172
/**
153173
* Accessible for testing.
154174
*/
@@ -194,10 +214,6 @@ MistralAiApi.ChatCompletionRequest createRequest(Prompt prompt, boolean stream)
194214
// Add the enabled functions definitions to the request's tools parameter.
195215
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
196216

197-
if (stream) {
198-
throw new IllegalArgumentException("Currently tool functions are not supported in streaming mode");
199-
}
200-
201217
request = ModelOptionsUtils.merge(
202218
MistralAiChatOptions.builder().withTools(this.getFunctionTools(functionsForThisRequest)).build(),
203219
request, ChatCompletionRequest.class);
@@ -241,7 +257,7 @@ protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionReques
241257

242258
// Recursively call chatCompletionWithTools until the model doesn't call a
243259
// functions anymore.
244-
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, previousRequest.stream());
260+
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false);
245261
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);
246262

247263
return newRequest;
@@ -252,6 +268,7 @@ protected List<ChatCompletionMessage> doGetUserMessages(ChatCompletionRequest re
252268
return request.messages();
253269
}
254270

271+
@SuppressWarnings("null")
255272
@Override
256273
protected ChatCompletionMessage doGetToolResponseMessage(ResponseEntity<ChatCompletion> chatCompletion) {
257274
return chatCompletion.getBody().choices().iterator().next().message();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
/*
2+
* Copyright 2023 - 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.mistralai.api;
17+
18+
import java.util.ArrayList;
19+
import java.util.List;
20+
import java.util.Optional;
21+
import java.util.UUID;
22+
23+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionChunk;
24+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionChunk.ChunkChoice;
25+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionFinishReason;
26+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage;
27+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ChatCompletionFunction;
28+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.Role;
29+
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.ToolCall;
30+
import org.springframework.util.CollectionUtils;
31+
32+
/**
33+
* Helper class to support Streaming function calling.
34+
*
35+
* It can merge the streamed ChatCompletionChunk in case of function calling message.
36+
*
37+
* @author Christian Tzolov
38+
* @since 0.8.1
39+
*/
40+
public class MIstralAiStreamFunctionCallingHelper {
41+
42+
/**
43+
* Merge the previous and current ChatCompletionChunk into a single one.
44+
* @param previous the previous ChatCompletionChunk
45+
* @param current the current ChatCompletionChunk
46+
* @return the merged ChatCompletionChunk
47+
*/
48+
public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) {
49+
50+
if (previous == null) {
51+
return current;
52+
}
53+
54+
String id = (current.id() != null ? current.id() : previous.id());
55+
Long created = (current.created() != null ? current.created() : previous.created());
56+
String model = (current.model() != null ? current.model() : previous.model());
57+
String object = (current.object() != null ? current.object() : previous.object());
58+
59+
ChunkChoice previousChoice0 = (CollectionUtils.isEmpty(previous.choices()) ? null : previous.choices().get(0));
60+
ChunkChoice currentChoice0 = (CollectionUtils.isEmpty(current.choices()) ? null : current.choices().get(0));
61+
62+
ChunkChoice choice = merge(previousChoice0, currentChoice0);
63+
64+
return new ChatCompletionChunk(id, object, created, model, List.of(choice));
65+
}
66+
67+
private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
68+
if (previous == null) {
69+
if (current.delta() != null && current.delta().toolCalls() != null) {
70+
Optional<String> id = current.delta()
71+
.toolCalls()
72+
.stream()
73+
.filter(tool -> tool.id() != null)
74+
.map(tool -> tool.id())
75+
.findFirst();
76+
if (!id.isPresent()) {
77+
var newId = UUID.randomUUID().toString();
78+
79+
var toolCallsWithID = current.delta()
80+
.toolCalls()
81+
.stream()
82+
.map(toolCall -> new ToolCall(newId, "function", toolCall.function()))
83+
.toList();
84+
85+
var role = current.delta().role() != null ? current.delta().role() : Role.ASSISTANT;
86+
current = new ChunkChoice(current.index(), new ChatCompletionMessage(current.delta().content(),
87+
role, current.delta().name(), toolCallsWithID), current.finishReason());
88+
}
89+
}
90+
return current;
91+
}
92+
93+
ChatCompletionFinishReason finishReason = (current.finishReason() != null ? current.finishReason()
94+
: previous.finishReason());
95+
Integer index = (current.index() != null ? current.index() : previous.index());
96+
97+
ChatCompletionMessage message = merge(previous.delta(), current.delta());
98+
99+
return new ChunkChoice(index, message, finishReason);
100+
}
101+
102+
private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
103+
String content = (current.content() != null ? current.content()
104+
: "" + ((previous.content() != null) ? previous.content() : ""));
105+
Role role = (current.role() != null ? current.role() : previous.role());
106+
role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null
107+
String name = (current.name() != null ? current.name() : previous.name());
108+
109+
List<ToolCall> toolCalls = new ArrayList<>();
110+
ToolCall lastPreviousTooCall = null;
111+
if (previous.toolCalls() != null) {
112+
lastPreviousTooCall = previous.toolCalls().get(previous.toolCalls().size() - 1);
113+
if (previous.toolCalls().size() > 1) {
114+
toolCalls.addAll(previous.toolCalls().subList(0, previous.toolCalls().size() - 1));
115+
}
116+
}
117+
if (current.toolCalls() != null) {
118+
if (current.toolCalls().size() > 1) {
119+
throw new IllegalStateException("Currently only one tool call is supported per message!");
120+
}
121+
var currentToolCall = current.toolCalls().iterator().next();
122+
if (currentToolCall.id() != null) {
123+
if (lastPreviousTooCall != null) {
124+
toolCalls.add(lastPreviousTooCall);
125+
}
126+
toolCalls.add(currentToolCall);
127+
}
128+
else {
129+
toolCalls.add(merge(lastPreviousTooCall, currentToolCall));
130+
}
131+
}
132+
else {
133+
if (lastPreviousTooCall != null) {
134+
toolCalls.add(lastPreviousTooCall);
135+
}
136+
}
137+
return new ChatCompletionMessage(content, role, name, toolCalls);
138+
}
139+
140+
private ToolCall merge(ToolCall previous, ToolCall current) {
141+
if (previous == null) {
142+
return current;
143+
}
144+
String id = (current.id() != null ? current.id() : previous.id());
145+
String type = (current.type() != null ? current.type() : previous.type());
146+
ChatCompletionFunction function = merge(previous.function(), current.function());
147+
return new ToolCall(id, type, function);
148+
}
149+
150+
private ChatCompletionFunction merge(ChatCompletionFunction previous, ChatCompletionFunction current) {
151+
if (previous == null) {
152+
return current;
153+
}
154+
String name = (current.name() != null ? current.name() : previous.name());
155+
StringBuilder arguments = new StringBuilder();
156+
if (previous.arguments() != null) {
157+
arguments.append(previous.arguments());
158+
}
159+
if (current.arguments() != null) {
160+
arguments.append(current.arguments());
161+
}
162+
return new ChatCompletionFunction(name, arguments.toString());
163+
}
164+
165+
/**
166+
* @param chatCompletion the ChatCompletionChunk to check
167+
* @return true if the ChatCompletionChunk is a streaming tool function call.
168+
*/
169+
public boolean isStreamingToolFunctionCall(ChatCompletionChunk chatCompletion) {
170+
171+
var choices = chatCompletion.choices();
172+
if (CollectionUtils.isEmpty(choices)) {
173+
return false;
174+
}
175+
176+
var choice = choices.get(0);
177+
return !CollectionUtils.isEmpty(choice.delta().toolCalls());
178+
}
179+
180+
/**
181+
* @param chatCompletion the ChatCompletionChunk to check
182+
* @return true if the ChatCompletionChunk is a streaming tool function call and it is
183+
* the last one.
184+
*/
185+
public boolean isStreamingToolFunctionCallFinish(ChatCompletionChunk chatCompletion) {
186+
187+
var choices = chatCompletion.choices();
188+
if (CollectionUtils.isEmpty(choices)) {
189+
return false;
190+
}
191+
192+
var choice = choices.get(0);
193+
return choice.finishReason() == ChatCompletionFinishReason.TOOL_CALL
194+
|| choice.finishReason() == ChatCompletionFinishReason.TOOL_CALLS;
195+
}
196+
197+
}
198+
// ---

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.List;
1919
import java.util.Map;
20+
import java.util.concurrent.atomic.AtomicBoolean;
2021
import java.util.function.Consumer;
2122
import java.util.function.Predicate;
2223

@@ -704,6 +705,8 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
704705
.toEntity(ChatCompletion.class);
705706
}
706707

708+
private MIstralAiStreamFunctionCallingHelper chunkMerger = new MIstralAiStreamFunctionCallingHelper();
709+
707710
/**
708711
* Creates a streaming chat response for the given chat conversation.
709712
* @param chatRequest The chat completion request. Must have the stream property set
@@ -715,14 +718,35 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
715718
Assert.notNull(chatRequest, "The request body can not be null.");
716719
Assert.isTrue(chatRequest.stream(), "Request must set the steam property to true.");
717720

721+
AtomicBoolean isInsideTool = new AtomicBoolean(false);
722+
718723
return this.webClient.post()
719724
.uri("/v1/chat/completions")
720725
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
721726
.retrieve()
722727
.bodyToFlux(String.class)
723728
.takeUntil(SSE_DONE_PREDICATE)
724729
.filter(SSE_DONE_PREDICATE.negate())
725-
.map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class));
730+
.map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class))
731+
.map(chunk -> {
732+
if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) {
733+
isInsideTool.set(true);
734+
}
735+
return chunk;
736+
})
737+
.windowUntil(chunk -> {
738+
if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) {
739+
isInsideTool.set(false);
740+
return true;
741+
}
742+
return !isInsideTool.get();
743+
})
744+
.concatMapIterable(window -> {
745+
Mono<ChatCompletionChunk> mono1 = window.reduce(new ChatCompletionChunk(null, null, null, null, null),
746+
(previous, current) -> this.chunkMerger.merge(previous, current));
747+
return List.of(mono1);
748+
})
749+
.flatMap(mono -> mono);
726750
}
727751

728752
}

0 commit comments

Comments
 (0)