Skip to content

Commit 2f5d53f

Browse files
markpollack and ilayaperumalg
authored and committed
Fix VectorStoreChatMemoryAdvisor streaming bug
- Override adviseStream method in VectorStoreChatMemoryAdvisor to properly handle streaming responses
- Add tests to verify the fix works with both normal and problematic streaming scenarios

Fixes #3152

Signed-off-by: Mark Pollack <mark.pollack@broadcom.com>
1 parent fa8f246 commit 2f5d53f

File tree

4 files changed

+288
-0
lines changed

4 files changed

+288
-0
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/AbstractChatMemoryAdvisorIT.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
2323
import org.slf4j.Logger;
2424
import org.slf4j.LoggerFactory;
25+
import reactor.core.publisher.Flux;
2526

2627
import org.springframework.ai.chat.client.ChatClient;
2728
import org.springframework.ai.chat.client.advisor.api.BaseChatMemoryAdvisor;
@@ -412,4 +413,74 @@ protected void testHandleMultipleMessagesInReactiveMode() {
412413
assertThat(memoryMessages.get(6).getText()).isEqualTo("What is my name and where do I live?");
413414
}
414415

416+
/**
417+
* Tests that the advisor correctly handles streaming responses and updates the
418+
* memory. This verifies that the adviseStream method in chat memory advisors is
419+
* working correctly.
420+
*/
421+
protected void testStreamingWithChatMemory() {
422+
// Arrange
423+
String conversationId = "streaming-conversation-" + System.currentTimeMillis();
424+
ChatMemory chatMemory = MessageWindowChatMemory.builder()
425+
.chatMemoryRepository(new InMemoryChatMemoryRepository())
426+
.build();
427+
428+
// Create advisor with the conversation ID
429+
var advisor = createAdvisor(chatMemory);
430+
431+
ChatClient chatClient = ChatClient.builder(this.chatModel).defaultAdvisors(advisor).build();
432+
433+
// Act - Send a message using streaming
434+
String initialQuestion = "My name is David and I live in Seattle.";
435+
436+
// Collect all streaming chunks
437+
List<String> streamingChunks = new ArrayList<>();
438+
Flux<String> responseStream = chatClient.prompt()
439+
.user(initialQuestion)
440+
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
441+
.stream()
442+
.content();
443+
444+
// Block and collect all streaming chunks
445+
responseStream.doOnNext(streamingChunks::add).blockLast();
446+
447+
// Join all chunks to get the complete response
448+
String completeResponse = String.join("", streamingChunks);
449+
450+
logger.info("Streaming response: {}", completeResponse);
451+
452+
// Verify memory contains the initial question and the response
453+
List<Message> memoryMessages = chatMemory.get(conversationId);
454+
assertThat(memoryMessages).hasSize(2); // 1 user message + 1 assistant response
455+
assertThat(memoryMessages.get(0).getText()).isEqualTo(initialQuestion);
456+
457+
// Send a follow-up question using streaming
458+
String followUpQuestion = "Where do I live?";
459+
460+
// Collect all streaming chunks for the follow-up
461+
List<String> followUpStreamingChunks = new ArrayList<>();
462+
Flux<String> followUpResponseStream = chatClient.prompt()
463+
.user(followUpQuestion)
464+
.advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
465+
.stream()
466+
.content();
467+
468+
// Block and collect all streaming chunks
469+
followUpResponseStream.doOnNext(followUpStreamingChunks::add).blockLast();
470+
471+
// Join all chunks to get the complete follow-up response
472+
String followUpCompleteResponse = String.join("", followUpStreamingChunks);
473+
474+
logger.info("Follow-up streaming response: {}", followUpCompleteResponse);
475+
476+
// Verify the follow-up response contains the location
477+
assertThat(followUpCompleteResponse).containsIgnoringCase("Seattle");
478+
479+
// Verify memory now contains all messages
480+
memoryMessages = chatMemory.get(conversationId);
481+
assertThat(memoryMessages).hasSize(4); // 2 user messages + 2 assistant responses
482+
assertThat(memoryMessages.get(0).getText()).isEqualTo(initialQuestion);
483+
assertThat(memoryMessages.get(2).getText()).isEqualTo(followUpQuestion);
484+
}
485+
415486
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/MessageChatMemoryAdvisorIT.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,4 +190,9 @@ void shouldStoreCompleteContentInStreamingMode() {
190190
logger.info("Assistant response stored in memory: {}", assistantMessage.getText());
191191
}
192192

