Skip to content

Commit 1805ca0

Browse files
WOONBEspring-builds
authored andcommitted
refactor : refactor MessageAggregator to include toolCalls
test: Add unit test for MessageAggregator tool call aggregation Fixes #3366 Signed-off-by: WOONBE <kepull2918@naver.com> (cherry picked from commit df90b9c)
1 parent 4f520f8 commit 1805ca0

File tree

2 files changed

+82
-4
lines changed

2 files changed

+82
-4
lines changed

spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.chat.model;
1818

19+
import java.util.ArrayList;
1920
import java.util.HashMap;
2021
import java.util.List;
2122
import java.util.Map;
@@ -24,6 +25,7 @@
2425

2526
import org.slf4j.Logger;
2627
import org.slf4j.LoggerFactory;
28+
import org.springframework.util.CollectionUtils;
2729
import reactor.core.publisher.Flux;
2830

2931
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -35,13 +37,16 @@
3537
import org.springframework.ai.chat.metadata.Usage;
3638
import org.springframework.util.StringUtils;
3739

40+
import static org.springframework.ai.chat.messages.AssistantMessage.*;
41+
3842
/**
3943
* Helper that for streaming chat responses, aggregate the chat response messages into a
4044
* single AssistantMessage. Job is performed in parallel to the chat response processing.
4145
*
4246
* @author Christian Tzolov
4347
* @author Alexandros Pappas
4448
* @author Thomas Vitale
49+
* @author Heonwoo Kim
4550
* @since 1.0.0
4651
*/
4752
public class MessageAggregator {
@@ -54,6 +59,7 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
5459
// Assistant Message
5560
AtomicReference<StringBuilder> messageTextContentRef = new AtomicReference<>(new StringBuilder());
5661
AtomicReference<Map<String, Object>> messageMetadataMapRef = new AtomicReference<>();
62+
AtomicReference<List<ToolCall>> toolCallsRef = new AtomicReference<>(new ArrayList<>());
5763

5864
// ChatGeneration Metadata
5965
AtomicReference<ChatGenerationMetadata> generationMetadataRef = new AtomicReference<>(
@@ -73,6 +79,7 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
7379
return fluxChatResponse.doOnSubscribe(subscription -> {
7480
messageTextContentRef.set(new StringBuilder());
7581
messageMetadataMapRef.set(new HashMap<>());
82+
toolCallsRef.set(new ArrayList<>());
7683
metadataIdRef.set("");
7784
metadataModelRef.set("");
7885
metadataUsagePromptTokensRef.set(0);
@@ -94,6 +101,11 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
94101
if (chatResponse.getResult().getOutput().getMetadata() != null) {
95102
messageMetadataMapRef.get().putAll(chatResponse.getResult().getOutput().getMetadata());
96103
}
104+
AssistantMessage outputMessage = chatResponse.getResult().getOutput();
105+
if (!CollectionUtils.isEmpty(outputMessage.getToolCalls())) {
106+
toolCallsRef.get().addAll(outputMessage.getToolCalls());
107+
}
108+
97109
}
98110
if (chatResponse.getMetadata() != null) {
99111
if (chatResponse.getMetadata().getUsage() != null) {
@@ -119,6 +131,13 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
119131
if (StringUtils.hasText(chatResponse.getMetadata().getModel())) {
120132
metadataModelRef.set(chatResponse.getMetadata().getModel());
121133
}
134+
Object toolCallsFromMetadata = chatResponse.getMetadata().get("toolCalls");
135+
if (toolCallsFromMetadata instanceof List) {
136+
@SuppressWarnings("unchecked")
137+
List<ToolCall> toolCallsList = (List<ToolCall>) toolCallsFromMetadata;
138+
toolCallsRef.get().addAll(toolCallsList);
139+
}
140+
122141
}
123142
}).doOnComplete(() -> {
124143

@@ -133,12 +152,25 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
133152
.promptMetadata(metadataPromptMetadataRef.get())
134153
.build();
135154

136-
onAggregationComplete.accept(new ChatResponse(List.of(new Generation(
137-
new AssistantMessage(messageTextContentRef.get().toString(), messageMetadataMapRef.get()),
155+
AssistantMessage finalAssistantMessage;
156+
List<ToolCall> collectedToolCalls = toolCallsRef.get();
157+
158+
if (!CollectionUtils.isEmpty(collectedToolCalls)) {
159+
160+
finalAssistantMessage = new AssistantMessage(messageTextContentRef.get().toString(),
161+
messageMetadataMapRef.get(), collectedToolCalls);
162+
}
163+
else {
164+
finalAssistantMessage = new AssistantMessage(messageTextContentRef.get().toString(),
165+
messageMetadataMapRef.get());
166+
}
167+
onAggregationComplete.accept(new ChatResponse(List.of(new Generation(finalAssistantMessage,
168+
138169
generationMetadataRef.get())), chatResponseMetadata));
139170

140171
messageTextContentRef.set(new StringBuilder());
141172
messageMetadataMapRef.set(new HashMap<>());
173+
toolCallsRef.set(new ArrayList<>());
142174
metadataIdRef.set("");
143175
metadataModelRef.set("");
144176
metadataUsagePromptTokensRef.set(0);

spring-ai-model/src/test/java/org/springframework/ai/chat/model/ChatResponseTests.java

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,32 @@
1919
import java.util.List;
2020
import java.util.Map;
2121
import java.util.Set;
22+
import java.util.concurrent.atomic.AtomicReference;
2223

2324
import org.junit.jupiter.api.Test;
2425

2526
import org.springframework.ai.chat.messages.AssistantMessage;
2627
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
28+
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
29+
import reactor.core.publisher.Flux;
2730

2831
import static org.assertj.core.api.Assertions.assertThat;
2932
import static org.assertj.core.api.Assertions.assertThatThrownBy;
33+
import static org.springframework.ai.chat.messages.AssistantMessage.*;
3034

3135
/**
3236
* Unit tests for {@link ChatResponse}.
3337
*
3438
* @author Thomas Vitale
39+
* @author Heonwoo Kim
3540
*/
3641
class ChatResponseTests {
3742

3843
@Test
3944
void whenToolCallsArePresentThenReturnTrue() {
4045
ChatResponse chatResponse = ChatResponse.builder()
41-
.generations(List.of(new Generation(new AssistantMessage("", Map.of(),
42-
List.of(new AssistantMessage.ToolCall("toolA", "function", "toolA", "{}"))))))
46+
.generations(List.of(new Generation(
47+
new AssistantMessage("", Map.of(), List.of(new ToolCall("toolA", "function", "toolA", "{}"))))))
4348
.build();
4449
assertThat(chatResponse.hasToolCalls()).isTrue();
4550
}
@@ -80,4 +85,45 @@ void whenFinishReasonIsNotPresent() {
8085
assertThat(chatResponse.hasFinishReasons(Set.of("completed"))).isFalse();
8186
}
8287

88+
@Test
89+
void messageAggregatorShouldCorrectlyAggregateToolCallsFromStream() {
90+
91+
MessageAggregator aggregator = new MessageAggregator();
92+
93+
ChatResponse chunk1 = new ChatResponse(
94+
List.of(new Generation(new AssistantMessage("Thinking about the weather... "))));
95+
96+
ToolCall weatherToolCall = new ToolCall("tool-id-123", "function", "getCurrentWeather",
97+
"{\"location\": \"Seoul\"}");
98+
99+
Map<String, Object> metadataWithToolCall = Map.of("toolCalls", List.of(weatherToolCall));
100+
ChatResponseMetadata responseMetadataForChunk2 = ChatResponseMetadata.builder()
101+
.metadata(metadataWithToolCall)
102+
.build();
103+
104+
ChatResponse chunk2 = new ChatResponse(List.of(new Generation(new AssistantMessage(""))),
105+
responseMetadataForChunk2);
106+
107+
Flux<ChatResponse> streamingResponse = Flux.just(chunk1, chunk2);
108+
109+
AtomicReference<ChatResponse> aggregatedResponseRef = new AtomicReference<>();
110+
111+
aggregator.aggregate(streamingResponse, aggregatedResponseRef::set).blockLast();
112+
113+
ChatResponse finalResponse = aggregatedResponseRef.get();
114+
assertThat(finalResponse).isNotNull();
115+
116+
AssistantMessage finalAssistantMessage = finalResponse.getResult().getOutput();
117+
118+
assertThat(finalAssistantMessage).isNotNull();
119+
assertThat(finalAssistantMessage.getText()).isEqualTo("Thinking about the weather... ");
120+
assertThat(finalAssistantMessage.hasToolCalls()).isTrue();
121+
assertThat(finalAssistantMessage.getToolCalls()).hasSize(1);
122+
123+
ToolCall resultToolCall = finalAssistantMessage.getToolCalls().get(0);
124+
assertThat(resultToolCall.id()).isEqualTo("tool-id-123");
125+
assertThat(resultToolCall.name()).isEqualTo("getCurrentWeather");
126+
assertThat(resultToolCall.arguments()).isEqualTo("{\"location\": \"Seoul\"}");
127+
}
128+
83129
}

0 commit comments

Comments
 (0)