|
45 | 45 | import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
46 | 46 | import org.junit.jupiter.params.ParameterizedTest;
|
47 | 47 | import org.junit.jupiter.params.provider.ValueSource;
|
| 48 | +import org.springframework.ai.model.SimpleApiKey; |
48 | 49 | import org.testcontainers.elasticsearch.ElasticsearchContainer;
|
49 | 50 | import org.testcontainers.junit.jupiter.Container;
|
50 | 51 | import org.testcontainers.junit.jupiter.Testcontainers;
|
|
57 | 58 | import org.springframework.ai.test.vectorstore.BaseVectorStoreTests;
|
58 | 59 | import org.springframework.ai.vectorstore.SearchRequest;
|
59 | 60 | import org.springframework.ai.vectorstore.VectorStore;
|
60 |
| -import org.springframework.ai.vectorstore.filter.Filter.Expression; |
61 |
| -import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; |
62 |
| -import org.springframework.ai.vectorstore.filter.Filter.Key; |
63 |
| -import org.springframework.ai.vectorstore.filter.Filter.Value; |
64 | 61 | import org.springframework.boot.SpringBootConfiguration;
|
65 | 62 | import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
|
66 | 63 | import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
|
@@ -127,10 +124,11 @@ protected void executeTest(Consumer<VectorStore> testFunction) {
|
127 | 124 | });
|
128 | 125 | }
|
129 | 126 |
|
130 |
| - @Test |
131 |
| - public void addAndDeleteDocumentsTest() { |
| 127 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 128 | + @ValueSource(strings = { "cosine", "custom_embedding_field" }) |
| 129 | + public void addAndDeleteDocumentsTest(String vectorStoreBeanName) { |
132 | 130 | getContextRunner().run(context -> {
|
133 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine", |
| 131 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
134 | 132 | ElasticsearchVectorStore.class);
|
135 | 133 | ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class);
|
136 | 134 |
|
@@ -160,12 +158,12 @@ public void addAndDeleteDocumentsTest() {
|
160 | 158 | }
|
161 | 159 |
|
162 | 160 | @ParameterizedTest(name = "{0} : {displayName} ")
|
163 |
| - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
164 |
| - public void addAndSearchTest(String similarityFunction) { |
| 161 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_embedding_field" }) |
| 162 | + public void addAndSearchTest(String vectorStoreBeanName) { |
165 | 163 |
|
166 | 164 | getContextRunner().run(context -> {
|
167 | 165 |
|
168 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 166 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
169 | 167 | ElasticsearchVectorStore.class);
|
170 | 168 |
|
171 | 169 | vectorStore.add(this.documents);
|
@@ -197,11 +195,11 @@ public void addAndSearchTest(String similarityFunction) {
|
197 | 195 | }
|
198 | 196 |
|
199 | 197 | @ParameterizedTest(name = "{0} : {displayName} ")
|
200 |
| - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
201 |
| - public void searchWithFilters(String similarityFunction) { |
| 198 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_embedding_field" }) |
| 199 | + public void searchWithFilters(String vectorStoreBeanName) { |
202 | 200 |
|
203 | 201 | getContextRunner().run(context -> {
|
204 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 202 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
205 | 203 | ElasticsearchVectorStore.class);
|
206 | 204 |
|
207 | 205 | var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner",
|
@@ -311,11 +309,11 @@ public void searchWithFilters(String similarityFunction) {
|
311 | 309 | }
|
312 | 310 |
|
313 | 311 | @ParameterizedTest(name = "{0} : {displayName} ")
|
314 |
| - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
315 |
| - public void documentUpdateTest(String similarityFunction) { |
| 312 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_embedding_field" }) |
| 313 | + public void documentUpdateTest(String vectorStoreBeanName) { |
316 | 314 |
|
317 | 315 | getContextRunner().run(context -> {
|
318 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 316 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
319 | 317 | ElasticsearchVectorStore.class);
|
320 | 318 |
|
321 | 319 | Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!",
|
@@ -369,10 +367,10 @@ public void documentUpdateTest(String similarityFunction) {
|
369 | 367 | }
|
370 | 368 |
|
371 | 369 | @ParameterizedTest(name = "{0} : {displayName} ")
|
372 |
| - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
373 |
| - public void searchThresholdTest(String similarityFunction) { |
| 370 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_embedding_field" }) |
| 371 | + public void searchThresholdTest(String vectorStoreBeanName) { |
374 | 372 | getContextRunner().run(context -> {
|
375 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 373 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
376 | 374 | ElasticsearchVectorStore.class);
|
377 | 375 |
|
378 | 376 | vectorStore.add(this.documents);
|
@@ -507,9 +505,20 @@ public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingModel embeddingMo
|
507 | 505 | .build();
|
508 | 506 | }
|
509 | 507 |
|
| 508 | + @Bean("vectorStore_custom_embedding_field") |
| 509 | + public ElasticsearchVectorStore vectorStoreCustomField(EmbeddingModel embeddingModel, RestClient restClient) { |
| 510 | + ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); |
| 511 | + options.setEmbeddingFieldName("custom_embedding_field"); |
| 512 | + return ElasticsearchVectorStore.builder(restClient, embeddingModel) |
| 513 | + .initializeSchema(true) |
| 514 | + .options(options) |
| 515 | + .build(); |
| 516 | + } |
| 517 | + |
510 | 518 | @Bean
|
511 | 519 | public EmbeddingModel embeddingModel() {
|
512 |
| - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); |
| 520 | + return new OpenAiEmbeddingModel( |
| 521 | + OpenAiApi.builder().apiKey(new SimpleApiKey(System.getenv("OPENAI_API_KEY"))).build()); |
513 | 522 | }
|
514 | 523 |
|
515 | 524 | @Bean
|
|
0 commit comments