diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java index b430793cdd0..a725537e6ab 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SearchRequest.java @@ -52,6 +52,22 @@ public class SearchRequest { private Filter.Expression filterExpression; + /** + * Default value for search request is to use the vector search. + */ + private boolean vectorSearch = true; + + /** + * Enables the full text search mode. If combined with the vector search, the hybrid + * search is done. + */ + private boolean fullTextSearch = false; + + /** + * Enables the reranking of the results. + */ + private boolean reRank = false; + private SearchRequest(String query) { this.query = query; } @@ -230,6 +246,36 @@ public SearchRequest withFilterExpression(String textExpression) { return this; } + /** + * Set the vector search mode. + * @param vectorSearch + * @return this.builder + */ + public SearchRequest withVectorSearch(boolean vectorSearch) { + this.vectorSearch = vectorSearch; + return this; + } + + /** + * Set the full text search mode. + * @param fullTextSearch + * @return this.builder + */ + public SearchRequest withFullTextSearch(boolean fullTextSearch) { + this.fullTextSearch = fullTextSearch; + return this; + } + + /** + * Set the rerank mode. + * @param rerank + * @return this.builder + */ + public SearchRequest withRerank(boolean rerank) { + this.reRank = rerank; + return this; + } + public String getQuery() { return query; } @@ -250,10 +296,23 @@ public boolean hasFilterExpression() { return this.filterExpression != null; } + public boolean isVectorSearch() { + return this.vectorSearch; + } + + public boolean isFullTextSearch() { + return this.fullTextSearch; + } + + public boolean isReRank() { + return this.reRank; + } + @Override public String toString() { return "SearchRequest{" + "query='" + query + '\'' + ", topK=" + topK + ", similarityThreshold=" - + similarityThreshold + ", filterExpression=" + filterExpression + '}'; + + similarityThreshold + ", filterExpression=" + filterExpression + ", isVectorSearch=" + vectorSearch + + ", isFullTextSearch=" + fullTextSearch + ", isRerank=" + reRank + '}'; } @Override @@ -264,12 +323,13 @@ public boolean equals(Object o) { return false; SearchRequest that = (SearchRequest) o; return topK == that.topK && Double.compare(that.similarityThreshold, similarityThreshold) == 0 - && Objects.equals(query, that.query) && Objects.equals(filterExpression, that.filterExpression); + && Objects.equals(query, that.query) && Objects.equals(filterExpression, that.filterExpression) + && vectorSearch == that.vectorSearch && fullTextSearch == that.fullTextSearch && reRank == that.reRank; } @Override public int hashCode() { - return Objects.hash(query, topK, similarityThreshold, filterExpression); + return Objects.hash(query, topK, similarityThreshold, filterExpression, vectorSearch, fullTextSearch); } } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java index 8c4ff0c2a9c..5f0cb1f24cd 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/VectorStore.java @@ -70,4 +70,71 @@ default List similaritySearch(String query) { return this.similaritySearch(SearchRequest.query(query)); } + /** + * Retrieves documents by query full text content and metadata filters to retrieve + * exactly the number of nearest-neighbor results that match the request criteria. + * @param request Search request for set search parameters, such as the query text, + * topK, similarity threshold and metadata filter expressions. + * @return a list of {@link Document} objects representing the retrieved documents + * that match the search criteria. + * @throws UnsupportedOperationException if the method is not supported by the current + * implementation. Subclasses should override this method to provide a specific + * implementation. + */ + default List fullTextSearch(SearchRequest request) { + throw new UnsupportedOperationException("The [" + this.getClass() + "] doesn't support full text search!"); + } + + /** + * Retrieves documents by query full text content using the default + * {@link SearchRequest}'s' search criteria. + * @param query Text to use for full text search. + * @return a list of {@link Document} objects representing the retrieved documents + * that match the search criteria. + */ + default List fullTextSearch(String query) { + return this.fullTextSearch(SearchRequest.query(query)); + } + + /** + * Performs a hybrid search by combining semantic and keyword-based search techniques + * to retrieve a list of relevant documents based on the provided + * {@link SearchRequest}. + *

+ * This method is intended to retrieve documents that match the query both + * semantically (using vector embeddings) and via keyword matching. The hybrid + * approach aims to enhance retrieval accuracy by leveraging the strengths of both + * search methods. + *

+ * @param request the {@link SearchRequest} object containing the query and search + * parameters. + * @return a list of {@link Document} objects representing the retrieved documents + * that match the search criteria. + * @throws UnsupportedOperationException if the method is not supported by the current + * implementation. Subclasses should override this method to provide a specific + * implementation. + */ + default List hybridSearch(SearchRequest request) { + throw new UnsupportedOperationException( + "The [" + this.getClass() + "] doesn't support hybrid (vector + text) search!"); + } + + /** + * Performs a hybrid search by combining semantic and keyword-based search techniques + * to retrieve a list of relevant documents based on the provided + * {@link SearchRequest}. + *

+ * This method is intended to retrieve documents that match the query both + * semantically (using vector embeddings) and via keyword matching. The hybrid + * approach aims to enhance retrieval accuracy by leveraging the strengths of both + * search methods. + *

+ * @param query Text to use for embedding similarity comparison. + * @return a list of {@link Document} objects representing the retrieved documents + * that match the search criteria. + */ + default List hybridSearch(String query) { + return this.hybridSearch(SearchRequest.query(query)); + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/SearchRequestTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/SearchRequestTests.java index 5535766b793..4f22de8647e 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/SearchRequestTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/SearchRequestTests.java @@ -56,6 +56,7 @@ public void createFrom() { assertThat(newRequest.getTopK()).isEqualTo(originalRequest.getTopK()); assertThat(newRequest.getFilterExpression()).isEqualTo(originalRequest.getFilterExpression()); assertThat(newRequest.getSimilarityThreshold()).isEqualTo(originalRequest.getSimilarityThreshold()); + assertThat(newRequest.isVectorSearch() == originalRequest.isVectorSearch()); } @Test @@ -135,10 +136,20 @@ public void withFilterExpression() { } + @Test() + public void withHybridSearchWithRerank() { + + var request = SearchRequest.query("Test").withFullTextSearch(true).withRerank(true); + assertThat(request.isVectorSearch()).isTrue(); + assertThat(request.isFullTextSearch()).isTrue(); + assertThat(request.isReRank()).isTrue(); + } + private void checkDefaults(SearchRequest request) { assertThat(request.getFilterExpression()).isNull(); assertThat(request.getSimilarityThreshold()).isEqualTo(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL); assertThat(request.getTopK()).isEqualTo(SearchRequest.DEFAULT_TOP_K); + assertThat(request.isVectorSearch()).isTrue(); } } diff --git a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java index b51f292fc6c..9655e7f3bb9 100644 --- a/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java +++ b/vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java @@ -29,11 +29,7 @@ import com.azure.search.documents.indexes.models.VectorSearch; import com.azure.search.documents.indexes.models.VectorSearchAlgorithmMetric; import com.azure.search.documents.indexes.models.VectorSearchProfile; -import com.azure.search.documents.models.IndexDocumentsResult; -import com.azure.search.documents.models.IndexingResult; -import com.azure.search.documents.models.SearchOptions; -import com.azure.search.documents.models.VectorSearchOptions; -import com.azure.search.documents.models.VectorizedQuery; +import com.azure.search.documents.models.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.document.Document; @@ -64,6 +60,7 @@ * @author Xiangyang Yu * @author Christian Tzolov * @author Josh Long + * @author Alessio Bertazzo */ public class AzureVectorStore implements VectorStore, InitializingBean { @@ -91,6 +88,8 @@ public class AzureVectorStore implements VectorStore, InitializingBean { private static final String METADATA_FIELD_PREFIX = "meta_"; + private static final String SEMANTIC_SEARCH_CONFIG_NAME = "default"; + private final SearchIndexClient searchIndexClient; private final EmbeddingModel embeddingModel; @@ -281,24 +280,84 @@ public List similaritySearch(String query) { public List similaritySearch(SearchRequest request) { Assert.notNull(request, "The search request must not be null."); + Assert.isTrue(request.isVectorSearch() && !request.isFullTextSearch(), + "The search request must be a vector search."); + + return this.search(request); + } + + @Override + public List fullTextSearch(String query) { + return this.search(SearchRequest.query(query) + .withVectorSearch(false) + .withFullTextSearch(true) + .withTopK(this.defaultTopK) + .withSimilarityThreshold(this.defaultSimilarityThreshold)); + } + + @Override + public List fullTextSearch(SearchRequest request) { + + Assert.notNull(request, "The search request must not be null."); + Assert.isTrue(!request.isVectorSearch() && request.isFullTextSearch(), + "The search request must be a full text search."); + + return this.search(request); + } + + @Override + public List hybridSearch(String query) { + return this.hybridSearch(SearchRequest.query(query) + .withVectorSearch(true) + .withFullTextSearch(true) + .withTopK(this.defaultTopK) + .withSimilarityThreshold(this.defaultSimilarityThreshold)); + } + + @Override + public List hybridSearch(SearchRequest request) { + + Assert.notNull(request, "The search request must not be null."); + Assert.isTrue(request.isVectorSearch() && request.isFullTextSearch(), + "The search request must be a hybrid (vector + full text) search."); + + return this.search(request); + } + + private List search(SearchRequest request) { + + var searchOptions = new SearchOptions().setTop(request.getTopK()); + + if (request.isVectorSearch()) { + var searchEmbedding = embeddingModel.embed(request.getQuery()); - var searchEmbedding = embeddingModel.embed(request.getQuery()); + final var vectorQuery = new VectorizedQuery(EmbeddingUtils.toList(searchEmbedding)) + .setKNearestNeighborsCount(request.getTopK()) + // Set the fields to compare the vector against. This is a comma-delimited + // list of field names. + .setFields(EMBEDDING_FIELD_NAME); - final var vectorQuery = new VectorizedQuery(EmbeddingUtils.toList(searchEmbedding)) - .setKNearestNeighborsCount(request.getTopK()) - // Set the fields to compare the vector against. This is a comma-delimited - // list of field names. - .setFields(EMBEDDING_FIELD_NAME); + searchOptions.setVectorSearchOptions(new VectorSearchOptions().setQueries(vectorQuery)); + } + + String searchText = null; + if (request.isFullTextSearch()) { + searchText = request.getQuery(); - var searchOptions = new SearchOptions() - .setVectorSearchOptions(new VectorSearchOptions().setQueries(vectorQuery)); + if (request.isReRank()) { + searchOptions + .setSemanticSearchOptions( + new SemanticSearchOptions().setSemanticConfigurationName(SEMANTIC_SEARCH_CONFIG_NAME)) + .setQueryType(QueryType.SEMANTIC); + } + } if (request.hasFilterExpression()) { String oDataFilter = this.filterExpressionConverter.convertExpression(request.getFilterExpression()); searchOptions.setFilter(oDataFilter); } - final var searchResults = searchClient.search(null, searchOptions, Context.NONE); + final var searchResults = searchClient.search(searchText, searchOptions, Context.NONE); return searchResults.stream() .filter(result -> result.getScore() >= request.getSimilarityThreshold()) diff --git a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java index ce58a246c3b..a3558ea5fbe 100644 --- a/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java +++ b/vector-stores/spring-ai-azure-store/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java @@ -72,7 +72,7 @@ public static void beforeAll() { } @Test - public void addAndSearchTest() { + public void addAndVectorSearchTest() { contextRunner.run(context -> { @@ -103,6 +103,82 @@ public void addAndSearchTest() { }); } + @Test + public void addAndFullTextSearchTest() { + + contextRunner.run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + vectorStore.add(documents); + + Awaitility.await() + .until(() -> vectorStore.fullTextSearch(SearchRequest.query("Great Depression") + .withTopK(1) + .withVectorSearch(false) + .withFullTextSearch(true)), hasSize(1)); + + List results = vectorStore.fullTextSearch(SearchRequest.query("Great Depression") + .withTopK(1) + .withVectorSearch(false) + .withFullTextSearch(true)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); + assertThat(resultDoc.getMetadata()).hasSize(2); + assertThat(resultDoc.getMetadata()).containsKey("meta2"); + assertThat(resultDoc.getMetadata()).containsKey("distance"); + + // Remove all documents from the store + vectorStore.delete(documents.stream().map(Document::getId).toList()); + + Awaitility.await() + .until(() -> vectorStore.fullTextSearch( + SearchRequest.query("Hello").withTopK(1).withVectorSearch(false).withFullTextSearch(true)), + hasSize(0)); + }); + } + + @Test + public void addAndHybridSearchTest() { + + contextRunner.run(context -> { + + VectorStore vectorStore = context.getBean(VectorStore.class); + + vectorStore.add(documents); + + Awaitility.await() + .until(() -> vectorStore.hybridSearch(SearchRequest.query("Great Depression") + .withTopK(1) + .withVectorSearch(true) + .withFullTextSearch(true)), hasSize(1)); + + List results = vectorStore.hybridSearch(SearchRequest.query("Great Depression") + .withTopK(1) + .withVectorSearch(true) + .withFullTextSearch(true)); + + assertThat(results).hasSize(1); + Document resultDoc = results.get(0); + assertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId()); + assertThat(resultDoc.getContent()).contains("The Great Depression (1929–1939) was an economic shock"); + assertThat(resultDoc.getMetadata()).hasSize(2); + assertThat(resultDoc.getMetadata()).containsKey("meta2"); + assertThat(resultDoc.getMetadata()).containsKey("distance"); + + // Remove all documents from the store + vectorStore.delete(documents.stream().map(Document::getId).toList()); + + Awaitility.await() + .until(() -> vectorStore.hybridSearch( + SearchRequest.query("Hello").withTopK(1).withVectorSearch(true).withFullTextSearch(true)), + hasSize(0)); + }); + } + @Test public void searchWithFilters() throws InterruptedException {