|
54 | 54 | import org.springframework.ai.openai.OpenAiEmbeddingModel;
|
55 | 55 | import org.springframework.ai.openai.api.OpenAiApi;
|
56 | 56 | import org.springframework.ai.vectorstore.SearchRequest;
|
57 |
| -import org.springframework.ai.vectorstore.filter.Filter; |
58 | 57 | import org.springframework.ai.vectorstore.filter.Filter.Expression;
|
59 | 58 | import org.springframework.ai.vectorstore.filter.Filter.ExpressionType;
|
60 | 59 | import org.springframework.ai.vectorstore.filter.Filter.Key;
|
@@ -117,10 +116,11 @@ void cleanDatabase() {
|
117 | 116 | });
|
118 | 117 | }
|
119 | 118 |
|
120 |
| - @Test |
121 |
| - public void addAndDeleteDocumentsTest() { |
| 119 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 120 | + @ValueSource(strings = { "cosine", "custom_field" }) |
| 121 | + public void addAndDeleteDocumentsTest(String vectorStoreBeanName) { |
122 | 122 | getContextRunner().run(context -> {
|
123 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine", |
| 123 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
124 | 124 | ElasticsearchVectorStore.class);
|
125 | 125 | ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class);
|
126 | 126 |
|
@@ -149,10 +149,11 @@ public void addAndDeleteDocumentsTest() {
|
149 | 149 | });
|
150 | 150 | }
|
151 | 151 |
|
152 |
| - @Test |
153 |
| - public void deleteDocumentsByFilterExpressionTest() { |
| 152 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 153 | + @ValueSource(strings = { "cosine", "custom_field" }) |
| 154 | + public void deleteDocumentsByFilterExpressionTest(String vectorStoreBeanName) { |
154 | 155 | getContextRunner().run(context -> {
|
155 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine", |
| 156 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
156 | 157 | ElasticsearchVectorStore.class);
|
157 | 158 | ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class);
|
158 | 159 |
|
@@ -202,10 +203,11 @@ public void deleteDocumentsByFilterExpressionTest() {
|
202 | 203 | });
|
203 | 204 | }
|
204 | 205 |
|
205 |
| - @Test |
206 |
| - public void deleteWithStringFilterExpressionTest() { |
| 206 | + @ParameterizedTest(name = "{0} : {displayName} ") |
| 207 | + @ValueSource(strings = { "cosine", "custom_field" }) |
| 208 | + public void deleteWithStringFilterExpressionTest(String vectorStoreBeanName) { |
207 | 209 | getContextRunner().run(context -> {
|
208 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine", |
| 210 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
209 | 211 | ElasticsearchVectorStore.class);
|
210 | 212 | ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class);
|
211 | 213 |
|
@@ -234,12 +236,12 @@ public void deleteWithStringFilterExpressionTest() {
|
234 | 236 | }
|
235 | 237 |
|
236 | 238 | @ParameterizedTest(name = "{0} : {displayName} ")
|
237 |
| - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
238 |
| - public void addAndSearchTest(String similarityFunction) { |
| 239 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_field" }) |
| 240 | + public void addAndSearchTest(String vectorStoreBeanName) { |
239 | 241 |
|
240 | 242 | getContextRunner().run(context -> {
|
241 | 243 |
|
242 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 244 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
243 | 245 | ElasticsearchVectorStore.class);
|
244 | 246 |
|
245 | 247 | vectorStore.add(this.documents);
|
@@ -271,11 +273,11 @@ public void addAndSearchTest(String similarityFunction) {
|
271 | 273 | }
|
272 | 274 |
|
273 | 275 | @ParameterizedTest(name = "{0} : {displayName} ")
|
274 |
| - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
275 |
| - public void searchWithFilters(String similarityFunction) { |
| 276 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_field" }) |
| 277 | + public void searchWithFilters(String vectorStoreBeanName) { |
276 | 278 |
|
277 | 279 | getContextRunner().run(context -> {
|
278 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 280 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
279 | 281 | ElasticsearchVectorStore.class);
|
280 | 282 |
|
281 | 283 | var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner",
|
@@ -385,11 +387,11 @@ public void searchWithFilters(String similarityFunction) {
|
385 | 387 | }
|
386 | 388 |
|
387 | 389 | @ParameterizedTest(name = "{0} : {displayName} ")
|
388 |
| - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
389 |
| - public void documentUpdateTest(String similarityFunction) { |
| 390 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_field" }) |
| 391 | + public void documentUpdateTest(String vectorStoreBeanName) { |
390 | 392 |
|
391 | 393 | getContextRunner().run(context -> {
|
392 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 394 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
393 | 395 | ElasticsearchVectorStore.class);
|
394 | 396 |
|
395 | 397 | Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!",
|
@@ -443,10 +445,10 @@ public void documentUpdateTest(String similarityFunction) {
|
443 | 445 | }
|
444 | 446 |
|
445 | 447 | @ParameterizedTest(name = "{0} : {displayName} ")
|
446 |
| - @ValueSource(strings = { "cosine", "l2_norm", "dot_product" }) |
447 |
| - public void searchThresholdTest(String similarityFunction) { |
| 448 | + @ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_field" }) |
| 449 | + public void searchThresholdTest(String vectorStoreBeanName) { |
448 | 450 | getContextRunner().run(context -> {
|
449 |
| - ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction, |
| 451 | + ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName, |
450 | 452 | ElasticsearchVectorStore.class);
|
451 | 453 |
|
452 | 454 | vectorStore.add(this.documents);
|
@@ -581,6 +583,16 @@ public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingModel embeddingMo
|
581 | 583 | .build();
|
582 | 584 | }
|
583 | 585 |
|
| 586 | + @Bean("vectorStore_custom_field") |
| 587 | + public ElasticsearchVectorStore vectorStoreCustomField(EmbeddingModel embeddingModel, RestClient restClient) { |
| 588 | + ElasticsearchVectorStoreOptions options = new ElasticsearchVectorStoreOptions(); |
| 589 | + options.setEmbeddingFieldName("custom_field"); |
| 590 | + return ElasticsearchVectorStore.builder(restClient, embeddingModel) |
| 591 | + .initializeSchema(true) |
| 592 | + .options(options) |
| 593 | + .build(); |
| 594 | + } |
| 595 | + |
584 | 596 | @Bean
|
585 | 597 | public EmbeddingModel embeddingModel() {
|
586 | 598 | return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
|
|
0 commit comments