Skip to content

Commit c979d1c

Browse files
committed
CosmosDB vector store auto configuration changes
- Configurable BatchingStrategy via auto configuration - Minor code cleanup
1 parent 745e718 commit c979d1c

File tree

3 files changed

+23
-11
lines changed

3 files changed

+23
-11
lines changed

spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/cosmosdb/CosmosDBVectorStoreAutoConfiguration.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
package org.springframework.ai.autoconfigure.vectorstore.cosmosdb;
1818

1919
import com.azure.cosmos.CosmosClientBuilder;
20+
21+
import org.springframework.ai.embedding.BatchingStrategy;
2022
import org.springframework.ai.embedding.EmbeddingModel;
23+
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
2124
import org.springframework.ai.vectorstore.CosmosDBVectorStore;
2225
import org.springframework.ai.vectorstore.CosmosDBVectorStoreConfig;
2326
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
@@ -32,9 +35,9 @@
3235

3336
/**
3437
* @author Theo van Kraay
38+
* @author Soby Chacko
3539
* @since 1.0.0
3640
*/
37-
3841
@AutoConfiguration
3942
@ConditionalOnClass({ CosmosDBVectorStore.class, EmbeddingModel.class, CosmosAsyncClient.class })
4043
@EnableConfigurationProperties(CosmosDBVectorStoreProperties.class)
@@ -53,12 +56,18 @@ public CosmosAsyncClient cosmosClient(CosmosDBVectorStoreProperties properties)
5356
.buildAsyncClient();
5457
}
5558

59+
@Bean
60+
@ConditionalOnMissingBean(BatchingStrategy.class)
61+
BatchingStrategy batchingStrategy() {
62+
return new TokenCountBatchingStrategy();
63+
}
64+
5665
@Bean
5766
@ConditionalOnMissingBean
5867
public CosmosDBVectorStore cosmosDBVectorStore(ObservationRegistry observationRegistry,
5968
ObjectProvider<VectorStoreObservationConvention> customObservationConvention,
6069
CosmosDBVectorStoreProperties properties, CosmosAsyncClient cosmosAsyncClient,
61-
EmbeddingModel embeddingModel) {
70+
EmbeddingModel embeddingModel, BatchingStrategy batchingStrategy) {
6271

6372
CosmosDBVectorStoreConfig config = new CosmosDBVectorStoreConfig();
6473
config.setDatabaseName(properties.getDatabaseName());
@@ -67,7 +76,7 @@ public CosmosDBVectorStore cosmosDBVectorStore(ObservationRegistry observationRe
6776
config.setVectorStoreThoughput(properties.getVectorStoreThoughput());
6877
config.setVectorDimensions(properties.getVectorDimensions());
6978
return new CosmosDBVectorStore(observationRegistry, customObservationConvention.getIfAvailable(),
70-
cosmosAsyncClient, config, embeddingModel);
79+
cosmosAsyncClient, config, embeddingModel, batchingStrategy);
7180
}
7281

7382
}

vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@
4545

4646
/**
4747
* @author Theo van Kraay
48+
* @author Soby Chacko
4849
* @since 1.0.0
4950
*/
50-
5151
public class CosmosDBVectorStore extends AbstractObservationVectorStore implements AutoCloseable {
5252

5353
private static final Logger logger = LoggerFactory.getLogger(CosmosDBVectorStore.class);
@@ -65,18 +65,24 @@ public class CosmosDBVectorStore extends AbstractObservationVectorStore implemen
6565
public CosmosDBVectorStore(ObservationRegistry observationRegistry,
6666
VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient,
6767
CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel) {
68+
this(observationRegistry, customObservationConvention, cosmosClient, properties, embeddingModel,
69+
new TokenCountBatchingStrategy());
70+
}
71+
72+
public CosmosDBVectorStore(ObservationRegistry observationRegistry,
73+
VectorStoreObservationConvention customObservationConvention, CosmosAsyncClient cosmosClient,
74+
CosmosDBVectorStoreConfig properties, EmbeddingModel embeddingModel, BatchingStrategy batchingStrategy) {
6875
super(observationRegistry, customObservationConvention);
6976
this.cosmosClient = cosmosClient;
7077
this.properties = properties;
71-
this.batchingStrategy = new TokenCountBatchingStrategy();
78+
this.batchingStrategy = batchingStrategy;
7279
cosmosClient.createDatabaseIfNotExists(properties.getDatabaseName()).block();
7380

7481
initializeContainer(properties.getContainerName(), properties.getDatabaseName(),
7582
properties.getVectorStoreThoughput(), properties.getVectorDimensions(),
7683
properties.getPartitionKeyPath());
7784

7885
this.embeddingModel = embeddingModel;
79-
8086
}
8187

8288
private void initializeContainer(String containerName, String databaseName, int vectorStoreThoughput,
@@ -94,9 +100,7 @@ private void initializeContainer(String containerName, String databaseName, int
94100
PartitionKeyDefinition subpartitionKeyDefinition = new PartitionKeyDefinition();
95101
List<String> pathsfromCommaSeparatedList = new ArrayList<String>();
96102
String[] subpartitionKeyPaths = partitionKeyPath.split(",");
97-
for (String path : subpartitionKeyPaths) {
98-
pathsfromCommaSeparatedList.add(path);
99-
}
103+
Collections.addAll(pathsfromCommaSeparatedList, subpartitionKeyPaths);
100104
if (subpartitionKeyPaths.length > 1) {
101105
subpartitionKeyDefinition.setPaths(pathsfromCommaSeparatedList);
102106
subpartitionKeyDefinition.setKind(PartitionKind.MULTI_HASH);
@@ -180,7 +184,7 @@ public void doAdd(List<Document> documents) {
180184
.getCreateItemOperation(mapCosmosDocument(doc, doc.getEmbedding()), new PartitionKey(doc.getId()));
181185
return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID
182186
// with the operation
183-
}).collect(Collectors.toList());
187+
}).toList();
184188

185189
try {
186190
// Extract just the CosmosItemOperations from the pairs

vector-stores/spring-ai-azure-cosmos-db-store/src/test/java/org/springframework/ai/vectorstore/CosmosDBVectorStoreIT.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
* @author Theo van Kraay
4343
* @since 1.0.0
4444
*/
45-
4645
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_ENDPOINT", matches = ".+")
4746
@EnabledIfEnvironmentVariable(named = "AZURE_COSMOSDB_KEY", matches = ".+")
4847
public class CosmosDBVectorStoreIT {

0 commit comments

Comments
 (0)