From c405c077f943c2f54980980fccb738a4dea912a3 Mon Sep 17 00:00:00 2001 From: Fu Cheng Date: Thu, 22 Aug 2024 12:23:38 +0800 Subject: [PATCH] Customize `SearchRequest` sent to `VectorStore` in `QuestionAnswerAdvisor` The new `customizeSearchRequest` method provides an extension point to further customize the `SearchRequest` used by `VectorStore` to search for similar documents. --- .../client/advisor/QuestionAnswerAdvisor.java | 23 +++++--- .../advisor/QuestionAnswerAdvisorTests.java | 53 ++++++++++++++++--- 2 files changed, 64 insertions(+), 12 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index dc8747fff5c..1de0e2d5ffd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -103,6 +103,8 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map .withQuery(request.userText()) .withFilterExpression(doGetFilterExpression(context)); + searchRequestToUse = customizeSearchRequest(searchRequestToUse, request, context); + // 2. Search for similar documents in the vector store. List documents = this.vectorStore.similaritySearch(searchRequestToUse); @@ -117,12 +119,7 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map Map advisedUserParams = new HashMap<>(request.userParams()); advisedUserParams.put("question_answer_context", documentContext); - AdvisedRequest advisedRequest = AdvisedRequest.from(request) - .withUserText(advisedUserText) - .withUserParams(advisedUserParams) - .build(); - - return advisedRequest; + return AdvisedRequest.from(request).withUserText(advisedUserText).withUserParams(advisedUserParams).build(); } @Override @@ -151,4 +148,18 @@ protected Filter.Expression doGetFilterExpression(Map context) { } + /** + * Customize {@link SearchRequest} for each request. The returned + * {@link SearchRequest} will be used in + * {@link VectorStore#similaritySearch(SearchRequest)} to find similar documents. + * @param searchRequest the original {@link SearchRequest} + * @param request the {@link AdvisedRequest} representing the current request + * @param context the shared data between advisors in the chain + * @return the customized {@link SearchRequest} + */ + protected SearchRequest customizeSearchRequest(SearchRequest searchRequest, AdvisedRequest request, + Map context) { + return searchRequest; + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java index 76851330780..ac40656e4e1 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisorTests.java @@ -20,14 +20,17 @@ import static org.mockito.Mockito.when; import java.util.List; - +import java.util.Map; +import java.util.Objects; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.client.AdvisedRequest; import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.model.ChatModel; @@ -61,7 +64,9 @@ public class QuestionAnswerAdvisorTests { public void qaAdvisorWithDynamicFilterExpressions() { when(chatModel.call(promptCaptor.capture())) - .thenReturn(new ChatResponse(List.of(new Generation("Your answer is ZXY")))); + .thenReturn(new ChatResponse( + List.of(new Generation(new AssistantMessage( + "Your answer is ZXY"))))); when(vectorStore.similaritySearch(vectorSearchCaptor.capture())) .thenReturn(List.of(new Document("doc1"), new Document("doc2"))); @@ -74,20 +79,16 @@ public void qaAdvisorWithDynamicFilterExpressions() { .defaultAdvisors(qaAdvisor) .build(); - // @formatter:off var content = chatClient.prompt() .user("Please answer my question XYZ") .advisors(a -> a.param(QuestionAnswerAdvisor.FILTER_EXPRESSION, "type == 'Spring'")) .call() .content(); - //formatter:on assertThat(content).isEqualTo("Your answer is ZXY"); Message systemMessage = promptCaptor.getValue().getInstructions().get(0); - System.out.println(systemMessage.getContent()); - assertThat(systemMessage.getContent()).isEqualToIgnoringWhitespace(""" Default system text. """); @@ -111,4 +112,44 @@ public void qaAdvisorWithDynamicFilterExpressions() { assertThat(vectorSearchCaptor.getValue().getSimilarityThreshold()).isEqualTo(0.99d); assertThat(vectorSearchCaptor.getValue().getTopK()).isEqualTo(6); } + + @Test + public void qaAdvisorWithCustomizedSearchRequest() { + when(chatModel.call(promptCaptor.capture())) + .thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage( + "Your answer is XYZ"))))); + + when(vectorStore.similaritySearch(vectorSearchCaptor.capture())) + .thenReturn(List.of(new Document("doc1"), new Document("doc2"))); + + var qaAdvisor = new RewriteQueryQuestionAnswerAdvisor(vectorStore); + var chatClient = ChatClient.builder(chatModel) + .defaultAdvisors(qaAdvisor) + .build(); + var updatedUserQuery = "Please answer my question 123"; + var content = chatClient.prompt() + .user("Please answer my question XYZ") + .advisors(a -> a.param("qa_updated_user_query", updatedUserQuery)) + .call() + .content(); + + assertThat(content).isEqualTo("Your answer is XYZ"); + assertThat(vectorSearchCaptor.getValue().getQuery()).isEqualTo(updatedUserQuery); + } + + private static class RewriteQueryQuestionAnswerAdvisor extends QuestionAnswerAdvisor { + + public RewriteQueryQuestionAnswerAdvisor(VectorStore vectorStore) { + super(vectorStore); + } + + @Override + protected SearchRequest customizeSearchRequest(SearchRequest searchRequest, AdvisedRequest request, + Map context) { + return SearchRequest.from(searchRequest) + .withQuery(Objects.toString(context.getOrDefault("qa_updated_user_query", ""))); + } + + } + }