Skip to content

Commit 9e578a1

Browse files
christophstroblmp911de
authored andcommitted
Introduce @EnableIfVectorSearchAvailable to wait and conditionally skip tests.
We now wait until a search index becomes available. If the search index doesn't come alive within 60 seconds, we skip that test (or test class). Closes: #5013 Original pull request: #5014
1 parent b4f4176 commit 9e578a1

File tree

10 files changed

+233
-47
lines changed

10 files changed

+233
-47
lines changed

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ static void initIndexes() {
167167
template.searchIndexOps(WithVectorFields.class).createIndex(rawIndex);
168168
template.searchIndexOps(WithVectorFields.class).createIndex(wrapperIndex);
169169

170-
template.awaitIndexCreation(WithVectorFields.class, rawIndex.getName());
171-
template.awaitIndexCreation(WithVectorFields.class, wrapperIndex.getName());
170+
template.awaitSearchIndexCreation(WithVectorFields.class, rawIndex.getName());
171+
template.awaitSearchIndexCreation(WithVectorFields.class, wrapperIndex.getName());
172172
}
173173

174174
private static void assertScoreIsDecreasing(Iterable<Document> documents) {

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,26 @@
1515
*/
1616
package org.springframework.data.mongodb.core.index;
1717

18-
import static org.assertj.core.api.Assertions.*;
19-
import static org.awaitility.Awaitility.*;
18+
import static org.assertj.core.api.Assertions.assertThatRuntimeException;
19+
import static org.awaitility.Awaitility.await;
2020
import static org.springframework.data.mongodb.test.util.Assertions.assertThat;
2121

22+
import java.time.Duration;
2223
import java.util.List;
2324

2425
import org.bson.Document;
2526
import org.junit.jupiter.api.AfterEach;
2627
import org.junit.jupiter.api.BeforeEach;
2728
import org.junit.jupiter.api.Test;
29+
import org.junit.jupiter.api.extension.ExtendWith;
2830
import org.junit.jupiter.params.ParameterizedTest;
2931
import org.junit.jupiter.params.provider.ValueSource;
30-
3132
import org.springframework.data.annotation.Id;
3233
import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction;
3334
import org.springframework.data.mongodb.core.mapping.Field;
3435
import org.springframework.data.mongodb.test.util.AtlasContainer;
36+
import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable;
37+
import org.springframework.data.mongodb.test.util.MongoServerCondition;
3538
import org.springframework.data.mongodb.test.util.MongoTestTemplate;
3639
import org.springframework.data.mongodb.test.util.MongoTestUtils;
3740
import org.springframework.lang.Nullable;
@@ -48,6 +51,7 @@
4851
* @author Christoph Strobl
4952
* @author Mark Paluch
5053
*/
54+
@ExtendWith(MongoServerCondition.class)
5155
@Testcontainers(disabledWithoutDocker = true)
5256
class VectorIndexIntegrationTests {
5357

@@ -66,19 +70,22 @@ class VectorIndexIntegrationTests {
6670

6771
@BeforeEach
6872
void init() {
69-
template.createCollection(Movie.class);
73+
74+
template.createCollectionIfNotExists(Movie.class);
7075
indexOps = template.searchIndexOps(Movie.class);
7176
}
7277

7378
@AfterEach
7479
void cleanup() {
7580

81+
template.flush(Movie.class);
7682
template.searchIndexOps(Movie.class).dropAllIndexes();
77-
template.dropCollection(Movie.class);
83+
template.awaitNoSearchIndexAvailable(Movie.class, Duration.ofSeconds(30));
7884
}
7985

8086
@ParameterizedTest // GH-4706
8187
@ValueSource(strings = { "euclidean", "cosine", "dotProduct" })
88+
@EnableIfVectorSearchAvailable(collection = Movie.class)
8289
void createsSimpleVectorIndex(String similarityFunction) {
8390

8491
VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding",
@@ -98,21 +105,23 @@ void createsSimpleVectorIndex(String similarityFunction) {
98105
}
99106

100107
@Test // GH-4706
108+
@EnableIfVectorSearchAvailable(collection = Movie.class)
101109
void dropIndex() {
102110

103111
VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding",
104112
builder -> builder.dimensions(1536).similarity("cosine"));
105113

106114
indexOps.createIndex(idx);
107115

108-
template.awaitIndexCreation(Movie.class, idx.getName());
116+
template.awaitSearchIndexCreation(Movie.class, idx.getName());
109117

110118
indexOps.dropIndex(idx.getName());
111119

112120
assertThat(readRawIndexInfo(idx.getName())).isNull();
113121
}
114122

115123
@Test // GH-4706
124+
@EnableIfVectorSearchAvailable(collection = Movie.class)
116125
void statusChanges() throws InterruptedException {
117126

118127
String indexName = "vector_index";
@@ -131,6 +140,7 @@ void statusChanges() throws InterruptedException {
131140
}
132141

133142
@Test // GH-4706
143+
@EnableIfVectorSearchAvailable(collection = Movie.class)
134144
void exists() throws InterruptedException {
135145

136146
String indexName = "vector_index";
@@ -148,6 +158,7 @@ void exists() throws InterruptedException {
148158
}
149159

150160
@Test // GH-4706
161+
@EnableIfVectorSearchAvailable(collection = Movie.class)
151162
void updatesVectorIndex() throws InterruptedException {
152163

153164
String indexName = "vector_index";
@@ -177,6 +188,7 @@ void updatesVectorIndex() throws InterruptedException {
177188
}
178189

179190
@Test // GH-4706
191+
@EnableIfVectorSearchAvailable(collection = Movie.class)
180192
void createsVectorIndexWithFilters() throws InterruptedException {
181193

182194
VectorIndex idx = new VectorIndex("vector_index")

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
package org.springframework.data.mongodb.test.util;
1717

1818
import org.springframework.core.env.StandardEnvironment;
19-
2019
import org.testcontainers.mongodb.MongoDBAtlasLocalContainer;
2120
import org.testcontainers.utility.DockerImageName;
2221

22+
import com.github.dockerjava.api.command.InspectContainerResponse;
23+
2324
/**
24-
* Extension to MongoDBAtlasLocalContainer.
25+
* Extension to {@link MongoDBAtlasLocalContainer}. Registers mapped host an port as system properties
26+
* ({@link #ATLAS_HOST}, {@link #ATLAS_PORT}).
2527
*
2628
* @author Christoph Strobl
2729
*/
@@ -31,6 +33,9 @@ public class AtlasContainer extends MongoDBAtlasLocalContainer {
3133
private static final String DEFAULT_TAG = "8.0.0";
3234
private static final String LATEST = "latest";
3335

36+
public static final String ATLAS_HOST = "docker.mongodb.atlas.host";
37+
public static final String ATLAS_PORT = "docker.mongodb.atlas.port";
38+
3439
private AtlasContainer(String dockerImageName) {
3540
super(DockerImageName.parse(dockerImageName));
3641
}
@@ -55,4 +60,20 @@ public static AtlasContainer tagged(String tag) {
5560
return new AtlasContainer(DEFAULT_IMAGE_NAME.withTag(tag));
5661
}
5762

63+
@Override
64+
protected void containerIsStarted(InspectContainerResponse containerInfo) {
65+
66+
super.containerIsStarted(containerInfo);
67+
68+
System.setProperty(ATLAS_HOST, getHost());
69+
System.setProperty(ATLAS_PORT, getMappedPort(27017).toString());
70+
}
71+
72+
@Override
73+
protected void containerIsStopping(InspectContainerResponse containerInfo) {
74+
75+
System.clearProperty(ATLAS_HOST);
76+
System.clearProperty(ATLAS_PORT);
77+
super.containerIsStopping(containerInfo);
78+
}
5879
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,30 @@
2525
import org.junit.jupiter.api.extension.ExtendWith;
2626

2727
/**
28+
* {@link EnableIfVectorSearchAvailable} indicates a specific method can only be run in an environment that has a search
29+
* server available. This means that not only the mongodb instance needs to have a
30+
* {@literal searchIndexManagementHostAndPort} configured, but also that the search index sever is actually up and
31+
* running, responding to a {@literal $listSearchIndexes} aggregation.
32+
*
2833
* @author Christoph Strobl
34+
* @since 5.0
35+
* @see Tag
2936
*/
30-
@Target({ ElementType.TYPE, ElementType.METHOD })
37+
@Target({ ElementType.METHOD })
3138
@Retention(RetentionPolicy.RUNTIME)
3239
@Documented
3340
@Tag("vector-search")
3441
@ExtendWith(MongoServerCondition.class)
3542
public @interface EnableIfVectorSearchAvailable {
3643

44+
/**
45+
* @return the name of the collection used to run the {@literal $listSearchIndexes} aggregation.
46+
*/
47+
String collectionName() default "";
48+
49+
/**
50+
* @return the type for resolving the name of the collection used to run the {@literal $listSearchIndexes}
51+
* aggregation. The {@link #collectionName()} has precedence over the type.
52+
*/
53+
Class<?> collection() default Object.class;
3754
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoExtensions.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ static class Client {
3131
static final String REACTIVE_REPLSET_KEY = "mongo.client.replset.reactive";
3232
}
3333

34-
static class Termplate {
34+
static class Template {
3535

3636
static final Namespace NAMESPACE = Namespace.create(MongoTemplateExtension.class);
3737
static final String SYNC = "mongo.template.sync";

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,21 @@
1515
*/
1616
package org.springframework.data.mongodb.test.util;
1717

18+
import java.time.Duration;
19+
1820
import org.junit.jupiter.api.extension.ConditionEvaluationResult;
1921
import org.junit.jupiter.api.extension.ExecutionCondition;
2022
import org.junit.jupiter.api.extension.ExtensionContext;
2123
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
2224
import org.springframework.core.annotation.AnnotatedElementUtils;
25+
import org.springframework.data.mongodb.MongoCollectionUtils;
2326
import org.springframework.data.util.Version;
27+
import org.springframework.util.NumberUtils;
28+
import org.springframework.util.StringUtils;
29+
import org.testcontainers.shaded.org.awaitility.Awaitility;
30+
31+
import com.mongodb.Function;
32+
import com.mongodb.client.MongoClient;
2433

2534
/**
2635
* @author Christoph Strobl
@@ -42,10 +51,13 @@ public ConditionEvaluationResult evaluateExecutionCondition(ExtensionContext con
4251
}
4352
}
4453

45-
if(context.getTags().contains("vector-search")) {
46-
if(!atlasEnvironment(context)) {
54+
if (context.getTags().contains("vector-search")) {
55+
if (!atlasEnvironment(context)) {
4756
return ConditionEvaluationResult.disabled("Disabled for servers not supporting Vector Search.");
4857
}
58+
if (!isSearchIndexAvailable(context)) {
59+
return ConditionEvaluationResult.disabled("Search index unavailable.");
60+
}
4961
}
5062

5163
if (context.getTags().contains("version-specific") && context.getElement().isPresent()) {
@@ -90,8 +102,55 @@ private Version serverVersion(ExtensionContext context) {
90102
Version.class);
91103
}
92104

105+
private boolean isSearchIndexAvailable(ExtensionContext context) {
106+
107+
EnableIfVectorSearchAvailable vectorSearchAvailable = AnnotatedElementUtils
108+
.findMergedAnnotation(context.getElement().get(), EnableIfVectorSearchAvailable.class);
109+
110+
if (vectorSearchAvailable == null) {
111+
return true;
112+
}
113+
114+
String collectionName = StringUtils.hasText(vectorSearchAvailable.collectionName())
115+
? vectorSearchAvailable.collectionName()
116+
: MongoCollectionUtils.getPreferredCollectionName(vectorSearchAvailable.collection());
117+
118+
return context.getStore(NAMESPACE).getOrComputeIfAbsent("search-index-%s-available".formatted(collectionName),
119+
(key) -> {
120+
try {
121+
doWithClient(client -> {
122+
Awaitility.await().atMost(Duration.ofSeconds(60)).pollInterval(Duration.ofMillis(200)).until(() -> {
123+
return MongoTestUtils.isSearchIndexReady(client, null, collectionName);
124+
});
125+
return "done waiting for search index";
126+
});
127+
} catch (Exception e) {
128+
return false;
129+
}
130+
return true;
131+
}, Boolean.class);
132+
133+
}
134+
93135
private boolean atlasEnvironment(ExtensionContext context) {
94-
return context.getStore(NAMESPACE).getOrComputeIfAbsent(Version.class, (key) -> MongoTestUtils.isVectorSearchEnabled(),
95-
Boolean.class);
136+
137+
return context.getStore(NAMESPACE).getOrComputeIfAbsent("mongodb-atlas",
138+
(key) -> doWithClient(MongoTestUtils::isVectorSearchEnabled), Boolean.class);
139+
}
140+
141+
private <T> T doWithClient(Function<MongoClient, T> function) {
142+
143+
String host = System.getProperty(AtlasContainer.ATLAS_HOST);
144+
String port = System.getProperty(AtlasContainer.ATLAS_PORT);
145+
146+
if (StringUtils.hasText(host) && StringUtils.hasText(port)) {
147+
try (MongoClient client = MongoTestUtils.client(host, NumberUtils.parseNumber(port, Integer.class))) {
148+
return function.apply(client);
149+
}
150+
}
151+
152+
try (MongoClient client = MongoTestUtils.client()) {
153+
return function.apply(client);
154+
}
96155
}
97156
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTemplateExtension.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@
3333

3434
import org.springframework.data.mongodb.core.MongoOperations;
3535
import org.springframework.data.mongodb.core.ReactiveMongoOperations;
36-
import org.springframework.data.mongodb.test.util.MongoExtensions.Termplate;
36+
import org.springframework.data.mongodb.test.util.MongoExtensions.Template;
3737
import org.springframework.data.util.ParsingUtils;
3838
import org.springframework.util.ClassUtils;
3939

4040
/**
4141
* JUnit {@link Extension} providing parameter resolution for synchronous and reactive MongoDB Template API objects.
4242
*
4343
* @author Christoph Strobl
44-
* @see Template
44+
* @see org.springframework.data.mongodb.test.util.Template
4545
* @see MongoTestTemplate
4646
* @see ReactiveMongoTestTemplate
4747
*/
@@ -65,32 +65,32 @@ public void postProcessTestInstance(Object testInstance, ExtensionContext contex
6565
@Override
6666
public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
6767
throws ParameterResolutionException {
68-
return super.supportsParameter(parameterContext, extensionContext) || parameterContext.isAnnotated(Template.class);
68+
return super.supportsParameter(parameterContext, extensionContext) || parameterContext.isAnnotated(org.springframework.data.mongodb.test.util.Template.class);
6969
}
7070

7171
@Override
7272
public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
7373
throws ParameterResolutionException {
7474

75-
if (parameterContext.getParameter().getAnnotation(Template.class) == null) {
75+
if (parameterContext.getParameter().getAnnotation(org.springframework.data.mongodb.test.util.Template.class) == null) {
7676
return super.resolveParameter(parameterContext, extensionContext);
7777
}
7878

7979
Class<?> parameterType = parameterContext.getParameter().getType();
80-
return getMongoTemplate(parameterType, parameterContext.getParameter().getAnnotation(Template.class),
80+
return getMongoTemplate(parameterType, parameterContext.getParameter().getAnnotation(org.springframework.data.mongodb.test.util.Template.class),
8181
extensionContext);
8282
}
8383

8484
private void injectFields(ExtensionContext context, Object testInstance, Predicate<Field> predicate) {
8585

86-
AnnotationUtils.findAnnotatedFields(context.getRequiredTestClass(), Template.class, predicate).forEach(field -> {
86+
AnnotationUtils.findAnnotatedFields(context.getRequiredTestClass(), org.springframework.data.mongodb.test.util.Template.class, predicate).forEach(field -> {
8787

8888
assertValidFieldCandidate(field);
8989

9090
try {
9191

9292
ReflectionUtils.makeAccessible(field).set(testInstance,
93-
getMongoTemplate(field.getType(), field.getAnnotation(Template.class), context));
93+
getMongoTemplate(field.getType(), field.getAnnotation(org.springframework.data.mongodb.test.util.Template.class), context));
9494
} catch (Throwable t) {
9595
ExceptionUtils.throwAsUncheckedException(t);
9696
}
@@ -107,14 +107,14 @@ private void assertSupportedType(String target, Class<?> type) {
107107
if (!ClassUtils.isAssignable(MongoOperations.class, type)
108108
&& !ClassUtils.isAssignable(ReactiveMongoOperations.class, type)) {
109109
throw new ExtensionConfigurationException(
110-
String.format("Can only resolve @%s %s of type %s or %s but was: %s", Template.class.getSimpleName(), target,
110+
String.format("Can only resolve @%s %s of type %s or %s but was: %s", org.springframework.data.mongodb.test.util.Template.class.getSimpleName(), target,
111111
MongoOperations.class.getName(), ReactiveMongoOperations.class.getName(), type.getName()));
112112
}
113113
}
114114

115-
private Object getMongoTemplate(Class<?> type, Template options, ExtensionContext extensionContext) {
115+
private Object getMongoTemplate(Class<?> type, org.springframework.data.mongodb.test.util.Template options, ExtensionContext extensionContext) {
116116

117-
Store templateStore = extensionContext.getStore(MongoExtensions.Termplate.NAMESPACE);
117+
Store templateStore = extensionContext.getStore(Template.NAMESPACE);
118118

119119
boolean replSetClient = holdsReplSetClient(extensionContext) || options.replicaSet();
120120

@@ -126,7 +126,7 @@ private Object getMongoTemplate(Class<?> type, Template options, ExtensionContex
126126

127127
if (ClassUtils.isAssignable(MongoOperations.class, type)) {
128128

129-
String key = Termplate.SYNC + "-" + dbName;
129+
String key = Template.SYNC + "-" + dbName;
130130
return templateStore.getOrComputeIfAbsent(key, it -> {
131131

132132
com.mongodb.client.MongoClient client = (com.mongodb.client.MongoClient) getMongoClient(
@@ -137,7 +137,7 @@ private Object getMongoTemplate(Class<?> type, Template options, ExtensionContex
137137

138138
if (ClassUtils.isAssignable(ReactiveMongoOperations.class, type)) {
139139

140-
String key = Termplate.REACTIVE + "-" + dbName;
140+
String key = Template.REACTIVE + "-" + dbName;
141141
return templateStore.getOrComputeIfAbsent(key, it -> {
142142

143143
com.mongodb.reactivestreams.client.MongoClient client = (com.mongodb.reactivestreams.client.MongoClient) getMongoClient(

0 commit comments

Comments
 (0)