Skip to content

Commit 5c9891c

Browse files
improve schema initialization logic and update deprecated code for Milvus vectorstore (#3705)
Fixes #3705 Auto-cherry-pick to 1.0.x Signed-off-by: jonghoon park <dev@jonghoonpark.com>
1 parent 8d45caf commit 5c9891c

File tree

2 files changed

+65
-17
lines changed

2 files changed

+65
-17
lines changed

vector-stores/spring-ai-milvus-store/src/main/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStore.java

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import io.milvus.param.R;
3939
import io.milvus.param.R.Status;
4040
import io.milvus.param.RpcStatus;
41+
import io.milvus.param.collection.CollectionSchemaParam;
4142
import io.milvus.param.collection.CreateCollectionParam;
4243
import io.milvus.param.collection.DropCollectionParam;
4344
import io.milvus.param.collection.FieldType;
@@ -443,6 +444,8 @@ void createCollection() {
443444
if (!isDatabaseCollectionExists()) {
444445
createCollection(this.databaseName, this.collectionName, this.idFieldName, this.isAutoId,
445446
this.contentFieldName, this.metadataFieldName, this.embeddingFieldName);
447+
createIndex(this.databaseName, this.collectionName, this.embeddingFieldName, this.indexType,
448+
this.metricType, this.indexParameters);
446449
}
447450

448451
R<DescribeIndexResponse> indexDescriptionResponse = this.milvusClient
@@ -452,19 +455,8 @@ void createCollection() {
452455
.build());
453456

454457
if (indexDescriptionResponse.getData() == null) {
455-
R<RpcStatus> indexStatus = this.milvusClient.createIndex(CreateIndexParam.newBuilder()
456-
.withDatabaseName(this.databaseName)
457-
.withCollectionName(this.collectionName)
458-
.withFieldName(this.embeddingFieldName)
459-
.withIndexType(this.indexType)
460-
.withMetricType(this.metricType)
461-
.withExtraParam(this.indexParameters)
462-
.withSyncMode(Boolean.FALSE)
463-
.build());
464-
465-
if (indexStatus.getException() != null) {
466-
throw new RuntimeException("Failed to create Index", indexStatus.getException());
467-
}
458+
createIndex(this.databaseName, this.collectionName, this.embeddingFieldName, this.indexType,
459+
this.metricType, this.indexParameters);
468460
}
469461

470462
R<RpcStatus> loadCollectionStatus = this.milvusClient.loadCollection(LoadCollectionParam.newBuilder()
@@ -507,10 +499,12 @@ void createCollection(String databaseName, String collectionName, String idField
507499
.withDescription("Spring AI Vector Store")
508500
.withConsistencyLevel(ConsistencyLevelEnum.STRONG)
509501
.withShardsNum(2)
510-
.addFieldType(docIdFieldType)
511-
.addFieldType(contentFieldType)
512-
.addFieldType(metadataFieldType)
513-
.addFieldType(embeddingFieldType)
502+
.withSchema(CollectionSchemaParam.newBuilder()
503+
.addFieldType(docIdFieldType)
504+
.addFieldType(contentFieldType)
505+
.addFieldType(metadataFieldType)
506+
.addFieldType(embeddingFieldType)
507+
.build())
514508
.build();
515509

516510
R<RpcStatus> collectionStatus = this.milvusClient.createCollection(createCollectionReq);
@@ -520,6 +514,23 @@ void createCollection(String databaseName, String collectionName, String idField
520514

521515
}
522516

517+
void createIndex(String databaseName, String collectionName, String embeddingFieldName, IndexType indexType,
518+
MetricType metricType, String indexParameters) {
519+
R<RpcStatus> indexStatus = this.milvusClient.createIndex(CreateIndexParam.newBuilder()
520+
.withDatabaseName(databaseName)
521+
.withCollectionName(collectionName)
522+
.withFieldName(embeddingFieldName)
523+
.withIndexType(indexType)
524+
.withMetricType(metricType)
525+
.withExtraParam(indexParameters)
526+
.withSyncMode(Boolean.FALSE)
527+
.build());
528+
529+
if (indexStatus.getException() != null) {
530+
throw new RuntimeException("Failed to create Index", indexStatus.getException());
531+
}
532+
}
533+
523534
int embeddingDimensions() {
524535
if (this.embeddingDimension != INVALID_EMBEDDING_DIMENSION) {
525536
return this.embeddingDimension;

vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/milvus/MilvusVectorStoreIT.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import java.nio.charset.StandardCharsets;
21+
import java.util.ArrayList;
2122
import java.util.Collections;
2223
import java.util.List;
2324
import java.util.Map;
@@ -26,6 +27,10 @@
2627
import java.util.function.Consumer;
2728
import java.util.stream.Collectors;
2829

30+
import ch.qos.logback.classic.Logger;
31+
import ch.qos.logback.classic.spi.ILoggingEvent;
32+
import ch.qos.logback.core.AppenderBase;
33+
import io.milvus.client.AbstractMilvusGrpcClient;
2934
import io.milvus.client.MilvusServiceClient;
3035
import io.milvus.param.ConnectParam;
3136
import io.milvus.param.IndexType;
@@ -34,6 +39,7 @@
3439
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
3540
import org.junit.jupiter.params.ParameterizedTest;
3641
import org.junit.jupiter.params.provider.ValueSource;
42+
import org.slf4j.LoggerFactory;
3743
import org.testcontainers.junit.jupiter.Container;
3844
import org.testcontainers.junit.jupiter.Testcontainers;
3945
import org.testcontainers.milvus.MilvusContainer;
@@ -323,6 +329,37 @@ public void deleteWithComplexFilterExpression() {
323329
});
324330
}
325331

332+
@Test
333+
void initializeSchema() {
334+
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=COSINE").run(context -> {
335+
VectorStore vectorStore = context.getBean(VectorStore.class);
336+
337+
Logger logger = (Logger) LoggerFactory.getLogger(AbstractMilvusGrpcClient.class);
338+
LogAppender logAppender = new LogAppender();
339+
logger.addAppender(logAppender);
340+
logAppender.start();
341+
342+
resetCollection(vectorStore);
343+
344+
assertThat(logAppender.capturedLogs).isEmpty();
345+
});
346+
}
347+
348+
static class LogAppender extends AppenderBase<ILoggingEvent> {
349+
350+
private final List<String> capturedLogs = new ArrayList<>();
351+
352+
@Override
353+
protected void append(ILoggingEvent eventObject) {
354+
capturedLogs.add(eventObject.getFormattedMessage());
355+
}
356+
357+
public List<String> getCapturedLogs() {
358+
return capturedLogs;
359+
}
360+
361+
}
362+
326363
@Test
327364
void getNativeClientTest() {
328365
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.milvus.metricType=COSINE").run(context -> {

0 commit comments

Comments
 (0)