Skip to content

Commit 54e5c07

Browse files
tzolovilayaperumalg
authored andcommitted
refactor: Move MessageAggregator to spring-ai-model module
Improve separation of concerns by keeping model-related functionality in the model module while maintaining client-specific functionality in the client module. - Move MessageAggregator class from spring-ai-client-chat to spring-ai-model module - Create new ChatClientMessageAggregator in spring-ai-client-chat module to handle client-specific aggregation - Extract client-specific aggregation logic from MessageAggregator to ChatClientMessageAggregator - Update references in advisor classes to use the new ChatClientMessageAggregator Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
1 parent a03e7cb commit 54e5c07

File tree

4 files changed

+66
-24
lines changed

4 files changed

+66
-24
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.client;
18+
19+
import java.util.HashMap;
20+
import java.util.Map;
21+
import java.util.concurrent.atomic.AtomicReference;
22+
import java.util.function.Consumer;
23+
24+
import org.slf4j.Logger;
25+
import org.slf4j.LoggerFactory;
26+
import reactor.core.publisher.Flux;
27+
28+
import org.springframework.ai.chat.model.MessageAggregator;
29+
30+
/**
31+
* Helper that for streaming chat responses, aggregate the chat response messages into a
32+
* single AssistantMessage. Job is performed in parallel to the chat response processing.
33+
*
34+
* @author Christian Tzolov
35+
* @author Alexandros Pappas
36+
* @author Thomas Vitale
37+
* @since 1.0.0
38+
*/
39+
public class ChatClientMessageAggregator {
40+
41+
private static final Logger logger = LoggerFactory.getLogger(ChatClientMessageAggregator.class);
42+
43+
public Flux<ChatClientResponse> aggregateChatClientResponse(Flux<ChatClientResponse> chatClientResponses,
44+
Consumer<ChatClientResponse> aggregationHandler) {
45+
46+
AtomicReference<Map<String, Object>> context = new AtomicReference<>(new HashMap<>());
47+
48+
return new MessageAggregator().aggregate(chatClientResponses.mapNotNull(chatClientResponse -> {
49+
context.get().putAll(chatClientResponse.context());
50+
return chatClientResponse.chatResponse();
51+
}), aggregatedChatResponse -> {
52+
ChatClientResponse aggregatedChatClientResponse = ChatClientResponse.builder()
53+
.chatResponse(aggregatedChatResponse)
54+
.context(context.get())
55+
.build();
56+
aggregationHandler.accept(aggregatedChatClientResponse);
57+
}).map(chatResponse -> ChatClientResponse.builder().chatResponse(chatResponse).context(context.get()).build());
58+
}
59+
60+
}

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import reactor.core.scheduler.Scheduler;
2929
import reactor.core.scheduler.Schedulers;
3030

31+
import org.springframework.ai.chat.client.ChatClientMessageAggregator;
3132
import org.springframework.ai.chat.client.ChatClientRequest;
3233
import org.springframework.ai.chat.client.ChatClientResponse;
3334
import org.springframework.ai.chat.client.advisor.api.Advisor;
@@ -40,7 +41,6 @@
4041
import org.springframework.ai.chat.messages.MessageType;
4142
import org.springframework.ai.chat.messages.SystemMessage;
4243
import org.springframework.ai.chat.messages.UserMessage;
43-
import org.springframework.ai.chat.model.MessageAggregator;
4444
import org.springframework.ai.chat.prompt.PromptTemplate;
4545

4646
/**
@@ -172,7 +172,7 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest
172172
.publishOn(scheduler)
173173
.map(request -> this.before(request, streamAdvisorChain))
174174
.flatMapMany(streamAdvisorChain::nextStream)
175-
.transform(flux -> new MessageAggregator().aggregateChatClientResponse(flux,
175+
.transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux,
176176
response -> this.after(response, streamAdvisorChain)));
177177
}
178178

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/SimpleLoggerAdvisor.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,18 @@
1818

1919
import java.util.function.Function;
2020

21-
import reactor.core.publisher.Flux;
22-
2321
import org.slf4j.Logger;
2422
import org.slf4j.LoggerFactory;
23+
import reactor.core.publisher.Flux;
24+
25+
import org.springframework.ai.chat.client.ChatClientMessageAggregator;
2526
import org.springframework.ai.chat.client.ChatClientRequest;
2627
import org.springframework.ai.chat.client.ChatClientResponse;
2728
import org.springframework.ai.chat.client.advisor.api.CallAdvisor;
2829
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
2930
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
3031
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
3132
import org.springframework.ai.chat.model.ChatResponse;
32-
import org.springframework.ai.chat.model.MessageAggregator;
3333
import org.springframework.ai.model.ModelOptionsUtils;
3434
import org.springframework.lang.Nullable;
3535

@@ -85,7 +85,7 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest
8585

8686
Flux<ChatClientResponse> chatClientResponses = streamAdvisorChain.nextStream(chatClientRequest);
8787

88-
return new MessageAggregator().aggregateChatClientResponse(chatClientResponses, this::logResponse);
88+
return new ChatClientMessageAggregator().aggregateChatClientResponse(chatClientResponses, this::logResponse);
8989
}
9090

9191
private void logRequest(ChatClientRequest request) {

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java renamed to spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
import org.slf4j.Logger;
2626
import org.slf4j.LoggerFactory;
27-
import org.springframework.ai.chat.client.ChatClientResponse;
2827
import reactor.core.publisher.Flux;
2928

3029
import org.springframework.ai.chat.messages.AssistantMessage;
@@ -49,23 +48,6 @@ public class MessageAggregator {
4948

5049
private static final Logger logger = LoggerFactory.getLogger(MessageAggregator.class);
5150

52-
public Flux<ChatClientResponse> aggregateChatClientResponse(Flux<ChatClientResponse> chatClientResponses,
53-
Consumer<ChatClientResponse> aggregationHandler) {
54-
55-
AtomicReference<Map<String, Object>> context = new AtomicReference<>(new HashMap<>());
56-
57-
return new MessageAggregator().aggregate(chatClientResponses.mapNotNull(chatClientResponse -> {
58-
context.get().putAll(chatClientResponse.context());
59-
return chatClientResponse.chatResponse();
60-
}), aggregatedChatResponse -> {
61-
ChatClientResponse aggregatedChatClientResponse = ChatClientResponse.builder()
62-
.chatResponse(aggregatedChatResponse)
63-
.context(context.get())
64-
.build();
65-
aggregationHandler.accept(aggregatedChatClientResponse);
66-
}).map(chatResponse -> ChatClientResponse.builder().chatResponse(chatResponse).context(context.get()).build());
67-
}
68-
6951
public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
7052
Consumer<ChatResponse> onAggregationComplete) {
7153

0 commit comments

Comments
 (0)