diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SelfQueryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SelfQueryAdvisor.java new file mode 100644 index 00000000000..bf69df851e2 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/SelfQueryAdvisor.java @@ -0,0 +1,197 @@ +package org.springframework.ai.chat.client.advisor; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.AdvisedRequest; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; + +import java.util.List; +import java.util.Map; + +/** + * Advisor that that generates and executes structured queries over its own data source. + * Inspired by the langchain SelfQueryRetriever. + * + * @author Florin Duroiu + */ +public class SelfQueryAdvisor extends QuestionAnswerAdvisor { + + private static final Logger logger = LoggerFactory.getLogger(SelfQueryAdvisor.class); + + private static final String STRUCTURED_REQUEST_PROMPT = """ + Your goal is to structure the user's query to match the request schema provided below.\s + + {schema} + << Structured Request Schema >> + When responding, please don't use markup or back-tics. Instead use a directly parsable JSON object formatted in the following schema: + + {{"query": "string", "filter": "string"}} + + The response JSON object should have the fields "query" and "filter"."query" is a text string to compare to document contents. "filter" is a logical condition statement for filtering documents. Any conditions in the "filter" should not be mentioned in the "query" as well. + + A logical condition statement is composed of one or more comparison and logical operation statements. + + A comparison statement takes the form: `attr comp val`: + - `comp` ({allowed_comparators}): comparator + - `attr` (string): name of attribute to apply the comparison to + - `val` (string): is the comparison value. Enclose val with single quotes. + - if `val` is a list of values, it should be in the format `attr == 'val1' OR attr == 'val2' OR ...` + A logical operation statement takes the form `statement1 op statement2 op ...`: + - `op` ({allowed_operators}): logical operator + - `statement1`, `statement2`, ... (comparison statements or logical operation statements): one or more statements to apply the operation to + + Make sure that you only use the comparators and logical operators listed above and no others. + Make sure that filters only refer to attributes that exist in the data source. + Make sure that filters only use the attributed names with its function names if there are functions applied on them. + Make sure that filters only use format `YYYY-MM-DD` when handling date data typed values. + Make sure that filters take into account the descriptions of attributes and only make comparisons that are feasible given the type of data being stored. + Make sure that filters are only used as needed. If there are no filters that should be applied return "NO_FILTER" for the filter value. + + User query: {query} + """; + + private final String attributeInfosAsJson; + + private final SearchRequest searchRequest; + + private final ChatModel chatModel; + + private static final ObjectMapper objectMapper = new ObjectMapper(); + + private static final String allowedComparators = String.join(",", + List.of("==", "!=", ">", ">=", "<", "<=", "-", "+")); + + private static final String allowedOperators = String.join(",", List.of("AND", "OR", "IN", "NIN", "NOT")); + + public SelfQueryAdvisor(List attributeInfos, VectorStore vectorStore, SearchRequest searchRequest, + ChatModel chatModel) { + super(vectorStore, searchRequest); + try { + this.attributeInfosAsJson = objectMapper.writeValueAsString(attributeInfos); + } + catch (Exception e) { + throw new RuntimeException("Failed to serialize metadata field info", e); + } + this.searchRequest = searchRequest; + this.chatModel = chatModel; + } + + @Override + public AdvisedRequest adviseRequest(AdvisedRequest request, Map context) { + String userQuery = request.userText(); + QueryFilter queryFilter = extractQueryFilter(userQuery); + if (queryFilter.isFilterFound()) { + searchRequest.withQuery(queryFilter.getQuery()).withFilterExpression(queryFilter.getFilter()); + } + var fromAdvisedRequestWithSummaryQuery = AdvisedRequest.from(request).withUserText(queryFilter.query).build(); + return super.adviseRequest(fromAdvisedRequestWithSummaryQuery, context); + } + + private QueryFilter extractQueryFilter(String userQuery) { + String queryExtractionResult = chatModel.call(queryExtractionPrompt(userQuery)) + .getResult() + .getOutput() + .getContent(); + try { + return objectMapper.readValue(queryExtractionResult, QueryFilter.class); + } + catch (Exception e) { + logger.warn("Failed to serialize metadata field info. Returning original query with NO_FILTER. Reason:", e); + return new QueryFilter(userQuery, QueryFilter.NO_FILTER); + } + } + + private Prompt queryExtractionPrompt(String query) { + PromptTemplate promptTemplate = new PromptTemplate(STRUCTURED_REQUEST_PROMPT, + Map.of("allowed_comparators", allowedComparators, "allowed_operators", allowedOperators, "schema", + attributeInfosAsJson, "query", query)); + return new Prompt(promptTemplate.createMessage()); + } + + public static class QueryFilter { + + public static final String NO_FILTER = "NO_FILTER"; + + private String query; + + private String filter; + + public QueryFilter() { + } + + public QueryFilter(String query, String filter) { + this.query = query; + this.filter = filter; + } + + public String getQuery() { + return query; + } + + public void setQuery(String query) { + this.query = query; + } + + public String getFilter() { + return filter; + } + + public void setFilter(String filter) { + this.filter = filter; + } + + public boolean isFilterFound() { + return !NO_FILTER.equals(filter); + } + + } + + public static class AttributeInfo { + + private String name; + + private String type; + + private String description; + + public AttributeInfo() { + } + + public AttributeInfo(String name, String type, String description) { + this.name = name; + this.type = type; + this.description = description; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/SelfQueryAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/SelfQueryAdvisorTests.java new file mode 100644 index 00000000000..d2506d13e2f --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/SelfQueryAdvisorTests.java @@ -0,0 +1,107 @@ +package org.springframework.ai.chat.client; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.client.advisor.SelfQueryAdvisor; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; + +import java.util.Collections; +import java.util.List; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.mockito.Mockito.when; +import static org.springframework.ai.chat.client.advisor.SelfQueryAdvisor.QueryFilter.NO_FILTER; + +@ExtendWith(MockitoExtension.class) +public class SelfQueryAdvisorTests { + + public static final List METADATA_FIELD_INFO = List.of( + new SelfQueryAdvisor.AttributeInfo("name", "string", "description1"), + new SelfQueryAdvisor.AttributeInfo("age", "integer", "description2")); + + @Mock + ChatModel chatModel; + + @Captor + ArgumentCaptor vectorSearchCaptor; + + @Captor + ArgumentCaptor promptCaptor; + + @Mock + VectorStore vectorStore; + + @Test + public void selfQueryAdvisorWithExtractedQuery() { + String query = "joyful"; + String name = "Joe"; + int age = 30; + String extractedJsonQuery = String.format(""" + {"query": "%s", "filter": "name == '%s' AND age == %d"} + """, query, name, age); + List docsResult = List.of(new Document("doc1"), new Document("doc2")); + when(vectorStore.similaritySearch(vectorSearchCaptor.capture())).thenReturn(docsResult); + when(chatModel.call(promptCaptor.capture())) + .thenReturn(new ChatResponse(List.of(new Generation(extractedJsonQuery)))); + + var selfQueryAdvisor = new SelfQueryAdvisor(METADATA_FIELD_INFO, vectorStore, SearchRequest.defaults(), + chatModel); + var chatClient = ChatClient.builder(chatModel).defaultSystem("You are a helpful assistant").build(); + + chatClient.prompt().advisors(selfQueryAdvisor).user("Look for a user named Joe aged 30").call().chatResponse(); + + FilterExpressionBuilder b = new FilterExpressionBuilder(); + assertThat(vectorSearchCaptor.getValue().getFilterExpression()) + .isEqualTo(b.and(b.eq("name", name), b.eq("age", age)).build()); + assertThat(vectorSearchCaptor.getValue().getQuery()).isEqualTo(query); + } + + @Test + public void selfQueryAdvisorWithExtractedQueryAndNoFilter() { + String query = "any query"; + String extractedJsonQuery = String.format(""" + {"query": "%s", "filter": "%s"} + """, query, NO_FILTER); + when(vectorStore.similaritySearch(vectorSearchCaptor.capture())).thenReturn(Collections.emptyList()); + when(chatModel.call(promptCaptor.capture())) + .thenReturn(new ChatResponse(List.of(new Generation(extractedJsonQuery)))); + + var selfQueryAdvisor = new SelfQueryAdvisor(METADATA_FIELD_INFO, vectorStore, SearchRequest.defaults(), + chatModel); + var chatClient = ChatClient.builder(chatModel).defaultSystem("You are a helpful assistant").build(); + + chatClient.prompt().advisors(selfQueryAdvisor).user("Look for a user named Joe aged 30").call().chatResponse(); + + assertThat(vectorSearchCaptor.getValue().getFilterExpression()).isNull(); + assertThat(vectorSearchCaptor.getValue().getQuery()).isEqualTo(query); + } + + @Test + public void selfQueryWithInvalidQueryExtracted() { + String extractedJsonQuery = "not a JSON String"; + when(chatModel.call(promptCaptor.capture())) + .thenReturn(new ChatResponse(List.of(new Generation(extractedJsonQuery)))); + + var selfQueryAdvisor = new SelfQueryAdvisor(METADATA_FIELD_INFO, vectorStore, SearchRequest.defaults(), + chatModel); + var chatClient = ChatClient.builder(chatModel).defaultSystem("You are a helpful assistant").build(); + + String userQuery = "Look for a user named Joe aged 30"; + chatClient.prompt().advisors(selfQueryAdvisor).user(userQuery).call().chatResponse(); + + // assert that vectorSearchCaptor was not called + assertThat(vectorSearchCaptor.getAllValues()).isEqualTo(Collections.emptyList()); + } + +}