193+
@Test
194+
void shouldHandleStreamingWithChatMemory() {
195+
testStreamingWithChatMemory();
196+
}
197+
193198
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/advisor/PromptChatMemoryAdvisorIT.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,4 +135,9 @@ void shouldHandleMultipleUserMessagesInPrompt() {
135135
testMultipleUserMessagesInPrompt();
136136
}
137137

138+
@Test
139+
void shouldHandleStreamingWithChatMemory() {
140+
testStreamingWithChatMemory();
141+
}
142+
138143
}

vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreWithChatMemoryAdvisorIT.java

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.testcontainers.containers.PostgreSQLContainer;
3131
import org.testcontainers.junit.jupiter.Container;
3232
import org.testcontainers.junit.jupiter.Testcontainers;
33+
import reactor.core.publisher.Flux;
3334

3435
import org.springframework.ai.chat.client.ChatClient;
3536
import org.springframework.ai.chat.client.advisor.vectorstore.VectorStoreChatMemoryAdvisor;
@@ -42,6 +43,7 @@
4243
import org.springframework.ai.chat.prompt.Prompt;
4344
import org.springframework.ai.document.Document;
4445
import org.springframework.ai.embedding.EmbeddingModel;
46+
import org.springframework.ai.vectorstore.SearchRequest;
4547
import org.springframework.jdbc.core.JdbcTemplate;
4648

