|
44 | 44 | import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
45 | 45 | import org.junit.jupiter.params.ParameterizedTest;
|
46 | 46 | import org.junit.jupiter.params.provider.ValueSource;
|
| 47 | +import org.springframework.ai.model.SimpleApiKey; |
47 | 48 | import org.testcontainers.elasticsearch.ElasticsearchContainer;
|
48 | 49 | import org.testcontainers.junit.jupiter.Container;
|
49 | 50 | import org.testcontainers.junit.jupiter.Testcontainers;
|
|
54 | 55 | import org.springframework.ai.openai.OpenAiEmbeddingModel;
|
55 | 56 | import org.springframework.ai.openai.api.OpenAiApi;
|
56 | 57 | import org.springframework.ai.vectorstore.SearchRequest;
|
57 |
| -import org.springframework.ai.vectorstore.filter.Filter; |
58 | 58 | import org.springframework.ai.vectorstore.filter.Filter.Expression;
|
59 | 59 | import org.springframework.ai.vectorstore.filter.Filter.ExpressionType;
|
60 | 60 | import org.springframework.ai.vectorstore.filter.Filter.Key;
|
@@ -117,10 +117,11 @@ void cleanDatabase() {
|
117 | 117 | });
|
118 | 118 | }
|
119 | 119 |
|
120 |
| - @Test |
121 |
| - public void addAndDeleteDocumentsTest() { |
| 120 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 121 | + @ValueSource(strings = { "cosine", "custom_embedding_field" }) |
| 122 | + public void addAndDeleteDocumentsTest(String vectorStoreBeanName) { |
122 | 123 | getContextRunner().run(context -> {
|
123 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine", |
| 124 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
124 | 125 | ElasticsearchVectorStore.class);
|
125 | 126 | ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class);
|
126 | 127 |
|
@@ -149,10 +150,11 @@ public void addAndDeleteDocumentsTest() {
|
149 | 150 | });
|
150 | 151 | }
|
151 | 152 |
|
152 |
| - @Test |
153 |
| - public void deleteDocumentsByFilterExpressionTest() { |
| 153 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 154 | + @ValueSource(strings = { "cosine", "custom_embedding_field" }) |
| 155 | + public void deleteDocumentsByFilterExpressionTest(String vectorStoreBeanName) { |
154 | 156 | getContextRunner().run(context -> {
|
155 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine", |
| 157 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
156 | 158 | ElasticsearchVectorStore.class);
|
157 | 159 | ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class);
|
158 | 160 |
|
@@ -202,10 +204,11 @@ public void deleteDocumentsByFilterExpressionTest() {
|
202 | 204 | });
|
203 | 205 | }
|
204 | 206 |
|
205 |
| - @Test |
206 |
| - public void deleteWithStringFilterExpressionTest() { |
| 207 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 208 | + @ValueSource(strings = { "cosine", "custom_embedding_field" }) |
| 209 | + public void deleteWithStringFilterExpressionTest(String vectorStoreBeanName) { |
207 | 210 | getContextRunner().run(context -> {
|
208 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine", |
| 211 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
209 | 212 | ElasticsearchVectorStore.class);
|
210 | 213 | ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class);
|
211 | 214 |
|
@@ -234,12 +237,12 @@ public void deleteWithStringFilterExpressionTest() {
|
234 | 237 | }
|
235 | 238 |
|
236 | 239 | @ParameterizedTest(name = "{0} : {displayName} ")
|
237 |
| - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
238 |
| - public void addAndSearchTest(String similarityFunction) { |
| 240 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_embedding_field" }) |
| 241 | + public void addAndSearchTest(String vectorStoreBeanName) { |
239 | 242 |
|
240 | 243 | getContextRunner().run(context -> {
|
241 | 244 |
|
242 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 245 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
243 | 246 | ElasticsearchVectorStore.class);
|
244 | 247 |
|
245 | 248 | vectorStore.add(this.documents);
|
@@ -271,11 +274,11 @@ public void addAndSearchTest(String similarityFunction) {
|
271 | 274 | }
|
272 | 275 |
|
273 | 276 | @ParameterizedTest(name = "{0} : {displayName} ")
|
274 |
| - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
275 |
| - public void searchWithFilters(String similarityFunction) { |
| 277 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_embedding_field" }) |
| 278 | + public void searchWithFilters(String vectorStoreBeanName) { |
276 | 279 |
|
277 | 280 | getContextRunner().run(context -> {
|
278 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 281 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
279 | 282 | ElasticsearchVectorStore.class);
|
280 | 283 |
|
281 | 284 | var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner",
|
@@ -385,11 +388,11 @@ public void searchWithFilters(String similarityFunction) {
|
385 | 388 | }
|
386 | 389 |
|
387 | 390 | @ParameterizedTest(name = "{0} : {displayName} ")
|
388 |
| - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
389 |
| - public void documentUpdateTest(String similarityFunction) { |
| 391 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_embedding_field" }) |
| 392 | + public void documentUpdateTest(String vectorStoreBeanName) { |
390 | 393 |
|
391 | 394 | getContextRunner().run(context -> {
|
392 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 395 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
393 | 396 | ElasticsearchVectorStore.class);
|
394 | 397 |
|
395 | 398 | Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!",
|
@@ -443,10 +446,10 @@ public void documentUpdateTest(String similarityFunction) {
|
443 | 446 | }
|
444 | 447 |
|
445 | 448 | @ParameterizedTest(name = "{0} : {displayName} ")
|
446 |
| - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
447 |
| - public void searchThresholdTest(String similarityFunction) { |
| 449 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_embedding_field" }) |
| 450 | + public void searchThresholdTest(String vectorStoreBeanName) { |
448 | 451 | getContextRunner().run(context -> {
|
449 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 452 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
450 | 453 | ElasticsearchVectorStore.class);
|
451 | 454 |
|
452 | 455 | vectorStore.add(this.documents);
|
@@ -581,9 +584,20 @@ public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingModel embeddingMo
|
581 | 584 | .build();
|
582 | 585 | }
|
583 | 586 |
|
| 587 | + @Bean("vectorStore_custom_embedding_field") |
| 588 | + public ElasticsearchVectorStore vectorStoreCustomField(EmbeddingModel embeddingModel, RestClient restClient) { |
| 589 | + ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); |
| 590 | + options.setEmbeddingFieldName("custom_embedding_field"); |
| 591 | + return ElasticsearchVectorStore.builder(restClient, embeddingModel) |
| 592 | + .initializeSchema(true) |
| 593 | + .options(options) |
| 594 | + .build(); |
| 595 | + } |
| 596 | + |
584 | 597 | @Bean
|
585 | 598 | public EmbeddingModel embeddingModel() {
|
586 |
| - return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); |
| 599 | + return new OpenAiEmbeddingModel( |
| 600 | + OpenAiApi.builder().apiKey(new SimpleApiKey(System.getenv("OPENAI_API_KEY"))).build()); |
587 | 601 | }
|
588 | 602 |
|
589 | 603 | @Bean
|
|
0 commit comments