Skip to content

Commit 6bd059a

Browse files
muthuishere authored and markpollack committed
Fix IN/NOT IN filters for PgVector JSON queries
PgVectorFilterExpressionConverter was generating incorrect SQL for IN and NOT IN filters with PostgreSQL JSON data types. This caused BadSqlGrammarException errors when executing queries. This change modifies the converter to generate correct SQL syntax for these operations, ensuring compatibility with PostgreSQL's JSON handling capabilities.

Why:
- Improves query reliability for PgVector stores
- Enables more complex filtering operations on JSON data
- Eliminates unexpected errors in query execution

Fixes #1179
1 parent bc3f9ac commit 6bd059a

File tree

3 files changed

+93
-10
lines changed

3 files changed

+93
-10
lines changed

vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverter.java

Lines changed: 43 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -15,24 +15,64 @@
1515
*/
1616
package org.springframework.ai.vectorstore;
1717

18+
import org.springframework.ai.vectorstore.filter.Filter;
1819
import org.springframework.ai.vectorstore.filter.Filter.Expression;
1920
import org.springframework.ai.vectorstore.filter.Filter.Group;
2021
import org.springframework.ai.vectorstore.filter.Filter.Key;
2122
import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter;
23+
import java.util.List;
2224

2325
/**
2426
* Converts {@link Expression} into PgVector metadata filter expression format.
2527
* (https://www.postgresql.org/docs/current/functions-json.html)
2628
*
29+
* @author Muthukumaran Navaneethakrishnan
2730
* @author Christian Tzolov
2831
*/
2932
public class PgVectorFilterExpressionConverter extends AbstractFilterExpressionConverter {
3033

3134
/**
 * Appends the converted form of {@code expression} to {@code context}.
 *
 * NOTE(review): this block is rendered as a diff. Lines preceded by a gutter
 * marker ending in {@code -} are the removed pre-change body (left operand,
 * operation symbol, right operand for every expression type); lines preceded
 * by a marker ending in {@code +} are the replacement.
 *
 * In the replacement, {@code IN} expressions are delegated to
 * {@code handleIn}, {@code NIN} expressions to {@code handleNotIn}, and every
 * other expression type is still written as left operand, operation symbol,
 * right operand.
 *
 * @param expression the filter expression to convert
 * @param context the buffer the converted text is appended to
 */
@Override
3235
protected void doExpression(Expression expression, StringBuilder context) {
33-
this.convertOperand(expression.left(), context);
34-
context.append(getOperationSymbol(expression));
35-
this.convertOperand(expression.right(), context);
36+
if (expression.type() == Filter.ExpressionType.IN) {
37+
handleIn(expression, context);
38+
}
39+
else if (expression.type() == Filter.ExpressionType.NIN) {
40+
handleNotIn(expression, context);
41+
}
42+
else {
43+
this.convertOperand(expression.left(), context);
44+
context.append(getOperationSymbol(expression));
45+
this.convertOperand(expression.right(), context);
46+
}
47+
}
48+
49+
private void handleIn(Expression expression, StringBuilder context) {
50+
context.append("(");
51+
convertToConditions(expression, context);
52+
context.append(")");
53+
}
54+
55+
private void convertToConditions(Expression expression, StringBuilder context) {
56+
Filter.Value right = (Filter.Value) expression.right();
57+
Object value = right.value();
58+
if (!(value instanceof List)) {
59+
throw new IllegalArgumentException("Expected a List, but got: " + value.getClass().getSimpleName());
60+
}
61+
List<Object> values = (List) value;
62+
for (int i = 0; i < values.size(); i++) {
63+
this.convertOperand(expression.left(), context);
64+
context.append(" == ");
65+
this.doSingleValue(values.get(i), context);
66+
if (i < values.size() - 1) {
67+
context.append(" || ");
68+
}
69+
}
70+
}
71+
72+
private void handleNotIn(Expression expression, StringBuilder context) {
73+
context.append("!(");
74+
convertToConditions(expression, context);
75+
context.append(")");
3676
}
3777

3878
private String getOperationSymbol(Expression exp) {
@@ -53,10 +93,6 @@ private String getOperationSymbol(Expression exp) {
5393
return " > ";
5494
case GTE:
5595
return " >= ";
56-
case IN:
57-
return " in ";
58-
case NIN:
59-
return " nin ";
6096
default:
6197
throw new RuntimeException("Not supported expression type: " + exp.type());
6298
}

vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorFilterExpressionConverterTests.java

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
3535

3636
/**
37+
* @author Muthukumaran Navaneethakrishnan
3738
* @author Christian Tzolov
3839
*/
3940
public class PgVectorFilterExpressionConverterTests {
@@ -61,7 +62,8 @@ public void tesIn() {
6162
// genre in ["comedy", "documentary", "drama"]
6263
String vectorExpr = converter.convertExpression(
6364
new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama"))));
64-
assertThat(vectorExpr).isEqualTo("$.genre in [\"comedy\",\"documentary\",\"drama\"]");
65+
assertThat(vectorExpr)
66+
.isEqualTo("($.genre == \"comedy\" || $.genre == \"documentary\" || $.genre == \"drama\")");
6567
}
6668

6769
@Test
@@ -82,7 +84,7 @@ public void testGroup() {
8284
new Expression(EQ, new Key("country"), new Value("BG")))),
8385
new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv")))));
8486
assertThat(vectorExpr)
85-
.isEqualTo("($.year >= 2020 || $.country == \"BG\") && $.city nin [\"Sofia\",\"Plovdiv\"]");
87+
.isEqualTo("($.year >= 2020 || $.country == \"BG\") && !($.city == \"Sofia\" || $.city == \"Plovdiv\")");
8688
}
8789

8890
@Test
@@ -93,7 +95,8 @@ public void tesBoolean() {
9395
new Expression(GTE, new Key("year"), new Value(2020))),
9496
new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US")))));
9597

96-
assertThat(vectorExpr).isEqualTo("$.isOpen == true && $.year >= 2020 && $.country in [\"BG\",\"NL\",\"US\"]");
98+
assertThat(vectorExpr).isEqualTo(
99+
"$.isOpen == true && $.year >= 2020 && ($.country == \"BG\" || $.country == \"NL\" || $.country == \"US\")");
97100
}
98101

99102
@Test

vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,12 +24,15 @@
2424
import java.util.List;
2525
import java.util.Map;
2626
import java.util.UUID;
27+
import java.util.stream.Stream;
2728

2829
import javax.sql.DataSource;
2930

3031
import org.junit.Assert;
3132
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
3233
import org.junit.jupiter.params.ParameterizedTest;
34+
import org.junit.jupiter.params.provider.Arguments;
35+
import org.junit.jupiter.params.provider.MethodSource;
3336
import org.junit.jupiter.params.provider.ValueSource;
3437
import org.springframework.ai.document.Document;
3538
import org.springframework.ai.embedding.EmbeddingModel;
@@ -57,6 +60,7 @@
5760
import com.zaxxer.hikari.HikariDataSource;
5861

5962
/**
63+
* @author Muthukumaran Navaneethakrishnan
6064
* @author Christian Tzolov
6165
*/
6266
@Testcontainers
@@ -128,6 +132,46 @@ public void addAndSearch(String distanceType) {
128132
});
129133
}
130134

