Skip to content

Refactor PgVectorStore filter template to use JSONB field access #589

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.ai.vectorstore.filter.converter;

import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.filter.Filter.Group;
import org.springframework.ai.vectorstore.filter.Filter.Key;
Expand All @@ -37,11 +38,11 @@ protected void doExpression(Expression expression, StringBuilder context) {
private String getOperationSymbol(Expression exp) {
switch (exp.type()) {
case AND:
return " && ";
return " AND ";
case OR:
return " || ";
return " OR ";
case EQ:
return " == ";
return " = ";
case NE:
return " != ";
case LT:
Expand All @@ -53,17 +54,34 @@ private String getOperationSymbol(Expression exp) {
case GTE:
return " >= ";
case IN:
return " in ";
return " IN ";
case NIN:
return " nin ";
return " NOT IN ";
default:
throw new RuntimeException("Not supported expression type: " + exp.type());
}
}

/**
 * Writes a filter key into the expression buffer as a text lookup on the metadata
 * column, in the form {@code metadata::jsonb->>'<key>'}.
 * @param key the filter key to render
 * @param context the buffer that receives the rendered key
 */
@Override
protected void doKey(Key key, StringBuilder context) {
// NOTE(review): this listing is a diff rendered without +/- markers; the JSONPath-style
// "$." append below appears to be the superseded line that the JSONB accessor lines
// after it replace -- confirm against the real file.
context.append("$." + key.key());
// Open the ->> accessor and the single-quoted key literal.
context.append("metadata::jsonb->>'");
// Keys that arrive wrapped in outer quotes are written without them
// (presumably removeOuterQuotes strips exactly one wrapping pair -- verify in the base class).
if (hasOuterQuotes(key.key())) {
context.append(removeOuterQuotes(key.key()));
}
else {
context.append(key.key());
}
// Close the key literal.
// NOTE(review): the key text is appended as-is; a single quote inside the key is not escaped.
context.append('\'');
}

/**
 * Appends a single filter value to the expression being built.
 * <p>
 * A {@link String} is written as a single-quoted SQL string literal. Embedded single
 * quotes are doubled, so a value such as {@code O'Brien} stays inside the literal
 * instead of terminating it early. Any other value is appended through
 * {@link StringBuilder#append(Object)}, i.e. in its {@code toString()} form, unquoted.
 * @param value the value to render
 * @param context the buffer that receives the rendered value
 */
@Override
protected void doSingleValue(Object value, StringBuilder context) {
	if (value instanceof String) {
		// Doubling the quote is the standard SQL escape for ' inside a string literal.
		String escaped = ((String) value).replace("'", "''");
		context.append('\'').append(escaped).append('\'');
	}
	else {
		context.append(value);
	}
}

@Override
Expand All @@ -76,4 +94,14 @@ protected void doEndGroup(Group group, StringBuilder context) {
context.append(")");
}

/**
 * Emits the opening parenthesis that begins a list of values.
 * @param listValue the list value being rendered; not consulted here
 * @param context the buffer that receives the parenthesis
 */
@Override
protected void doStartValueRange(Filter.Value listValue, StringBuilder context) {
	context.append('(');
}

/**
 * Emits the closing parenthesis that ends a list of values.
 * @param listValue the list value being rendered; not consulted here
 * @param context the buffer that receives the parenthesis
 */
@Override
protected void doEndValueRange(Filter.Value listValue, StringBuilder context) {
	context.append(')');
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public class PgVectorFilterExpressionConverterTests {
public void testEQ() {
// Converts a single equality expression and checks the rendered filter string.
// country == "BG"
String vectorExpr = converter.convertExpression(new Expression(EQ, new Key("country"), new Value("BG")));
// NOTE(review): this looks like the superseded JSONPath-style expectation from the diff view;
// it cannot hold together with the assertion that follows -- confirm against the real file.
assertThat(vectorExpr).isEqualTo("$.country == \"BG\"");
// Expected form: text lookup on the JSONB metadata compared with a single-quoted literal.
assertThat(vectorExpr).isEqualTo("metadata::jsonb->>'country' = 'BG'");
}

@Test
Expand All @@ -55,15 +55,15 @@ public void tesEqAndGte() {
String vectorExpr = converter
.convertExpression(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")),
new Expression(GTE, new Key("year"), new Value(2020))));
assertThat(vectorExpr).isEqualTo("$.genre == \"drama\" && $.year >= 2020");
assertThat(vectorExpr).isEqualTo("metadata::jsonb->>'genre' = 'drama' AND metadata::jsonb->>'year' >= 2020");
}

// Converts an IN expression over a list of strings and checks the rendered filter string.
@Test
public void tesIn() {
// genre in ["comedy", "documentary", "drama"]
String vectorExpr = converter.convertExpression(
new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama"))));
// NOTE(review): this looks like the superseded JSONPath-style expectation from the diff view;
// it cannot hold together with the assertion that follows -- confirm against the real file.
assertThat(vectorExpr).isEqualTo("$.genre in [\"comedy\",\"documentary\",\"drama\"]");
// Expected form: the list is rendered in parentheses with single-quoted, comma-separated items.
assertThat(vectorExpr).isEqualTo("metadata::jsonb->>'genre' IN ('comedy','documentary','drama')");
}

@Test
Expand All @@ -73,7 +73,8 @@ public void testNe() {
.convertExpression(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)),
new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")),
new Expression(NE, new Key("city"), new Value("Sofia")))));
assertThat(vectorExpr).isEqualTo("$.year >= 2020 || $.country == \"BG\" && $.city != \"Sofia\"");
assertThat(vectorExpr).isEqualTo(
"metadata::jsonb->>'year' >= 2020 OR metadata::jsonb->>'country' = 'BG' AND metadata::jsonb->>'city' != 'Sofia'");
}

