Skip to content

Customize SearchRequest sent to VectorStore in QuestionAnswerAdvisor #1264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object>
.withQuery(request.userText())
.withFilterExpression(doGetFilterExpression(context));

searchRequestToUse = customizeSearchRequest(searchRequestToUse, request, context);

// 2. Search for similar documents in the vector store.
List<Document> documents = this.vectorStore.similaritySearch(searchRequestToUse);

Expand All @@ -117,12 +119,7 @@ public AdvisedRequest adviseRequest(AdvisedRequest request, Map<String, Object>
Map<String, Object> advisedUserParams = new HashMap<>(request.userParams());
advisedUserParams.put("question_answer_context", documentContext);

AdvisedRequest advisedRequest = AdvisedRequest.from(request)
.withUserText(advisedUserText)
.withUserParams(advisedUserParams)
.build();

return advisedRequest;
return AdvisedRequest.from(request).withUserText(advisedUserText).withUserParams(advisedUserParams).build();
}

@Override
Expand Down Expand Up @@ -151,4 +148,18 @@ protected Filter.Expression doGetFilterExpression(Map<String, Object> context) {

}

/**
* Customize {@link SearchRequest} for each request. The returned
* {@link SearchRequest} will be used in
* {@link VectorStore#similaritySearch(SearchRequest)} to find similar documents.
* @param searchRequest the original {@link SearchRequest}
* @param request the {@link AdvisedRequest} representing the current request
* @param context the shared data between advisors in the chain
* @return the customized {@link SearchRequest}
*/
protected SearchRequest customizeSearchRequest(SearchRequest searchRequest, AdvisedRequest request,
Map<String, Object> context) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it necessary to inherit and reimplement the implementation to return a SearchRequest method, even though QuestionAndAnswerAdvisor is already receiving a SearchRequest object through the constructor?

return searchRequest;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
import static org.mockito.Mockito.when;

import java.util.List;

import java.util.Map;
import java.util.Objects;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.model.ChatModel;
Expand Down Expand Up @@ -61,7 +64,9 @@ public class QuestionAnswerAdvisorTests {
public void qaAdvisorWithDynamicFilterExpressions() {

when(chatModel.call(promptCaptor.capture()))
.thenReturn(new ChatResponse(List.of(new Generation("Your answer is ZXY"))));
.thenReturn(new ChatResponse(
List.of(new Generation(new AssistantMessage(
"Your answer is ZXY")))));

when(vectorStore.similaritySearch(vectorSearchCaptor.capture()))
.thenReturn(List.of(new Document("doc1"), new Document("doc2")));
Expand All @@ -74,20 +79,16 @@ public void qaAdvisorWithDynamicFilterExpressions() {
.defaultAdvisors(qaAdvisor)
.build();

// @formatter:off
var content = chatClient.prompt()
.user("Please answer my question XYZ")
.advisors(a -> a.param(QuestionAnswerAdvisor.FILTER_EXPRESSION, "type == 'Spring'"))
.call()
.content();
//formatter:on

assertThat(content).isEqualTo("Your answer is ZXY");

Message systemMessage = promptCaptor.getValue().getInstructions().get(0);

System.out.println(systemMessage.getContent());

assertThat(systemMessage.getContent()).isEqualToIgnoringWhitespace("""
Default system text.
""");
Expand All @@ -111,4 +112,44 @@ public void qaAdvisorWithDynamicFilterExpressions() {
assertThat(vectorSearchCaptor.getValue().getSimilarityThreshold()).isEqualTo(0.99d);
assertThat(vectorSearchCaptor.getValue().getTopK()).isEqualTo(6);
}

@Test
public void qaAdvisorWithCustomizedSearchRequest() {
when(chatModel.call(promptCaptor.capture()))
.thenReturn(new ChatResponse(List.of(new Generation(new AssistantMessage(
"Your answer is XYZ")))));

when(vectorStore.similaritySearch(vectorSearchCaptor.capture()))
.thenReturn(List.of(new Document("doc1"), new Document("doc2")));

var qaAdvisor = new RewriteQueryQuestionAnswerAdvisor(vectorStore);
var chatClient = ChatClient.builder(chatModel)
.defaultAdvisors(qaAdvisor)
.build();
var updatedUserQuery = "Please answer my question 123";
var content = chatClient.prompt()
.user("Please answer my question XYZ")
.advisors(a -> a.param("qa_updated_user_query", updatedUserQuery))
.call()
.content();

assertThat(content).isEqualTo("Your answer is XYZ");
assertThat(vectorSearchCaptor.getValue().getQuery()).isEqualTo(updatedUserQuery);
}

private static class RewriteQueryQuestionAnswerAdvisor extends QuestionAnswerAdvisor {

public RewriteQueryQuestionAnswerAdvisor(VectorStore vectorStore) {
super(vectorStore);
}

@Override
protected SearchRequest customizeSearchRequest(SearchRequest searchRequest, AdvisedRequest request,
Map<String, Object> context) {
return SearchRequest.from(searchRequest)
.withQuery(Objects.toString(context.getOrDefault("qa_updated_user_query", "")));
}

}

}