135+
static Stream<Arguments> provideFilters() {
136+
return Stream.of(Arguments.of("country in ['BG','NL']", 3), // String Filters In
137+
Arguments.of("year in [2020]", 1), // Numeric Filters In
138+
Arguments.of("country not in ['BG']", 1), // String Filter Not In
139+
Arguments.of("year not in [2020]", 2) // Numeric Filter Not In
140+
);
141+
}
142+
143+
@ParameterizedTest(name = "Filter expression {0} should return {1} records ")
144+
@MethodSource("provideFilters")
145+
public void searchWithInFilter(String expression, Integer expectedRecords) {
146+
147+
contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=COSINE_DISTANCE")
148+
.run(context -> {
149+
150+
VectorStore vectorStore = context.getBean(VectorStore.class);
151+
152+
var bgDocument = new Document("The World is Big and Salvation Lurks Around the Corner",
153+
Map.of("country", "BG", "year", 2020, "foo bar 1", "bar.foo"));
154+
var nlDocument = new Document("The World is Big and Salvation Lurks Around the Corner",
155+
Map.of("country", "NL"));
156+
var bgDocument2 = new Document("The World is Big and Salvation Lurks Around the Corner",
157+
Map.of("country", "BG", "year", 2023));
158+
159+
vectorStore.add(List.of(bgDocument, nlDocument, bgDocument2));
160+
161+
SearchRequest searchRequest = SearchRequest.query("The World")
162+
.withFilterExpression(expression)
163+
.withTopK(5)
164+
.withSimilarityThresholdAll();
165+
166+
List<Document> results = vectorStore.similaritySearch(searchRequest);
167+
168+
assertThat(results).hasSize(expectedRecords);
169+
170+
// Remove all documents from the store
171+
dropTable(context);
172+
});
173+
}
174+
131175
@ParameterizedTest(name = "{0} : {displayName} ")
132176
@ValueSource(strings = { "COSINE_DISTANCE", "EUCLIDEAN_DISTANCE", "NEGATIVE_INNER_PRODUCT" })
133177
public void searchWithFilters(String distanceType) {

0 commit comments

Comments
 (0)