Skip to content

Commit c0b9240

Browse files
markpollackilayaperumalg
authored andcommitted
Fix VectorStoreDocumentRetriever to handle Filter.Expression objects directly
- Updated computeRequestFilterExpression to check if the context value is already a Filter.Expression object before attempting to parse it as a string - Added docs for FILTER_EXPRESSION key that it accepts both String and Filter.Expression - Added test Fixes #3179
1 parent 368be3a commit c0b9240

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

spring-ai-rag/src/main/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetriever.java

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@
4545
* List<Document> documents = retriever.retrieve(new Query("example query"));
4646
* }</pre>
4747
*
48+
* <p>
49+
* The {@link #FILTER_EXPRESSION} context key can be used to provide a filter expression
50+
* for a specific query. This key accepts either a string representation of a filter
51+
* expression or a {@link Filter.Expression} object directly.
52+
*
4853
* @author Thomas Vitale
4954
* @since 1.0.0
5055
*/
@@ -89,10 +94,27 @@ public List<Document> retrieve(Query query) {
8994
return this.vectorStore.similaritySearch(searchRequest);
9095
}
9196

97+
/**
98+
* Computes the filter expression to use for the current request.
99+
* <p>
100+
* The filter expression can be provided in the query context using the
101+
* {@link #FILTER_EXPRESSION} key. This key accepts either a string representation of
102+
* a filter expression or a {@link Filter.Expression} object directly.
103+
* <p>
104+
* If no filter expression is provided in the context, the default filter expression
105+
* configured for this retriever is used.
106+
* @param query the query containing potential context with filter expression
107+
* @return the filter expression to use for the request
108+
*/
92109
private Filter.Expression computeRequestFilterExpression(Query query) {
93110
var contextFilterExpression = query.context().get(FILTER_EXPRESSION);
94-
if (contextFilterExpression != null && StringUtils.hasText(contextFilterExpression.toString())) {
95-
return new FilterExpressionTextParser().parse(contextFilterExpression.toString());
111+
if (contextFilterExpression != null) {
112+
if (contextFilterExpression instanceof Filter.Expression) {
113+
return (Filter.Expression) contextFilterExpression;
114+
}
115+
else if (StringUtils.hasText(contextFilterExpression.toString())) {
116+
return new FilterExpressionTextParser().parse(contextFilterExpression.toString());
117+
}
96118
}
97119
return this.filterExpression.get();
98120
}

spring-ai-rag/src/test/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetrieverTests.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,32 @@ void retrieveWithQueryObjectAndRequestFilterExpression() {
234234
.isEqualTo(new FilterExpressionBuilder().eq("location", "Rivendell").build());
235235
}
236236

237+
@Test
238+
void retrieveWithQueryObjectAndFilterExpressionObject() {
239+
var mockVectorStore = mock(VectorStore.class);
240+
var documentRetriever = VectorStoreDocumentRetriever.builder().vectorStore(mockVectorStore).build();
241+
242+
// Create a Filter.Expression object directly
243+
var filterExpression = new Filter.Expression(EQ, new Filter.Key("location"), new Filter.Value("Rivendell"));
244+
245+
var query = Query.builder()
246+
.text("test query")
247+
.context(Map.of(VectorStoreDocumentRetriever.FILTER_EXPRESSION, filterExpression))
248+
.build();
249+
documentRetriever.retrieve(query);
250+
251+
// Verify the mock interaction
252+
var searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class);
253+
verify(mockVectorStore).similaritySearch(searchRequestCaptor.capture());
254+
255+
// Verify the search request
256+
var searchRequest = searchRequestCaptor.getValue();
257+
assertThat(searchRequest.getQuery()).isEqualTo("test query");
258+
assertThat(searchRequest.getSimilarityThreshold()).isEqualTo(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL);
259+
assertThat(searchRequest.getTopK()).isEqualTo(SearchRequest.DEFAULT_TOP_K);
260+
assertThat(searchRequest.getFilterExpression()).isEqualTo(filterExpression);
261+
}
262+
237263
static final class TenantContextHolder {
238264

239265
private static final ThreadLocal<String> tenantIdentifier = new ThreadLocal<>();

0 commit comments

Comments
 (0)