15
15
*/
16
16
package org .springframework .ai .vectorstore ;
17
17
18
- import java .sql .PreparedStatement ;
19
- import java .sql .ResultSet ;
20
- import java .sql .SQLException ;
21
- import java .util .List ;
22
- import java .util .Map ;
23
- import java .util .Optional ;
24
- import java .util .UUID ;
25
-
18
+ import com .fasterxml .jackson .core .JsonProcessingException ;
19
+ import com .fasterxml .jackson .databind .ObjectMapper ;
20
+ import com .pgvector .PGvector ;
21
+ import io .micrometer .observation .ObservationRegistry ;
26
22
import org .postgresql .util .PGobject ;
27
23
import org .slf4j .Logger ;
28
24
import org .slf4j .LoggerFactory ;
46
42
import org .springframework .lang .Nullable ;
47
43
import org .springframework .util .StringUtils ;
48
44
49
- import com .fasterxml .jackson .core .JsonProcessingException ;
50
- import com .fasterxml .jackson .databind .ObjectMapper ;
51
- import com .pgvector .PGvector ;
52
-
53
- import io .micrometer .observation .ObservationRegistry ;
45
+ import java .sql .PreparedStatement ;
46
+ import java .sql .ResultSet ;
47
+ import java .sql .SQLException ;
48
+ import java .util .ArrayList ;
49
+ import java .util .List ;
50
+ import java .util .Map ;
51
+ import java .util .Optional ;
52
+ import java .util .UUID ;
54
53
55
54
/**
56
55
* Uses the "vector_store" table to store the Spring AI vector data. The table and the
@@ -81,6 +80,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
81
80
82
81
public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter ();
83
82
83
+ public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000 ;
84
+
84
85
private final String vectorTableName ;
85
86
86
87
private final String vectorIndexName ;
@@ -109,6 +110,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini
109
110
110
111
private final BatchingStrategy batchingStrategy ;
111
112
113
+ private final int maxDocumentBatchSize ;
114
+
112
115
public PgVectorStore (JdbcTemplate jdbcTemplate , EmbeddingModel embeddingModel ) {
113
116
this (jdbcTemplate , embeddingModel , INVALID_EMBEDDING_DIMENSION , PgDistanceType .COSINE_DISTANCE , false ,
114
117
PgIndexType .NONE , false );
@@ -132,7 +135,6 @@ public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, Embeddin
132
135
133
136
this (DEFAULT_SCHEMA_NAME , vectorTableName , DEFAULT_SCHEMA_VALIDATION , jdbcTemplate , embeddingModel , dimensions ,
134
137
distanceType , removeExistingVectorStoreTable , createIndexMethod , initializeSchema );
135
-
136
138
}
137
139
138
140
private PgVectorStore (String schemaName , String vectorTableName , boolean vectorTableValidationsEnabled ,
@@ -141,14 +143,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
141
143
142
144
this (schemaName , vectorTableName , vectorTableValidationsEnabled , jdbcTemplate , embeddingModel , dimensions ,
143
145
distanceType , removeExistingVectorStoreTable , createIndexMethod , initializeSchema ,
144
- ObservationRegistry .NOOP , null , new TokenCountBatchingStrategy ());
146
+ ObservationRegistry .NOOP , null , new TokenCountBatchingStrategy (), MAX_DOCUMENT_BATCH_SIZE );
145
147
}
146
148
147
149
private PgVectorStore (String schemaName , String vectorTableName , boolean vectorTableValidationsEnabled ,
148
150
JdbcTemplate jdbcTemplate , EmbeddingModel embeddingModel , int dimensions , PgDistanceType distanceType ,
149
151
boolean removeExistingVectorStoreTable , PgIndexType createIndexMethod , boolean initializeSchema ,
150
152
ObservationRegistry observationRegistry , VectorStoreObservationConvention customObservationConvention ,
151
- BatchingStrategy batchingStrategy ) {
153
+ BatchingStrategy batchingStrategy , int maxDocumentBatchSize ) {
152
154
153
155
super (observationRegistry , customObservationConvention );
154
156
@@ -172,6 +174,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT
172
174
this .initializeSchema = initializeSchema ;
173
175
this .schemaValidator = new PgVectorSchemaValidator (jdbcTemplate );
174
176
this .batchingStrategy = batchingStrategy ;
177
+ this .maxDocumentBatchSize = maxDocumentBatchSize ;
175
178
}
176
179
177
180
public PgDistanceType getDistanceType () {
@@ -180,40 +183,50 @@ public PgDistanceType getDistanceType() {
180
183
181
184
@ Override
182
185
public void doAdd (List <Document > documents ) {
186
+ this .embeddingModel .embed (documents , EmbeddingOptionsBuilder .builder ().build (), this .batchingStrategy );
183
187
184
- int size = documents .size ();
188
+ List <List <Document >> batchedDocuments = batchDocuments (documents );
189
+ batchedDocuments .forEach (this ::insertOrUpdateBatch );
190
+ }
185
191
186
- this .embeddingModel .embed (documents , EmbeddingOptionsBuilder .builder ().build (), this .batchingStrategy );
192
+ private List <List <Document >> batchDocuments (List <Document > documents ) {
193
+ List <List <Document >> batches = new ArrayList <>();
194
+ for (int i = 0 ; i < documents .size (); i += this .maxDocumentBatchSize ) {
195
+ batches .add (documents .subList (i , Math .min (i + this .maxDocumentBatchSize , documents .size ())));
196
+ }
197
+ return batches ;
198
+ }
187
199
188
- this .jdbcTemplate .batchUpdate (
189
- "INSERT INTO " + getFullyQualifiedTableName ()
190
- + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
191
- + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? " ,
192
- new BatchPreparedStatementSetter () {
193
- @ Override
194
- public void setValues (PreparedStatement ps , int i ) throws SQLException {
195
-
196
- var document = documents .get (i );
197
- var content = document .getContent ();
198
- var json = toJson (document .getMetadata ());
199
- var embedding = document .getEmbedding ();
200
- var pGvector = new PGvector (embedding );
201
-
202
- StatementCreatorUtils .setParameterValue (ps , 1 , SqlTypeValue .TYPE_UNKNOWN ,
203
- UUID .fromString (document .getId ()));
204
- StatementCreatorUtils .setParameterValue (ps , 2 , SqlTypeValue .TYPE_UNKNOWN , content );
205
- StatementCreatorUtils .setParameterValue (ps , 3 , SqlTypeValue .TYPE_UNKNOWN , json );
206
- StatementCreatorUtils .setParameterValue (ps , 4 , SqlTypeValue .TYPE_UNKNOWN , pGvector );
207
- StatementCreatorUtils .setParameterValue (ps , 5 , SqlTypeValue .TYPE_UNKNOWN , content );
208
- StatementCreatorUtils .setParameterValue (ps , 6 , SqlTypeValue .TYPE_UNKNOWN , json );
209
- StatementCreatorUtils .setParameterValue (ps , 7 , SqlTypeValue .TYPE_UNKNOWN , pGvector );
210
- }
200
+ private void insertOrUpdateBatch (List <Document > batch ) {
201
+ String sql = "INSERT INTO " + getFullyQualifiedTableName ()
202
+ + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO "
203
+ + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? " ;
204
+
205
+ this .jdbcTemplate .batchUpdate (sql , new BatchPreparedStatementSetter () {
206
+ @ Override
207
+ public void setValues (PreparedStatement ps , int i ) throws SQLException {
208
+
209
+ var document = batch .get (i );
210
+ var content = document .getContent ();
211
+ var json = toJson (document .getMetadata ());
212
+ var embedding = document .getEmbedding ();
213
+ var pGvector = new PGvector (embedding );
214
+
215
+ StatementCreatorUtils .setParameterValue (ps , 1 , SqlTypeValue .TYPE_UNKNOWN ,
216
+ UUID .fromString (document .getId ()));
217
+ StatementCreatorUtils .setParameterValue (ps , 2 , SqlTypeValue .TYPE_UNKNOWN , content );
218
+ StatementCreatorUtils .setParameterValue (ps , 3 , SqlTypeValue .TYPE_UNKNOWN , json );
219
+ StatementCreatorUtils .setParameterValue (ps , 4 , SqlTypeValue .TYPE_UNKNOWN , pGvector );
220
+ StatementCreatorUtils .setParameterValue (ps , 5 , SqlTypeValue .TYPE_UNKNOWN , content );
221
+ StatementCreatorUtils .setParameterValue (ps , 6 , SqlTypeValue .TYPE_UNKNOWN , json );
222
+ StatementCreatorUtils .setParameterValue (ps , 7 , SqlTypeValue .TYPE_UNKNOWN , pGvector );
223
+ }
211
224
212
- @ Override
213
- public int getBatchSize () {
214
- return size ;
215
- }
216
- });
225
+ @ Override
226
+ public int getBatchSize () {
227
+ return batch . size () ;
228
+ }
229
+ });
217
230
}
218
231
219
232
private String toJson (Map <String , Object > map ) {
@@ -285,7 +298,7 @@ private String comparisonOperator() {
285
298
// Initialize
286
299
// ---------------------------------------------------------------------------------
287
300
@ Override
288
- public void afterPropertiesSet () throws Exception {
301
+ public void afterPropertiesSet () {
289
302
290
303
logger .info ("Initializing PGVectorStore schema for table: {} in schema: {}" , this .getVectorTableName (),
291
304
this .getSchemaName ());
@@ -390,7 +403,7 @@ public enum PgIndexType {
390
403
* speed-recall tradeoff). There’s no training step like IVFFlat, so the index can
391
404
* be created without any data in the table.
392
405
*/
393
- HNSW ;
406
+ HNSW
394
407
395
408
}
396
409
@@ -443,7 +456,7 @@ private static class DocumentRowMapper implements RowMapper<Document> {
443
456
444
457
private static final String COLUMN_DISTANCE = "distance" ;
445
458
446
- private ObjectMapper objectMapper ;
459
+ private final ObjectMapper objectMapper ;
447
460
448
461
public DocumentRowMapper (ObjectMapper objectMapper ) {
449
462
this .objectMapper = objectMapper ;
@@ -509,6 +522,8 @@ public static class Builder {
509
522
510
523
private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy ();
511
524
525
+ private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE ;
526
+
512
527
@ Nullable
513
528
private VectorStoreObservationConvention searchObservationConvention ;
514
529
@@ -576,11 +591,17 @@ public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) {
576
591
return this ;
577
592
}
578
593
594
+ public Builder withMaxDocumentBatchSize (int maxDocumentBatchSize ) {
595
+ this .maxDocumentBatchSize = maxDocumentBatchSize ;
596
+ return this ;
597
+ }
598
+
579
599
public PgVectorStore build () {
580
600
return new PgVectorStore (this .schemaName , this .vectorTableName , this .vectorTableValidationsEnabled ,
581
601
this .jdbcTemplate , this .embeddingModel , this .dimensions , this .distanceType ,
582
602
this .removeExistingVectorStoreTable , this .indexType , this .initializeSchema ,
583
- this .observationRegistry , this .searchObservationConvention , this .batchingStrategy );
603
+ this .observationRegistry , this .searchObservationConvention , this .batchingStrategy ,
604
+ this .maxDocumentBatchSize );
584
605
}
585
606
586
607
}
0 commit comments