diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index e6045170..53b59cb3 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -354,7 +354,7 @@ private Flux extractError(ClientResponse response, Str if (responseException.getStatusCode().isSameCodeAs(HttpStatus.BAD_REQUEST)) { return Mono.error(new McpTransportSessionNotFoundException(sessionRepresentation, toPropagate)); } - return Mono.empty(); + return Mono.error(toPropagate); }).flux(); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 43d6f40f..b7a9e4a0 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -32,6 +32,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; import org.springframework.context.annotation.Bean; @@ -96,9 +97,11 @@ public void before() { @AfterEach public void after() { + reactor.netty.http.HttpResources.disposeLoopsAndConnections(); if (mcpServerTransportProvider != null) { mcpServerTransportProvider.closeGracefully().block(); } + Schedulers.shutdownNow(); if (tomcatServer.appContext() != null) { tomcatServer.appContext().close(); } @@ -779,6 +782,33 @@ void testToolCallSuccess() { mcpServer.close(); } + @Test + void testThrowingToolCallIsCaughtBeforeTimeout() { + McpSyncServer mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + // We trigger a timeout on blocking read, raising an exception + Mono.never().block(Duration.ofSeconds(1)); + return null; + })) + .build(); + + try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // We expect the tool call to fail immediately with the exception raised by + // the offending tool + // instead of getting back a timeout. + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) + .withMessageContaining("Timeout on blocking read"); + } + + mcpServer.close(); + } + @Test void testToolListChangeHandlingSuccess() { diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index ab48fc0f..aaea4bb0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -321,82 +321,81 @@ public HttpClientSseClientTransport build() { } + protected Flux eventStream() { + + HttpRequest request = requestBuilder.copy() + .uri(Utils.resolveUri(this.baseUri, this.sseEndpoint)) + .header("Accept", "text/event-stream") + .header("Cache-Control", "no-cache") + .GET() + .build(); + + return Flux.create(sseSink -> this.httpClient + .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) + .exceptionallyCompose(e -> { + sseSink.error(e); + return CompletableFuture.failedFuture(e); + })).map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent); + } + @Override public Mono connect(Function, Mono> handler) { return Mono.create(sink -> { - HttpRequest request = requestBuilder.copy() - .uri(Utils.resolveUri(this.baseUri, this.sseEndpoint)) - .header("Accept", "text/event-stream") - .header("Cache-Control", "no-cache") - .GET() - .build(); - - Disposable connection = Flux.create(sseSink -> this.httpClient - .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) - .exceptionallyCompose(e -> { - sseSink.error(e); - return CompletableFuture.failedFuture(e); - })) - .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) - .flatMap(responseEvent -> { - if (isClosing) { - return Mono.empty(); - } + Flux events = eventStream(); - int statusCode = responseEvent.responseInfo().statusCode(); - - if (statusCode >= 200 && statusCode < 300) { - try { - if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - String messageEndpointUri = responseEvent.sseEvent().data(); - if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { - sink.success(); - return Flux.empty(); // No further processing needed - } - else { - sink.error(new McpError("Failed to handle SSE endpoint event")); - } - } - else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, - responseEvent.sseEvent().data()); + Disposable connection = events.flatMap(responseEvent -> { + if (isClosing) { + return Mono.empty(); + } + + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + try { + if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + String messageEndpointUri = responseEvent.sseEvent().data(); + if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { sink.success(); - return Flux.just(message); + return Flux.empty(); // No further processing needed } else { - logger.error("Received unrecognized SSE event type: {}", - responseEvent.sseEvent().event()); - sink.error(new McpError( - "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + sink.error(new McpError("Failed to handle SSE endpoint event")); } } - catch (IOException e) { - logger.error("Error processing SSE event", e); - sink.error(new McpError("Error processing SSE event")); + else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + responseEvent.sseEvent().data()); + sink.success(); + return Flux.just(message); + } + else { + logger.error("Received unrecognized SSE event type: {}", responseEvent.sseEvent().event()); + sink.error(new McpError( + "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); } } - return Flux.error( - new RuntimeException("Failed to send message: " + responseEvent)); - - }) - .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))) - .onErrorComplete(t -> { - if (!isClosing) { - logger.warn("SSE stream observed an error", t); - sink.error(t); - } - return true; - }) - .doFinally(s -> { - Disposable ref = this.sseSubscription.getAndSet(null); - if (ref != null && !ref.isDisposed()) { - ref.dispose(); + catch (IOException e) { + logger.error("Error processing SSE event", e); + sink.error(new McpError("Error processing SSE event")); } - }) - .contextWrite(sink.contextView()) - .subscribe(); + } + return Flux.error( + new RuntimeException("Failed to send message: " + responseEvent)); + + }).flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))).onErrorComplete(t -> { + if (!isClosing) { + logger.warn("SSE stream observed an error", t); + sink.error(t); + } + return true; + }).doFinally(s -> { + Disposable ref = this.sseSubscription.getAndSet(null); + if (ref != null && !ref.isDisposed()) { + ref.dispose(); + } + }).contextWrite(sink.contextView()).subscribe(); this.sseSubscription.set(connection); }); @@ -421,13 +420,17 @@ public Mono sendMessage(JSONRPCMessage message) { } return this.serializeMessage(message) - .flatMap(body -> sendHttpPost(messageEndpointUri, body)) - .doOnNext(response -> { + .flatMap(body -> sendHttpPost(messageEndpointUri, body).handle((response, sink) -> { if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 && response.statusCode() != 206) { - logger.error("Error sending message: {}", response.statusCode()); + sink.error(new RuntimeException( + "Sending message failed with a non-OK HTTP code: " + response.statusCode())); } - }) + else { + sink.next(response); + sink.complete(); + } + })) .doOnError(error -> { if (!isClosing) { logger.error("Error sending message: {}", error.getMessage()); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index e4348be2..ca0a2282 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -4,34 +4,43 @@ package io.modelcontextprotocol.client.transport; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; + import java.net.URI; import java.net.http.HttpClient; +import java.net.http.HttpClient.Version; +import java.net.http.HttpHeaders; import java.net.http.HttpRequest; +import java.net.http.HttpResponse.ResponseInfo; import java.time.Duration; import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.client.transport.ResponseSubscribers.SseResponseEvent; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; -import org.springframework.http.codec.ServerSentEvent; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; - /** * Tests for the {@link HttpClientSseClientTransport} class. * @@ -51,15 +60,55 @@ class HttpClientSseClientTransportTests { private TestHttpClientSseClientTransport transport; + public record MyResponseInfo(int statusCode, HttpHeaders headers, Version version) implements ResponseInfo { + MyResponseInfo(int statusCode, HttpHeaders headers) { + this(statusCode, headers, Version.HTTP_1_1); + } + + MyResponseInfo(int statusCode) { + this(statusCode, HttpHeaders.of(Map.of(), (k, v) -> true), Version.HTTP_1_1); + } + } + // Test class to access protected methods static class TestHttpClientSseClientTransport extends HttpClientSseClientTransport { private final AtomicInteger inboundMessageCount = new AtomicInteger(0); - private Sinks.Many> events = Sinks.many().unicast().onBackpressureBuffer(); + private Sinks.Many events = Sinks.many().unicast().onBackpressureBuffer(); public TestHttpClientSseClientTransport(final String baseUri) { - super(HttpClient.newHttpClient(), HttpRequest.newBuilder(), baseUri, "/sse", new ObjectMapper()); + super(HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(), + HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", + new ObjectMapper()); + } + + CopyOnWriteArrayList requestMessages = new CopyOnWriteArrayList<>(); + + CopyOnWriteArrayList notificationMessages = new CopyOnWriteArrayList<>(); + + CopyOnWriteArrayList responseMessages = new CopyOnWriteArrayList<>(); + + Function, Mono> handler = (messageMono) -> messageMono + .doOnNext(message -> { + // System.out.println("Received message $$$$$$$$$$$$$$: " + message); + if (message instanceof JSONRPCRequest request) { + requestMessages.add(request); + } + else if (message instanceof JSONRPCNotification notificaiton) { + notificationMessages.add(notificaiton); + } + else if (message instanceof JSONRPCResponse response) { + responseMessages.add(response); + } + else { + throw new IllegalArgumentException("Unsupported message type: " + message.getClass()); + } + }); + + @Override + protected Flux eventStream() { + return super.eventStream().mergeWith(events.asFlux()); } public int getInboundMessageCount() { @@ -67,12 +116,14 @@ public int getInboundMessageCount() { } public void simulateEndpointEvent(String jsonMessage) { - events.tryEmitNext(ServerSentEvent.builder().event("endpoint").data(jsonMessage).build()); + events.tryEmitNext(new SseResponseEvent(new MyResponseInfo(200), + new ResponseSubscribers.SseEvent(null, "endpoint", jsonMessage))); inboundMessageCount.incrementAndGet(); } public void simulateMessageEvent(String jsonMessage) { - events.tryEmitNext(ServerSentEvent.builder().event("message").data(jsonMessage).build()); + events.tryEmitNext(new SseResponseEvent(new MyResponseInfo(200), + new ResponseSubscribers.SseEvent(null, "message", jsonMessage))); inboundMessageCount.incrementAndGet(); } @@ -88,7 +139,7 @@ void startContainer() { void setUp() { startContainer(); transport = new TestHttpClientSseClientTransport(host); - transport.connect(Function.identity()).block(); + transport.connect(transport.handler).block(); } @AfterEach @@ -123,6 +174,7 @@ void testMessageProcessing() { StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); assertThat(transport.getInboundMessageCount()).isEqualTo(1); + assertThat(transport.requestMessages).hasSize(1); } @Test