Skip to content

Commit f5f0212

Browse files
committed
Fix MessageChatMemoryAdvisor in streaming case.
- The fix overrides the adviseStream method to use ChatClientMessageAggregator to properly aggregate streaming chunks before storing them in memory, similar to how PromptChatMemoryAdvisor handles streaming responses - Added test Signed-off-by: Mark Pollack <mark.pollack@broadcom.com>
1 parent 348adee commit f5f0212

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,56 @@ void shouldHandleNonExistentConversation() {
138138
testHandleNonExistentConversation();
139139
}
140140

141+
@Test
142+
void shouldStoreCompleteContentInStreamingMode() {
143+
// Arrange
144+
String conversationId = "streaming-test-" + System.currentTimeMillis();
145+
ChatMemory chatMemory = MessageWindowChatMemory.builder()
146+
.chatMemoryRepository(new InMemoryChatMemoryRepository())
147+
.build();
148+
149+
// Create MessageChatMemoryAdvisor with the conversation ID
150+
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory)
151+
.conversationId(conversationId)
152+
.build();
153+
154+
ChatClient chatClient = ChatClient.builder(chatModel).defaultAdvisors(advisor).build();
155+
156+
// Act - Use streaming API
157+
String userInput = "Tell me a short joke about programming";
158+
159+
// Collect the streaming responses
160+
List<String> streamedResponses = new ArrayList<>();
161+
chatClient.prompt()
162+
.user(userInput)
163+
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
164+
.stream()
165+
.content()
166+
.collectList()
167+
.block();
168+
169+
// Wait a moment to ensure all processing is complete
170+
try {
171+
Thread.sleep(500);
172+
}
173+
catch (InterruptedException e) {
174+
Thread.currentThread().interrupt();
175+
}
176+
177+
// Assert - Check that the memory contains the complete content
178+
List<Message> memoryMessages = chatMemory.get(conversationId);
179+
180+
// Should have at least 2 messages (user + assistant)
181+
assertThat(memoryMessages).hasSizeGreaterThanOrEqualTo(2);
182+
183+
// First message should be the user message
184+
assertThat(memoryMessages.get(0).getText()).isEqualTo(userInput);
185+
186+
// Last message should be the assistant's response and should have content
187+
Message assistantMessage = memoryMessages.get(memoryMessages.size() - 1);
188+
assertThat(assistantMessage.getText()).isNotEmpty();
189+
190+
logger.info("Assistant response stored in memory: {}", assistantMessage.getText());
191+
}
192+
141193
}

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,18 @@
2020
import java.util.List;
2121

2222
import org.springframework.util.Assert;
23+
import reactor.core.publisher.Flux;
24+
import reactor.core.publisher.Mono;
2325
import reactor.core.scheduler.Scheduler;
2426

27+
import org.springframework.ai.chat.client.ChatClientMessageAggregator;
2528
import org.springframework.ai.chat.client.ChatClientRequest;
2629
import org.springframework.ai.chat.client.ChatClientResponse;
2730
import org.springframework.ai.chat.client.advisor.api.Advisor;
2831
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
2932
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
3033
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
34+
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
3135
import org.springframework.ai.chat.memory.ChatMemory;
3236
import org.springframework.ai.chat.messages.Message;
3337
import org.springframework.ai.chat.messages.UserMessage;
@@ -109,6 +113,21 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh
109113
return chatClientResponse;
110114
}
111115

116+
@Override
117+
public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest,
118+
StreamAdvisorChain streamAdvisorChain) {
119+
// Get the scheduler from BaseAdvisor
120+
Scheduler scheduler = this.getScheduler();
121+
122+
// Process the request with the before method
123+
return Mono.just(chatClientRequest)
124+
.publishOn(scheduler)
125+
.map(request -> this.before(request, streamAdvisorChain))
126+
.flatMapMany(streamAdvisorChain::nextStream)
127+
.transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux,
128+
response -> this.after(response, streamAdvisorChain)));
129+
}
130+
112131
public static Builder builder(ChatMemory chatMemory) {
113132
return new Builder(chatMemory);
114133
}

0 commit comments

Comments
 (0)