Skip to content

Commit 8a5f9df

Browse files
michaelsembwever authored and tzolov committed
Add option to CassandraVectorStore to return embeddings in documents from similarity searches
1 parent 6e40575 commit 8a5f9df

File tree

3 files changed

+89
-20
lines changed

3 files changed

+89
-20
lines changed

vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -160,11 +160,9 @@ public void add(List<Document> documents) {
160160
futures[i++] = CompletableFuture.runAsync(() -> {
161161
List<Object> primaryKeyValues = this.conf.documentIdTranslator.apply(d.getId());
162162

163-
var embedding = (null != d.getEmbedding() && !d.getEmbedding().isEmpty() ? d.getEmbedding()
164-
: this.embeddingClient.embed(d))
165-
.stream()
166-
.map(Double::floatValue)
167-
.toList();
163+
if (null == d.getEmbedding() || d.getEmbedding().isEmpty()) {
164+
d.setEmbedding(this.embeddingClient.embed(d));
165+
}
168166

169167
BoundStatementBuilder builder = prepareAddStatement(d.getMetadata().keySet()).boundStatementBuilder();
170168
for (int k = 0; k < primaryKeyValues.size(); ++k) {
@@ -173,7 +171,9 @@ public void add(List<Document> documents) {
173171
}
174172

175173
builder = builder.setString(this.conf.schema.content(), d.getContent())
176-
.setVector(this.conf.schema.embedding(), CqlVector.newInstance(embedding), Float.class);
174+
.setVector(this.conf.schema.embedding(),
175+
CqlVector.newInstance(d.getEmbedding().stream().map(Double::floatValue).toList()),
176+
Float.class);
177177

178178
for (var metadataColumn : this.conf.schema.metadataColumns()
179179
.stream()
@@ -235,8 +235,15 @@ public List<Document> similaritySearch(SearchRequest request) {
235235
docFields.put(metadata.name(), value);
236236
}
237237
}
238+
Document doc = new Document(getDocumentId(row), row.getString(this.conf.schema.content()), docFields);
238239

239-
documents.add(new Document(getDocumentId(row), row.getString(this.conf.schema.content()), docFields));
240+
if (this.conf.returnEmbeddings) {
241+
doc.setEmbedding(row.getVector(this.conf.schema.embedding(), Float.class)
242+
.stream()
243+
.map(Float::doubleValue)
244+
.toList());
245+
}
246+
documents.add(doc);
240247
}
241248
return documents;
242249
}
@@ -328,6 +335,9 @@ private String similaritySearchStatement() {
328335
for (var m : this.conf.schema.metadataColumns()) {
329336
extraSelectFields.append(',').append(m.name());
330337
}
338+
if (this.conf.returnEmbeddings) {
339+
extraSelectFields.append(',').append(this.conf.schema.embedding());
340+
}
331341

332342
// java-driver-query-builder doesn't support orderByAnnOf yet
333343
String query = String.format(QUERY_FORMAT, similarityFunction, ids.toString(), this.conf.schema.content(),

vector-stores/spring-ai-cassandra/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ public interface PrimaryKeyTranslator extends Function<List<Object>, String> {
132132

133133
final boolean disallowSchemaChanges;
134134

135+
final boolean returnEmbeddings;
136+
135137
final DocumentIdTranslator documentIdTranslator;
136138

137139
final PrimaryKeyTranslator primaryKeyTranslator;
@@ -148,6 +150,7 @@ private CassandraVectorStoreConfig(Builder builder) {
148150
builder.contentColumnName, builder.embeddingColumnName, builder.indexName, builder.metadataColumns);
149151

150152
this.disallowSchemaChanges = builder.disallowSchemaCreation;
153+
this.returnEmbeddings = builder.returnEmbeddings;
151154
this.documentIdTranslator = builder.documentIdTranslator;
152155
this.primaryKeyTranslator = builder.primaryKeyTranslator;
153156
this.executor = Executors.newFixedThreadPool(builder.fixedThreadPoolExecutorSize);
@@ -199,6 +202,8 @@ public static class Builder {
199202

200203
private boolean disallowSchemaCreation = false;
201204

205+
private boolean returnEmbeddings = false;
206+
202207
private int fixedThreadPoolExecutorSize = DEFAULT_ADD_CONCURRENCY;
203208

204209
private DocumentIdTranslator documentIdTranslator = (String id) -> List.of(id);
@@ -308,6 +313,11 @@ public Builder disallowSchemaChanges() {
308313
return this;
309314
}
310315

316+
public Builder returnEmbeddings() {
317+
this.returnEmbeddings = true;
318+
return this;
319+
}
320+
311321
/**
312322
* Executor to use when adding documents. The hotspot is the call to the
313323
* embeddingClient. For remote transformers you probably want a higher value to

vector-stores/spring-ai-cassandra/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,14 @@ class CassandraVectorStoreIT {
6868
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
6969
.withUserConfiguration(TestApplication.class);
7070

71-
List<Document> documents = List.of(
72-
new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
73-
new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
74-
new Document("3", getText("classpath:/test/data/great.depression.txt"),
75-
Map.of("meta2", "meta2", "something_extra", "blue")));
71+
private static List<Document> documents() {
72+
return List.of(new Document("1", getText("classpath:/test/data/spring.ai.txt"), Map.of("meta1", "meta1")),
73+
new Document("2", getText("classpath:/test/data/time.shelter.txt"), Map.of()),
74+
new Document("3", getText("classpath:/test/data/great.depression.txt"),
75+
Map.of("meta2", "meta2", "something_extra", "blue")));
76+
}
7677

77-
public static String getText(String uri) {
78+
private static String getText(String uri) {
7879
var resource = new DefaultResourceLoader().getResource(uri);
7980
try {
8081
return resource.getContentAsString(StandardCharsets.UTF_8);
@@ -99,13 +100,21 @@ void addAndSearch() {
99100
contextRunner.run(context -> {
100101
try (CassandraVectorStore store = createTestStore(context, new SchemaColumn("meta1", DataTypes.TEXT),
101102
new SchemaColumn("meta2", DataTypes.TEXT))) {
103+
104+
List<Document> documents = documents();
102105
store.add(documents);
106+
for (Document d : documents) {
107+
assertThat(d.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNotNull(),
108+
e -> assertThat(e).isNotEmpty());
109+
}
103110

104111
List<Document> results = store.similaritySearch(SearchRequest.query("Spring").withTopK(1));
105112

106113
assertThat(results).hasSize(1);
107114
Document resultDoc = results.get(0);
108-
assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId());
115+
assertThat(resultDoc.getId()).isEqualTo(documents().get(0).getId());
116+
assertThat(resultDoc.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNull(),
117+
e -> assertThat(e).isEmpty());
109118

110119
assertThat(resultDoc.getContent()).contains(
111120
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
@@ -114,7 +123,43 @@ void addAndSearch() {
114123
assertThat(resultDoc.getMetadata()).containsKeys("meta1", CassandraVectorStore.SIMILARITY_FIELD_NAME);
115124

116125
// Remove all documents from the store
117-
store.delete(documents.stream().map(doc -> doc.getId()).toList());
126+
store.delete(documents().stream().map(doc -> doc.getId()).toList());
127+
128+
results = store.similaritySearch(SearchRequest.query("Spring").withTopK(1));
129+
assertThat(results).isEmpty();
130+
}
131+
});
132+
}
133+
134+
@Test
135+
void addAndSearchReturnEmbeddings() {
136+
contextRunner.run(context -> {
137+
CassandraVectorStoreConfig.Builder builder = storeBuilder(context.getBean(CqlSession.class))
138+
.returnEmbeddings();
139+
140+
try (CassandraVectorStore store = createTestStore(context, builder)) {
141+
List<Document> documents = documents();
142+
store.add(documents);
143+
for (Document d : documents) {
144+
assertThat(d.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNotNull(),
145+
e -> assertThat(e).isNotEmpty());
146+
}
147+
148+
List<Document> results = store.similaritySearch(SearchRequest.query("Spring").withTopK(1));
149+
150+
assertThat(results).hasSize(1);
151+
Document resultDoc = results.get(0);
152+
assertThat(resultDoc.getId()).isEqualTo(documents().get(0).getId());
153+
assertThat(resultDoc.getEmbedding()).isNotEmpty();
154+
155+
assertThat(resultDoc.getContent()).contains(
156+
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
157+
158+
assertThat(resultDoc.getMetadata()).hasSize(1);
159+
assertThat(resultDoc.getMetadata()).containsKey(CassandraVectorStore.SIMILARITY_FIELD_NAME);
160+
161+
// Remove all documents from the store
162+
store.delete(documents().stream().map(doc -> doc.getId()).toList());
118163

119164
results = store.similaritySearch(SearchRequest.query("Spring").withTopK(1));
120165
assertThat(results).isEmpty();
@@ -309,7 +354,7 @@ void documentUpdate() {
309354
void searchWithThreshold() {
310355
contextRunner.run(context -> {
311356
try (CassandraVectorStore store = context.getBean(CassandraVectorStore.class)) {
312-
store.add(documents);
357+
store.add(documents());
313358

314359
List<Document> fullResult = store
315360
.similaritySearch(SearchRequest.query("Spring").withTopK(5).withSimilarityThresholdAll());
@@ -327,7 +372,7 @@ void searchWithThreshold() {
327372

328373
assertThat(results).hasSize(1);
329374
Document resultDoc = results.get(0);
330-
assertThat(resultDoc.getId()).isEqualTo(documents.get(0).getId());
375+
assertThat(resultDoc.getId()).isEqualTo(documents().get(0).getId());
331376

332377
assertThat(resultDoc.getContent()).contains(
333378
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
@@ -370,17 +415,21 @@ public CqlSession cqlSession() {
370415

371416
}
372417

373-
static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) {
418+
private static CassandraVectorStoreConfig.Builder storeBuilder(CqlSession cqlSession) {
374419
return CassandraVectorStoreConfig.builder()
375420
.withCqlSession(cqlSession)
376421
.withKeyspaceName("test_" + CassandraVectorStoreConfig.DEFAULT_KEYSPACE_NAME);
377422
}
378423

379-
private CassandraVectorStore createTestStore(ApplicationContext context, SchemaColumn... metadataFields) {
380-
424+
private static CassandraVectorStore createTestStore(ApplicationContext context, SchemaColumn... metadataFields) {
381425
CassandraVectorStoreConfig.Builder builder = storeBuilder(context.getBean(CqlSession.class))
382426
.addMetadataColumns(metadataFields);
383427

428+
return createTestStore(context, builder);
429+
}
430+
431+
private static CassandraVectorStore createTestStore(ApplicationContext context,
432+
CassandraVectorStoreConfig.Builder builder) {
384433
CassandraVectorStoreConfig conf = builder.build();
385434
conf.dropKeyspace();
386435
return new CassandraVectorStore(conf, context.getBean(EmbeddingClient.class));

0 commit comments

Comments
 (0)