|
import org.testcontainers.containers.PostgreSQLContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.vectorstore.VectorStoreChatMemoryAdvisor;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.jdbc.core.JdbcTemplate;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
|
@@ -117,6 +119,78 @@ private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(ChatM
|
117 | 119 | """);
|
118 | 120 | }
|
119 | 121 |
|
| 122 | + /** |
| 123 | + * Create a mock ChatModel that supports streaming responses for testing. |
| 124 | + * @return A mock ChatModel that returns a predefined streaming response |
| 125 | + */ |
| 126 | + private static @NotNull ChatModel chatModelWithStreamingSupport() { |
| 127 | + ChatModel chatModel = mock(ChatModel.class); |
| 128 | + |
| 129 | + // Mock the regular call method |
| 130 | + ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class); |
| 131 | + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage(""" |
| 132 | + Why don't scientists trust atoms? |
| 133 | + Because they make up everything! |
| 134 | + """)))); |
| 135 | + given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse); |
| 136 | + |
| 137 | + // Mock the streaming method |
| 138 | + ArgumentCaptor<Prompt> streamArgumentCaptor = ArgumentCaptor.forClass(Prompt.class); |
| 139 | + Flux<ChatResponse> streamingResponse = Flux.just( |
| 140 | + new ChatResponse(List.of(new Generation(new AssistantMessage("Why")))), |
| 141 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" don't")))), |
| 142 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" scientists")))), |
| 143 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" trust")))), |
| 144 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" atoms?")))), |
| 145 | + new ChatResponse(List.of(new Generation(new AssistantMessage("\nBecause")))), |
| 146 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" they")))), |
| 147 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" make")))), |
| 148 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" up")))), |
| 149 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" everything!"))))); |
| 150 | + given(chatModel.stream(streamArgumentCaptor.capture())).willReturn(streamingResponse); |
| 151 | + |
| 152 | + return chatModel; |
| 153 | + } |
| 154 | + |
| 155 | + /** |
| 156 | + * Create a mock ChatModel that simulates the problematic streaming behavior. This |
| 157 | + * mock includes a final empty message that triggers the bug in |
| 158 | + * VectorStoreChatMemoryAdvisor. |
| 159 | + * @return A mock ChatModel that returns a problematic streaming response |
| 160 | + */ |
| 161 | + private static @NotNull ChatModel chatModelWithProblematicStreamingBehavior() { |
| 162 | + ChatModel chatModel = mock(ChatModel.class); |
| 163 | + |
| 164 | + // Mock the regular call method |
| 165 | + ArgumentCaptor<Prompt> argumentCaptor = ArgumentCaptor.forClass(Prompt.class); |
| 166 | + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage(""" |
| 167 | + Why don't scientists trust atoms? |
| 168 | + Because they make up everything! |
| 169 | + """)))); |
| 170 | + given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse); |
| 171 | + |
| 172 | + // Mock the streaming method with a problematic final message (empty content) |
| 173 | + // This simulates the real-world condition that triggers the bug |
| 174 | + ArgumentCaptor<Prompt> streamArgumentCaptor = ArgumentCaptor.forClass(Prompt.class); |
| 175 | + Flux<ChatResponse> streamingResponse = Flux.just( |
| 176 | + new ChatResponse(List.of(new Generation(new AssistantMessage("Why")))), |
| 177 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" don't")))), |
| 178 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" scientists")))), |
| 179 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" trust")))), |
| 180 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" atoms?")))), |
| 181 | + new ChatResponse(List.of(new Generation(new AssistantMessage("\nBecause")))), |
| 182 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" they")))), |
| 183 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" make")))), |
| 184 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" up")))), |
| 185 | + new ChatResponse(List.of(new Generation(new AssistantMessage(" everything!")))), |
| 186 | + // This final empty message triggers the bug in |
| 187 | + // VectorStoreChatMemoryAdvisor |
| 188 | + new ChatResponse(List.of(new Generation(new AssistantMessage(""))))); |
| 189 | + given(chatModel.stream(streamArgumentCaptor.capture())).willReturn(streamingResponse); |
| 190 | + |
| 191 | + return chatModel; |
| 192 | + } |
| 193 | + |
120 | 194 | /**
|
121 | 195 | * Test that chats with {@link VectorStoreChatMemoryAdvisor} get advised with similar
|
122 | 196 | * messages from the (gp)vector store.
|
@@ -182,6 +256,139 @@ void advisedChatShouldHaveSimilarMessagesFromVectorStoreWhenSystemMessageProvide
|
182 | 256 | """);
|
183 | 257 | }
|
184 | 258 |
|
| 259 | + /** |
| 260 | + * Test that streaming chats with {@link VectorStoreChatMemoryAdvisor} get advised |
| 261 | + * with similar messages from the vector store and properly handle streaming |
| 262 | + * responses. |
| 263 | + * |
| 264 | + * This test verifies that the fix for the bug reported in |
| 265 | + * https://github.com/spring-projects/spring-ai/issues/3152 works correctly. The |
| 266 | + * VectorStoreChatMemoryAdvisor now properly handles streaming responses and saves the |
| 267 | + * assistant's messages to the vector store. |
| 268 | + */ |
| 269 | + @Test |
| 270 | + void advisedStreamingChatShouldHaveSimilarMessagesFromVectorStore() throws Exception { |
| 271 | + // Create a ChatModel with streaming support |
| 272 | + ChatModel chatModel = chatModelWithStreamingSupport(); |
| 273 | + |
| 274 | + // Create the embedding model |
| 275 | + EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed(); |
| 276 | + |
| 277 | + // Create and initialize the vector store |
| 278 | + PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel); |
| 279 | + String conversationId = UUID.randomUUID().toString(); |
| 280 | + initStore(store, conversationId); |
| 281 | + |
| 282 | + // Create a chat client with the VectorStoreChatMemoryAdvisor |
| 283 | + ChatClient chatClient = ChatClient.builder(chatModel).build(); |
| 284 | + |
| 285 | + // Execute a streaming chat request |
| 286 | + Flux<String> responseStream = chatClient.prompt() |
| 287 | + .user("joke") |
| 288 | + .advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build()) |
| 289 | + .param(ChatMemory.CONVERSATION_ID, conversationId)) |
| 290 | + .stream() |
| 291 | + .content(); |
| 292 | + |
| 293 | + // Collect all streaming chunks |
| 294 | + List<String> streamingChunks = responseStream.collectList().block(); |
| 295 | + |
| 296 | + // Verify the streaming response |
| 297 | + assertThat(streamingChunks).isNotNull(); |
| 298 | + String completeResponse = String.join("", streamingChunks); |
| 299 | + assertThat(completeResponse).contains("scientists", "atoms", "everything"); |
| 300 | + |
| 301 | + // Verify the request was properly advised with vector store content |
| 302 | + ArgumentCaptor<Prompt> promptCaptor = ArgumentCaptor.forClass(Prompt.class); |
| 303 | + verify(chatModel).stream(promptCaptor.capture()); |
| 304 | + Prompt capturedPrompt = promptCaptor.getValue(); |
| 305 | + assertThat(capturedPrompt.getInstructions().get(0)).isInstanceOf(SystemMessage.class); |
| 306 | + assertThat(capturedPrompt.getInstructions().get(0).getText()).isEqualToIgnoringWhitespace(""" |
| 307 | +
|
| 308 | + Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers. |
| 309 | +
|
| 310 | + --------------------- |
| 311 | + LONG_TERM_MEMORY: |
| 312 | + Tell me a good joke |
| 313 | + Tell me a bad joke |
| 314 | + --------------------- |
| 315 | + """); |
| 316 | + |
| 317 | + // Verify that the assistant's response was properly added to the vector store |
| 318 | + // after |
| 319 | + // streaming completed |
| 320 | + // This verifies that the fix for the adviseStream implementation works correctly |
| 321 | + String filter = "conversationId=='" + conversationId + "' && messageType=='ASSISTANT'"; |
| 322 | + var searchRequest = SearchRequest.builder().query("atoms").filterExpression(filter).build(); |
| 323 | + |
| 324 | + List<Document> assistantDocuments = store.similaritySearch(searchRequest); |
| 325 | + |
| 326 | + // With our fix, the assistant's response should be saved to the vector store |
| 327 | + assertThat(assistantDocuments).isNotEmpty(); |
| 328 | + assertThat(assistantDocuments.get(0).getText()).contains("scientists", "atoms", "everything"); |
| 329 | + } |
| 330 | + |
| 331 | + /** |
| 332 | + * Test that verifies the fix for the bug reported in |
| 333 | + * https://github.com/spring-projects/spring-ai/issues/3152. The |
| 334 | + * VectorStoreChatMemoryAdvisor now properly handles streaming responses with empty |
| 335 | + * messages by using ChatClientMessageAggregator to aggregate messages before calling |
| 336 | + * the after method. |
| 337 | + */ |
| 338 | + @Test |
| 339 | + void vectorStoreChatMemoryAdvisorShouldHandleEmptyMessagesInStream() throws Exception { |
| 340 | + // Create a ChatModel with problematic streaming behavior |
| 341 | + ChatModel chatModel = chatModelWithProblematicStreamingBehavior(); |
| 342 | + |
| 343 | + // Create the embedding model |
| 344 | + EmbeddingModel embeddingModel = embeddingNModelShouldAlwaysReturnFakedEmbed(); |
| 345 | + |
| 346 | + // Create and initialize the vector store |
| 347 | + PgVectorStore store = createPgVectorStoreUsingTestcontainer(embeddingModel); |
| 348 | + String conversationId = UUID.randomUUID().toString(); |
| 349 | + initStore(store, conversationId); |
| 350 | + |
| 351 | + // Create a chat client with the VectorStoreChatMemoryAdvisor |
| 352 | + ChatClient chatClient = ChatClient.builder(chatModel).build(); |
| 353 | + |
| 354 | + // Execute a streaming chat request |
| 355 | + // This should now succeed with our fix |
| 356 | + Flux<String> responseStream = chatClient.prompt() |
| 357 | + .user("joke") |
| 358 | + .advisors(a -> a.advisors(VectorStoreChatMemoryAdvisor.builder(store).build()) |
| 359 | + .param(ChatMemory.CONVERSATION_ID, conversationId)) |
| 360 | + .stream() |
| 361 | + .content(); |
| 362 | + |
| 363 | + // Collect all streaming chunks - this should no longer throw an exception |
| 364 | + List<String> streamingChunks = responseStream.collectList().block(); |
| 365 | + |
| 366 | + // Verify the streaming response |
| 367 | + assertThat(streamingChunks).isNotNull(); |
| 368 | + String completeResponse = String.join("", streamingChunks); |
| 369 | + assertThat(completeResponse).contains("scientists", "atoms", "everything"); |
| 370 | + |
| 371 | + // Verify that the assistant's response was properly added to the vector store |
| 372 | + // This verifies that our fix works correctly |
| 373 | + String filter = "conversationId=='" + conversationId + "' && messageType=='ASSISTANT'"; |
| 374 | + var searchRequest = SearchRequest.builder().query("atoms").filterExpression(filter).build(); |
| 375 | + |
| 376 | + List<Document> assistantDocuments = store.similaritySearch(searchRequest); |
| 377 | + assertThat(assistantDocuments).isNotEmpty(); |
| 378 | + assertThat(assistantDocuments.get(0).getText()).contains("scientists", "atoms", "everything"); |
| 379 | + } |
| 380 | + |
| 381 | + /** |
| 382 | + * Helper method to get the root cause of an exception |
| 383 | + */ |
| 384 | + private Throwable getRootCause(Throwable throwable) { |
| 385 | + Throwable cause = throwable; |
| 386 | + while (cause.getCause() != null && cause.getCause() != cause) { |
| 387 | + cause = cause.getCause(); |
| 388 | + } |
| 389 | + return cause; |
| 390 | + } |
| 391 | + |
185 | 392 | @SuppressWarnings("unchecked")
|
186 | 393 | private @NotNull EmbeddingModel embeddingNModelShouldAlwaysReturnFakedEmbed() {
|
187 | 394 | EmbeddingModel embeddingModel = mock(EmbeddingModel.class);
|
|
0 commit comments