Commit 202148d

sobychacko authored and markpollack committed
Prevent timeouts with configurable batching for PgVectorStore inserts
Resolves #1199

- Implement configurable maxDocumentBatchSize to prevent insert timeouts when adding large numbers of documents
- Update PgVectorStore to process document inserts in controlled batches
- Add maxDocumentBatchSize property to PgVectorStoreProperties
- Update PgVectorStoreAutoConfiguration to use the new batching property
- Add tests to verify batching behavior and performance

This change addresses PgVectorStore inserts timing out on large document volumes. With configurable batching, users can control the insert process to avoid timeouts while maintaining performance and reducing memory overhead for large-scale document additions.
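The new cap defaults to PgVectorStore.MAX_DOCUMENT_BATCH_SIZE (10,000 documents, per the diff below). As a minimal sketch of tuning it externally, assuming the store's CONFIG_PREFIX resolves to spring.ai.vectorstore.pgvector and that Spring Boot's relaxed binding maps the kebab-case key onto setMaxDocumentBatchSize(int):

    # application.properties: illustrative value, not a recommendation
    spring.ai.vectorstore.pgvector.max-document-batch-size=5000

Smaller values trade more round trips for shorter-lived statements; larger values approach the single-statement behavior this commit replaces.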
1 parent 42dcb45 commit 202148d

File tree: 4 files changed, +132 -52 lines changed


spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed
 			.withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
 			.withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null))
 			.withBatchingStrategy(batchingStrategy)
+			.withMaxDocumentBatchSize(properties.getMaxDocumentBatchSize())
 			.build();
 	}

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java

Lines changed: 11 additions & 0 deletions
@@ -24,6 +24,7 @@
 /**
  * @author Christian Tzolov
  * @author Muthukumaran Navaneethakrishnan
+ * @author Soby Chacko
  */
 @ConfigurationProperties(PgVectorStoreProperties.CONFIG_PREFIX)
 public class PgVectorStoreProperties extends CommonVectorStoreProperties {
@@ -45,6 +46,8 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties {

 	private boolean schemaValidation = PgVectorStore.DEFAULT_SCHEMA_VALIDATION;

+	private int maxDocumentBatchSize = PgVectorStore.MAX_DOCUMENT_BATCH_SIZE;
+
 	public int getDimensions() {
 		return dimensions;
 	}
@@ -101,4 +104,12 @@ public void setSchemaValidation(boolean schemaValidation) {
 		this.schemaValidation = schemaValidation;
 	}

+	public int getMaxDocumentBatchSize() {
+		return this.maxDocumentBatchSize;
+	}
+
+	public void setMaxDocumentBatchSize(int maxDocumentBatchSize) {
+		this.maxDocumentBatchSize = maxDocumentBatchSize;
+	}
+
 }

vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java

Lines changed: 71 additions & 50 deletions
@@ -15,14 +15,10 @@
  */
 package org.springframework.ai.vectorstore;

-import java.sql.PreparedStatement;
-import java.sql.ResultSet;
-import java.sql.SQLException;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.UUID;
-
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.pgvector.PGvector;
+import io.micrometer.observation.ObservationRegistry;
 import org.postgresql.util.PGobject;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -46,11 +42,14 @@
 import org.springframework.lang.Nullable;
 import org.springframework.util.StringUtils;

-import com.fasterxml.jackson.core.JsonProcessingException;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.pgvector.PGvector;
-
-import io.micrometer.observation.ObservationRegistry;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.UUID;

 /**
  * Uses the "vector_store" table to store the Spring AI vector data. The table and the
@@ -81,6 +80,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini

 	public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter();

+	public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000;
+
 	private final String vectorTableName;

 	private final String vectorIndexName;
@@ -109,6 +110,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini

 	private final BatchingStrategy batchingStrategy;

+	private final int maxDocumentBatchSize;
+
 	public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
 		this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false,
 				PgIndexType.NONE, false);
@@ -132,7 +135,6 @@ public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, Embeddin

 		this(DEFAULT_SCHEMA_NAME, vectorTableName, DEFAULT_SCHEMA_VALIDATION, jdbcTemplate, embeddingModel, dimensions,
 				distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema);
-
 	}

 	private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
@@ -141,14 +143,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT

 		this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions,
 				distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema,
-				ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy());
+				ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy(), MAX_DOCUMENT_BATCH_SIZE);
 	}

 	private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled,
 			JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType,
 			boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema,
 			ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention,
-			BatchingStrategy batchingStrategy) {
+			BatchingStrategy batchingStrategy, int maxDocumentBatchSize) {

 		super(observationRegistry, customObservationConvention);
@@ -172,6 +174,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
 		this.initializeSchema = initializeSchema;
 		this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate);
 		this.batchingStrategy = batchingStrategy;
+		this.maxDocumentBatchSize = maxDocumentBatchSize;
 	}

 	public PgDistanceType getDistanceType() {
@@ -180,40 +183,50 @@ public PgDistanceType getDistanceType() {

 	@Override
 	public void doAdd(List<Document> documents) {
+		this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);

-		int size = documents.size();
+		List<List<Document>> batchedDocuments = batchDocuments(documents);
+		batchedDocuments.forEach(this::insertOrUpdateBatch);
+	}

-		this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
+	private List<List<Document>> batchDocuments(List<Document> documents) {
+		List<List<Document>> batches = new ArrayList<>();
+		for (int i = 0; i < documents.size(); i += this.maxDocumentBatchSize) {
+			batches.add(documents.subList(i, Math.min(i + this.maxDocumentBatchSize, documents.size())));
+		}
+		return batches;
+	}

-		this.jdbcTemplate.batchUpdate(
-				"INSERT INTO " + getFullyQualifiedTableName()
-						+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
-						+ "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ",
-				new BatchPreparedStatementSetter() {
-					@Override
-					public void setValues(PreparedStatement ps, int i) throws SQLException {
-
-						var document = documents.get(i);
-						var content = document.getContent();
-						var json = toJson(document.getMetadata());
-						var embedding = document.getEmbedding();
-						var pGvector = new PGvector(embedding);
-
-						StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
-								UUID.fromString(document.getId()));
-						StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
-						StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
-						StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
-						StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
-						StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
-						StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
-					}
+	private void insertOrUpdateBatch(List<Document> batch) {
+		String sql = "INSERT INTO " + getFullyQualifiedTableName()
+				+ " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
+				+ "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? ";
+
+		this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() {
+			@Override
+			public void setValues(PreparedStatement ps, int i) throws SQLException {
+
+				var document = batch.get(i);
+				var content = document.getContent();
+				var json = toJson(document.getMetadata());
+				var embedding = document.getEmbedding();
+				var pGvector = new PGvector(embedding);
+
+				StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
+						UUID.fromString(document.getId()));
+				StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
+				StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
+				StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
+				StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content);
+				StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json);
+				StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector);
+			}

-					@Override
-					public int getBatchSize() {
-						return size;
-					}
-				});
+			@Override
+			public int getBatchSize() {
+				return batch.size();
+			}
+		});
 	}

 	private String toJson(Map<String, Object> map) {
@@ -285,7 +298,7 @@ private String comparisonOperator() {
 	// Initialize
 	// ---------------------------------------------------------------------------------
 	@Override
-	public void afterPropertiesSet() throws Exception {
+	public void afterPropertiesSet() {

 		logger.info("Initializing PGVectorStore schema for table: {} in schema: {}", this.getVectorTableName(),
 				this.getSchemaName());
@@ -390,7 +403,7 @@ public enum PgIndexType {
 	 * speed-recall tradeoff). There’s no training step like IVFFlat, so the index can
 	 * be created without any data in the table.
 	 */
-	HNSW;
+	HNSW

 }
@@ -443,7 +456,7 @@ private static class DocumentRowMapper implements RowMapper<Document> {

 		private static final String COLUMN_DISTANCE = "distance";

-		private ObjectMapper objectMapper;
+		private final ObjectMapper objectMapper;

 		public DocumentRowMapper(ObjectMapper objectMapper) {
 			this.objectMapper = objectMapper;
@@ -509,6 +522,8 @@ public static class Builder {

 		private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();

+		private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE;
+
 		@Nullable
 		private VectorStoreObservationConvention searchObservationConvention;

@@ -576,11 +591,17 @@ public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) {
 			return this;
 		}

+		public Builder withMaxDocumentBatchSize(int maxDocumentBatchSize) {
+			this.maxDocumentBatchSize = maxDocumentBatchSize;
+			return this;
+		}
+
 		public PgVectorStore build() {
 			return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled,
 					this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType,
 					this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema,
-					this.observationRegistry, this.searchObservationConvention, this.batchingStrategy);
+					this.observationRegistry, this.searchObservationConvention, this.batchingStrategy,
+					this.maxDocumentBatchSize);
 		}

 	}
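For orientation, a minimal sketch of setting the same cap programmatically through the builder. The jdbcTemplate and embeddingModel variables are assumed to be pre-configured beans, and 5_000 is an illustrative value:

    // Hypothetical wiring; only withMaxDocumentBatchSize(...) is new in this commit.
    PgVectorStore vectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel)
        .withMaxDocumentBatchSize(5_000) // each JDBC batch now covers at most 5,000 documents
        .build();

    // add() delegates to doAdd(), which embeds the full list once, then splits it into
    // sublists of at most maxDocumentBatchSize and runs one batchUpdate per sublist.
    vectorStore.add(documents);

Note that batchDocuments(...) returns subList views rather than copies, so the split adds no per-document allocation beyond the list of views.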

vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java

Lines changed: 49 additions & 2 deletions
@@ -15,15 +15,31 @@
  */
 package org.springframework.ai.vectorstore;

+import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.CsvSource;
+import org.mockito.ArgumentCaptor;

 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.only;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import java.util.Collections;
+
+import org.springframework.ai.document.Document;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.jdbc.core.BatchPreparedStatementSetter;
+import org.springframework.jdbc.core.JdbcTemplate;

 /**
  * @author Muthukumaran Navaneethakrishnan
+ * @author Soby Chacko
  */
-
 public class PgVectorStoreTests {

 	@ParameterizedTest(name = "{0} - Verifies valid Table name")
@@ -53,8 +69,39 @@ public class PgVectorStoreTests {
 			// 64
 			// characters
 	})
-	public void isValidTable(String tableName, Boolean expected) {
+	void isValidTable(String tableName, Boolean expected) {
 		assertThat(PgVectorSchemaValidator.isValidNameForDatabaseObject(tableName)).isEqualTo(expected);
 	}

+	@Test
+	void shouldAddDocumentsInBatchesAndEmbedOnce() {
+		// Given
+		var jdbcTemplate = mock(JdbcTemplate.class);
+		var embeddingModel = mock(EmbeddingModel.class);
+		var pgVectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withMaxDocumentBatchSize(1000)
+			.build();
+
+		// Testing with 9989 documents
+		var documents = Collections.nCopies(9989, new Document("foo"));
+
+		// When
+		pgVectorStore.doAdd(documents);
+
+		// Then
+		verify(embeddingModel, only()).embed(eq(documents), any(), any());
+
+		var batchUpdateCaptor = ArgumentCaptor.forClass(BatchPreparedStatementSetter.class);
+		verify(jdbcTemplate, times(10)).batchUpdate(anyString(), batchUpdateCaptor.capture());
+
+		assertThat(batchUpdateCaptor.getAllValues()).hasSize(10)
+			.allSatisfy(BatchPreparedStatementSetter::getBatchSize)
+			.satisfies(batches -> {
+				for (int i = 0; i < 9; i++) {
+					assertThat(batches.get(i).getBatchSize()).as("Batch at index %d should have size 1000", i)
+						.isEqualTo(1000);
+				}
+				assertThat(batches.get(9).getBatchSize()).as("Last batch should have size 989").isEqualTo(989);
+			});
+	}
+
 }
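The expected counts follow from the batching arithmetic: 9,989 documents under a 1,000-document cap yield ten batchUpdate invocations (nine full batches of 1,000 plus a final batch of 989), while embed(...) is still verified to run exactly once over the whole list.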
