From 698f9b33e49dc959e4bf9313a94e08017c751ae5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Mon, 14 Apr 2025 12:08:30 +0200 Subject: [PATCH] Fix flaky test running blocking code in event loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../WebFluxSseIntegrationTests.java | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 76f908b8a..57af9e9b6 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -35,10 +35,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; -import reactor.test.StepVerifier; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; @@ -47,6 +45,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertWith; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -106,12 +105,9 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) { var clientBuilder = clientBuilders.get(clientType); McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }); + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), + (exchange, request) -> exchange.createMessage(mock(CreateMessageRequest.class)) + .thenReturn(mock(CallToolResult.class))); var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); @@ -148,6 +144,8 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + AtomicReference samplingResult = new AtomicReference<>(); + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { @@ -162,16 +160,9 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { .build()) .build(); - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); + return exchange.createMessage(craeteMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); }); var mcpServer = McpServer.async(mcpServerTransportProvider) @@ -191,8 +182,17 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { assertThat(response).isNotNull(); assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); } - mcpServer.close(); + mcpServer.closeGracefully().block(); } // ---------------------------------------