16
16
17
17
package org .springframework .ai .vectorstore ;
18
18
19
- import com . azure . cosmos .* ;
20
- import com . azure . cosmos . implementation . guava25 . collect . ImmutableList ;
21
- import com . azure . cosmos . models .* ;
22
- import com . azure . cosmos . util .CosmosPagedFlux ;
23
- import com . fasterxml . jackson . databind . JsonNode ;
24
- import com . fasterxml . jackson . databind . ObjectMapper ;
25
- import com . fasterxml . jackson . databind . node . ObjectNode ;
26
- import io . micrometer . observation . ObservationRegistry ;
19
+ import java . util . ArrayList ;
20
+ import java . util . Collections ;
21
+ import java . util . HashMap ;
22
+ import java . util .List ;
23
+ import java . util . Optional ;
24
+ import java . util . stream . Collectors ;
25
+ import java . util . stream . IntStream ;
26
+
27
27
import org .apache .commons .lang3 .tuple .ImmutablePair ;
28
- import org .apache .commons .lang3 .tuple .Pair ;
29
28
import org .slf4j .Logger ;
30
29
import org .slf4j .LoggerFactory ;
31
30
import org .springframework .ai .document .Document ;
38
37
import org .springframework .ai .vectorstore .observation .AbstractObservationVectorStore ;
39
38
import org .springframework .ai .vectorstore .observation .VectorStoreObservationContext ;
40
39
import org .springframework .ai .vectorstore .observation .VectorStoreObservationConvention ;
40
+
41
+ import com .azure .cosmos .CosmosAsyncClient ;
42
+ import com .azure .cosmos .CosmosAsyncContainer ;
43
+ import com .azure .cosmos .CosmosAsyncDatabase ;
44
+ import com .azure .cosmos .implementation .guava25 .collect .ImmutableList ;
45
+ import com .azure .cosmos .models .CosmosBulkOperations ;
46
+ import com .azure .cosmos .models .CosmosContainerProperties ;
47
+ import com .azure .cosmos .models .CosmosItemOperation ;
48
+ import com .azure .cosmos .models .CosmosQueryRequestOptions ;
49
+ import com .azure .cosmos .models .CosmosVectorDataType ;
50
+ import com .azure .cosmos .models .CosmosVectorDistanceFunction ;
51
+ import com .azure .cosmos .models .CosmosVectorEmbedding ;
52
+ import com .azure .cosmos .models .CosmosVectorEmbeddingPolicy ;
53
+ import com .azure .cosmos .models .CosmosVectorIndexSpec ;
54
+ import com .azure .cosmos .models .CosmosVectorIndexType ;
55
+ import com .azure .cosmos .models .ExcludedPath ;
56
+ import com .azure .cosmos .models .IncludedPath ;
57
+ import com .azure .cosmos .models .IndexingMode ;
58
+ import com .azure .cosmos .models .IndexingPolicy ;
59
+ import com .azure .cosmos .models .PartitionKey ;
60
+ import com .azure .cosmos .models .PartitionKeyDefinition ;
61
+ import com .azure .cosmos .models .PartitionKind ;
62
+ import com .azure .cosmos .models .SqlParameter ;
63
+ import com .azure .cosmos .models .SqlQuerySpec ;
64
+ import com .azure .cosmos .models .ThroughputProperties ;
65
+ import com .azure .cosmos .util .CosmosPagedFlux ;
66
+ import com .fasterxml .jackson .databind .JsonNode ;
67
+ import com .fasterxml .jackson .databind .ObjectMapper ;
68
+ import com .fasterxml .jackson .databind .node .ObjectNode ;
69
+
70
+ import io .micrometer .observation .ObservationRegistry ;
41
71
import reactor .core .publisher .Flux ;
42
- import java .util .*;
43
- import java .util .stream .Collectors ;
44
- import java .util .stream .IntStream ;
45
72
46
73
/**
47
74
* @author Theo van Kraay
@@ -79,38 +106,38 @@ public CosmosDBVectorStore(ObservationRegistry observationRegistry,
79
106
cosmosClient .createDatabaseIfNotExists (properties .getDatabaseName ()).block ();
80
107
81
108
initializeContainer (properties .getContainerName (), properties .getDatabaseName (),
82
- properties .getVectorStoreThoughput (), properties .getVectorDimensions (),
109
+ properties .getVectorStoreThroughput (), properties .getVectorDimensions (),
83
110
properties .getPartitionKeyPath ());
84
111
85
112
this .embeddingModel = embeddingModel ;
86
113
}
87
114
88
- private void initializeContainer (String containerName , String databaseName , int vectorStoreThoughput ,
115
+ private void initializeContainer (String containerName , String databaseName , int vectorStoreThroughput ,
89
116
long vectorDimensions , String partitionKeyPath ) {
90
117
91
118
// Set defaults if not provided
92
- if (vectorStoreThoughput == 0 ) {
93
- vectorStoreThoughput = 400 ;
119
+ if (vectorStoreThroughput == 0 ) {
120
+ vectorStoreThroughput = 400 ;
94
121
}
95
122
if (partitionKeyPath == null ) {
96
123
partitionKeyPath = "/id" ;
97
124
}
98
125
99
126
// handle hierarchical partition key
100
- PartitionKeyDefinition subpartitionKeyDefinition = new PartitionKeyDefinition ();
101
- List <String > pathsfromCommaSeparatedList = new ArrayList <String >();
102
- String [] subpartitionKeyPaths = partitionKeyPath .split ("," );
103
- Collections .addAll (pathsfromCommaSeparatedList , subpartitionKeyPaths );
104
- if (subpartitionKeyPaths .length > 1 ) {
105
- subpartitionKeyDefinition .setPaths (pathsfromCommaSeparatedList );
106
- subpartitionKeyDefinition .setKind (PartitionKind .MULTI_HASH );
127
+ PartitionKeyDefinition subPartitionKeyDefinition = new PartitionKeyDefinition ();
128
+ List <String > pathsFromCommaSeparatedList = new ArrayList <String >();
129
+ String [] subPartitionKeyPaths = partitionKeyPath .split ("," );
130
+ Collections .addAll (pathsFromCommaSeparatedList , subPartitionKeyPaths );
131
+ if (subPartitionKeyPaths .length > 1 ) {
132
+ subPartitionKeyDefinition .setPaths (pathsFromCommaSeparatedList );
133
+ subPartitionKeyDefinition .setKind (PartitionKind .MULTI_HASH );
107
134
}
108
135
else {
109
- subpartitionKeyDefinition .setPaths (Collections .singletonList (partitionKeyPath ));
110
- subpartitionKeyDefinition .setKind (PartitionKind .HASH );
136
+ subPartitionKeyDefinition .setPaths (Collections .singletonList (partitionKeyPath ));
137
+ subPartitionKeyDefinition .setKind (PartitionKind .HASH );
111
138
}
112
139
CosmosContainerProperties collectionDefinition = new CosmosContainerProperties (containerName ,
113
- subpartitionKeyDefinition );
140
+ subPartitionKeyDefinition );
114
141
// Set vector embedding policy
115
142
CosmosVectorEmbeddingPolicy embeddingPolicy = new CosmosVectorEmbeddingPolicy ();
116
143
CosmosVectorEmbedding embedding = new CosmosVectorEmbedding ();
@@ -135,16 +162,16 @@ private void initializeContainer(String containerName, String databaseName, int
135
162
indexingPolicy .setVectorIndexes (List .of (cosmosVectorIndexSpec ));
136
163
collectionDefinition .setIndexingPolicy (indexingPolicy );
137
164
138
- ThroughputProperties throughputProperties = ThroughputProperties .createManualThroughput (vectorStoreThoughput );
139
- CosmosAsyncDatabase cosmosAsyncDatabase = cosmosClient .getDatabase (databaseName );
165
+ ThroughputProperties throughputProperties = ThroughputProperties .createManualThroughput (vectorStoreThroughput );
166
+ CosmosAsyncDatabase cosmosAsyncDatabase = this . cosmosClient .getDatabase (databaseName );
140
167
cosmosAsyncDatabase .createContainerIfNotExists (collectionDefinition , throughputProperties ).block ();
141
168
this .container = cosmosAsyncDatabase .getContainer (containerName );
142
169
}
143
170
144
171
@ Override
145
172
public void close () {
146
- if (cosmosClient != null ) {
147
- cosmosClient .close ();
173
+ if (this . cosmosClient != null ) {
174
+ this . cosmosClient .close ();
148
175
logger .info ("Cosmos DB client closed successfully." );
149
176
}
150
177
}
@@ -192,7 +219,7 @@ public void doAdd(List<Document> documents) {
192
219
.map (ImmutablePair ::getValue )
193
220
.collect (Collectors .toList ());
194
221
195
- container .executeBulkOperations (Flux .fromIterable (itemOperations )).doOnNext (response -> {
222
+ this . container .executeBulkOperations (Flux .fromIterable (itemOperations )).doOnNext (response -> {
196
223
if (response != null && response .getResponse () != null ) {
197
224
int statusCode = response .getResponse ().getStatusCode ();
198
225
if (statusCode == 409 ) {
@@ -236,7 +263,7 @@ public Optional<Boolean> doDelete(List<String> idList) {
236
263
237
264
// Execute bulk delete operations synchronously by using blockLast() on the
238
265
// Flux
239
- container .executeBulkOperations (Flux .fromIterable (itemOperations ))
266
+ this . container .executeBulkOperations (Flux .fromIterable (itemOperations ))
240
267
.doOnNext (response -> logger .info ("Document deleted with status: {}" ,
241
268
response .getResponse ().getStatusCode ()))
242
269
.doOnError (error -> logger .error ("Error deleting document: {}" , error .getMessage ()))
@@ -279,9 +306,11 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
279
306
Filter .Expression filterExpression = request .getFilterExpression ();
280
307
if (filterExpression != null ) {
281
308
CosmosDBFilterExpressionConverter filterExpressionConverter = new CosmosDBFilterExpressionConverter (
282
- properties .getMetadataFieldsList ()); // Use the expression directly as
283
- // it handles the "metadata"
284
- // fields internally
309
+ this .properties .getMetadataFieldsList ()); // Use the expression
310
+ // directly as
311
+ // it handles the
312
+ // "metadata"
313
+ // fields internally
285
314
String filterQuery = filterExpressionConverter .convertExpression (filterExpression );
286
315
queryBuilder .append (" AND " ).append (filterQuery );
287
316
}
@@ -297,7 +326,7 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
297
326
SqlQuerySpec sqlQuerySpec = new SqlQuerySpec (query , parameters );
298
327
CosmosQueryRequestOptions options = new CosmosQueryRequestOptions ();
299
328
300
- CosmosPagedFlux <JsonNode > pagedFlux = container .queryItems (sqlQuerySpec , options , JsonNode .class );
329
+ CosmosPagedFlux <JsonNode > pagedFlux = this . container .queryItems (sqlQuerySpec , options , JsonNode .class );
301
330
302
331
logger .info ("Executing similarity search query: {}" , query );
303
332
try {
@@ -322,9 +351,9 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
322
351
@ Override
323
352
public VectorStoreObservationContext .Builder createObservationContextBuilder (String operationName ) {
324
353
return VectorStoreObservationContext .builder (VectorStoreProvider .COSMOSDB .value (), operationName )
325
- .withCollectionName (container .getId ())
354
+ .withCollectionName (this . container .getId ())
326
355
.withDimensions (this .embeddingModel .dimensions ())
327
- .withNamespace (container .getDatabase ().getId ())
356
+ .withNamespace (this . container .getDatabase ().getId ())
328
357
.withSimilarityMetric ("cosine" );
329
358
}
330
359
0 commit comments