Skip to content

Commit 52999b4

Browse files
tzolovmarkpollack
authored andcommitted
Align with the latest Chroma API
- Added Jackson annotations for JSON property handling in Chroma API request/response records. Use @JsonInclude(JsonInclude.Include.NON_NULL) to ignore empty fields such as null where. Use @JsonProperty(XYZ) to provide a strong contract with the Cohere API. - Updated the Docker image for running Chroma locally and in tests to version 0.5.20. - Enhanced methods with non-null assertions for improved code safety. - Added a simple search test to verify document similarity search functionality. Fixes #1749
1 parent 1abfd9a commit 52999b4

File tree

7 files changed

+105
-54
lines changed

7 files changed

+105
-54
lines changed

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/chroma.adoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ If all goes well, you should retrieve the document containing the text "Spring A
265265
=== Run Chroma Locally
266266

267267
```shell
268-
docker run -it --rm --name chroma -p 8000:8000 ghcr.io/chroma-core/chroma:0.4.15
268+
docker run -it --rm --name chroma -p 8000:8000 ghcr.io/chroma-core/chroma:0.5.20
269269
```
270270

271271
Starts a chroma store at <http://localhost:8000/api/v1>

spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfigurationIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
public class ChromaVectorStoreAutoConfigurationIT {
5656

5757
@Container
58-
static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.0");
58+
static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.20");
5959

6060
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
6161
.withConfiguration(AutoConfigurations.of(ChromaVectorStoreAutoConfiguration.class))

spring-ai-spring-boot-testcontainers/src/test/java/org/springframework/ai/testcontainers/service/connection/chroma/ChromaImage.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
*/
2424
public final class ChromaImage {
2525

26-
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.11");
26+
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.20");
2727

2828
private ChromaImage() {
2929

vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java

Lines changed: 63 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.regex.Matcher;
2424
import java.util.regex.Pattern;
2525

26+
import com.fasterxml.jackson.annotation.JsonInclude;
2627
import com.fasterxml.jackson.annotation.JsonProperty;
2728
import com.fasterxml.jackson.core.JsonProcessingException;
2829
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -123,17 +124,6 @@ public Collection createCollection(CreateCollectionRequest createCollectionReque
123124
.getBody();
124125
}
125126

126-
public Map<String, Object> createCollection2(CreateCollectionRequest createCollectionRequest) {
127-
128-
return this.restClient.post()
129-
.uri("/api/v1/collections")
130-
.headers(this::httpHeaders)
131-
.body(createCollectionRequest)
132-
.retrieve()
133-
.toEntity(Map.class)
134-
.getBody();
135-
}
136-
137127
/**
138128
* Delete a collection with the given name.
139129
* @param collectionName the name of the collection to delete.
@@ -281,7 +271,11 @@ private String getErrorMessage(HttpStatusCodeException e) {
281271
* @param name The name of the collection.
282272
* @param metadata Metadata associated with the collection.
283273
*/
284-
public record Collection(String id, String name, Map<String, Object> metadata) {
274+
@JsonInclude(JsonInclude.Include.NON_NULL)
275+
public record Collection(// @formatter:off
276+
@JsonProperty("id") String id,
277+
@JsonProperty("name") String name,
278+
@JsonProperty("metadata") Map<String, Object> metadata) { // @formatter:on
285279

286280
}
287281

@@ -291,7 +285,10 @@ public record Collection(String id, String name, Map<String, Object> metadata) {
291285
* @param name The name of the collection to create.
292286
* @param metadata Optional metadata to associate with the collection.
293287
*/
294-
public record CreateCollectionRequest(String name, Map<String, Object> metadata) {
288+
@JsonInclude(JsonInclude.Include.NON_NULL)
289+
public record CreateCollectionRequest(// @formatter:off
290+
@JsonProperty("name") String name,
291+
@JsonProperty("metadata") Map<String, Object> metadata) {// @formatter:on
295292

296293
public CreateCollectionRequest(String name) {
297294
this(name, new HashMap<>(Map.of("hnsw:space", "cosine")));
@@ -300,7 +297,7 @@ public CreateCollectionRequest(String name) {
300297
}
301298

302299
//
303-
// Chroma Collection API (https://docs.trychroma.com/js_reference/Collection)
300+
// Chroma Collection API (https://docs.trychroma.com/reference/js-client/Collection)
304301
//
305302

306303
/**
@@ -312,14 +309,17 @@ public CreateCollectionRequest(String name) {
312309
* can filter on this metadata.
313310
* @param documents The documents contents to associate with the embeddings.
314311
*/
315-
public record AddEmbeddingsRequest(List<String> ids, List<float[]> embeddings,
316-
@JsonProperty("metadatas") List<Map<String, Object>> metadata, List<String> documents) {
312+
@JsonInclude(JsonInclude.Include.NON_NULL)
313+
public record AddEmbeddingsRequest(// @formatter:off
314+
@JsonProperty("ids") List<String> ids,
315+
@JsonProperty("embeddings") List<float[]> embeddings,
316+
@JsonProperty("metadatas") List<Map<String, Object>> metadata,
317+
@JsonProperty("documents") List<String> documents) {// @formatter:on
317318

318319
// Convenance for adding a single embedding.
319320
public AddEmbeddingsRequest(String id, float[] embedding, Map<String, Object> metadata, String document) {
320321
this(List.of(id), List.of(embedding), List.of(metadata), List.of(document));
321322
}
322-
323323
}
324324

325325
/**
@@ -329,12 +329,14 @@ public AddEmbeddingsRequest(String id, float[] embedding, Map<String, Object> me
329329
* @param where Condition to filter items to delete based on metadata values.
330330
* (Optional)
331331
*/
332-
public record DeleteEmbeddingsRequest(List<String> ids, Map<String, Object> where) {
332+
@JsonInclude(JsonInclude.Include.NON_NULL)
333+
public record DeleteEmbeddingsRequest(// @formatter:off
334+
@JsonProperty("ids") List<String> ids,
335+
@JsonProperty("where") Map<String, Object> where) {// @formatter:on
333336

334337
public DeleteEmbeddingsRequest(List<String> ids) {
335-
this(ids, Map.of());
338+
this(ids, null);
336339
}
337-
338340
}
339341

340342
/**
@@ -348,19 +350,24 @@ public DeleteEmbeddingsRequest(List<String> ids) {
348350
* "metadatas", "documents", "distances". Ids are always included. Defaults to
349351
* [metadatas, documents, distances].
350352
*/
351-
public record GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, int limit, int offset,
352-
List<Include> include) {
353+
@JsonInclude(JsonInclude.Include.NON_NULL)
354+
public record GetEmbeddingsRequest(// @formatter:off
355+
@JsonProperty("ids") List<String> ids,
356+
@JsonProperty("where") Map<String, Object> where,
357+
@JsonProperty("limit") Integer limit,
358+
@JsonProperty("offset") Integer offset,
359+
@JsonProperty("include") List<Include> include) {// @formatter:on
353360

354361
public GetEmbeddingsRequest(List<String> ids) {
355-
this(ids, Map.of(), 10, 0, Include.all);
362+
this(ids, null, 10, 0, Include.all);
356363
}
357364

358365
public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where) {
359-
this(ids, where, 10, 0, Include.all);
366+
this(ids, CollectionUtils.isEmpty(where) ? null : where, 10, 0, Include.all);
360367
}
361368

362-
public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, int limit, int offset) {
363-
this(ids, where, limit, offset, Include.all);
369+
public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, Integer limit, Integer offset) {
370+
this(ids, CollectionUtils.isEmpty(where) ? null : where, limit, offset, Include.all);
364371
}
365372

366373
}
@@ -373,9 +380,12 @@ public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, int lim
373380
* @param documents List of document contents. One for each returned document.
374381
* @param metadata List of document metadata. One for each returned document.
375382
*/
376-
public record GetEmbeddingResponse(List<String> ids, List<float[]> embeddings, List<String> documents,
377-
@JsonProperty("metadatas") List<Map<String, String>> metadata) {
378-
383+
@JsonInclude(JsonInclude.Include.NON_NULL)
384+
public record GetEmbeddingResponse(// @formatter:off
385+
@JsonProperty("ids") List<String> ids,
386+
@JsonProperty("embeddings") List<float[]> embeddings,
387+
@JsonProperty("documents") List<String> documents,
388+
@JsonProperty("metadatas") List<Map<String, String>> metadata) {// @formatter:on
379389
}
380390

381391
/**
@@ -390,18 +400,22 @@ public record GetEmbeddingResponse(List<String> ids, List<float[]> embeddings, L
390400
* "metadatas", "documents", "distances". Ids are always included. Defaults to
391401
* [metadatas, documents, distances].
392402
*/
393-
public record QueryRequest(@JsonProperty("query_embeddings") List<float[]> queryEmbeddings,
394-
@JsonProperty("n_results") int nResults, Map<String, Object> where, List<Include> include) {
403+
@JsonInclude(JsonInclude.Include.NON_NULL)
404+
public record QueryRequest( // @formatter:off
405+
@JsonProperty("query_embeddings") List<float[]> queryEmbeddings,
406+
@JsonProperty("n_results") Integer nResults,
407+
@JsonProperty("where") Map<String, Object> where,
408+
@JsonProperty("include") List<Include> include) { // @formatter:on
395409

396410
/**
397411
* Convenience to query for a single embedding instead of a batch of embeddings.
398412
*/
399-
public QueryRequest(float[] queryEmbedding, int nResults) {
400-
this(List.of(queryEmbedding), nResults, Map.of(), Include.all);
413+
public QueryRequest(float[] queryEmbedding, Integer nResults) {
414+
this(List.of(queryEmbedding), nResults, null, Include.all);
401415
}
402416

403-
public QueryRequest(float[] queryEmbedding, int nResults, Map<String, Object> where) {
404-
this(List.of(queryEmbedding), nResults, where, Include.all);
417+
public QueryRequest(float[] queryEmbedding, Integer nResults, Map<String, Object> where) {
418+
this(List.of(queryEmbedding), nResults, CollectionUtils.isEmpty(where) ? null : where, Include.all);
405419
}
406420

407421
public enum Include {
@@ -434,9 +448,13 @@ public enum Include {
434448
* @param metadata List of list of document metadata. One for each returned document.
435449
* @param distances List of list of search distances. One for each returned document.
436450
*/
437-
public record QueryResponse(List<List<String>> ids, List<List<float[]>> embeddings, List<List<String>> documents,
438-
@JsonProperty("metadatas") List<List<Map<String, Object>>> metadata, List<List<Double>> distances) {
439-
451+
@JsonInclude(JsonInclude.Include.NON_NULL)
452+
public record QueryResponse(// @formatter:off
453+
@JsonProperty("ids") List<List<String>> ids,
454+
@JsonProperty("embeddings") List<List<float[]>> embeddings,
455+
@JsonProperty("documents") List<List<String>> documents,
456+
@JsonProperty("metadatas") List<List<Map<String, Object>>> metadata,
457+
@JsonProperty("distances") List<List<Double>> distances) {// @formatter:on
440458
}
441459

442460
/**
@@ -448,8 +466,13 @@ public record QueryResponse(List<List<String>> ids, List<List<float[]>> embeddin
448466
* @param metadata The metadata of the document.
449467
* @param distances The distance of the document to the query embedding.
450468
*/
451-
public record Embedding(String id, float[] embedding, String document, Map<String, Object> metadata,
452-
Double distances) {
469+
@JsonInclude(JsonInclude.Include.NON_NULL)
470+
public record Embedding(// @formatter:off
471+
@JsonProperty("id") String id,
472+
@JsonProperty("embedding") float[] embedding,
473+
@JsonProperty("document") String document,
474+
@JsonProperty("metadata") Map<String, Object> metadata,
475+
@JsonProperty("distances") Double distances) {// @formatter:on
453476

454477
}
455478

vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@
4343
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
4444
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
4545
import org.springframework.beans.factory.InitializingBean;
46+
import org.springframework.lang.NonNull;
4647
import org.springframework.util.Assert;
4748
import org.springframework.util.CollectionUtils;
48-
import org.springframework.util.StringUtils;
4949

5050
/**
5151
* {@link ChromaVectorStore} is a concrete implementation of the {@link VectorStore}
@@ -134,7 +134,7 @@ public void setFilterExpressionConverter(FilterExpressionConverter filterExpress
134134
}
135135

136136
@Override
137-
public void doAdd(List<Document> documents) {
137+
public void doAdd(@NonNull List<Document> documents) {
138138
Assert.notNull(documents, "Documents must not be null");
139139
if (CollectionUtils.isEmpty(documents)) {
140140
return;
@@ -160,24 +160,23 @@ public void doAdd(List<Document> documents) {
160160
}
161161

162162
@Override
163-
public Optional<Boolean> doDelete(List<String> idList) {
163+
public Optional<Boolean> doDelete(@NonNull List<String> idList) {
164164
Assert.notNull(idList, "Document id list must not be null");
165165
int status = this.chromaApi.deleteEmbeddings(this.collectionId, new DeleteEmbeddingsRequest(idList));
166166
return Optional.of(status == 200);
167167
}
168168

169169
@Override
170-
public List<Document> doSimilaritySearch(SearchRequest request) {
171-
172-
String nativeFilterExpression = (request.getFilterExpression() != null)
173-
? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
170+
public @NonNull List<Document> doSimilaritySearch(@NonNull SearchRequest request) {
174171

175172
String query = request.getQuery();
176173
Assert.notNull(query, "Query string must not be null");
177174

178175
float[] embedding = this.embeddingModel.embed(query);
179-
Map<String, Object> where = (StringUtils.hasText(nativeFilterExpression)) ? jsonToMap(nativeFilterExpression)
180-
: Map.of();
176+
177+
Map<String, Object> where = (request.getFilterExpression() != null)
178+
? jsonToMap(this.filterExpressionConverter.convertExpression(request.getFilterExpression())) : null;
179+
181180
var queryRequest = new ChromaApi.QueryRequest(embedding, request.getTopK(), where);
182181
var queryResponse = this.chromaApi.queryCollection(this.collectionId, queryRequest);
183182
var embeddings = this.chromaApi.toEmbeddingResponseList(queryResponse);
@@ -241,7 +240,8 @@ public void afterPropertiesSet() throws Exception {
241240
}
242241

243242
@Override
244-
public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
243+
public @NonNull VectorStoreObservationContext.Builder createObservationContextBuilder(
244+
@NonNull String operationName) {
245245
return VectorStoreObservationContext.builder(VectorStoreProvider.CHROMA.value(), operationName)
246246
.withDimensions(this.embeddingModel.dimensions())
247247
.withCollectionName(this.collectionName + ":" + this.collectionId)

vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/ChromaImage.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
*/
2424
public final class ChromaImage {
2525

26-
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.16");
26+
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ghcr.io/chroma-core/chroma:0.5.20");
2727

2828
private ChromaImage() {
2929

vector-stores/spring-ai-chroma-store/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,34 @@ public void addAndSearch() {
9292
});
9393
}
9494

95+
@Test
96+
public void simpleSearch() {
97+
this.contextRunner.run(context -> {
98+
99+
VectorStore vectorStore = context.getBean(VectorStore.class);
100+
101+
var document = Document.builder()
102+
.withId("simpleDoc")
103+
.withContent("The sky is blue because of Rayleigh scattering.")
104+
.build();
105+
106+
vectorStore.add(List.of(document));
107+
108+
List<Document> results = vectorStore.similaritySearch("Why is the sky blue?");
109+
110+
assertThat(results).hasSize(1);
111+
Document resultDoc = results.get(0);
112+
assertThat(resultDoc.getId()).isEqualTo(document.getId());
113+
assertThat(resultDoc.getContent()).isEqualTo("The sky is blue because of Rayleigh scattering.");
114+
115+
// Remove all documents from the store
116+
assertThat(vectorStore.delete(List.of(document.getId()))).isEqualTo(Optional.of(Boolean.TRUE));
117+
118+
results = vectorStore.similaritySearch(SearchRequest.query("Why is the sky blue?"));
119+
assertThat(results).hasSize(0);
120+
});
121+
}
122+
95123
@Test
96124
public void addAndSearchWithFilters() {
97125

0 commit comments

Comments
 (0)