From 8c39fb8f3dfd2dc82ad2b9adf31103bb7c538c59 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Fri, 13 Jun 2025 07:05:47 +0200 Subject: [PATCH 01/13] feat: implement Streamable HTTP transport and refactor SSE transport to reactive streams This change imporves the transport layer with reactive patterns and adds support for the latest MCP specification while maintaining backward compatibility with existing SSE transport. - Add HttpClientStreamableHttpTransport implementing 2025-03-26 MCP Streamable HTTP spec - Add ResponseSubscribers utility for handling SSE and JSON HTTP responses - Refactor HttpClientSseClientTransport to use reactive streams instead of CompletableFuture - Replace FlowSseClient with direct reactive stream handling - Use Disposable-based connection management instead of CountDownLatch - Replace message endpoint discovery with Sinks.One approach - Add resiliency tests using Toxiproxy for network failure scenarios - Minor type safety improvements in StdioClientTransport and DefaultMcpTransportStream Signed-off-by: Christian Tzolov --- ...bClientStreamableHttpAsyncClientTests.java | 6 +- ...ebClientStreamableHttpSyncClientTests.java | 6 +- mcp/pom.xml | 18 +- .../HttpClientSseClientTransport.java | 260 ++++--- .../HttpClientStreamableHttpTransport.java | 668 ++++++++++++++++++ .../client/transport/ResponseSubscribers.java | 354 ++++++++++ .../transport/StdioClientTransport.java | 2 +- .../spec/DefaultMcpTransportStream.java | 21 +- .../MockMcpServerTransportProvider.java | 2 - ...AbstractMcpAsyncClientResiliencyTests.java | 222 ++++++ ...eamableHttpAsyncClientResiliencyTests.java | 20 + ...pClientStreamableHttpAsyncClientTests.java | 41 ++ ...tpClientStreamableHttpSyncClientTests.java | 40 ++ .../client/HttpSseMcpAsyncClientTests.java | 5 +- .../HttpClientSseClientTransportTests.java | 31 +- mcp/src/test/resources/logback.xml | 3 + 16 files changed, 1533 insertions(+), 166 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java index f824193fd..5ff707b3c 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java @@ -1,12 +1,12 @@ package io.modelcontextprotocol.client; -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import org.springframework.web.reactive.function.client.WebClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; @Timeout(15) public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java index 9ecd8a7d1..70260c8bf 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java @@ -1,12 +1,12 @@ package io.modelcontextprotocol.client; -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import org.springframework.web.reactive.function.client.WebClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; @Timeout(15) public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { diff --git a/mcp/pom.xml b/mcp/pom.xml index 773432827..829b99bc1 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -141,11 +141,11 @@ - net.bytebuddy - byte-buddy - ${byte-buddy.version} - test - + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + io.projectreactor reactor-test @@ -202,6 +202,14 @@ test + + org.testcontainers + toxiproxy + ${toxiproxy.version} + test + + + diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index d951349d1..9beb2b373 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -9,30 +9,32 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; + +import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; /** * Server-Sent Events (SSE) implementation of the * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE - * transport specification, using Java's HttpClient. + * transport specification, using Java's HttpClient and FlowSseClient. * *

* This transport implementation establishes a bidirectional communication channel between @@ -75,9 +77,6 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** SSE endpoint path */ private final String sseEndpoint; - /** SSE client for handling server-sent events. Uses the /sse endpoint */ - private final FlowSseClient sseClient; - /** * HTTP client for sending messages to the server. Uses HTTP POST over the message * endpoint @@ -93,19 +92,26 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** Flag indicating if the transport is in closing state */ private volatile boolean isClosing = false; - /** Latch for coordinating endpoint discovery */ - private final CountDownLatch closeLatch = new CountDownLatch(1); + // /** Latch for coordinating endpoint discovery */ + // private final CountDownLatch closeLatch = new CountDownLatch(1); - /** Holds the discovered message endpoint URL */ - private final AtomicReference messageEndpoint = new AtomicReference<>(); + // /** Holds the discovered message endpoint URL */ + // private final AtomicReference messageEndpoint = new + // AtomicReference<>(); - /** Holds the SSE connection future */ - private final AtomicReference> connectionFuture = new AtomicReference<>(); + /** Holds the SSE subscription disposable */ + private final AtomicReference sseSubscription = new AtomicReference<>(); + + /** + * Sink for managing the message endpoint URI provided by the server. Stores the most + * recent endpoint URI and makes it available for outbound message processing. + */ + protected final Sinks.One messageEndpointSink = Sinks.one(); /** * Creates a new transport instance with default HTTP client and object mapper. * @param baseUri the base URI of the MCP server - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * @deprecated Use {@link HttpClientSseClientTransport2#builder(String)} instead. This * constructor will be removed in future versions. */ @Deprecated(forRemoval = true) @@ -119,7 +125,7 @@ public HttpClientSseClientTransport(String baseUri) { * @param baseUri the base URI of the MCP server * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper or clientBuilder is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * @deprecated Use {@link HttpClientSseClientTransport2#builder(String)} instead. This * constructor will be removed in future versions. */ @Deprecated(forRemoval = true) @@ -134,7 +140,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas * @param sseEndpoint the SSE endpoint path * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper or clientBuilder is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * @deprecated Use {@link HttpClientSseClientTransport2#builder(String)} instead. This * constructor will be removed in future versions. */ @Deprecated(forRemoval = true) @@ -152,7 +158,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas * @param sseEndpoint the SSE endpoint path * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This + * @deprecated Use {@link HttpClientSseClientTransport2#builder(String)} instead. This * constructor will be removed in future versions. */ @Deprecated(forRemoval = true) @@ -184,12 +190,10 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques this.objectMapper = objectMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; - - this.sseClient = new FlowSseClient(this.httpClient, requestBuilder); } /** - * Creates a new builder for {@link HttpClientSseClientTransport}. + * Creates a new builder for {@link HttpClientSseClientTransport2}. * @param baseUri the base URI of the MCP server * @return a new builder instance */ @@ -198,7 +202,7 @@ public static Builder builder(String baseUri) { } /** - * Builder for {@link HttpClientSseClientTransport}. + * Builder for {@link HttpClientSseClientTransport2}. */ public static class Builder { @@ -225,7 +229,7 @@ public static class Builder { /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server - * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. + * @deprecated Use {@link HttpClientSseClientTransport2#builder(String)} instead. * This constructor is deprecated and will be removed or made {@code protected} or * {@code private} in a future release. */ @@ -313,7 +317,7 @@ public Builder objectMapper(ObjectMapper objectMapper) { } /** - * Builds a new {@link HttpClientSseClientTransport} instance. + * Builds a new {@link HttpClientSseClientTransport2} instance. * @return a new transport instance */ public HttpClientSseClientTransport build() { @@ -323,63 +327,82 @@ public HttpClientSseClientTransport build() { } - /** - * Establishes the SSE connection with the server and sets up message handling. - * - *

- * This method: - *

- * @param handler the function to process received JSON-RPC messages - * @return a Mono that completes when the connection is established - */ @Override public Mono connect(Function, Mono> handler) { - CompletableFuture future = new CompletableFuture<>(); - connectionFuture.set(future); - - URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); - sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { - @Override - public void onEvent(SseEvent event) { - if (isClosing) { - return; - } - try { - if (ENDPOINT_EVENT_TYPE.equals(event.type())) { - String endpoint = event.data(); - messageEndpoint.set(endpoint); - closeLatch.countDown(); - future.complete(null); + return Mono.create(sink -> { + + HttpRequest request = requestBuilder.copy() + .uri(Utils.resolveUri(this.baseUri, this.sseEndpoint)) + .header("Accept", "text/event-stream") + .header("Cache-Control", "no-cache") + .GET() + .build(); + + Disposable connection = Flux + .create(sseSink -> this.httpClient.sendAsync(request, + responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink))) + .flatMap(responseEvent -> { + if (isClosing) { + return Mono.empty(); } - else if (MESSAGE_EVENT_TYPE.equals(event.type())) { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); - handler.apply(Mono.just(message)).subscribe(); + + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + try { + if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + String messageEndpointUri = responseEvent.sseEvent().data(); + if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { + sink.success(); + return Flux.empty(); // No further processing needed + } + else { + sink.error(new McpError("Failed to handle SSE endpoint event")); + } + } + else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + responseEvent.sseEvent().data()); + sink.success(); + return Flux.just(message); + } + else { + logger.error("Received unrecognized SSE event type: {}", + responseEvent.sseEvent().event()); + sink.error(new McpError( + "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + } + } + catch (IOException e) { + logger.error("Error processing SSE event", e); + sink.error(new McpError("Error processing SSE event")); + } } - else { - logger.error("Received unrecognized SSE event type: {}", event.type()); + return Flux.error( + new RuntimeException("Failed to send message: " + responseEvent)); + + }) + .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))) + .onErrorResume(t -> { + if (!isClosing) { + logger.error("SSE connection error", t); + sink.error(t); } - } - catch (IOException e) { - logger.error("Error processing SSE event", e); - future.completeExceptionally(e); - } - } + return Mono.empty(); - @Override - public void onError(Throwable error) { - if (!isClosing) { - logger.error("SSE connection error", error); - future.completeExceptionally(error); - } - } - }); + }) + .doFinally(s -> { + Disposable ref = this.sseSubscription.getAndSet(null); + if (ref != null && !ref.isDisposed()) { + ref.dispose(); + } + }) + .contextWrite(sink.contextView()) + .subscribe(); - return Mono.fromFuture(future); + this.sseSubscription.set(connection); + }); } /** @@ -394,45 +417,58 @@ public void onError(Throwable error) { */ @Override public Mono sendMessage(JSONRPCMessage message) { - if (isClosing) { - return Mono.empty(); - } - try { - if (!closeLatch.await(10, TimeUnit.SECONDS)) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); + return this.messageEndpointSink.asMono().flatMap(messageEndpointUri -> { + if (isClosing) { + return Mono.empty(); } - } - catch (InterruptedException e) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); - } - String endpoint = messageEndpoint.get(); - if (endpoint == null) { - return Mono.error(new McpError("No message endpoint available")); - } + try { + return this.serializeMessage(message) + .flatMap(body -> sendHttpPost(messageEndpointUri, body)) + .doOnNext(this::logIfNotOk) + .doOnError(error -> { + if (!isClosing) { + logger.error("Error sending message: {}", error.getMessage()); + } + }) + .then(); + } + catch (Exception e) { + if (!isClosing) { + return Mono.error(new RuntimeException("Failed to serialize message", e)); + } + return Mono.empty(); + } + }).then(); - try { - String jsonText = this.objectMapper.writeValueAsString(message); - URI requestUri = Utils.resolveUri(baseUri, endpoint); - HttpRequest request = this.requestBuilder.copy() - .uri(requestUri) - .POST(HttpRequest.BodyPublishers.ofString(jsonText)) - .build(); + } - return Mono.fromFuture( - httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> { - if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 - && response.statusCode() != 206) { - logger.error("Error sending message: {}", response.statusCode()); - } - })); - } - catch (IOException e) { - if (!isClosing) { - return Mono.error(new RuntimeException("Failed to serialize message", e)); + private Mono serializeMessage(final JSONRPCMessage message) { + return Mono.defer(() -> { + try { + return Mono.just(objectMapper.writeValueAsString(message)); + } + catch (IOException e) { + return Mono.error(new McpError("Failed to serialize message")); } - return Mono.empty(); + }); + } + + private Mono> sendHttpPost(final String endpoint, final String body) { + final URI requestUri = Utils.resolveUri(baseUri, endpoint); + final HttpRequest request = this.requestBuilder.copy() + .uri(requestUri) + .POST(HttpRequest.BodyPublishers.ofString(body)) + .build(); + + return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding())); + } + + private void logIfNotOk(final HttpResponse response) { + if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 + && response.statusCode() != 206) { + logger.error("Error sending message: {}", response.statusCode()); } } @@ -440,7 +476,7 @@ public Mono sendMessage(JSONRPCMessage message) { * Gracefully closes the transport connection. * *

- * Sets the closing flag and cancels any pending connection future. This prevents new + * Sets the closing flag and disposes of the SSE subscription. This prevents new * messages from being sent and allows ongoing operations to complete. * @return a Mono that completes when the closing process is initiated */ @@ -448,9 +484,9 @@ public Mono sendMessage(JSONRPCMessage message) { public Mono closeGracefully() { return Mono.fromRunnable(() -> { isClosing = true; - CompletableFuture future = connectionFuture.get(); - if (future != null && !future.isDone()) { - future.cancel(true); + Disposable subscription = sseSubscription.get(); + if (subscription != null && !subscription.isDisposed()) { + subscription.dispose(); } }); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java new file mode 100644 index 000000000..2c17b7148 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -0,0 +1,668 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.time.Duration; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; +import io.modelcontextprotocol.spec.DefaultMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportStream; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSession; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +/** + * An implementation of the Streamable HTTP protocol as defined by the + * 2025-03-26 version of the MCP specification. + * + *

+ * The transport is capable of resumability and reconnects. It reacts to transport-level + * session invalidation and will propagate {@link McpTransportSessionNotFoundException + * appropriate exceptions} to the higher level abstraction layer when needed in order to + * allow proper state management. The implementation handles servers that are stateful and + * provide session meta information, but can also communicate with stateless servers that + * do not provide a session identifier and do not support SSE streams. + *

+ *

+ * This implementation does not handle backwards compatibility with the "HTTP + * with SSE" transport. In order to communicate over the phased-out + * 2024-11-05 protocol, use {@link HttpClientSseClientTransport} or + * {@link WebFluxSseClientTransport}. + *

+ * + * @author Christian Tzolov + * @see Streamable + * HTTP transport specification + */ +public class HttpClientStreamableHttpTransport implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(HttpClientStreamableHttpTransport.class); + + private static final String DEFAULT_ENDPOINT = "/mcp"; + + /** + * HTTP client for sending messages to the server. Uses HTTP POST over the message + * endpoint + */ + private final HttpClient httpClient; + + /** HTTP request builder for building requests to send messages to the server */ + private final HttpRequest.Builder requestBuilder; + + /** + * Event type for JSON-RPC messages received through the SSE connection. The server + * sends messages with this event type to transmit JSON-RPC protocol data. + */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + private static final String APPLICATION_JSON = "application/json"; + + private static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static int NOT_FOUND = 404; + + public static int METHOD_NOT_ALLOWED = 405; + + public static int BAD_REQUEST = 400; + + private final ObjectMapper objectMapper; + + private final URI baseUri; + + private final String endpoint; + + private final boolean openConnectionOnStartup; + + private final boolean resumableStreams; + + private final AtomicReference activeSession = new AtomicReference<>(); + + private final AtomicReference, Mono>> handler = new AtomicReference<>(); + + private final AtomicReference> exceptionHandler = new AtomicReference<>(); + + private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient httpClient, + HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, + boolean openConnectionOnStartup) { + this.objectMapper = objectMapper; + this.httpClient = httpClient; + this.requestBuilder = requestBuilder; + this.baseUri = URI.create(baseUri); + this.endpoint = endpoint; + this.resumableStreams = resumableStreams; + this.openConnectionOnStartup = openConnectionOnStartup; + this.activeSession.set(createTransportSession()); + } + + public static Builder builder(String baseUri) { + return new Builder(baseUri); + } + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler.set(handler); + if (openConnectionOnStartup) { + logger.debug("Eagerly opening connection on startup"); + return this.reconnect(null).then(); + } + return Mono.empty(); + }); + } + + private DefaultMcpTransportSession createTransportSession() { + Function> onClose = sessionId -> sessionId == null ? Mono.empty() + : createDelete(sessionId); + return new DefaultMcpTransportSession(onClose); + } + + private Publisher createDelete(String sessionId) { + + return Mono.defer(() -> { // Do we need to defer this? + + HttpRequest request = this.requestBuilder.copy() + .uri(Utils.resolveUri(this.baseUri, this.endpoint)) + .header("Cache-Control", "no-cache") + .header("mcp-session-id", sessionId) + .DELETE() + .build(); + + return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())) + .doOnError(e -> logger.warn("Got error when closing transport", e)) + .then(); + }); + } + + @Override + public void setExceptionHandler(Consumer handler) { + logger.debug("Exception handler registered"); + this.exceptionHandler.set(handler); + } + + private void handleException(Throwable t) { + logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); + if (t instanceof McpTransportSessionNotFoundException) { + McpTransportSession invalidSession = this.activeSession.getAndSet(createTransportSession()); + logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); + invalidSession.close(); + } + Consumer handler = this.exceptionHandler.get(); + if (handler != null) { + handler.accept(t); + } + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + logger.debug("Graceful close triggered"); + DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession()); + if (currentSession != null) { + return currentSession.closeGracefully(); + } + return Mono.empty(); + }); + } + + private Mono reconnect(McpTransportStream stream) { + + return Mono.deferContextual(ctx -> { + + if (stream != null) { + logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); + } + else { + logger.debug("Reconnecting with no prior stream"); + } + + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); + + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); + } + + if (stream != null && stream.lastId().isPresent()) { + requestBuilder = requestBuilder.header("last-event-id", stream.lastId().get()); + } + + HttpRequest request = requestBuilder.uri(Utils.resolveUri(this.baseUri, this.endpoint)) + .header("Accept", TEXT_EVENT_STREAM) + .header("Cache-Control", "no-cache") + .GET() + .build(); + + Disposable connection = Flux.create(sseSink -> this.httpClient.sendAsync(request, + responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) + // .whenComplete((response, throwable) -> { + // if (throwable != null) { + // sseSink.error(throwable); + // } else { + // int status = response.statusCode(); + // if (status == METHOD_NOT_ALLOWED) { // NotAllowed + // logger.debug("The server does not support SSE streams, using + // request-response mode."); + // sseSink.complete(); + // } else if (status == NOT_FOUND || status == BAD_REQUEST) { // NotFound + // String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + // sseSink.error(new McpTransportSessionNotFoundException( + // "Session not found for session ID: " + sessionIdRepresentation)); + // } else if (!isEventStream(response)) { + // String message = "Failed to connect to SSE stream. HTTP " + + // response.statusCode(); + // if (response.body() != null) { + // message += ": " + response.body(); + // } + // logger.info("Opening an SSE stream failed. This can be safely ignored." + + // message); + // sseSink.error(new RuntimeException(message)); + // } + // // If status is OK, the lineSubscriber will handle the + // // stream + // logger.debug("Established SSE stream via GET"); + // } + // }) + ).flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + + if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + try { + // We don't support batching ATM and probably won't since the + // next version considers removing it. + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, + responseEvent.sseEvent().data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(responseEvent.sseEvent().id()), List.of(message)); + + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + + } + catch (IOException ioException) { + return Flux.error( + new McpError("Error parsing JSON-RPC message: " + responseEvent.sseEvent().data())); + } + } + } + else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed + logger.debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (statusCode == NOT_FOUND) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + else if (statusCode == BAD_REQUEST) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + + return Flux.error( + new McpError("Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + + }).flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(ctx) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + return Mono.just(connection); + }); + + } + + // private static boolean isEventStream(HttpResponse response) { + // String contentType = + // response.headers().firstValue("Content-Type").orElse("").toLowerCase(); + // return response.statusCode() >= 200 && response.statusCode() < 300 && + // contentType.contains(TEXT_EVENT_STREAM); + // } + + private BodyHandler toSendMessageBodySubscriber(FluxSink sink) { + + BodyHandler responseBodyHandler = responseInfo -> { + + String contentType = responseInfo.headers().firstValue("Content-Type").orElse("").toLowerCase(); + + if (contentType.contains(TEXT_EVENT_STREAM)) { + // For SSE streams, use line subscriber that returns Void + logger.debug("Received SSE stream response, using line subscriber"); + return ResponseSubscribers.sseToBodySubscriber(responseInfo, sink); + } + else if (contentType.contains(APPLICATION_JSON)) { + // For JSON responses and others, use string subscriber + logger.debug("Received response, using string subscriber"); + return ResponseSubscribers.jsonoBodySubscriber(responseInfo, sink); + } + + logger.debug("Received Bodyless response, using discarding subscriber"); + // return HttpResponse.BodySubscribers.discarding(); + return ResponseSubscribers.bodylessBodySubscriber(responseInfo, sink); + }; + + return responseBodyHandler; + + } + + public String toString(McpSchema.JSONRPCMessage message) { + try { + return this.objectMapper.writeValueAsString(message); + } + catch (IOException e) { + throw new RuntimeException("Failed to serialize JSON-RPC message", e); + } + } + + public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { + return Mono.create(messageSink -> { + logger.debug("Sending message {}", sendMessage); + + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); + + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); + } + + String jsonBody = this.toString(sendMessage); + + HttpRequest request = requestBuilder.uri(Utils.resolveUri(this.baseUri, this.endpoint)) + .header("Accept", TEXT_EVENT_STREAM + ", " + APPLICATION_JSON) + .header("Content-Type", APPLICATION_JSON) + .header("Cache-Control", "no-cache") + .POST(HttpRequest.BodyPublishers.ofString(jsonBody)) + .build(); + + Disposable connection = Flux.create(responseEventSink -> { + + // Create the async request with proper body subscriber selection + Mono.fromFuture(this.httpClient.sendAsync(request, this.toSendMessageBodySubscriber(responseEventSink)) + // .whenComplete((res, e) -> { + // if (e != null) { + // logger.warn("Error sending message", e); + // responseEventSink.error(e); + // } else if (res.statusCode() == NOT_FOUND) { + // String sessionIdRepresentation = + // sessionIdOrPlaceholder(transportSession); + // McpTransportSessionNotFoundException exception = new + // McpTransportSessionNotFoundException( + // "Session not found for session ID: " + sessionIdRepresentation); + // this.handleException(exception); + // responseEventSink.error(exception); + // } else if (res.statusCode() == BAD_REQUEST) { + // System.out.println("BAD_REQUEST"); + // } else { + // logger.debug("whenComplete complete: resp: {}, reqBode: {}", request, + // jsonBody); + // } + // })).doOnSubscribe(sub -> { + // logger.debug("OnSubscribe: {}, Sending message to server: {}", sub, + // jsonBody); + // } + ).subscribe(); + + }).flatMap(responseEvent -> { + if (transportSession.markInitialized( + responseEvent.responseInfo().headers().firstValue("mcp-session-id").orElseGet(() -> null))) { + // Once we have a session, we try to open an async stream for + // the server to send notifications and requests out-of-band. + + reconnect(null).contextWrite(messageSink.contextView()).subscribe(); + } + + String sessionRepresentation = sessionIdOrPlaceholder(transportSession); + + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + + String contentType = responseEvent.responseInfo() + .headers() + .firstValue("Content-Type") + .orElse("") + .toLowerCase(); + + if (contentType.isBlank()) { + logger.debug("No content type returned for POST in session {}", sessionRepresentation); + // No content type means no response body, so we can just return + // an empty stream + messageSink.success(); + return Flux.empty(); + } + else if (contentType.contains(TEXT_EVENT_STREAM)) { + try { + // We don't support batching ATM and probably won't since the + // next version considers removing it. + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, + responseEvent.sseEvent().data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(responseEvent.sseEvent().id()), List.of(message)); + + McpTransportStream sessionStream = new DefaultMcpTransportStream<>( + this.resumableStreams, this::reconnect); + + logger.debug("Connected stream {}", sessionStream.streamId()); + + messageSink.success(); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + + } + catch (IOException ioException) { + return Flux.error( + new McpError("Error parsing JSON-RPC message: " + responseEvent.sseEvent().data())); + } + } + else if (contentType.contains(APPLICATION_JSON)) { + McpSchema.JSONRPCMessage jsonRpcResponse = responseEvent.jsonRpcMessage(); + messageSink.success(); + return Flux.just(jsonRpcResponse); // ??? + } + logger.warn("Unknown media type {} returned for POST in session {}", contentType, + sessionRepresentation); + + return Flux.error( + new RuntimeException("Unknown media type returned: " + contentType)); + } + else if (statusCode == NOT_FOUND) { + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionRepresentation); + return Flux.error(exception); + } + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + else if (statusCode == BAD_REQUEST) { + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionRepresentation); + return Flux.error(exception); + } + + return Flux.error( + new RuntimeException("Failed to send message: " + responseEvent)); + }).flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))).onErrorResume(t -> { + // handle the error first + this.handleException(t); + // inform the caller of sendMessage + messageSink.error(t); + return Flux.empty(); + }).doFinally(s -> { + logger.debug("SendMessage finally: {}", s); + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }).contextWrite(messageSink.contextView()).subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + }); + } + + private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { + return transportSession.sessionId().orElse("[missing_session_id]"); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + + /** + * Builder for {@link HttpClientStreamableHttpTransport}. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String baseUri; + + private HttpClient.Builder clientBuilder = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .connectTimeout(Duration.ofSeconds(10)); + + private String endpoint = DEFAULT_ENDPOINT; + + private boolean resumableStreams = true; + + private boolean openConnectionOnStartup = false; + + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + + /** + * Creates a new builder with the specified base URI. + * @param baseUri the base URI of the MCP server + */ + private Builder(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + } + + /** + * Sets the HTTP client builder. + * @param clientBuilder the HTTP client builder + * @return this builder + */ + public Builder clientBuilder(HttpClient.Builder clientBuilder) { + Assert.notNull(clientBuilder, "clientBuilder must not be null"); + this.clientBuilder = clientBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param clientCustomizer the consumer to customize the HTTP client builder + * @return this builder + */ + public Builder customizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + return this; + } + + /** + * Sets the HTTP request builder. + * @param requestBuilder the HTTP request builder + * @return this builder + */ + public Builder requestBuilder(HttpRequest.Builder requestBuilder) { + Assert.notNull(requestBuilder, "requestBuilder must not be null"); + this.requestBuilder = requestBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param requestCustomizer the consumer to customize the HTTP request builder + * @return this builder + */ + public Builder customizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + return this; + } + + /** + * Configure the {@link ObjectMapper} to use. + * @param objectMapper instance to use + * @return the builder instance + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Configure the endpoint to make HTTP requests against. + * @param endpoint endpoint to use + * @return the builder instance + */ + public Builder endpoint(String endpoint) { + Assert.hasText(endpoint, "endpoint must be a non-empty String"); + this.endpoint = endpoint; + return this; + } + + /** + * Configure whether to use the stream resumability feature by keeping track of + * SSE event ids. + * @param resumableStreams if {@code true} event ids will be tracked and upon + * disconnection, the last seen id will be used upon reconnection as a header to + * resume consuming messages. + * @return the builder instance + */ + public Builder resumableStreams(boolean resumableStreams) { + this.resumableStreams = resumableStreams; + return this; + } + + /** + * Configure whether the client should open an SSE connection upon startup. Not + * all servers support this (although it is in theory possible with the current + * specification), so use with caution. By default, this value is {@code false}. + * @param openConnectionOnStartup if {@code true} the {@link #connect(Function)} + * method call will try to open an SSE connection before sending any JSON-RPC + * request + * @return the builder instance + */ + public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { + this.openConnectionOnStartup = openConnectionOnStartup; + return this; + } + + /** + * Construct a fresh instance of {@link HttpClientStreamableHttpTransport} using + * the current builder configuration. + * @return a new instance of {@link HttpClientStreamableHttpTransport} + */ + public HttpClientStreamableHttpTransport build() { + ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + + return new HttpClientStreamableHttpTransport(objectMapper, clientBuilder.build(), requestBuilder, baseUri, + endpoint, resumableStreams, openConnectionOnStartup); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java new file mode 100644 index 000000000..a49ef7255 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java @@ -0,0 +1,354 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodySubscriber; +import java.net.http.HttpResponse.ResponseInfo; +import java.util.concurrent.atomic.AtomicReference; +import java.util.regex.Pattern; + +import org.reactivestreams.FlowAdapters; +import org.reactivestreams.Subscription; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.FluxSink; + +public class ResponseSubscribers { + + /** + * Represents a Server-Sent Event with its standard fields. + * + * @param id the event ID, may be {@code null} + * @param event the event type, may be {@code null} (defaults to "message") + * @param data the event payload data, never {@code null} + */ + public static record SseEvent(String id, String event, String data) { + } + + public record ResponseEvent(ResponseInfo responseInfo, SseEvent sseEvent, JSONRPCMessage jsonRpcMessage) { + + public ResponseEvent(ResponseInfo responseInfo, SseEvent sseEvent) { + this(responseInfo, sseEvent, null); + } + + public ResponseEvent(ResponseInfo responseInfo, JSONRPCMessage jsonRpcMessage) { + this(responseInfo, null, jsonRpcMessage); + } + } + + public static BodySubscriber sseToBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + return HttpResponse.BodySubscribers + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new SseLineSubscriber(responseInfo, sink))); + } + + public static BodySubscriber jsonoBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + return HttpResponse.BodySubscribers + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new JsonLineSubscriber(responseInfo, sink))); + } + + public static BodySubscriber bodylessBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + return HttpResponse.BodySubscribers + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new BodylessResponseLineSubscriber(responseInfo, sink))); + } + + public static class SseLineSubscriber extends BaseSubscriber { + + /** + * Pattern to extract data content from SSE "data:" lines. + */ + private static final Pattern EVENT_DATA_PATTERN = Pattern.compile("^data:(.+)$", Pattern.MULTILINE); + + /** + * Pattern to extract event ID from SSE "id:" lines. + */ + private static final Pattern EVENT_ID_PATTERN = Pattern.compile("^id:(.+)$", Pattern.MULTILINE); + + /** + * Pattern to extract event type from SSE "event:" lines. + */ + private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE); + + /** + * The sink for emitting parsed response events. + */ + private final FluxSink sink; + + /** + * StringBuilder for accumulating multi-line event data. + */ + private final StringBuilder eventBuilder; + + /** + * Current event's ID, if specified. + */ + private final AtomicReference currentEventId; + + /** + * Current event's type, if specified. + */ + private final AtomicReference currentEventType; + + /** + * The response information from the HTTP response. Send with each event to + * provide context. + */ + private ResponseInfo responseInfo; + + /** + * Creates a new LineSubscriber that will emit parsed SSE events to the provided + * sink. + * @param sink the {@link FluxSink} to emit parsed {@link ResponseEvent} objects + * to + */ + public SseLineSubscriber(ResponseInfo responseInfo, FluxSink sink) { + this.sink = sink; + this.eventBuilder = new StringBuilder(); + this.currentEventId = new AtomicReference<>(); + this.currentEventType = new AtomicReference<>(); + this.responseInfo = responseInfo; + } + + /** + * Initializes the subscription and sets up disposal callback. + * @param subscription the {@link Subscription} to the upstream line source + */ + @Override + protected void hookOnSubscribe(Subscription subscription) { + + sink.onRequest(n -> { + if (subscription != null) { + subscription.request(n); + } + }); + + // Register disposal callback to cancel subscription when Flux is disposed + sink.onDispose(() -> { + if (subscription != null) { + subscription.cancel(); + } + }); + } + + /** + * Processes each line from the SSE stream according to the SSE protocol. Empty + * lines trigger event emission, other lines are parsed for data, id, or event + * type. + * @param line the line to process from the SSE stream + */ + @Override + protected void hookOnNext(String line) { + if (line.isEmpty()) { + // Empty line means end of event + if (this.eventBuilder.length() > 0) { + String eventData = this.eventBuilder.toString(); + SseEvent sseEvent = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); + + this.sink.next(new ResponseEvent(responseInfo, sseEvent)); + this.eventBuilder.setLength(0); + } + } + else { + if (line.startsWith("data:")) { + var matcher = EVENT_DATA_PATTERN.matcher(line); + if (matcher.find()) { + this.eventBuilder.append(matcher.group(1).trim()).append("\n"); + } + } + else if (line.startsWith("id:")) { + var matcher = EVENT_ID_PATTERN.matcher(line); + if (matcher.find()) { + this.currentEventId.set(matcher.group(1).trim()); + } + } + else if (line.startsWith("event:")) { + var matcher = EVENT_TYPE_PATTERN.matcher(line); + if (matcher.find()) { + this.currentEventType.set(matcher.group(1).trim()); + } + } + } + } + + /** + * Called when the upstream line source completes normally. + */ + @Override + protected void hookOnComplete() { + if (this.eventBuilder.length() > 0) { + String eventData = this.eventBuilder.toString(); + SseEvent sseEvent = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); + this.sink.next(new ResponseEvent(responseInfo, sseEvent)); + } + this.sink.complete(); + } + + /** + * Called when an error occurs in the upstream line source. + * @param throwable the error that occurred + */ + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + } + + public static class JsonLineSubscriber extends BaseSubscriber { + + private ObjectMapper objectMapper = new ObjectMapper(); + + /** + * The sink for emitting parsed response events. + */ + private final FluxSink sink; + + /** + * StringBuilder for accumulating multi-line event data. + */ + private final StringBuilder eventBuilder; + + /** + * The response information from the HTTP response. Send with each event to + * provide context. + */ + private ResponseInfo responseInfo; + + /** + * Creates a new JsonLineSubscriber that will emit parsed JSON-RPC messages. + * @param sink the {@link FluxSink} to emit parsed {@link ResponseEvent} objects + * to + */ + public JsonLineSubscriber(ResponseInfo responseInfo, FluxSink sink) { + this.sink = sink; + this.eventBuilder = new StringBuilder(); + this.responseInfo = responseInfo; + } + + /** + * Initializes the subscription and sets up disposal callback. + * @param subscription the {@link Subscription} to the upstream line source + */ + @Override + protected void hookOnSubscribe(Subscription subscription) { + + sink.onRequest(n -> { + if (subscription != null) { + subscription.request(n); + } + }); + + // Register disposal callback to cancel subscription when Flux is disposed + sink.onDispose(() -> { + if (subscription != null) { + subscription.cancel(); + } + }); + } + + /** + * Aggregate each line from the Http response. + * @param line next line to process from the Http response + */ + @Override + protected void hookOnNext(String line) { + this.eventBuilder.append(line).append("\n"); + } + + /** + * Called when the upstream line source completes normally. + */ + @Override + protected void hookOnComplete() { + if (this.eventBuilder.length() > 0) { + String jsonData = this.eventBuilder.toString(); + try { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, + jsonData); + this.sink.next(new ResponseEvent(responseInfo, jsonRpcResponse)); + } + catch (IOException e) { + sink.error(e); + } + } + this.sink.complete(); + } + + /** + * Called when an error occurs in the upstream line source. + * @param throwable the error that occurred + */ + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + } + + public static class BodylessResponseLineSubscriber extends BaseSubscriber { + + /** + * The sink for emitting parsed response events. + */ + private final FluxSink sink; + + private final ResponseInfo responseInfo; + + public BodylessResponseLineSubscriber(ResponseInfo responseInfo, FluxSink sink) { + this.sink = sink; + this.responseInfo = responseInfo; + } + + /** + * Initializes the subscription and sets up disposal callback. + * @param subscription the {@link Subscription} to the upstream line source + */ + @Override + protected void hookOnSubscribe(Subscription subscription) { + + sink.onRequest(n -> { + if (subscription != null) { + subscription.request(n); + } + }); + + // Register disposal callback to cancel subscription when Flux is disposed + sink.onDispose(() -> { + if (subscription != null) { + subscription.cancel(); + } + }); + } + + @Override + protected void hookOnNext(String line) { + System.out.println(">>>>>>>>>>>>>> Received line: " + line); + } + + /** + * Called when the upstream line source completes normally. + */ + @Override + protected void hookOnComplete() { + this.sink.next(new ResponseEvent(responseInfo, new SseEvent(null, null, null))); + this.sink.complete(); + } + + /** + * Called when an error occurs in the upstream line source. + * @param throwable the error that occurred + */ + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 8545348ed..5553445b6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -362,7 +362,7 @@ public Mono closeGracefully() { } else { logger.warn("Process not started"); - return Mono.empty(); + return Mono.empty(); } })).doOnNext(process -> { if (process.exitValue() != 0) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java index ecc6f8666..eb2b7edeb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java @@ -61,14 +61,19 @@ public long streamId() { @Override public Publisher consumeSseStream( Publisher, Iterable>> eventStream) { - return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> { - if (resumable && !(e instanceof McpTransportSessionNotFoundException)) { - Mono.from(reconnect.apply(this)).contextWrite(ctx).subscribe(); - } - }).doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(id -> { - String previousId = this.lastId.getAndSet(id); - logger.debug("Updating last id {} -> {} for stream {}", previousId, id, this.streamId); - })).flatMapIterable(Tuple2::getT2)); + + // @formatter:off + return Flux.deferContextual(ctx -> Flux.from(eventStream) + .doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(id -> { + String previousId = this.lastId.getAndSet(id); + logger.debug("Updating last id {} -> {} for stream {}", previousId, id, this.streamId); + })) + .doOnError(e -> { + if (resumable && !(e instanceof McpTransportSessionNotFoundException)) { + Mono.from(reconnect.apply(this)).contextWrite(ctx).subscribe(); + } + }) + .flatMapIterable(Tuple2::getT2)); // @formatter:on } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 20a8c0cf5..7ba35bbf0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -15,8 +15,6 @@ */ package io.modelcontextprotocol; -import java.util.Map; - import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerSession.Factory; diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..4aa93ec85 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -0,0 +1,222 @@ +package io.modelcontextprotocol.client; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransport; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Resiliency test suite for the {@link McpAsyncClient} that can be used with different + * {@link McpTransport} implementations that support Streamable HTTP. + * + * The purpose of these tests is to allow validating the transport layer resiliency + * instead of the functionality offered by the logical layer of MCP concepts such as + * tools, resources, prompts, etc. + * + * @author Dariusz Jędrzejczyk + */ +public abstract class AbstractMcpAsyncClientResiliencyTests { + + private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + static void disconnect() { + long start = System.nanoTime(); + try { + // disconnect + // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", + // ToxicDirection.DOWNSTREAM, 0); + // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", + // ToxicDirection.UPSTREAM, 0); + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + static void restartMcpServer() { + container.stop(); + container.start(); + } + + abstract McpClientTransport createMcpTransport(); + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionInvalidation() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + restartMcpServer(); + + // The first try will face the session mismatch exception and the second one + // will go through the re-initialization process. + StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + AtomicReference> tools = new AtomicReference<>(); + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + StepVerifier.create(mcpAsyncClient.listTools()) + .consumeNextWith(list -> tools.set(list.tools())) + .verifyComplete(); + + disconnect(); + + String name = tools.get().get(0).name(); + // Assuming this is the echo tool + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(name, Map.of("message", "hello")); + StepVerifier.create(mcpAsyncClient.callTool(request)).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.callTool(request)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionClose() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + // In case of Streamable HTTP this call should issue a HTTP DELETE request + // invalidating the session + StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); + // The next use should immediately re-initialize with no issue and send the + // request without any broken connections. + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..f088fa7ba --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java @@ -0,0 +1,20 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +@Timeout(15) +public class HttpClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientStreamableHttpTransport.builder(host).build(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java new file mode 100644 index 000000000..aa081b51b --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java @@ -0,0 +1,41 @@ +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +@Timeout(15) +public class HttpClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { + + private String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + + return HttpClientStreamableHttpTransport.builder(host).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java new file mode 100644 index 000000000..8285f417f --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java @@ -0,0 +1,40 @@ +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +@Timeout(15) +public class HttpClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientStreamableHttpTransport.builder(host).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index 1b66a98cd..6cb3f7b65 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -4,12 +4,13 @@ package io.modelcontextprotocol.client; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + /** * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. * diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 1b1c72012..e4348be25 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -7,23 +7,20 @@ import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; -import java.net.http.HttpResponse; import java.time.Duration; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; import reactor.core.publisher.Mono; @@ -34,11 +31,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.fasterxml.jackson.databind.ObjectMapper; /** * Tests for the {@link HttpClientSseClientTransport} class. @@ -371,25 +363,4 @@ void testChainedCustomizations() { customizedTransport.closeGracefully().block(); } - @Test - @SuppressWarnings("unchecked") - void testResolvingClientEndpoint() { - HttpClient httpClient = Mockito.mock(HttpClient.class); - HttpResponse httpResponse = Mockito.mock(HttpResponse.class); - CompletableFuture> future = new CompletableFuture<>(); - future.complete(httpResponse); - when(httpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(future); - - HttpClientSseClientTransport transport = new HttpClientSseClientTransport(httpClient, HttpRequest.newBuilder(), - "http://example.com", "http://example.com/sse", new ObjectMapper()); - - transport.connect(Function.identity()); - - ArgumentCaptor httpRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); - verify(httpClient).sendAsync(httpRequestCaptor.capture(), any(HttpResponse.BodyHandler.class)); - assertThat(httpRequestCaptor.getValue().uri()).isEqualTo(URI.create("http://example.com/sse")); - - transport.closeGracefully().block(); - } - } diff --git a/mcp/src/test/resources/logback.xml b/mcp/src/test/resources/logback.xml index 0246d6c75..d860fb985 100644 --- a/mcp/src/test/resources/logback.xml +++ b/mcp/src/test/resources/logback.xml @@ -16,6 +16,9 @@ + + + From 53f57400935523a77ddc6243f50bfca403f5b472 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 24 Jun 2025 13:19:17 +0200 Subject: [PATCH 02/13] fix java 17 related issue Signed-off-by: Christian Tzolov --- .../transport/HttpClientSseClientTransport.java | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 9beb2b373..077140531 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -426,7 +426,12 @@ public Mono sendMessage(JSONRPCMessage message) { try { return this.serializeMessage(message) .flatMap(body -> sendHttpPost(messageEndpointUri, body)) - .doOnNext(this::logIfNotOk) + .doOnNext(response -> { + if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 + && response.statusCode() != 206) { + logger.error("Error sending message: {}", response.statusCode()); + } + }) .doOnError(error -> { if (!isClosing) { logger.error("Error sending message: {}", error.getMessage()); @@ -465,13 +470,6 @@ private Mono> sendHttpPost(final String endpoint, final Strin return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding())); } - private void logIfNotOk(final HttpResponse response) { - if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 - && response.statusCode() != 206) { - logger.error("Error sending message: {}", response.statusCode()); - } - } - /** * Gracefully closes the transport connection. * From 57f0e0aafdde44caa77dfad6069946273153cbb9 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 24 Jun 2025 13:26:15 +0200 Subject: [PATCH 03/13] minor Signed-off-by: Christian Tzolov --- .../transport/HttpClientSseClientTransport.java | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 077140531..bb8f19015 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -111,7 +111,7 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** * Creates a new transport instance with default HTTP client and object mapper. * @param baseUri the base URI of the MCP server - * @deprecated Use {@link HttpClientSseClientTransport2#builder(String)} instead. This + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This * constructor will be removed in future versions. */ @Deprecated(forRemoval = true) @@ -125,7 +125,7 @@ public HttpClientSseClientTransport(String baseUri) { * @param baseUri the base URI of the MCP server * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper or clientBuilder is null - * @deprecated Use {@link HttpClientSseClientTransport2#builder(String)} instead. This + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This * constructor will be removed in future versions. */ @Deprecated(forRemoval = true) @@ -140,7 +140,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas * @param sseEndpoint the SSE endpoint path * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper or clientBuilder is null - * @deprecated Use {@link HttpClientSseClientTransport2#builder(String)} instead. This + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This * constructor will be removed in future versions. */ @Deprecated(forRemoval = true) @@ -158,7 +158,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, String bas * @param sseEndpoint the SSE endpoint path * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null - * @deprecated Use {@link HttpClientSseClientTransport2#builder(String)} instead. This + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. This * constructor will be removed in future versions. */ @Deprecated(forRemoval = true) @@ -193,7 +193,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques } /** - * Creates a new builder for {@link HttpClientSseClientTransport2}. + * Creates a new builder for {@link HttpClientSseClientTransport}. * @param baseUri the base URI of the MCP server * @return a new builder instance */ @@ -202,7 +202,7 @@ public static Builder builder(String baseUri) { } /** - * Builder for {@link HttpClientSseClientTransport2}. + * Builder for {@link HttpClientSseClientTransport}. */ public static class Builder { @@ -229,7 +229,7 @@ public static class Builder { /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server - * @deprecated Use {@link HttpClientSseClientTransport2#builder(String)} instead. + * @deprecated Use {@link HttpClientSseClientTransport#builder(String)} instead. * This constructor is deprecated and will be removed or made {@code protected} or * {@code private} in a future release. */ @@ -317,7 +317,7 @@ public Builder objectMapper(ObjectMapper objectMapper) { } /** - * Builds a new {@link HttpClientSseClientTransport2} instance. + * Builds a new {@link HttpClientSseClientTransport} instance. * @return a new transport instance */ public HttpClientSseClientTransport build() { From ae387a6f44ac742838216d6cfb3312c8844a9c7a Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 24 Jun 2025 13:28:53 +0200 Subject: [PATCH 04/13] remove commented code Signed-off-by: Christian Tzolov --- .../client/transport/HttpClientSseClientTransport.java | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index bb8f19015..685c0a7b1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -34,7 +34,7 @@ /** * Server-Sent Events (SSE) implementation of the * {@link io.modelcontextprotocol.spec.McpTransport} that follows the MCP HTTP with SSE - * transport specification, using Java's HttpClient and FlowSseClient. + * transport specification, using Java's HttpClient. * *

* This transport implementation establishes a bidirectional communication channel between @@ -92,13 +92,6 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** Flag indicating if the transport is in closing state */ private volatile boolean isClosing = false; - // /** Latch for coordinating endpoint discovery */ - // private final CountDownLatch closeLatch = new CountDownLatch(1); - - // /** Holds the discovered message endpoint URL */ - // private final AtomicReference messageEndpoint = new - // AtomicReference<>(); - /** Holds the SSE subscription disposable */ private final AtomicReference sseSubscription = new AtomicReference<>(); From fef3bdbc89e47bf662ddf18bac66cdb26671ab7f Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 25 Jun 2025 18:34:22 +0200 Subject: [PATCH 05/13] refactor: improve error handling and cleanup HTTP client transports - Add proper exception handling with CompletableFuture.exceptionallyCompose for async HTTP operations - Add test for specific exception type handling in resiliency tests This change makes the HTTP client transports more robust by ensuring exceptions are properly propagated. Signed-off-by: Christian Tzolov --- .../HttpClientSseClientTransport.java | 29 +-- .../HttpClientStreamableHttpTransport.java | 174 +++++++----------- ...eamableHttpAsyncClientResiliencyTests.java | 19 ++ 3 files changed, 100 insertions(+), 122 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 685c0a7b1..3eb76b869 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -9,6 +9,7 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.time.Duration; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -332,10 +333,13 @@ public Mono connect(Function, Mono> h .GET() .build(); - Disposable connection = Flux - .create(sseSink -> this.httpClient.sendAsync(request, - responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink))) - .flatMap(responseEvent -> { + Disposable connection = Flux.create(sseSink -> this.httpClient + .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) + .exceptionallyCompose(e -> { + logger.warn("Error sending message", e); + sseSink.error(e); + return CompletableFuture.failedFuture(e); + })).flatMap(responseEvent -> { if (isClosing) { return Mono.empty(); } @@ -375,24 +379,19 @@ else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { return Flux.error( new RuntimeException("Failed to send message: " + responseEvent)); - }) - .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))) - .onErrorResume(t -> { + }).flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))).onErrorResume(t -> { if (!isClosing) { logger.error("SSE connection error", t); sink.error(t); } return Mono.empty(); - }) - .doFinally(s -> { + }).doFinally(s -> { Disposable ref = this.sseSubscription.getAndSet(null); if (ref != null && !ref.isDisposed()) { ref.dispose(); } - }) - .contextWrite(sink.contextView()) - .subscribe(); + }).contextWrite(sink.contextView()).subscribe(); this.sseSubscription.set(connection); }); @@ -460,7 +459,11 @@ private Mono> sendHttpPost(final String endpoint, final Strin .POST(HttpRequest.BodyPublishers.ofString(body)) .build(); - return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding())); + return Mono.fromFuture( + httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).exceptionallyCompose(e -> { + logger.warn("Error sending message", e); + return CompletableFuture.failedFuture(e); + })); } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index 2c17b7148..68bec4fed 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -13,6 +13,7 @@ import java.time.Duration; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -160,9 +161,12 @@ private Publisher createDelete(String sessionId) { .DELETE() .build(); - return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())) - .doOnError(e -> logger.warn("Got error when closing transport", e)) - .then(); + return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()) + .exceptionallyCompose(e -> { + logger.warn("Error sending message", e); + + return CompletableFuture.failedFuture(e); + })).doOnError(e -> logger.warn("Got error when closing transport", e)).then(); }); } @@ -227,86 +231,63 @@ private Mono reconnect(McpTransportStream stream) { .GET() .build(); - Disposable connection = Flux.create(sseSink -> this.httpClient.sendAsync(request, - responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) - // .whenComplete((response, throwable) -> { - // if (throwable != null) { - // sseSink.error(throwable); - // } else { - // int status = response.statusCode(); - // if (status == METHOD_NOT_ALLOWED) { // NotAllowed - // logger.debug("The server does not support SSE streams, using - // request-response mode."); - // sseSink.complete(); - // } else if (status == NOT_FOUND || status == BAD_REQUEST) { // NotFound - // String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - // sseSink.error(new McpTransportSessionNotFoundException( - // "Session not found for session ID: " + sessionIdRepresentation)); - // } else if (!isEventStream(response)) { - // String message = "Failed to connect to SSE stream. HTTP " + - // response.statusCode(); - // if (response.body() != null) { - // message += ": " + response.body(); - // } - // logger.info("Opening an SSE stream failed. This can be safely ignored." + - // message); - // sseSink.error(new RuntimeException(message)); - // } - // // If status is OK, the lineSubscriber will handle the - // // stream - // logger.debug("Established SSE stream via GET"); - // } - // }) - ).flatMap(responseEvent -> { - int statusCode = responseEvent.responseInfo().statusCode(); - - if (statusCode >= 200 && statusCode < 300) { - - if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - try { - // We don't support batching ATM and probably won't since the - // next version considers removing it. - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, - responseEvent.sseEvent().data()); - - Tuple2, Iterable> idWithMessages = Tuples - .of(Optional.ofNullable(responseEvent.sseEvent().id()), List.of(message)); - - McpTransportStream sessionStream = stream != null ? stream - : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); - logger.debug("Connected stream {}", sessionStream.streamId()); - - return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); - - } - catch (IOException ioException) { - return Flux.error( - new McpError("Error parsing JSON-RPC message: " + responseEvent.sseEvent().data())); + Disposable connection = Flux.create(sseSink -> this.httpClient + .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) + .exceptionallyCompose(e -> { + logger.warn("Error sending message", e); + sseSink.error(e); + return CompletableFuture.failedFuture(e); + })).flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + + if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + try { + // We don't support batching ATM and probably won't since + // the + // next version considers removing it. + McpSchema.JSONRPCMessage message = McpSchema + .deserializeJsonRpcMessage(this.objectMapper, responseEvent.sseEvent().data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(responseEvent.sseEvent().id()), List.of(message)); + + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + + } + catch (IOException ioException) { + return Flux.error(new McpError( + "Error parsing JSON-RPC message: " + responseEvent.sseEvent().data())); + } } } - } - else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed - logger.debug("The server does not support SSE streams, using request-response mode."); - return Flux.empty(); - } - else if (statusCode == NOT_FOUND) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - else if (statusCode == BAD_REQUEST) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } + else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed + logger.debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (statusCode == NOT_FOUND) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + else if (statusCode == BAD_REQUEST) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } - return Flux.error( - new McpError("Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + return Flux.error( + new McpError("Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); - }).flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + }).flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) .onErrorComplete(t -> { this.handleException(t); return true; @@ -327,13 +308,6 @@ else if (statusCode == BAD_REQUEST) { } - // private static boolean isEventStream(HttpResponse response) { - // String contentType = - // response.headers().firstValue("Content-Type").orElse("").toLowerCase(); - // return response.statusCode() >= 200 && response.statusCode() < 300 && - // contentType.contains(TEXT_EVENT_STREAM); - // } - private BodyHandler toSendMessageBodySubscriber(FluxSink sink) { BodyHandler responseBodyHandler = responseInfo -> { @@ -395,29 +369,11 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { // Create the async request with proper body subscriber selection Mono.fromFuture(this.httpClient.sendAsync(request, this.toSendMessageBodySubscriber(responseEventSink)) - // .whenComplete((res, e) -> { - // if (e != null) { - // logger.warn("Error sending message", e); - // responseEventSink.error(e); - // } else if (res.statusCode() == NOT_FOUND) { - // String sessionIdRepresentation = - // sessionIdOrPlaceholder(transportSession); - // McpTransportSessionNotFoundException exception = new - // McpTransportSessionNotFoundException( - // "Session not found for session ID: " + sessionIdRepresentation); - // this.handleException(exception); - // responseEventSink.error(exception); - // } else if (res.statusCode() == BAD_REQUEST) { - // System.out.println("BAD_REQUEST"); - // } else { - // logger.debug("whenComplete complete: resp: {}, reqBode: {}", request, - // jsonBody); - // } - // })).doOnSubscribe(sub -> { - // logger.debug("OnSubscribe: {}, Sending message to server: {}", sub, - // jsonBody); - // } - ).subscribe(); + .exceptionallyCompose(e -> { + logger.warn("Error sending message", e); + responseEventSink.error(e); + return CompletableFuture.failedFuture(e); + })).subscribe(); }).flatMap(responseEvent -> { if (transportSession.markInitialized( diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java index f088fa7ba..ddc896625 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java @@ -4,10 +4,14 @@ package io.modelcontextprotocol.client; +import java.util.concurrent.CompletionException; + +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.spec.McpClientTransport; +import reactor.test.StepVerifier; @Timeout(15) public class HttpClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { @@ -17,4 +21,19 @@ protected McpClientTransport createMcpTransport() { return HttpClientStreamableHttpTransport.builder(host).build(); } + @Test + void testPingWithEaxctExceptionType() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError(CompletionException.class).verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + } From 9b0ac03f2143ff65ee9dd408f3d4cc09cd1fa78b Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 25 Jun 2025 19:09:28 +0200 Subject: [PATCH 06/13] remove redundant class Signed-off-by: Christian Tzolov --- .../client/transport/FlowSseClient.java | 211 ------------------ 1 file changed, 211 deletions(-) delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java deleted file mode 100644 index abfafa551..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java +++ /dev/null @@ -1,211 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.client.transport; - -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Flow; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; -import java.util.regex.Pattern; - -/** - * A Server-Sent Events (SSE) client implementation using Java's Flow API for reactive - * stream processing. This client establishes a connection to an SSE endpoint and - * processes the incoming event stream, parsing SSE-formatted messages into structured - * events. - * - *

- * The client supports standard SSE event fields including: - *

    - *
  • event - The event type (defaults to "message" if not specified)
  • - *
  • id - The event ID
  • - *
  • data - The event payload data
  • - *
- * - *

- * Events are delivered to a provided {@link SseEventHandler} which can process events and - * handle any errors that occur during the connection. - * - * @author Christian Tzolov - * @see SseEventHandler - * @see SseEvent - */ -public class FlowSseClient { - - private final HttpClient httpClient; - - private final HttpRequest.Builder requestBuilder; - - /** - * Pattern to extract the data content from SSE data field lines. Matches lines - * starting with "data:" and captures the remaining content. - */ - private static final Pattern EVENT_DATA_PATTERN = Pattern.compile("^data:(.+)$", Pattern.MULTILINE); - - /** - * Pattern to extract the event ID from SSE id field lines. Matches lines starting - * with "id:" and captures the ID value. - */ - private static final Pattern EVENT_ID_PATTERN = Pattern.compile("^id:(.+)$", Pattern.MULTILINE); - - /** - * Pattern to extract the event type from SSE event field lines. Matches lines - * starting with "event:" and captures the event type. - */ - private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE); - - /** - * Record class representing a Server-Sent Event with its standard fields. - * - * @param id the event ID (may be null) - * @param type the event type (defaults to "message" if not specified in the stream) - * @param data the event payload data - */ - public static record SseEvent(String id, String type, String data) { - } - - /** - * Interface for handling SSE events and errors. Implementations can process received - * events and handle any errors that occur during the SSE connection. - */ - public interface SseEventHandler { - - /** - * Called when an SSE event is received. - * @param event the received SSE event containing id, type, and data - */ - void onEvent(SseEvent event); - - /** - * Called when an error occurs during the SSE connection. - * @param error the error that occurred - */ - void onError(Throwable error); - - } - - /** - * Creates a new FlowSseClient with the specified HTTP client. - * @param httpClient the {@link HttpClient} instance to use for SSE connections - */ - public FlowSseClient(HttpClient httpClient) { - this(httpClient, HttpRequest.newBuilder()); - } - - /** - * Creates a new FlowSseClient with the specified HTTP client and request builder. - * @param httpClient the {@link HttpClient} instance to use for SSE connections - * @param requestBuilder the {@link HttpRequest.Builder} to use for SSE requests - */ - public FlowSseClient(HttpClient httpClient, HttpRequest.Builder requestBuilder) { - this.httpClient = httpClient; - this.requestBuilder = requestBuilder; - } - - /** - * Subscribes to an SSE endpoint and processes the event stream. - * - *

- * This method establishes a connection to the specified URL and begins processing the - * SSE stream. Events are parsed and delivered to the provided event handler. The - * connection remains active until either an error occurs or the server closes the - * connection. - * @param url the SSE endpoint URL to connect to - * @param eventHandler the handler that will receive SSE events and error - * notifications - * @throws RuntimeException if the connection fails with a non-200 status code - */ - public void subscribe(String url, SseEventHandler eventHandler) { - HttpRequest request = this.requestBuilder.copy() - .uri(URI.create(url)) - .header("Accept", "text/event-stream") - .header("Cache-Control", "no-cache") - .GET() - .build(); - - StringBuilder eventBuilder = new StringBuilder(); - AtomicReference currentEventId = new AtomicReference<>(); - AtomicReference currentEventType = new AtomicReference<>("message"); - - Flow.Subscriber lineSubscriber = new Flow.Subscriber<>() { - private Flow.Subscription subscription; - - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.subscription = subscription; - subscription.request(Long.MAX_VALUE); - } - - @Override - public void onNext(String line) { - if (line.isEmpty()) { - // Empty line means end of event - if (eventBuilder.length() > 0) { - String eventData = eventBuilder.toString(); - SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - eventHandler.onEvent(event); - eventBuilder.setLength(0); - } - } - else { - if (line.startsWith("data:")) { - var matcher = EVENT_DATA_PATTERN.matcher(line); - if (matcher.find()) { - eventBuilder.append(matcher.group(1).trim()).append("\n"); - } - } - else if (line.startsWith("id:")) { - var matcher = EVENT_ID_PATTERN.matcher(line); - if (matcher.find()) { - currentEventId.set(matcher.group(1).trim()); - } - } - else if (line.startsWith("event:")) { - var matcher = EVENT_TYPE_PATTERN.matcher(line); - if (matcher.find()) { - currentEventType.set(matcher.group(1).trim()); - } - } - } - subscription.request(1); - } - - @Override - public void onError(Throwable throwable) { - eventHandler.onError(throwable); - } - - @Override - public void onComplete() { - // Handle any remaining event data - if (eventBuilder.length() > 0) { - String eventData = eventBuilder.toString(); - SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - eventHandler.onEvent(event); - } - } - }; - - Function, HttpResponse.BodySubscriber> subscriberFactory = subscriber -> HttpResponse.BodySubscribers - .fromLineSubscriber(subscriber); - - CompletableFuture> future = this.httpClient.sendAsync(request, - info -> subscriberFactory.apply(lineSubscriber)); - - future.thenAccept(response -> { - int status = response.statusCode(); - if (status != 200 && status != 201 && status != 202 && status != 206) { - throw new RuntimeException("Failed to connect to SSE stream. Unexpected status code: " + status); - } - }).exceptionally(throwable -> { - eventHandler.onError(throwable); - return null; - }); - } - -} From b4273a0b2cb865ea5eb85842c38e92837772d2f6 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 26 Jun 2025 13:59:54 +0200 Subject: [PATCH 07/13] Add sse lost-connection test to verify that no timeouts are thrown Signed-off-by: Christian Tzolov --- ...pSseMcpAsyncClientLostConnectionTests.java | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java new file mode 100644 index 000000000..0a72b785d --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import static org.assertj.core.api.Assertions.assertThatCode; + +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import reactor.test.StepVerifier; + +@Timeout(15) +public class HttpSseMcpAsyncClientLostConnectionTests { + + private static final Logger logger = LoggerFactory.getLogger(HttpSseMcpAsyncClientLostConnectionTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + static void disconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + McpAsyncClient client(McpClientTransport transport) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(Duration.ofSeconds(14)) + .initializationTimeout(Duration.ofSeconds(2)) + .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + var client = client(transport); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPingWithEaxctExceptionType() { + withClient(HttpClientSseClientTransport.builder(host).build(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + // Veryfiy that the exception type is IOException and not TimeoutException + StepVerifier.create(mcpAsyncClient.ping()).expectError(IOException.class).verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} From a644fcf88228e3841e9392857ab25b5dee307064 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Fri, 27 Jun 2025 15:49:01 +0200 Subject: [PATCH 08/13] refactor: improve HTTP client error handling and cleanup debug code - Replace exceptionallyCompose with whenComplete for better async error handling - Make ResponseSubscribers class package-private Signed-off-by: Christian Tzolov --- .../HttpClientStreamableHttpTransport.java | 35 ++++++++++++------- .../client/transport/ResponseSubscribers.java | 3 +- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index 68bec4fed..6223abb84 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -162,10 +162,13 @@ private Publisher createDelete(String sessionId) { .build(); return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()) - .exceptionallyCompose(e -> { - logger.warn("Error sending message", e); - - return CompletableFuture.failedFuture(e); + .whenComplete((response, throwable) -> { + if (throwable != null) { + logger.warn("Error sending message", throwable); + } + else { + logger.debug("SSE connection established successfully"); + } })).doOnError(e -> logger.warn("Got error when closing transport", e)).then(); }); } @@ -233,10 +236,14 @@ private Mono reconnect(McpTransportStream stream) { Disposable connection = Flux.create(sseSink -> this.httpClient .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) - .exceptionallyCompose(e -> { - logger.warn("Error sending message", e); - sseSink.error(e); - return CompletableFuture.failedFuture(e); + .whenComplete((response, throwable) -> { + if (throwable != null) { + logger.warn("Error sending message", throwable); + sseSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } })).flatMap(responseEvent -> { int statusCode = responseEvent.responseInfo().statusCode(); @@ -369,10 +376,14 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { // Create the async request with proper body subscriber selection Mono.fromFuture(this.httpClient.sendAsync(request, this.toSendMessageBodySubscriber(responseEventSink)) - .exceptionallyCompose(e -> { - logger.warn("Error sending message", e); - responseEventSink.error(e); - return CompletableFuture.failedFuture(e); + .whenComplete((response, throwable) -> { + if (throwable != null) { + logger.warn("Error sending message", throwable); + responseEventSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } })).subscribe(); }).flatMap(responseEvent -> { diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java index a49ef7255..22c632472 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java @@ -20,7 +20,7 @@ import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.FluxSink; -public class ResponseSubscribers { +class ResponseSubscribers { /** * Represents a Server-Sent Event with its standard fields. @@ -328,7 +328,6 @@ protected void hookOnSubscribe(Subscription subscription) { @Override protected void hookOnNext(String line) { - System.out.println(">>>>>>>>>>>>>> Received line: " + line); } /** From 9afbb8ed719f24317d2f0de019675ecc69e2e0aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 3 Jul 2025 19:10:03 +0200 Subject: [PATCH 09/13] Improve error handling and logging MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../WebClientStreamableHttpTransport.java | 13 +++--- .../HttpClientSseClientTransport.java | 10 ++--- .../HttpClientStreamableHttpTransport.java | 41 +++++++------------ .../spec/McpClientSession.java | 20 ++++----- ...eamableHttpAsyncClientResiliencyTests.java | 2 +- mcp/src/test/resources/logback.xml | 3 -- 6 files changed, 36 insertions(+), 53 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index dd7c65396..e60451706 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -129,11 +129,10 @@ private DefaultMcpTransportSession createTransportSession() { Function> onClose = sessionId -> sessionId == null ? Mono.empty() : webClient.delete().uri(this.endpoint).headers(httpHeaders -> { httpHeaders.add("mcp-session-id", sessionId); - }) - .retrieve() - .toBodilessEntity() - .doOnError(e -> logger.warn("Got error when closing transport", e)) - .then(); + }).retrieve().toBodilessEntity().onErrorComplete(e -> { + logger.warn("Got error when closing transport", e); + return true; + }).then(); return new DefaultMcpTransportSession(onClose); } @@ -305,12 +304,12 @@ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { } }) .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) - .onErrorResume(t -> { + .onErrorComplete(t -> { // handle the error first this.handleException(t); // inform the caller of sendMessage sink.error(t); - return Flux.empty(); + return true; }) .doFinally(s -> { Disposable ref = disposableRef.getAndSet(null); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 3eb76b869..20077c4a7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -336,7 +336,6 @@ public Mono connect(Function, Mono> h Disposable connection = Flux.create(sseSink -> this.httpClient .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) .exceptionallyCompose(e -> { - logger.warn("Error sending message", e); sseSink.error(e); return CompletableFuture.failedFuture(e); })).flatMap(responseEvent -> { @@ -386,6 +385,9 @@ else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { } return Mono.empty(); + }).onErrorComplete(t -> { + logger.warn("SSE stream observed an error", t); + return true; }).doFinally(s -> { Disposable ref = this.sseSubscription.getAndSet(null); if (ref != null && !ref.isDisposed()) { @@ -459,11 +461,7 @@ private Mono> sendHttpPost(final String endpoint, final Strin .POST(HttpRequest.BodyPublishers.ofString(body)) .build(); - return Mono.fromFuture( - httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).exceptionallyCompose(e -> { - logger.warn("Error sending message", e); - return CompletableFuture.failedFuture(e); - })); + return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding())); } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index 6223abb84..fbff9a240 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -138,7 +138,10 @@ public Mono connect(Function, Mono { + logger.warn("Eager connect failed ", t); + return true; + }).then(); } return Mono.empty(); }); @@ -151,26 +154,14 @@ private DefaultMcpTransportSession createTransportSession() { } private Publisher createDelete(String sessionId) { - - return Mono.defer(() -> { // Do we need to defer this? - - HttpRequest request = this.requestBuilder.copy() - .uri(Utils.resolveUri(this.baseUri, this.endpoint)) - .header("Cache-Control", "no-cache") - .header("mcp-session-id", sessionId) - .DELETE() - .build(); - - return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()) - .whenComplete((response, throwable) -> { - if (throwable != null) { - logger.warn("Error sending message", throwable); - } - else { - logger.debug("SSE connection established successfully"); - } - })).doOnError(e -> logger.warn("Got error when closing transport", e)).then(); - }); + HttpRequest request = this.requestBuilder.copy() + .uri(Utils.resolveUri(this.baseUri, this.endpoint)) + .header("Cache-Control", "no-cache") + .header("mcp-session-id", sessionId) + .DELETE() + .build(); + + return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())).then(); } @Override @@ -238,7 +229,6 @@ private Mono reconnect(McpTransportStream stream) { .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) .whenComplete((response, throwable) -> { if (throwable != null) { - logger.warn("Error sending message", throwable); sseSink.error(throwable); } else { @@ -378,13 +368,12 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { Mono.fromFuture(this.httpClient.sendAsync(request, this.toSendMessageBodySubscriber(responseEventSink)) .whenComplete((response, throwable) -> { if (throwable != null) { - logger.warn("Error sending message", throwable); responseEventSink.error(throwable); } else { logger.debug("SSE connection established successfully"); } - })).subscribe(); + })).onErrorComplete().subscribe(); }).flatMap(responseEvent -> { if (transportSession.markInitialized( @@ -467,12 +456,12 @@ else if (statusCode == BAD_REQUEST) { return Flux.error( new RuntimeException("Failed to send message: " + responseEvent)); - }).flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))).onErrorResume(t -> { + }).flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))).onErrorComplete(t -> { // handle the error first this.handleException(t); // inform the caller of sendMessage messageSink.error(t); - return Flux.empty(); + return true; }).doFinally(s -> { logger.debug("SendMessage finally: {}", s); Disposable ref = disposableRef.getAndSet(null); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index fa0853d81..14b36d451 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -141,14 +141,18 @@ else if (message instanceof McpSchema.JSONRPCRequest request) { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); - return this.transport.sendMessage(errorResponse).then(Mono.empty()); - }).flatMap(this.transport::sendMessage).subscribe(); + return Mono.just(errorResponse); + }).flatMap(this.transport::sendMessage).onErrorComplete(t -> { + logger.warn("Issue sending response to the client, ", t); + return true; + }).subscribe(); } else if (message instanceof McpSchema.JSONRPCNotification notification) { logger.debug("Received notification: {}", notification); - handleIncomingNotification(notification) - .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())) - .subscribe(); + handleIncomingNotification(notification).onErrorComplete(t -> { + logger.error("Error handling notification: {}", t.getMessage()); + return true; + }).subscribe(); } else { logger.warn("Received unknown message type: {}", message); @@ -171,11 +175,7 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR } return handler.handle(request.params()) - .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)))); // TODO: add error message - // through the data field + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)); }); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java index ddc896625..66e4c4c5c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java @@ -22,7 +22,7 @@ protected McpClientTransport createMcpTransport() { } @Test - void testPingWithEaxctExceptionType() { + void testPingWithExactExceptionType() { withClient(createMcpTransport(), mcpAsyncClient -> { StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); diff --git a/mcp/src/test/resources/logback.xml b/mcp/src/test/resources/logback.xml index d860fb985..0246d6c75 100644 --- a/mcp/src/test/resources/logback.xml +++ b/mcp/src/test/resources/logback.xml @@ -16,9 +16,6 @@ - - - From 61702830aa05aeb45356efd9abf43f3c2cef5e19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 3 Jul 2025 20:01:06 +0200 Subject: [PATCH 10/13] Remove union-like type, remove unnecessary code, centralize deserialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../HttpClientSseClientTransport.java | 58 ++++++------ .../HttpClientStreamableHttpTransport.java | 74 ++++++++------- .../client/transport/ResponseSubscribers.java | 94 +++++++------------ 3 files changed, 104 insertions(+), 122 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 20077c4a7..ab48fc0f7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -338,7 +338,9 @@ public Mono connect(Function, Mono> h .exceptionallyCompose(e -> { sseSink.error(e); return CompletableFuture.failedFuture(e); - })).flatMap(responseEvent -> { + })) + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) + .flatMap(responseEvent -> { if (isClosing) { return Mono.empty(); } @@ -378,22 +380,23 @@ else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { return Flux.error( new RuntimeException("Failed to send message: " + responseEvent)); - }).flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))).onErrorResume(t -> { + }) + .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))) + .onErrorComplete(t -> { if (!isClosing) { - logger.error("SSE connection error", t); + logger.warn("SSE stream observed an error", t); sink.error(t); } - return Mono.empty(); - - }).onErrorComplete(t -> { - logger.warn("SSE stream observed an error", t); return true; - }).doFinally(s -> { + }) + .doFinally(s -> { Disposable ref = this.sseSubscription.getAndSet(null); if (ref != null && !ref.isDisposed()) { ref.dispose(); } - }).contextWrite(sink.contextView()).subscribe(); + }) + .contextWrite(sink.contextView()) + .subscribe(); this.sseSubscription.set(connection); }); @@ -417,28 +420,19 @@ public Mono sendMessage(JSONRPCMessage message) { return Mono.empty(); } - try { - return this.serializeMessage(message) - .flatMap(body -> sendHttpPost(messageEndpointUri, body)) - .doOnNext(response -> { - if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 - && response.statusCode() != 206) { - logger.error("Error sending message: {}", response.statusCode()); - } - }) - .doOnError(error -> { - if (!isClosing) { - logger.error("Error sending message: {}", error.getMessage()); - } - }) - .then(); - } - catch (Exception e) { - if (!isClosing) { - return Mono.error(new RuntimeException("Failed to serialize message", e)); - } - return Mono.empty(); - } + return this.serializeMessage(message) + .flatMap(body -> sendHttpPost(messageEndpointUri, body)) + .doOnNext(response -> { + if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 + && response.statusCode() != 206) { + logger.error("Error sending message: {}", response.statusCode()); + } + }) + .doOnError(error -> { + if (!isClosing) { + logger.error("Error sending message: {}", error.getMessage()); + } + }); }).then(); } @@ -449,6 +443,7 @@ private Mono serializeMessage(final JSONRPCMessage message) { return Mono.just(objectMapper.writeValueAsString(message)); } catch (IOException e) { + // TODO: why McpError and not RuntimeException? return Mono.error(new McpError("Failed to serialize message")); } }); @@ -461,6 +456,7 @@ private Mono> sendHttpPost(final String endpoint, final Strin .POST(HttpRequest.BodyPublishers.ofString(body)) .build(); + // TODO: why discard the body? return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding())); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index fbff9a240..0ece5a1c2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -13,7 +13,6 @@ import java.time.Duration; import java.util.List; import java.util.Optional; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -60,7 +59,7 @@ * "https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#http-with-sse">"HTTP * with SSE" transport. In order to communicate over the phased-out * 2024-11-05 protocol, use {@link HttpClientSseClientTransport} or - * {@link WebFluxSseClientTransport}. + * {@code WebFluxSseClientTransport}. *

* * @author Christian Tzolov @@ -234,7 +233,9 @@ private Mono reconnect(McpTransportStream stream) { else { logger.debug("SSE connection established successfully"); } - })).flatMap(responseEvent -> { + })) + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) + .flatMap(responseEvent -> { int statusCode = responseEvent.responseInfo().statusCode(); if (statusCode >= 200 && statusCode < 300) { @@ -319,12 +320,11 @@ private BodyHandler toSendMessageBodySubscriber(FluxSink si else if (contentType.contains(APPLICATION_JSON)) { // For JSON responses and others, use string subscriber logger.debug("Received response, using string subscriber"); - return ResponseSubscribers.jsonoBodySubscriber(responseInfo, sink); + return ResponseSubscribers.aggregateBodySubscriber(responseInfo, sink); } logger.debug("Received Bodyless response, using discarding subscriber"); - // return HttpResponse.BodySubscribers.discarding(); - return ResponseSubscribers.bodylessBodySubscriber(responseInfo, sink); + return ResponseSubscribers.bodilessBodySubscriber(responseInfo, sink); }; return responseBodyHandler; @@ -404,34 +404,42 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { return Flux.empty(); } else if (contentType.contains(TEXT_EVENT_STREAM)) { - try { - // We don't support batching ATM and probably won't since the - // next version considers removing it. - McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, - responseEvent.sseEvent().data()); - - Tuple2, Iterable> idWithMessages = Tuples - .of(Optional.ofNullable(responseEvent.sseEvent().id()), List.of(message)); - - McpTransportStream sessionStream = new DefaultMcpTransportStream<>( - this.resumableStreams, this::reconnect); - - logger.debug("Connected stream {}", sessionStream.streamId()); - - messageSink.success(); - - return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); - - } - catch (IOException ioException) { - return Flux.error( - new McpError("Error parsing JSON-RPC message: " + responseEvent.sseEvent().data())); - } + return Flux.just(((ResponseSubscribers.SseResponseEvent) responseEvent).sseEvent()) + .flatMap(sseEvent -> { + try { + // We don't support batching ATM and probably won't + // since the + // next version considers removing it. + McpSchema.JSONRPCMessage message = McpSchema + .deserializeJsonRpcMessage(this.objectMapper, sseEvent.data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(sseEvent.id()), List.of(message)); + + McpTransportStream sessionStream = new DefaultMcpTransportStream<>( + this.resumableStreams, this::reconnect); + + logger.debug("Connected stream {}", sessionStream.streamId()); + + messageSink.success(); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + } + catch (IOException ioException) { + return Flux.error( + new McpError("Error parsing JSON-RPC message: " + sseEvent.data())); + } + }); } else if (contentType.contains(APPLICATION_JSON)) { - McpSchema.JSONRPCMessage jsonRpcResponse = responseEvent.jsonRpcMessage(); messageSink.success(); - return Flux.just(jsonRpcResponse); // ??? + String data = ((ResponseSubscribers.AggregateResponseEvent) responseEvent).data(); + try { + return Mono.just(McpSchema.deserializeJsonRpcMessage(objectMapper, data)); + } + catch (IOException e) { + return Mono.error(e); + } } logger.warn("Unknown media type {} returned for POST in session {}", contentType, sessionRepresentation); @@ -489,9 +497,9 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { */ public static class Builder { - private ObjectMapper objectMapper; + private final String baseUri; - private String baseUri; + private ObjectMapper objectMapper; private HttpClient.Builder clientBuilder = HttpClient.newBuilder() .version(HttpClient.Version.HTTP_1_1) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java index 22c632472..7df8f2d84 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java @@ -3,7 +3,6 @@ */ package io.modelcontextprotocol.client.transport; -import java.io.IOException; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodySubscriber; import java.net.http.HttpResponse.ResponseInfo; @@ -13,10 +12,6 @@ import org.reactivestreams.FlowAdapters; import org.reactivestreams.Subscription; -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.FluxSink; @@ -29,36 +24,41 @@ class ResponseSubscribers { * @param event the event type, may be {@code null} (defaults to "message") * @param data the event payload data, never {@code null} */ - public static record SseEvent(String id, String event, String data) { + record SseEvent(String id, String event, String data) { } - public record ResponseEvent(ResponseInfo responseInfo, SseEvent sseEvent, JSONRPCMessage jsonRpcMessage) { + sealed interface ResponseEvent permits SseResponseEvent, AggregateResponseEvent, DummyEvent { - public ResponseEvent(ResponseInfo responseInfo, SseEvent sseEvent) { - this(responseInfo, sseEvent, null); - } + ResponseInfo responseInfo(); - public ResponseEvent(ResponseInfo responseInfo, JSONRPCMessage jsonRpcMessage) { - this(responseInfo, null, jsonRpcMessage); - } } - public static BodySubscriber sseToBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + record DummyEvent(ResponseInfo responseInfo) implements ResponseEvent { + + } + + record SseResponseEvent(ResponseInfo responseInfo, SseEvent sseEvent) implements ResponseEvent { + } + + record AggregateResponseEvent(ResponseInfo responseInfo, String data) implements ResponseEvent { + } + + static BodySubscriber sseToBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { return HttpResponse.BodySubscribers .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new SseLineSubscriber(responseInfo, sink))); } - public static BodySubscriber jsonoBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + static BodySubscriber aggregateBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { return HttpResponse.BodySubscribers - .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new JsonLineSubscriber(responseInfo, sink))); + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new AggregateSubscriber(responseInfo, sink))); } - public static BodySubscriber bodylessBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + static BodySubscriber bodilessBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { return HttpResponse.BodySubscribers - .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new BodylessResponseLineSubscriber(responseInfo, sink))); + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new BodilessResponseLineSubscriber(responseInfo, sink))); } - public static class SseLineSubscriber extends BaseSubscriber { + static class SseLineSubscriber extends BaseSubscriber { /** * Pattern to extract data content from SSE "data:" lines. @@ -150,7 +150,7 @@ protected void hookOnNext(String line) { String eventData = this.eventBuilder.toString(); SseEvent sseEvent = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - this.sink.next(new ResponseEvent(responseInfo, sseEvent)); + this.sink.next(new SseResponseEvent(responseInfo, sseEvent)); this.eventBuilder.setLength(0); } } @@ -184,7 +184,7 @@ protected void hookOnComplete() { if (this.eventBuilder.length() > 0) { String eventData = this.eventBuilder.toString(); SseEvent sseEvent = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - this.sink.next(new ResponseEvent(responseInfo, sseEvent)); + this.sink.next(new SseResponseEvent(responseInfo, sseEvent)); } this.sink.complete(); } @@ -200,9 +200,7 @@ protected void hookOnError(Throwable throwable) { } - public static class JsonLineSubscriber extends BaseSubscriber { - - private ObjectMapper objectMapper = new ObjectMapper(); + static class AggregateSubscriber extends BaseSubscriber { /** * The sink for emitting parsed response events. @@ -225,7 +223,7 @@ public static class JsonLineSubscriber extends BaseSubscriber { * @param sink the {@link FluxSink} to emit parsed {@link ResponseEvent} objects * to */ - public JsonLineSubscriber(ResponseInfo responseInfo, FluxSink sink) { + public AggregateSubscriber(ResponseInfo responseInfo, FluxSink sink) { this.sink = sink; this.eventBuilder = new StringBuilder(); this.responseInfo = responseInfo; @@ -237,19 +235,10 @@ public JsonLineSubscriber(ResponseInfo responseInfo, FluxSink sin */ @Override protected void hookOnSubscribe(Subscription subscription) { - - sink.onRequest(n -> { - if (subscription != null) { - subscription.request(n); - } - }); + sink.onRequest(subscription::request); // Register disposal callback to cancel subscription when Flux is disposed - sink.onDispose(() -> { - if (subscription != null) { - subscription.cancel(); - } - }); + sink.onDispose(subscription::cancel); } /** @@ -267,15 +256,8 @@ protected void hookOnNext(String line) { @Override protected void hookOnComplete() { if (this.eventBuilder.length() > 0) { - String jsonData = this.eventBuilder.toString(); - try { - McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, - jsonData); - this.sink.next(new ResponseEvent(responseInfo, jsonRpcResponse)); - } - catch (IOException e) { - sink.error(e); - } + String data = this.eventBuilder.toString(); + this.sink.next(new AggregateResponseEvent(responseInfo, data)); } this.sink.complete(); } @@ -291,7 +273,7 @@ protected void hookOnError(Throwable throwable) { } - public static class BodylessResponseLineSubscriber extends BaseSubscriber { + static class BodilessResponseLineSubscriber extends BaseSubscriber { /** * The sink for emitting parsed response events. @@ -300,7 +282,7 @@ public static class BodylessResponseLineSubscriber extends BaseSubscriber sink) { + public BodilessResponseLineSubscriber(ResponseInfo responseInfo, FluxSink sink) { this.sink = sink; this.responseInfo = responseInfo; } @@ -313,29 +295,25 @@ public BodylessResponseLineSubscriber(ResponseInfo responseInfo, FluxSink { - if (subscription != null) { - subscription.request(n); - } + subscription.request(n); }); // Register disposal callback to cancel subscription when Flux is disposed sink.onDispose(() -> { - if (subscription != null) { - subscription.cancel(); - } + subscription.cancel(); }); } - @Override - protected void hookOnNext(String line) { - } - /** * Called when the upstream line source completes normally. */ @Override protected void hookOnComplete() { - this.sink.next(new ResponseEvent(responseInfo, new SseEvent(null, null, null))); + // emit dummy event to be able to inspect the response info + // this is a shortcut allowing for a more streamlined processing using + // operator composition instead of having to deal with the CompletableFuture + // along the Subscriber for inspecting the result + this.sink.next(new DummyEvent(responseInfo)); this.sink.complete(); } From 7315a53329bc32f0e15193580dd6613163dc9431 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 3 Jul 2025 20:18:40 +0200 Subject: [PATCH 11/13] Reduce the volume of logs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../mcp-spring-webmvc/src/test/resources/logback.xml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml index bc1140bb5..d4ccbc173 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml @@ -9,16 +9,16 @@ - + - + - + - + From cf02a57f5167476c2a7d6142e73daa5dcfc3a949 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 3 Jul 2025 20:24:47 +0200 Subject: [PATCH 12/13] Remove subscription null check as the subscription can not be null MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../client/transport/ResponseSubscribers.java | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java index 7df8f2d84..f4f7c4f88 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java @@ -123,16 +123,12 @@ public SseLineSubscriber(ResponseInfo responseInfo, FluxSink sink protected void hookOnSubscribe(Subscription subscription) { sink.onRequest(n -> { - if (subscription != null) { - subscription.request(n); - } + subscription.request(n); }); // Register disposal callback to cancel subscription when Flux is disposed sink.onDispose(() -> { - if (subscription != null) { - subscription.cancel(); - } + subscription.cancel(); }); } From 4b14fdfda000ed7a44a792f235e5dba3fff4bf3c Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 3 Jul 2025 21:25:39 +0200 Subject: [PATCH 13/13] Address rewview comments Signed-off-by: Christian Tzolov --- ...AbstractMcpAsyncClientResiliencyTests.java | 6 +- .../HttpClientStreamableHttpTransport.java | 38 +++++++----- .../client/transport/ResponseSubscribers.java | 62 ++++--------------- ...AbstractMcpAsyncClientResiliencyTests.java | 4 ++ ...eamableHttpAsyncClientResiliencyTests.java | 4 +- 5 files changed, 44 insertions(+), 70 deletions(-) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index 6748eb75c..22e8f195b 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -79,7 +79,7 @@ public abstract class AbstractMcpAsyncClientResiliencyTests { host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; } - private static void disconnect() { + static void disconnect() { long start = System.nanoTime(); try { // disconnect @@ -96,7 +96,7 @@ private static void disconnect() { } } - private static void reconnect() { + static void reconnect() { long start = System.nanoTime(); try { proxy.toxics().get("RESET_UPSTREAM").remove(); @@ -110,7 +110,7 @@ private static void reconnect() { } } - private static void restartMcpServer() { + static void restartMcpServer() { container.stop(); container.start(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index 0ece5a1c2..4cf1690ff 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -13,6 +13,7 @@ import java.time.Duration; import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletionException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; @@ -135,7 +136,7 @@ public static Builder builder(String baseUri) { public Mono connect(Function, Mono> handler) { return Mono.deferContextual(ctx -> { this.handler.set(handler); - if (openConnectionOnStartup) { + if (this.openConnectionOnStartup) { logger.debug("Eagerly opening connection on startup"); return this.reconnect(null).onErrorComplete(t -> { logger.warn("Eager connect failed ", t); @@ -286,6 +287,7 @@ else if (statusCode == BAD_REQUEST) { }).flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) .onErrorComplete(t -> { this.handleException(t); return true; @@ -373,7 +375,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { else { logger.debug("SSE connection established successfully"); } - })).onErrorComplete().subscribe(); + })).onErrorMap(CompletionException.class, t -> t.getCause()).onErrorComplete().subscribe(); }).flatMap(responseEvent -> { if (transportSession.markInitialized( @@ -464,19 +466,25 @@ else if (statusCode == BAD_REQUEST) { return Flux.error( new RuntimeException("Failed to send message: " + responseEvent)); - }).flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))).onErrorComplete(t -> { - // handle the error first - this.handleException(t); - // inform the caller of sendMessage - messageSink.error(t); - return true; - }).doFinally(s -> { - logger.debug("SendMessage finally: {}", s); - Disposable ref = disposableRef.getAndSet(null); - if (ref != null) { - transportSession.removeConnection(ref); - } - }).contextWrite(messageSink.contextView()).subscribe(); + }) + .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + // handle the error first + this.handleException(t); + // inform the caller of sendMessage + messageSink.error(t); + return true; + }) + .doFinally(s -> { + logger.debug("SendMessage finally: {}", s); + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(messageSink.contextView()) + .subscribe(); disposableRef.set(connection); transportSession.addConnection(connection); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java index f4f7c4f88..26b0d13bd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java @@ -15,15 +15,20 @@ import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.FluxSink; +/** + * Utility class providing various {@link BodySubscriber} implementations for handling + * different types of HTTP response bodies in the context of Model Context Protocol (MCP) + * clients. + * + *

+ * Defines subscribers for processing Server-Sent Events (SSE), aggregate responses, and + * bodiless responses. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ class ResponseSubscribers { - /** - * Represents a Server-Sent Event with its standard fields. - * - * @param id the event ID, may be {@code null} - * @param event the event type, may be {@code null} (defaults to "message") - * @param data the event payload data, never {@code null} - */ record SseEvent(String id, String event, String data) { } @@ -115,10 +120,6 @@ public SseLineSubscriber(ResponseInfo responseInfo, FluxSink sink this.responseInfo = responseInfo; } - /** - * Initializes the subscription and sets up disposal callback. - * @param subscription the {@link Subscription} to the upstream line source - */ @Override protected void hookOnSubscribe(Subscription subscription) { @@ -132,12 +133,6 @@ protected void hookOnSubscribe(Subscription subscription) { }); } - /** - * Processes each line from the SSE stream according to the SSE protocol. Empty - * lines trigger event emission, other lines are parsed for data, id, or event - * type. - * @param line the line to process from the SSE stream - */ @Override protected void hookOnNext(String line) { if (line.isEmpty()) { @@ -172,9 +167,6 @@ else if (line.startsWith("event:")) { } } - /** - * Called when the upstream line source completes normally. - */ @Override protected void hookOnComplete() { if (this.eventBuilder.length() > 0) { @@ -185,10 +177,6 @@ protected void hookOnComplete() { this.sink.complete(); } - /** - * Called when an error occurs in the upstream line source. - * @param throwable the error that occurred - */ @Override protected void hookOnError(Throwable throwable) { this.sink.error(throwable); @@ -225,10 +213,6 @@ public AggregateSubscriber(ResponseInfo responseInfo, FluxSink si this.responseInfo = responseInfo; } - /** - * Initializes the subscription and sets up disposal callback. - * @param subscription the {@link Subscription} to the upstream line source - */ @Override protected void hookOnSubscribe(Subscription subscription) { sink.onRequest(subscription::request); @@ -237,18 +221,11 @@ protected void hookOnSubscribe(Subscription subscription) { sink.onDispose(subscription::cancel); } - /** - * Aggregate each line from the Http response. - * @param line next line to process from the Http response - */ @Override protected void hookOnNext(String line) { this.eventBuilder.append(line).append("\n"); } - /** - * Called when the upstream line source completes normally. - */ @Override protected void hookOnComplete() { if (this.eventBuilder.length() > 0) { @@ -258,10 +235,6 @@ protected void hookOnComplete() { this.sink.complete(); } - /** - * Called when an error occurs in the upstream line source. - * @param throwable the error that occurred - */ @Override protected void hookOnError(Throwable throwable) { this.sink.error(throwable); @@ -283,10 +256,6 @@ public BodilessResponseLineSubscriber(ResponseInfo responseInfo, FluxSink