4749
import static org.assertj.core.api.Assertions.assertThat;
@@ -117,6 +119,78 @@ private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(ChatM
117119
""");
118120
}
119121

122+
/**
123+
* Create a mock ChatModel that supports streaming responses for testing.
124+
* @return A mock ChatModel that returns a predefined streaming response
125+
*/
126+
private static @NotNull ChatModel chatModelWithStreamingSupport() {
127+
ChatModel chatModel = mock(ChatModel.class);
128+
129+
// Mock the regular call method
130+
ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class);
131+
ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
132+
Why don't scientists trust atoms?
133+
Because they make up everything!
134+
"""))));
135+
given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse);
136+
137+
// Mock the streaming method
138+
ArgumentCaptor<Prompt> streamArgumentCaptor = ArgumentCaptor.forClass(Prompt.class);
139+
Flux<ChatResponse> streamingResponse = Flux.just(
140+
new ChatResponse(List.of(new Generation(new AssistantMessage("Why")))),
141+
new ChatResponse(List.of(new Generation(new AssistantMessage(" don't")))),
142+
new ChatResponse(List.of(new Generation(new AssistantMessage(" scientists")))),
143+
new ChatResponse(List.of(new Generation(new AssistantMessage(" trust")))),
144+
new ChatResponse(List.of(new Generation(new AssistantMessage(" atoms?")))),
145+
new ChatResponse(List.of(new Generation(new AssistantMessage("\nBecause")))),
146+
new ChatResponse(List.of(new Generation(new AssistantMessage(" they")))),
147+
new ChatResponse(List.of(new Generation(new AssistantMessage(" make")))),
148+
new ChatResponse(List.of(new Generation(new AssistantMessage(" up")))),
149+
new ChatResponse(List.of(new Generation(new AssistantMessage(" everything!")))));
150+
given(chatModel.stream(streamArgumentCaptor.capture())).willReturn(streamingResponse);
151+
152+
return chatModel;
153+
}
154+
155+
/**
156+
* Create a mock ChatModel that simulates the problematic streaming behavior. This
157+
* mock includes a final empty message that triggers the bug in
158+
* VectorStoreChatMemoryAdvisor.
159+
* @return A mock ChatModel that returns a problematic streaming response
160+
*/
161+
private static @NotNull ChatModel chatModelWithProblematicStreamingBehavior() {
162+
ChatModel chatModel = mock(ChatModel.class);
163+
164+
// Mock the regular call method
165+
ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class);
166+
ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage("""
167+
Why don't scientists trust atoms?
168+
Because they make up everything!
169+
"""))));
170+
given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse);
171+
172+
// Mock the streaming method with a problematic final message (empty content)
173+
// This simulates the real-world condition that triggers the bug
174+
ArgumentCaptor<Prompt> streamArgumentCaptor = ArgumentCaptor.forClass(Prompt.class);
175+
Flux<ChatResponse> streamingResponse = Flux.just(
176+
new ChatResponse(List.of(new Generation(new AssistantMessage("Why")))),
177+
new ChatResponse(List.of(new Generation(new AssistantMessage(" don't")))),
178+
new ChatResponse(List.of(new Generation(new AssistantMessage(" scientists")))),
179+
new ChatResponse(List.of(new Generation(new AssistantMessage(" trust")))),
180+
new ChatResponse(List.of(new Generation(new AssistantMessage(" atoms?")))),
181+
new ChatResponse(List.of(new Generation(new AssistantMessage("\nBecause")))),
182+
new ChatResponse(List.of(new Generation(new AssistantMessage(" they")))),
183+
new ChatResponse(List.of(new Generation(new AssistantMessage(" make")))),
184+
new ChatResponse(List.of(new Generation(new AssistantMessage(" up")))),
185+
new ChatResponse(List.of(new Generation(new AssistantMessage(" everything!")))),
186+
// This final empty message triggers the bug in
187+
// VectorStoreChatMemoryAdvisor
188+
new ChatResponse(List.of(new Generation(new AssistantMessage("")))));
189+
given(chatModel.stream(streamArgumentCaptor.capture())).willReturn(streamingResponse);
190+
191+
return chatModel;
192+
}
193+
120194
/**
121195
* Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar
122196
* messages from the (gp)vector store.
@@ -182,6 +256,139 @@ void advisedChatShouldHaveSimilarMessagesFromVectorStoreWhenSystemMessageProvide
182256
""");
183257
}
184258

259+
/**
260+
* Test that streaming chats with {@link VectorStoreChatMemoryAdvisor} get advised
261+
* with similar messages from the vector store and properly handle streaming
262+
* responses.
263+
*
264+
* This test verifies that the fix for the bug reported in
265+
* https://github.com/spring-projects/spring-ai/issues/3152 works correctly. The
266+
* VectorStoreChatMemoryAdvisor now properly handles streaming responses and saves the
267+
* assistant's messages to the vector store.
268+
*/
269+
@Test
270+
void advisedStreamingChatShouldHaveSimilarMessagesFromVectorStore() throws Exception {
271+
// Create a ChatModel with streaming support
272+
ChatModel chatModel = chatModelWithStreamingSupport();
273+
274+
// Create the embedding model
275+
EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed();
276+
277+
// Create and initialize the vector store
278+
PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel);
279+
String conversationId = UUID.randomUUID().toString();
280+
initStore(store, conversationId);
281+
282+
// Create a chat client with the VectorStoreChatMemoryAdvisor
283+
ChatClient chatClient = ChatClient.builder(chatModel).build();
284+
285+
// Execute a streaming chat request
286+
Flux<String> responseStream = chatClient.prompt()
287+
.user("joke")
288+
.advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build())
289+
.param(ChatMemory.CONVERSATION_ID, conversationId))
290+
.stream()
291+
.content();
292+
293+
// Collect all streaming chunks
294+
List<String> streamingChunks = responseStream.collectList().block();
295+
296+
// Verify the streaming response
297+
assertThat(streamingChunks).isNotNull();
298+
String completeResponse = String.join("", streamingChunks);
299+
assertThat(completeResponse).contains("scientists", "atoms", "everything");
300+
301+
// Verify the request was properly advised with vector store content
302+
ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class);
303+
verify(chatModel).stream(promptCaptor.capture());
304+
Prompt capturedPrompt = promptCaptor.getValue();
305+
assertThat(capturedPrompt.getInstructions().get(0)).isInstanceOf(SystemMessage.class);
306+
assertThat(capturedPrompt.getInstructions().get(0).getText()).isEqualToIgnoringWhitespace("""
307+
308+
Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers.
309+
310+
---------------------
311+
LONG_TERM_MEMORY:
312+
Tell me a good joke
313+
Tell me a bad joke
314+
---------------------
315+
""");
316+
317+
// Verify that the assistant's response was properly added to the vector store
318+
// after
319+
// streaming completed
320+
// This verifies that the fix for the adviseStream implementation works correctly
321+
String filter = "conversationId=='" + conversationId + "' && messageType=='ASSISTANT'";
322+
var searchRequest = SearchRequest.builder().query("atoms").filterExpression(filter).build();
323+
324+
List<Document> assistantDocuments = store.similaritySearch(searchRequest);
325+
326+
// With our fix, the assistant's response should be saved to the vector store
327+
assertThat(assistantDocuments).isNotEmpty();
328+
assertThat(assistantDocuments.get(0).getText()).contains("scientists", "atoms", "everything");
329+
}
330+
331+
/**
332+
* Test that verifies the fix for the bug reported in
333+
* https://github.com/spring-projects/spring-ai/issues/3152. The
334+
* VectorStoreChatMemoryAdvisor now properly handles streaming responses with empty
335+
* messages by using ChatClientMessageAggregator to aggregate messages before calling
336+
* the after method.
337+
*/
338+
@Test
339+
void vectorStoreChatMemoryAdvisorShouldHandleEmptyMessagesInStream() throws Exception {
340+
// Create a ChatModel with problematic streaming behavior
341+
ChatModel chatModel = chatModelWithProblematicStreamingBehavior();
342+
343+
// Create the embedding model
344+
EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed();
345+
346+
// Create and initialize the vector store
347+
PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel);
348+
String conversationId = UUID.randomUUID().toString();
349+
initStore(store, conversationId);
350+
351+
// Create a chat client with the VectorStoreChatMemoryAdvisor
352+
ChatClient chatClient = ChatClient.builder(chatModel).build();
353+
354+
// Execute a streaming chat request
355+
// This should now succeed with our fix
356+
Flux<String> responseStream = chatClient.prompt()
357+
.user("joke")
358+
.advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build())
359+
.param(ChatMemory.CONVERSATION_ID, conversationId))
360+
.stream()
361+
.content();
362+
363+
// Collect all streaming chunks - this should no longer throw an exception
364+
List<String> streamingChunks = responseStream.collectList().block();
365+
366+
// Verify the streaming response
367+
assertThat(streamingChunks).isNotNull();
368+
String completeResponse = String.join("", streamingChunks);
369+
assertThat(completeResponse).contains("scientists", "atoms", "everything");
370+
371+
// Verify that the assistant's response was properly added to the vector store
372+
// This verifies that our fix works correctly
373+
String filter = "conversationId=='" + conversationId + "' && messageType=='ASSISTANT'";
374+
var searchRequest = SearchRequest.builder().query("atoms").filterExpression(filter).build();
375+
376+
List<Document> assistantDocuments = store.similaritySearch(searchRequest);
377+
assertThat(assistantDocuments).isNotEmpty();
378+
assertThat(assistantDocuments.get(0).getText()).contains("scientists", "atoms", "everything");
379+
}
380+
381+
/**
382+
* Helper method to get the root cause of an exception
383+
*/
384+
private Throwable getRootCause(Throwable throwable) {
385+
Throwable cause = throwable;
386+
while (cause.getCause() != null && cause.getCause() != cause) {
387+
cause = cause.getCause();
388+
}
389+
return cause;
390+
}
391+
185392
@SuppressWarnings("unchecked")
186393
private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() {
187394
EmbeddingModel embeddingModel = mock(EmbeddingModel.class);

0 commit comments

Comments (0)