Skip to content

Commit 02687b8

Browse files
christophstroblmp911de
authored andcommitted
Introduce @EnableIfVectorSearchAvailable to wait and conditionally skip tests.
We now wait until a search index becomes available. If the search index doesn't come alive within 60 seconds, we skip that test (or test class). Closes: #5013 Original pull request: #5014
1 parent 01cd906 commit 02687b8

File tree

5 files changed

+157
-27
lines changed

5 files changed

+157
-27
lines changed

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
*/
1616
package org.springframework.data.mongodb.core.index;
1717

18-
import static org.assertj.core.api.Assertions.*;
19-
import static org.awaitility.Awaitility.*;
18+
import static org.assertj.core.api.Assertions.assertThatRuntimeException;
19+
import static org.awaitility.Awaitility.await;
2020
import static org.springframework.data.mongodb.test.util.Assertions.assertThat;
2121

2222
import java.util.List;
@@ -26,16 +26,17 @@
2626
import org.junit.jupiter.api.AfterEach;
2727
import org.junit.jupiter.api.BeforeEach;
2828
import org.junit.jupiter.api.Test;
29+
import org.junit.jupiter.api.extension.ExtendWith;
2930
import org.junit.jupiter.params.ParameterizedTest;
3031
import org.junit.jupiter.params.provider.ValueSource;
31-
3232
import org.springframework.data.annotation.Id;
3333
import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction;
3434
import org.springframework.data.mongodb.core.mapping.Field;
3535
import org.springframework.data.mongodb.test.util.AtlasContainer;
36+
import org.springframework.data.mongodb.test.util.EnableIfVectorSearchAvailable;
37+
import org.springframework.data.mongodb.test.util.MongoServerCondition;
3638
import org.springframework.data.mongodb.test.util.MongoTestTemplate;
3739
import org.springframework.data.mongodb.test.util.MongoTestUtils;
38-
3940
import org.testcontainers.junit.jupiter.Container;
4041
import org.testcontainers.junit.jupiter.Testcontainers;
4142