@Test
Expand All @@ -83,8 +84,8 @@ public void testGroup() {
new Group(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)),
new Expression(EQ, new Key("country"), new Value("BG")))),
new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv")))));
assertThat(vectorExpr)
.isEqualTo("($.year >= 2020 || $.country == \"BG\") && $.city nin [\"Sofia\",\"Plovdiv\"]");
assertThat(vectorExpr).isEqualTo(
"(metadata::jsonb->>'year' >= 2020 OR metadata::jsonb->>'country' = 'BG') AND metadata::jsonb->>'city' NOT IN ('Sofia','Plovdiv')");
}

@Test
Expand All @@ -95,7 +96,8 @@ public void tesBoolean() {
new Expression(GTE, new Key("year"), new Value(2020))),
new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US")))));

assertThat(vectorExpr).isEqualTo("$.isOpen == true && $.year >= 2020 && $.country in [\"BG\",\"NL\",\"US\"]");
assertThat(vectorExpr).isEqualTo(
"metadata::jsonb->>'isOpen' = true AND metadata::jsonb->>'year' >= 2020 AND metadata::jsonb->>'country' IN ('BG','NL','US')");
}

@Test
Expand All @@ -105,14 +107,15 @@ public void testDecimal() {
.convertExpression(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)),
new Expression(LTE, new Key("temperature"), new Value(20.13))));

assertThat(vectorExpr).isEqualTo("$.temperature >= -15.6 && $.temperature <= 20.13");
assertThat(vectorExpr)
.isEqualTo("metadata::jsonb->>'temperature' >= -15.6 AND metadata::jsonb->>'temperature' <= 20.13");
}

// Checks how a key that contains spaces and is wrapped in double quotes is rendered.
@Test
public void testComplexIdentifiers() {
String vectorExpr = converter
.convertExpression(new Expression(EQ, new Key("\"country 1 2 3\""), new Value("BG")));
// NOTE(review): this looks like the superseded JSONPath-style expectation from the diff view;
// it cannot hold together with the assertion that follows -- confirm against the real file.
assertThat(vectorExpr).isEqualTo("$.\"country 1 2 3\" == \"BG\"");
// Expected form: the outer double quotes of the key are gone and the key sits in single quotes.
assertThat(vectorExpr).isEqualTo("metadata::jsonb->>'country 1 2 3' = 'BG'");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ public List<Document> similaritySearch(SearchRequest request) {
String jsonPathFilter = "";

if (StringUtils.hasText(nativeFilterExpression)) {
jsonPathFilter = " AND metadata::jsonb @@ '" + nativeFilterExpression + "'::jsonpath ";
jsonPathFilter = " AND (" + nativeFilterExpression + ") ";
}

double distance = 1 - request.getSimilarityThreshold();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ public void searchWithFilters(String distanceType) {
assertThat(results).hasSize(1);
assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId());

results = vectorStore.similaritySearch(searchRequest.withFilterExpression("country in ['NL', 'SP']"));

assertThat(results).hasSize(1);
assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId());

results = vectorStore.similaritySearch(searchRequest.withFilterExpression("country == 'BG'"));

assertThat(results).hasSize(2);
Expand Down