Skip to content

Commit 268248b

Browse files
CChuYongilayaperumalg
authored andcommitted
Fix PgVectorStore doDelete function as batch
Signed-off-by: CChuYong <yeongmin1061@gmail.com>
1 parent d5203ed commit 268248b

File tree

2 files changed

+60
-6
lines changed

2 files changed

+60
-6
lines changed

vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/pgvector/PgVectorStore.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@
152152
* @author Soby Chacko
153153
* @author Sebastien Deleuze
154154
* @author Jihoon Kim
155+
* @author YeongMin Song
155156
* @since 1.0.0
156157
*/
157158
public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean {
@@ -319,12 +320,21 @@ private Object convertIdToPgType(String id) {
319320

320321
@Override
321322
public void doDelete(List<String> idList) {
322-
int updateCount = 0;
323-
for (String id : idList) {
324-
int count = this.jdbcTemplate.update("DELETE FROM " + getFullyQualifiedTableName() + " WHERE id = ?",
325-
UUID.fromString(id));
326-
updateCount = updateCount + count;
327-
}
323+
String sql = "DELETE FROM " + getFullyQualifiedTableName() + " WHERE id = ?";
324+
325+
this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
326+
327+
@Override
328+
public void setValues(PreparedStatement ps, int i) throws SQLException {
329+
var id = idList.get(i);
330+
StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, convertIdToPgType(id));
331+
}
332+
333+
@Override
334+
public int getBatchSize() {
335+
return idList.size();
336+
}
337+
});
328338
}
329339

330340
@Override

vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/pgvector/PgVectorStoreIT.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
* @author Christian Tzolov
7676
* @author Thomas Vitale
7777
* @author Jihoon Kim
78+
* @author YeongMin Song
7879
*/
7980
@Testcontainers
8081
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
@@ -232,6 +233,47 @@ public void testToPgTypeWithNonUuidIdType() {
232233
});
233234
}
234235

236+
@Test
237+
public void testBulkOperationWithUuidIdType() {
238+
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
239+
.run(context -> {
240+
241+
VectorStore vectorStore = context.getBean(VectorStore.class);
242+
243+
List<Document> documents = List.of(
244+
new Document(new RandomIdGenerator().generateId(), "TEXT", new HashMap<>()),
245+
new Document(new RandomIdGenerator().generateId(), "TEXT", new HashMap<>()),
246+
new Document(new RandomIdGenerator().generateId(), "TEXT", new HashMap<>()));
247+
vectorStore.add(documents);
248+
249+
List<String> idList = documents.stream().map(Document::getId).toList();
250+
vectorStore.delete(idList);
251+
252+
dropTable(context);
253+
});
254+
}
255+
256+
@Test
257+
public void testBulkOperationWithNonUuidIdType() {
258+
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
259+
.withPropertyValues("test.spring.ai.vectorstore.pgvector.initializeSchema=" + false)
260+
.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "TEXT")
261+
.run(context -> {
262+
VectorStore vectorStore = context.getBean(VectorStore.class);
263+
initSchema(context);
264+
265+
List<Document> documents = List.of(new Document("NON_UUID_1", "TEXT", new HashMap<>()),
266+
new Document("NON_UUID_2", "TEXT", new HashMap<>()),
267+
new Document("NON_UUID_3", "TEXT", new HashMap<>()));
268+
vectorStore.add(documents);
269+
270+
List<String> idList = documents.stream().map(Document::getId).toList();
271+
vectorStore.delete(idList);
272+
273+
dropTable(context);
274+
});
275+
}
276+
235277
@ParameterizedTest(name = "Filter expression {0} should return {1} records ")
236278
@MethodSource("provideFilters")
237279
public void searchWithInFilter(String expression, Integer expectedRecords) {
@@ -436,6 +478,8 @@ void getNativeClientTest() {
436478
PgVectorStore vectorStore = context.getBean(PgVectorStore.class);
437479
Optional<JdbcTemplate> nativeClient = vectorStore.getNativeClient();
438480
assertThat(nativeClient).isPresent();
481+
482+
dropTable(context);
439483
});
440484
}
441485

0 commit comments

Comments
 (0)