@@ -48,6 +49,7 @@
4849
* @author Christoph Strobl
4950
* @author Mark Paluch
5051
*/
52+
@ExtendWith(MongoServerCondition.class)
5153
@Testcontainers(disabledWithoutDocker = true)
5254
class VectorIndexIntegrationTests {
5355

@@ -66,6 +68,7 @@ class VectorIndexIntegrationTests {
6668

6769
@BeforeEach
6870
void init() {
71+
6972
template.createCollection(Movie.class);
7073
indexOps = template.searchIndexOps(Movie.class);
7174
}
@@ -79,6 +82,7 @@ void cleanup() {
7982

8083
@ParameterizedTest // GH-4706
8184
@ValueSource(strings = { "euclidean", "cosine", "dotProduct" })
85+
@EnableIfVectorSearchAvailable(collection = Movie.class)
8286
void createsSimpleVectorIndex(String similarityFunction) {
8387

8488
VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding",
@@ -98,6 +102,7 @@ void createsSimpleVectorIndex(String similarityFunction) {
98102
}
99103

100104
@Test // GH-4706
105+
@EnableIfVectorSearchAvailable(collection = Movie.class)
101106
void dropIndex() {
102107

103108
VectorIndex idx = new VectorIndex("vector_index").addVector("plotEmbedding",
@@ -113,6 +118,7 @@ void dropIndex() {
113118
}
114119

115120
@Test // GH-4706
121+
@EnableIfVectorSearchAvailable(collection = Movie.class)
116122
void statusChanges() throws InterruptedException {
117123

118124
String indexName = "vector_index";
@@ -131,6 +137,7 @@ void statusChanges() throws InterruptedException {
131137
}
132138

133139
@Test // GH-4706
140+
@EnableIfVectorSearchAvailable(collection = Movie.class)
134141
void exists() throws InterruptedException {
135142

136143
String indexName = "vector_index";
@@ -148,6 +155,7 @@ void exists() throws InterruptedException {
148155
}
149156

150157
@Test // GH-4706
158+
@EnableIfVectorSearchAvailable(collection = Movie.class)
151159
void updatesVectorIndex() throws InterruptedException {
152160

153161
String indexName = "vector_index";
@@ -177,6 +185,7 @@ void updatesVectorIndex() throws InterruptedException {
177185
}
178186

179187
@Test // GH-4706
188+
@EnableIfVectorSearchAvailable(collection = Movie.class)
180189
void createsVectorIndexWithFilters() throws InterruptedException {
181190

182191
VectorIndex idx = new VectorIndex("vector_index")

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/AtlasContainer.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
package org.springframework.data.mongodb.test.util;
1717

1818
import org.springframework.core.env.StandardEnvironment;
19-
2019
import org.testcontainers.mongodb.MongoDBAtlasLocalContainer;
2120
import org.testcontainers.utility.DockerImageName;
2221

22+
import com.github.dockerjava.api.command.InspectContainerResponse;
23+
2324
/**
24-
* Extension to MongoDBAtlasLocalContainer.
25+
* Extension to {@link MongoDBAtlasLocalContainer}. Registers mapped host an port as system properties
26+
* ({@link #ATLAS_HOST}, {@link #ATLAS_PORT}).
2527
*
2628
* @author Christoph Strobl
2729
*/
@@ -31,6 +33,9 @@ public class AtlasContainer extends MongoDBAtlasLocalContainer {
3133
private static final String DEFAULT_TAG = "8.0.0";
3234
private static final String LATEST = "latest";
3335

36+
public static final String ATLAS_HOST = "docker.mongodb.atlas.host";
37+
public static final String ATLAS_PORT = "docker.mongodb.atlas.port";
38+
3439
private AtlasContainer(String dockerImageName) {
3540
super(DockerImageName.parse(dockerImageName));
3641
}
@@ -55,4 +60,20 @@ public static AtlasContainer tagged(String tag) {
5560
return new AtlasContainer(DEFAULT_IMAGE_NAME.withTag(tag));
5661
}
5762

63+
@Override
64+
protected void containerIsStarted(InspectContainerResponse containerInfo) {
65+
66+
super.containerIsStarted(containerInfo);
67+
68+
System.setProperty(ATLAS_HOST, getHost());
69+
System.setProperty(ATLAS_PORT, getMappedPort(27017).toString());
70+
}
71+
72+
@Override
73+
protected void containerIsStopping(InspectContainerResponse containerInfo) {
74+
75+
System.clearProperty(ATLAS_HOST);
76+
System.clearProperty(ATLAS_PORT);
77+
super.containerIsStopping(containerInfo);
78+
}
5879
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/EnableIfVectorSearchAvailable.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,30 @@
2525
import org.junit.jupiter.api.extension.ExtendWith;
2626

2727
/**
28+
* {@link EnableIfVectorSearchAvailable} indicates a specific method can only be run in an environment that has a search
29+
* server available. This means that not only the mongodb instance needs to have a
30+
* {@literal searchIndexManagementHostAndPort} configured, but also that the search index sever is actually up and
31+
* running, responding to a {@literal $listSearchIndexes} aggregation.
32+
*
2833
* @author Christoph Strobl
34+
* @since 5.0
35+
* @see Tag
2936
*/
30-
@Target({ ElementType.TYPE, ElementType.METHOD })
37+
@Target({ ElementType.METHOD })
3138
@Retention(RetentionPolicy.RUNTIME)
3239
@Documented
3340
@Tag("vector-search")
3441
@ExtendWith(MongoServerCondition.class)
3542
public @interface EnableIfVectorSearchAvailable {
3643

44+
/**
45+
* @return the name of the collection used to run the {@literal $listSearchIndexes} aggregation.
46+
*/
47+
String collectionName() default "";
48+
49+
/**
50+
* @return the type for resolving the name of the collection used to run the {@literal $listSearchIndexes}
51+
* aggregation. The {@link #collectionName()} has precedence over the type.
52+
*/
53+
Class<?> collection() default Object.class;
3754
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoServerCondition.java

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,21 @@
1515
*/
1616
package org.springframework.data.mongodb.test.util;
1717

18+
import java.time.Duration;
19+
1820
import org.junit.jupiter.api.extension.ConditionEvaluationResult;
1921
import org.junit.jupiter.api.extension.ExecutionCondition;
2022
import org.junit.jupiter.api.extension.ExtensionContext;
2123
import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
2224
import org.springframework.core.annotation.AnnotatedElementUtils;
25+
import org.springframework.data.mongodb.MongoCollectionUtils;
2326
import org.springframework.data.util.Version;
27+
import org.springframework.util.NumberUtils;
28+
import org.springframework.util.StringUtils;
29+
import org.testcontainers.shaded.org.awaitility.Awaitility;
30+
31+
import com.mongodb.Function;
32+
import com.mongodb.client.MongoClient;
2433

2534
/**
2635
* @author Christoph Strobl
@@ -42,10 +51,13 @@ public ConditionEvaluationResult evaluateExecutionCondition(ExtensionContext con
4251
}
4352
}
4453

45-
if(context.getTags().contains("vector-search")) {
46-
if(!atlasEnvironment(context)) {
54+
if (context.getTags().contains("vector-search")) {
55+
if (!atlasEnvironment(context)) {
4756
return ConditionEvaluationResult.disabled("Disabled for servers not supporting Vector Search.");
4857
}
58+
if (!isSearchIndexAvailable(context)) {
59+
return ConditionEvaluationResult.disabled("Search index unavailable.");
60+
}
4961
}
5062

5163
if (context.getTags().contains("version-specific") && context.getElement().isPresent()) {
@@ -90,8 +102,55 @@ private Version serverVersion(ExtensionContext context) {
90102
Version.class);
91103
}
92104

105+
private boolean isSearchIndexAvailable(ExtensionContext context) {
106+
107+
EnableIfVectorSearchAvailable vectorSearchAvailable = AnnotatedElementUtils
108+
.findMergedAnnotation(context.getElement().get(), EnableIfVectorSearchAvailable.class);
109+
110+
if (vectorSearchAvailable == null) {
111+
return true;
112+
}
113+
114+
String collectionName = StringUtils.hasText(vectorSearchAvailable.collectionName())
115+
? vectorSearchAvailable.collectionName()
116+
: MongoCollectionUtils.getPreferredCollectionName(vectorSearchAvailable.collection());
117+
118+
return context.getStore(NAMESPACE).getOrComputeIfAbsent("search-index-%s-available".formatted(collectionName),
119+
(key) -> {
120+
try {
121+
doWithClient(client -> {
122+
Awaitility.await().atMost(Duration.ofSeconds(60)).pollInterval(Duration.ofMillis(200)).until(() -> {
123+
return MongoTestUtils.isSearchIndexReady(client, null, collectionName);
124+
});
125+
return "done waiting for search index";
126+
});
127+
} catch (Exception e) {
128+
return false;
129+
}
130+
return true;
131+
}, Boolean.class);
132+
133+
}
134+
93135
private boolean atlasEnvironment(ExtensionContext context) {
94-
return context.getStore(NAMESPACE).getOrComputeIfAbsent(Version.class, (key) -> MongoTestUtils.isVectorSearchEnabled(),
95-
Boolean.class);
136+
137+
return context.getStore(NAMESPACE).getOrComputeIfAbsent("mongodb-atlas",
138+
(key) -> doWithClient(MongoTestUtils::isVectorSearchEnabled), Boolean.class);
139+
}
140+
141+
private <T> T doWithClient(Function<MongoClient, T> function) {
142+
143+
String host = System.getProperty(AtlasContainer.ATLAS_HOST);
144+
String port = System.getProperty(AtlasContainer.ATLAS_PORT);
145+
146+
if (StringUtils.hasText(host) && StringUtils.hasText(port)) {
147+
try (MongoClient client = MongoTestUtils.client(host, NumberUtils.parseNumber(port, Integer.class))) {
148+
return function.apply(client);
149+
}
150+
}
151+
152+
try (MongoClient client = MongoTestUtils.client()) {
153+
return function.apply(client);
154+
}
96155
}
97156
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/test/util/MongoTestUtils.java

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
*/
1616
package org.springframework.data.mongodb.test.util;
1717

18+
import org.jspecify.annotations.Nullable;
19+
import org.springframework.util.StringUtils;
1820
import reactor.core.publisher.Mono;
1921
import reactor.test.StepVerifier;
2022
import reactor.util.retry.Retry;
2123

2224
import java.time.Duration;
2325
import java.util.List;
26+
import java.util.concurrent.TimeUnit;
2427

2528
import org.bson.Document;
2629
import org.springframework.core.env.Environment;
@@ -30,6 +33,7 @@
3033
import org.springframework.util.ObjectUtils;
3134

3235
import com.mongodb.ConnectionString;
36+
import com.mongodb.MongoClientSettings;
3337
import com.mongodb.ReadPreference;
3438
import com.mongodb.WriteConcern;
3539
import com.mongodb.client.MongoClient;
@@ -68,6 +72,10 @@ public static MongoClient client(String host, int port) {
6872
}
6973

7074
public static MongoClient client(ConnectionString connectionString) {
75+
MongoClientSettings settings = MongoClientSettings.builder().applyConnectionString(connectionString)
76+
.applyToSocketSettings(builder -> {
77+
builder.connectTimeout(120, TimeUnit.SECONDS);
78+
}).build();
7179
return com.mongodb.client.MongoClients.create(connectionString, SpringDataMongoDB.driverInformation());
7280
}
7381

@@ -176,11 +184,10 @@ public static void dropCollectionNow(String dbName, String collectionName,
176184
* @param collectionName must not be {@literal null}.
177185
* @param client must not be {@literal null}.
178186
*/
179-
public static void dropCollectionNow(String dbName, String collectionName,
180-
com.mongodb.client.MongoClient client) {
187+
public static void dropCollectionNow(String dbName, String collectionName, com.mongodb.client.MongoClient client) {
181188

182-
com.mongodb.client.MongoDatabase database = client.getDatabase(dbName)
183-
.withWriteConcern(WriteConcern.MAJORITY).withReadPreference(ReadPreference.primary());
189+
com.mongodb.client.MongoDatabase database = client.getDatabase(dbName).withWriteConcern(WriteConcern.MAJORITY)
190+
.withReadPreference(ReadPreference.primary());
184191

185192
database.getCollection(collectionName).drop();
186193
}
@@ -205,11 +212,10 @@ public static void flushCollection(String dbName, String collectionName,
205212
.verifyComplete();
206213
}
207214

208-
public static void flushCollection(String dbName, String collectionName,
209-
com.mongodb.client.MongoClient client) {
215+
public static void flushCollection(String dbName, String collectionName, com.mongodb.client.MongoClient client) {
210216

211-
com.mongodb.client.MongoDatabase database = client.getDatabase(dbName)
212-
.withWriteConcern(WriteConcern.MAJORITY).withReadPreference(ReadPreference.primary());
217+
com.mongodb.client.MongoDatabase database = client.getDatabase(dbName).withWriteConcern(WriteConcern.MAJORITY)
218+
.withReadPreference(ReadPreference.primary());
213219

214220
database.getCollection(collectionName).deleteMany(new Document());
215221
}
@@ -267,19 +273,36 @@ public static boolean serverIsReplSet() {
267273
@SuppressWarnings("unchecked")
268274
public static boolean isVectorSearchEnabled() {
269275
try (MongoClient client = MongoTestUtils.client()) {
276+
return isVectorSearchEnabled(client);
277+
}
278+
}
270279

280+
public static boolean isVectorSearchEnabled(MongoClient client) {
281+
try {
271282
return client.getDatabase("admin").runCommand(new Document("getCmdLineOpts", "1")).get("argv", List.class)
272-
.stream().anyMatch(it -> {
273-
if(it instanceof String cfgString) {
274-
return cfgString.startsWith("searchIndexManagementHostAndPort");
275-
}
276-
return false;
277-
});
283+
.stream().anyMatch(it -> {
284+
if (it instanceof String cfgString) {
285+
return cfgString.startsWith("searchIndexManagementHostAndPort");
286+
}
287+
return false;
288+
});
278289
} catch (Exception e) {
279290
return false;
280291
}
281292
}
282293

294+
public static boolean isSearchIndexReady(MongoClient client, @Nullable String database, String collectionName) {
295+
296+
try {
297+
MongoCollection<Document> collection = client.getDatabase(StringUtils.hasText(database) ? database : "test").getCollection(collectionName);
298+
collection.aggregate(List.of(new Document("$listSearchIndexes", new Document())));
299+
} catch (Exception e) {
300+
return false;
301+
}
302+
return true;
303+
304+
}
305+
283306
public static Duration getTimeout() {
284307

285308
return ObjectUtils.nullSafeEquals("jenkins", ENV.getProperty("user.name")) ? Duration.ofMillis(100)
@@ -297,10 +320,11 @@ private static void giveTheServerALittleTimeToThink() {
297320

298321
public static CollectionInfo readCollectionInfo(MongoDatabase db, String collectionName) {
299322

300-
List<Document> list = db.runCommand(new Document().append("listCollections", 1).append("filter", new Document("name", collectionName)))
323+
List<Document> list = db
324+
.runCommand(new Document().append("listCollections", 1).append("filter", new Document("name", collectionName)))
301325
.get("cursor", Document.class).get("firstBatch", List.class);
302326

303-
if(list.isEmpty()) {
327+
if (list.isEmpty()) {
304328
throw new IllegalStateException(String.format("Collection %s not found.", collectionName));
305329
}
306330
return CollectionInfo.from(list.get(0));

0 commit comments

Comments
 (0)