16
16
17
17
package org .springframework .ai .chat .model ;
18
18
19
+ import java .util .ArrayList ;
19
20
import java .util .HashMap ;
20
21
import java .util .List ;
21
22
import java .util .Map ;
24
25
25
26
import org .slf4j .Logger ;
26
27
import org .slf4j .LoggerFactory ;
28
+ import org .springframework .util .CollectionUtils ;
27
29
import reactor .core .publisher .Flux ;
28
30
29
31
import org .springframework .ai .chat .messages .AssistantMessage ;
35
37
import org .springframework .ai .chat .metadata .Usage ;
36
38
import org .springframework .util .StringUtils ;
37
39
40
+ import static org .springframework .ai .chat .messages .AssistantMessage .*;
41
+
38
42
/**
39
43
* Helper that for streaming chat responses, aggregate the chat response messages into a
40
44
* single AssistantMessage. Job is performed in parallel to the chat response processing.
41
45
*
42
46
* @author Christian Tzolov
43
47
* @author Alexandros Pappas
44
48
* @author Thomas Vitale
49
+ * @author Heonwoo Kim
45
50
* @since 1.0.0
46
51
*/
47
52
public class MessageAggregator {
@@ -54,6 +59,7 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
54
59
// Assistant Message
55
60
AtomicReference <StringBuilder > messageTextContentRef = new AtomicReference <>(new StringBuilder ());
56
61
AtomicReference <Map <String , Object >> messageMetadataMapRef = new AtomicReference <>();
62
+ AtomicReference <List <ToolCall >> toolCallsRef = new AtomicReference <>(new ArrayList <>());
57
63
58
64
// ChatGeneration Metadata
59
65
AtomicReference <ChatGenerationMetadata > generationMetadataRef = new AtomicReference <>(
@@ -73,6 +79,7 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
73
79
return fluxChatResponse .doOnSubscribe (subscription -> {
74
80
messageTextContentRef .set (new StringBuilder ());
75
81
messageMetadataMapRef .set (new HashMap <>());
82
+ toolCallsRef .set (new ArrayList <>());
76
83
metadataIdRef .set ("" );
77
84
metadataModelRef .set ("" );
78
85
metadataUsagePromptTokensRef .set (0 );
@@ -94,6 +101,11 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
94
101
if (chatResponse .getResult ().getOutput ().getMetadata () != null ) {
95
102
messageMetadataMapRef .get ().putAll (chatResponse .getResult ().getOutput ().getMetadata ());
96
103
}
104
+ AssistantMessage outputMessage = chatResponse .getResult ().getOutput ();
105
+ if (!CollectionUtils .isEmpty (outputMessage .getToolCalls ())) {
106
+ toolCallsRef .get ().addAll (outputMessage .getToolCalls ());
107
+ }
108
+
97
109
}
98
110
if (chatResponse .getMetadata () != null ) {
99
111
if (chatResponse .getMetadata ().getUsage () != null ) {
@@ -119,6 +131,13 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
119
131
if (StringUtils .hasText (chatResponse .getMetadata ().getModel ())) {
120
132
metadataModelRef .set (chatResponse .getMetadata ().getModel ());
121
133
}
134
+ Object toolCallsFromMetadata = chatResponse .getMetadata ().get ("toolCalls" );
135
+ if (toolCallsFromMetadata instanceof List ) {
136
+ @ SuppressWarnings ("unchecked" )
137
+ List <ToolCall > toolCallsList = (List <ToolCall >) toolCallsFromMetadata ;
138
+ toolCallsRef .get ().addAll (toolCallsList );
139
+ }
140
+
122
141
}
123
142
}).doOnComplete (() -> {
124
143
@@ -133,12 +152,25 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
133
152
.promptMetadata (metadataPromptMetadataRef .get ())
134
153
.build ();
135
154
136
- onAggregationComplete .accept (new ChatResponse (List .of (new Generation (
137
- new AssistantMessage (messageTextContentRef .get ().toString (), messageMetadataMapRef .get ()),
155
+ AssistantMessage finalAssistantMessage ;
156
+ List <ToolCall > collectedToolCalls = toolCallsRef .get ();
157
+
158
+ if (!CollectionUtils .isEmpty (collectedToolCalls )) {
159
+
160
+ finalAssistantMessage = new AssistantMessage (messageTextContentRef .get ().toString (),
161
+ messageMetadataMapRef .get (), collectedToolCalls );
162
+ }
163
+ else {
164
+ finalAssistantMessage = new AssistantMessage (messageTextContentRef .get ().toString (),
165
+ messageMetadataMapRef .get ());
166
+ }
167
+ onAggregationComplete .accept (new ChatResponse (List .of (new Generation (finalAssistantMessage ,
168
+
138
169
generationMetadataRef .get ())), chatResponseMetadata ));
139
170
140
171
messageTextContentRef .set (new StringBuilder ());
141
172
messageMetadataMapRef .set (new HashMap <>());
173
+ toolCallsRef .set (new ArrayList <>());
142
174
metadataIdRef .set ("" );
143
175
metadataModelRef .set ("" );
144
176
metadataUsagePromptTokensRef .set (0 );
0 commit comments