diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index dd7c65396..e60451706 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -129,11 +129,10 @@ private DefaultMcpTransportSession createTransportSession() { Function> onClose = sessionId -> sessionId == null ? Mono.empty() : webClient.delete().uri(this.endpoint).headers(httpHeaders -> { httpHeaders.add("mcp-session-id", sessionId); - }) - .retrieve() - .toBodilessEntity() - .doOnError(e -> logger.warn("Got error when closing transport", e)) - .then(); + }).retrieve().toBodilessEntity().onErrorComplete(e -> { + logger.warn("Got error when closing transport", e); + return true; + }).then(); return new DefaultMcpTransportSession(onClose); } @@ -305,12 +304,12 @@ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { } }) .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) - .onErrorResume(t -> { + .onErrorComplete(t -> { // handle the error first this.handleException(t); // inform the caller of sendMessage sink.error(t); - return Flux.empty(); + return true; }) .doFinally(s -> { Disposable ref = disposableRef.getAndSet(null); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java index f824193fd..5ff707b3c 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java @@ -1,12 +1,12 @@ package io.modelcontextprotocol.client; -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import org.springframework.web.reactive.function.client.WebClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; @Timeout(15) public class WebClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java index 9ecd8a7d1..70260c8bf 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java @@ -1,12 +1,12 @@ package io.modelcontextprotocol.client; -import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; -import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; +import org.springframework.web.reactive.function.client.WebClient; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; -import org.springframework.web.reactive.function.client.WebClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; @Timeout(15) public class WebClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { diff --git a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml index bc1140bb5..d4ccbc173 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml @@ -9,16 +9,16 @@ - + - + - + - + diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index 6748eb75c..22e8f195b 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -79,7 +79,7 @@ public abstract class AbstractMcpAsyncClientResiliencyTests { host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; } - private static void disconnect() { + static void disconnect() { long start = System.nanoTime(); try { // disconnect @@ -96,7 +96,7 @@ private static void disconnect() { } } - private static void reconnect() { + static void reconnect() { long start = System.nanoTime(); try { proxy.toxics().get("RESET_UPSTREAM").remove(); @@ -110,7 +110,7 @@ private static void reconnect() { } } - private static void restartMcpServer() { + static void restartMcpServer() { container.stop(); container.start(); } diff --git a/mcp/pom.xml b/mcp/pom.xml index 773432827..829b99bc1 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -141,11 +141,11 @@ - net.bytebuddy - byte-buddy - ${byte-buddy.version} - test - + net.bytebuddy + byte-buddy + ${byte-buddy.version} + test + io.projectreactor reactor-test @@ -202,6 +202,14 @@ test + + org.testcontainers + toxiproxy + ${toxiproxy.version} + test + + + diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java deleted file mode 100644 index abfafa551..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/FlowSseClient.java +++ /dev/null @@ -1,211 +0,0 @@ -/* -* Copyright 2024 - 2024 the original author or authors. -*/ -package io.modelcontextprotocol.client.transport; - -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.Flow; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; -import java.util.regex.Pattern; - -/** - * A Server-Sent Events (SSE) client implementation using Java's Flow API for reactive - * stream processing. This client establishes a connection to an SSE endpoint and - * processes the incoming event stream, parsing SSE-formatted messages into structured - * events. - * - *

- * The client supports standard SSE event fields including: - *

    - *
  • event - The event type (defaults to "message" if not specified)
  • - *
  • id - The event ID
  • - *
  • data - The event payload data
  • - *
- * - *

- * Events are delivered to a provided {@link SseEventHandler} which can process events and - * handle any errors that occur during the connection. - * - * @author Christian Tzolov - * @see SseEventHandler - * @see SseEvent - */ -public class FlowSseClient { - - private final HttpClient httpClient; - - private final HttpRequest.Builder requestBuilder; - - /** - * Pattern to extract the data content from SSE data field lines. Matches lines - * starting with "data:" and captures the remaining content. - */ - private static final Pattern EVENT_DATA_PATTERN = Pattern.compile("^data:(.+)$", Pattern.MULTILINE); - - /** - * Pattern to extract the event ID from SSE id field lines. Matches lines starting - * with "id:" and captures the ID value. - */ - private static final Pattern EVENT_ID_PATTERN = Pattern.compile("^id:(.+)$", Pattern.MULTILINE); - - /** - * Pattern to extract the event type from SSE event field lines. Matches lines - * starting with "event:" and captures the event type. - */ - private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE); - - /** - * Record class representing a Server-Sent Event with its standard fields. - * - * @param id the event ID (may be null) - * @param type the event type (defaults to "message" if not specified in the stream) - * @param data the event payload data - */ - public static record SseEvent(String id, String type, String data) { - } - - /** - * Interface for handling SSE events and errors. Implementations can process received - * events and handle any errors that occur during the SSE connection. - */ - public interface SseEventHandler { - - /** - * Called when an SSE event is received. - * @param event the received SSE event containing id, type, and data - */ - void onEvent(SseEvent event); - - /** - * Called when an error occurs during the SSE connection. - * @param error the error that occurred - */ - void onError(Throwable error); - - } - - /** - * Creates a new FlowSseClient with the specified HTTP client. - * @param httpClient the {@link HttpClient} instance to use for SSE connections - */ - public FlowSseClient(HttpClient httpClient) { - this(httpClient, HttpRequest.newBuilder()); - } - - /** - * Creates a new FlowSseClient with the specified HTTP client and request builder. - * @param httpClient the {@link HttpClient} instance to use for SSE connections - * @param requestBuilder the {@link HttpRequest.Builder} to use for SSE requests - */ - public FlowSseClient(HttpClient httpClient, HttpRequest.Builder requestBuilder) { - this.httpClient = httpClient; - this.requestBuilder = requestBuilder; - } - - /** - * Subscribes to an SSE endpoint and processes the event stream. - * - *

- * This method establishes a connection to the specified URL and begins processing the - * SSE stream. Events are parsed and delivered to the provided event handler. The - * connection remains active until either an error occurs or the server closes the - * connection. - * @param url the SSE endpoint URL to connect to - * @param eventHandler the handler that will receive SSE events and error - * notifications - * @throws RuntimeException if the connection fails with a non-200 status code - */ - public void subscribe(String url, SseEventHandler eventHandler) { - HttpRequest request = this.requestBuilder.copy() - .uri(URI.create(url)) - .header("Accept", "text/event-stream") - .header("Cache-Control", "no-cache") - .GET() - .build(); - - StringBuilder eventBuilder = new StringBuilder(); - AtomicReference currentEventId = new AtomicReference<>(); - AtomicReference currentEventType = new AtomicReference<>("message"); - - Flow.Subscriber lineSubscriber = new Flow.Subscriber<>() { - private Flow.Subscription subscription; - - @Override - public void onSubscribe(Flow.Subscription subscription) { - this.subscription = subscription; - subscription.request(Long.MAX_VALUE); - } - - @Override - public void onNext(String line) { - if (line.isEmpty()) { - // Empty line means end of event - if (eventBuilder.length() > 0) { - String eventData = eventBuilder.toString(); - SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - eventHandler.onEvent(event); - eventBuilder.setLength(0); - } - } - else { - if (line.startsWith("data:")) { - var matcher = EVENT_DATA_PATTERN.matcher(line); - if (matcher.find()) { - eventBuilder.append(matcher.group(1).trim()).append("\n"); - } - } - else if (line.startsWith("id:")) { - var matcher = EVENT_ID_PATTERN.matcher(line); - if (matcher.find()) { - currentEventId.set(matcher.group(1).trim()); - } - } - else if (line.startsWith("event:")) { - var matcher = EVENT_TYPE_PATTERN.matcher(line); - if (matcher.find()) { - currentEventType.set(matcher.group(1).trim()); - } - } - } - subscription.request(1); - } - - @Override - public void onError(Throwable throwable) { - eventHandler.onError(throwable); - } - - @Override - public void onComplete() { - // Handle any remaining event data - if (eventBuilder.length() > 0) { - String eventData = eventBuilder.toString(); - SseEvent event = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); - eventHandler.onEvent(event); - } - } - }; - - Function, HttpResponse.BodySubscriber> subscriberFactory = subscriber -> HttpResponse.BodySubscribers - .fromLineSubscriber(subscriber); - - CompletableFuture> future = this.httpClient.sendAsync(request, - info -> subscriberFactory.apply(lineSubscriber)); - - future.thenAccept(response -> { - int status = response.statusCode(); - if (status != 200 && status != 201 && status != 202 && status != 206) { - throw new RuntimeException("Failed to connect to SSE stream. Unexpected status code: " + status); - } - }).exceptionally(throwable -> { - eventHandler.onError(throwable); - return null; - }); - } - -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index d951349d1..ab48fc0f7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -10,24 +10,27 @@ import java.net.http.HttpResponse; import java.time.Duration; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.client.transport.FlowSseClient.SseEvent; + +import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; /** * Server-Sent Events (SSE) implementation of the @@ -75,9 +78,6 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** SSE endpoint path */ private final String sseEndpoint; - /** SSE client for handling server-sent events. Uses the /sse endpoint */ - private final FlowSseClient sseClient; - /** * HTTP client for sending messages to the server. Uses HTTP POST over the message * endpoint @@ -93,14 +93,14 @@ public class HttpClientSseClientTransport implements McpClientTransport { /** Flag indicating if the transport is in closing state */ private volatile boolean isClosing = false; - /** Latch for coordinating endpoint discovery */ - private final CountDownLatch closeLatch = new CountDownLatch(1); - - /** Holds the discovered message endpoint URL */ - private final AtomicReference messageEndpoint = new AtomicReference<>(); + /** Holds the SSE subscription disposable */ + private final AtomicReference sseSubscription = new AtomicReference<>(); - /** Holds the SSE connection future */ - private final AtomicReference> connectionFuture = new AtomicReference<>(); + /** + * Sink for managing the message endpoint URI provided by the server. Stores the most + * recent endpoint URI and makes it available for outbound message processing. + */ + protected final Sinks.One messageEndpointSink = Sinks.one(); /** * Creates a new transport instance with default HTTP client and object mapper. @@ -184,8 +184,6 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques this.objectMapper = objectMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; - - this.sseClient = new FlowSseClient(this.httpClient, requestBuilder); } /** @@ -323,63 +321,85 @@ public HttpClientSseClientTransport build() { } - /** - * Establishes the SSE connection with the server and sets up message handling. - * - *

- * This method: - *

    - *
  • Initiates the SSE connection
  • - *
  • Handles endpoint discovery events
  • - *
  • Processes incoming JSON-RPC messages
  • - *
- * @param handler the function to process received JSON-RPC messages - * @return a Mono that completes when the connection is established - */ @Override public Mono connect(Function, Mono> handler) { - CompletableFuture future = new CompletableFuture<>(); - connectionFuture.set(future); - - URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); - sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { - @Override - public void onEvent(SseEvent event) { - if (isClosing) { - return; - } - - try { - if (ENDPOINT_EVENT_TYPE.equals(event.type())) { - String endpoint = event.data(); - messageEndpoint.set(endpoint); - closeLatch.countDown(); - future.complete(null); + + return Mono.create(sink -> { + + HttpRequest request = requestBuilder.copy() + .uri(Utils.resolveUri(this.baseUri, this.sseEndpoint)) + .header("Accept", "text/event-stream") + .header("Cache-Control", "no-cache") + .GET() + .build(); + + Disposable connection = Flux.create(sseSink -> this.httpClient + .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) + .exceptionallyCompose(e -> { + sseSink.error(e); + return CompletableFuture.failedFuture(e); + })) + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) + .flatMap(responseEvent -> { + if (isClosing) { + return Mono.empty(); + } + + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + try { + if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + String messageEndpointUri = responseEvent.sseEvent().data(); + if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { + sink.success(); + return Flux.empty(); // No further processing needed + } + else { + sink.error(new McpError("Failed to handle SSE endpoint event")); + } + } + else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, + responseEvent.sseEvent().data()); + sink.success(); + return Flux.just(message); + } + else { + logger.error("Received unrecognized SSE event type: {}", + responseEvent.sseEvent().event()); + sink.error(new McpError( + "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + } + } + catch (IOException e) { + logger.error("Error processing SSE event", e); + sink.error(new McpError("Error processing SSE event")); + } } - else if (MESSAGE_EVENT_TYPE.equals(event.type())) { - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data()); - handler.apply(Mono.just(message)).subscribe(); + return Flux.error( + new RuntimeException("Failed to send message: " + responseEvent)); + + }) + .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))) + .onErrorComplete(t -> { + if (!isClosing) { + logger.warn("SSE stream observed an error", t); + sink.error(t); } - else { - logger.error("Received unrecognized SSE event type: {}", event.type()); + return true; + }) + .doFinally(s -> { + Disposable ref = this.sseSubscription.getAndSet(null); + if (ref != null && !ref.isDisposed()) { + ref.dispose(); } - } - catch (IOException e) { - logger.error("Error processing SSE event", e); - future.completeExceptionally(e); - } - } + }) + .contextWrite(sink.contextView()) + .subscribe(); - @Override - public void onError(Throwable error) { - if (!isClosing) { - logger.error("SSE connection error", error); - future.completeExceptionally(error); - } - } + this.sseSubscription.set(connection); }); - - return Mono.fromFuture(future); } /** @@ -394,53 +414,57 @@ public void onError(Throwable error) { */ @Override public Mono sendMessage(JSONRPCMessage message) { - if (isClosing) { - return Mono.empty(); - } - try { - if (!closeLatch.await(10, TimeUnit.SECONDS)) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); + return this.messageEndpointSink.asMono().flatMap(messageEndpointUri -> { + if (isClosing) { + return Mono.empty(); } - } - catch (InterruptedException e) { - return Mono.error(new McpError("Failed to wait for the message endpoint")); - } - String endpoint = messageEndpoint.get(); - if (endpoint == null) { - return Mono.error(new McpError("No message endpoint available")); - } + return this.serializeMessage(message) + .flatMap(body -> sendHttpPost(messageEndpointUri, body)) + .doOnNext(response -> { + if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 + && response.statusCode() != 206) { + logger.error("Error sending message: {}", response.statusCode()); + } + }) + .doOnError(error -> { + if (!isClosing) { + logger.error("Error sending message: {}", error.getMessage()); + } + }); + }).then(); - try { - String jsonText = this.objectMapper.writeValueAsString(message); - URI requestUri = Utils.resolveUri(baseUri, endpoint); - HttpRequest request = this.requestBuilder.copy() - .uri(requestUri) - .POST(HttpRequest.BodyPublishers.ofString(jsonText)) - .build(); + } - return Mono.fromFuture( - httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding()).thenAccept(response -> { - if (response.statusCode() != 200 && response.statusCode() != 201 && response.statusCode() != 202 - && response.statusCode() != 206) { - logger.error("Error sending message: {}", response.statusCode()); - } - })); - } - catch (IOException e) { - if (!isClosing) { - return Mono.error(new RuntimeException("Failed to serialize message", e)); + private Mono serializeMessage(final JSONRPCMessage message) { + return Mono.defer(() -> { + try { + return Mono.just(objectMapper.writeValueAsString(message)); } - return Mono.empty(); - } + catch (IOException e) { + // TODO: why McpError and not RuntimeException? + return Mono.error(new McpError("Failed to serialize message")); + } + }); + } + + private Mono> sendHttpPost(final String endpoint, final String body) { + final URI requestUri = Utils.resolveUri(baseUri, endpoint); + final HttpRequest request = this.requestBuilder.copy() + .uri(requestUri) + .POST(HttpRequest.BodyPublishers.ofString(body)) + .build(); + + // TODO: why discard the body? + return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.discarding())); } /** * Gracefully closes the transport connection. * *

- * Sets the closing flag and cancels any pending connection future. This prevents new + * Sets the closing flag and disposes of the SSE subscription. This prevents new * messages from being sent and allows ongoing operations to complete. * @return a Mono that completes when the closing process is initiated */ @@ -448,9 +472,9 @@ public Mono sendMessage(JSONRPCMessage message) { public Mono closeGracefully() { return Mono.fromRunnable(() -> { isClosing = true; - CompletableFuture future = connectionFuture.get(); - if (future != null && !future.isDone()) { - future.cancel(true); + Disposable subscription = sseSubscription.get(); + if (subscription != null && !subscription.isDisposed()) { + subscription.dispose(); } }); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java new file mode 100644 index 000000000..4cf1690ff --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -0,0 +1,640 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodyHandler; +import java.time.Duration; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; +import io.modelcontextprotocol.spec.DefaultMcpTransportSession; +import io.modelcontextprotocol.spec.DefaultMcpTransportStream; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSession; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +/** + * An implementation of the Streamable HTTP protocol as defined by the + * 2025-03-26 version of the MCP specification. + * + *

+ * The transport is capable of resumability and reconnects. It reacts to transport-level + * session invalidation and will propagate {@link McpTransportSessionNotFoundException + * appropriate exceptions} to the higher level abstraction layer when needed in order to + * allow proper state management. The implementation handles servers that are stateful and + * provide session meta information, but can also communicate with stateless servers that + * do not provide a session identifier and do not support SSE streams. + *

+ *

+ * This implementation does not handle backwards compatibility with the "HTTP + * with SSE" transport. In order to communicate over the phased-out + * 2024-11-05 protocol, use {@link HttpClientSseClientTransport} or + * {@code WebFluxSseClientTransport}. + *

+ * + * @author Christian Tzolov + * @see Streamable + * HTTP transport specification + */ +public class HttpClientStreamableHttpTransport implements McpClientTransport { + + private static final Logger logger = LoggerFactory.getLogger(HttpClientStreamableHttpTransport.class); + + private static final String DEFAULT_ENDPOINT = "/mcp"; + + /** + * HTTP client for sending messages to the server. Uses HTTP POST over the message + * endpoint + */ + private final HttpClient httpClient; + + /** HTTP request builder for building requests to send messages to the server */ + private final HttpRequest.Builder requestBuilder; + + /** + * Event type for JSON-RPC messages received through the SSE connection. The server + * sends messages with this event type to transmit JSON-RPC protocol data. + */ + private static final String MESSAGE_EVENT_TYPE = "message"; + + private static final String APPLICATION_JSON = "application/json"; + + private static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static int NOT_FOUND = 404; + + public static int METHOD_NOT_ALLOWED = 405; + + public static int BAD_REQUEST = 400; + + private final ObjectMapper objectMapper; + + private final URI baseUri; + + private final String endpoint; + + private final boolean openConnectionOnStartup; + + private final boolean resumableStreams; + + private final AtomicReference activeSession = new AtomicReference<>(); + + private final AtomicReference, Mono>> handler = new AtomicReference<>(); + + private final AtomicReference> exceptionHandler = new AtomicReference<>(); + + private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient httpClient, + HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, + boolean openConnectionOnStartup) { + this.objectMapper = objectMapper; + this.httpClient = httpClient; + this.requestBuilder = requestBuilder; + this.baseUri = URI.create(baseUri); + this.endpoint = endpoint; + this.resumableStreams = resumableStreams; + this.openConnectionOnStartup = openConnectionOnStartup; + this.activeSession.set(createTransportSession()); + } + + public static Builder builder(String baseUri) { + return new Builder(baseUri); + } + + @Override + public Mono connect(Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler.set(handler); + if (this.openConnectionOnStartup) { + logger.debug("Eagerly opening connection on startup"); + return this.reconnect(null).onErrorComplete(t -> { + logger.warn("Eager connect failed ", t); + return true; + }).then(); + } + return Mono.empty(); + }); + } + + private DefaultMcpTransportSession createTransportSession() { + Function> onClose = sessionId -> sessionId == null ? Mono.empty() + : createDelete(sessionId); + return new DefaultMcpTransportSession(onClose); + } + + private Publisher createDelete(String sessionId) { + HttpRequest request = this.requestBuilder.copy() + .uri(Utils.resolveUri(this.baseUri, this.endpoint)) + .header("Cache-Control", "no-cache") + .header("mcp-session-id", sessionId) + .DELETE() + .build(); + + return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())).then(); + } + + @Override + public void setExceptionHandler(Consumer handler) { + logger.debug("Exception handler registered"); + this.exceptionHandler.set(handler); + } + + private void handleException(Throwable t) { + logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t); + if (t instanceof McpTransportSessionNotFoundException) { + McpTransportSession invalidSession = this.activeSession.getAndSet(createTransportSession()); + logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId()); + invalidSession.close(); + } + Consumer handler = this.exceptionHandler.get(); + if (handler != null) { + handler.accept(t); + } + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + logger.debug("Graceful close triggered"); + DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(createTransportSession()); + if (currentSession != null) { + return currentSession.closeGracefully(); + } + return Mono.empty(); + }); + } + + private Mono reconnect(McpTransportStream stream) { + + return Mono.deferContextual(ctx -> { + + if (stream != null) { + logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId()); + } + else { + logger.debug("Reconnecting with no prior stream"); + } + + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); + + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); + } + + if (stream != null && stream.lastId().isPresent()) { + requestBuilder = requestBuilder.header("last-event-id", stream.lastId().get()); + } + + HttpRequest request = requestBuilder.uri(Utils.resolveUri(this.baseUri, this.endpoint)) + .header("Accept", TEXT_EVENT_STREAM) + .header("Cache-Control", "no-cache") + .GET() + .build(); + + Disposable connection = Flux.create(sseSink -> this.httpClient + .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + sseSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })) + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) + .flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + + if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + try { + // We don't support batching ATM and probably won't since + // the + // next version considers removing it. + McpSchema.JSONRPCMessage message = McpSchema + .deserializeJsonRpcMessage(this.objectMapper, responseEvent.sseEvent().data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(responseEvent.sseEvent().id()), List.of(message)); + + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + + } + catch (IOException ioException) { + return Flux.error(new McpError( + "Error parsing JSON-RPC message: " + responseEvent.sseEvent().data())); + } + } + } + else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed + logger.debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (statusCode == NOT_FOUND) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + else if (statusCode == BAD_REQUEST) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + + return Flux.error( + new McpError("Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + + }).flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(ctx) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + return Mono.just(connection); + }); + + } + + private BodyHandler toSendMessageBodySubscriber(FluxSink sink) { + + BodyHandler responseBodyHandler = responseInfo -> { + + String contentType = responseInfo.headers().firstValue("Content-Type").orElse("").toLowerCase(); + + if (contentType.contains(TEXT_EVENT_STREAM)) { + // For SSE streams, use line subscriber that returns Void + logger.debug("Received SSE stream response, using line subscriber"); + return ResponseSubscribers.sseToBodySubscriber(responseInfo, sink); + } + else if (contentType.contains(APPLICATION_JSON)) { + // For JSON responses and others, use string subscriber + logger.debug("Received response, using string subscriber"); + return ResponseSubscribers.aggregateBodySubscriber(responseInfo, sink); + } + + logger.debug("Received Bodyless response, using discarding subscriber"); + return ResponseSubscribers.bodilessBodySubscriber(responseInfo, sink); + }; + + return responseBodyHandler; + + } + + public String toString(McpSchema.JSONRPCMessage message) { + try { + return this.objectMapper.writeValueAsString(message); + } + catch (IOException e) { + throw new RuntimeException("Failed to serialize JSON-RPC message", e); + } + } + + public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { + return Mono.create(messageSink -> { + logger.debug("Sending message {}", sendMessage); + + final AtomicReference disposableRef = new AtomicReference<>(); + final McpTransportSession transportSession = this.activeSession.get(); + + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); + + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); + } + + String jsonBody = this.toString(sendMessage); + + HttpRequest request = requestBuilder.uri(Utils.resolveUri(this.baseUri, this.endpoint)) + .header("Accept", TEXT_EVENT_STREAM + ", " + APPLICATION_JSON) + .header("Content-Type", APPLICATION_JSON) + .header("Cache-Control", "no-cache") + .POST(HttpRequest.BodyPublishers.ofString(jsonBody)) + .build(); + + Disposable connection = Flux.create(responseEventSink -> { + + // Create the async request with proper body subscriber selection + Mono.fromFuture(this.httpClient.sendAsync(request, this.toSendMessageBodySubscriber(responseEventSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + responseEventSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })).onErrorMap(CompletionException.class, t -> t.getCause()).onErrorComplete().subscribe(); + + }).flatMap(responseEvent -> { + if (transportSession.markInitialized( + responseEvent.responseInfo().headers().firstValue("mcp-session-id").orElseGet(() -> null))) { + // Once we have a session, we try to open an async stream for + // the server to send notifications and requests out-of-band. + + reconnect(null).contextWrite(messageSink.contextView()).subscribe(); + } + + String sessionRepresentation = sessionIdOrPlaceholder(transportSession); + + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + + String contentType = responseEvent.responseInfo() + .headers() + .firstValue("Content-Type") + .orElse("") + .toLowerCase(); + + if (contentType.isBlank()) { + logger.debug("No content type returned for POST in session {}", sessionRepresentation); + // No content type means no response body, so we can just return + // an empty stream + messageSink.success(); + return Flux.empty(); + } + else if (contentType.contains(TEXT_EVENT_STREAM)) { + return Flux.just(((ResponseSubscribers.SseResponseEvent) responseEvent).sseEvent()) + .flatMap(sseEvent -> { + try { + // We don't support batching ATM and probably won't + // since the + // next version considers removing it. + McpSchema.JSONRPCMessage message = McpSchema + .deserializeJsonRpcMessage(this.objectMapper, sseEvent.data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(sseEvent.id()), List.of(message)); + + McpTransportStream sessionStream = new DefaultMcpTransportStream<>( + this.resumableStreams, this::reconnect); + + logger.debug("Connected stream {}", sessionStream.streamId()); + + messageSink.success(); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + } + catch (IOException ioException) { + return Flux.error( + new McpError("Error parsing JSON-RPC message: " + sseEvent.data())); + } + }); + } + else if (contentType.contains(APPLICATION_JSON)) { + messageSink.success(); + String data = ((ResponseSubscribers.AggregateResponseEvent) responseEvent).data(); + try { + return Mono.just(McpSchema.deserializeJsonRpcMessage(objectMapper, data)); + } + catch (IOException e) { + return Mono.error(e); + } + } + logger.warn("Unknown media type {} returned for POST in session {}", contentType, + sessionRepresentation); + + return Flux.error( + new RuntimeException("Unknown media type returned: " + contentType)); + } + else if (statusCode == NOT_FOUND) { + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionRepresentation); + return Flux.error(exception); + } + // Some implementations can return 400 when presented with a + // session id that it doesn't know about, so we will + // invalidate the session + // https://github.com/modelcontextprotocol/typescript-sdk/issues/389 + else if (statusCode == BAD_REQUEST) { + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionRepresentation); + return Flux.error(exception); + } + + return Flux.error( + new RuntimeException("Failed to send message: " + responseEvent)); + }) + .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + // handle the error first + this.handleException(t); + // inform the caller of sendMessage + messageSink.error(t); + return true; + }) + .doFinally(s -> { + logger.debug("SendMessage finally: {}", s); + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + }) + .contextWrite(messageSink.contextView()) + .subscribe(); + + disposableRef.set(connection); + transportSession.addConnection(connection); + }); + } + + private static String sessionIdOrPlaceholder(McpTransportSession transportSession) { + return transportSession.sessionId().orElse("[missing_session_id]"); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return this.objectMapper.convertValue(data, typeRef); + } + + /** + * Builder for {@link HttpClientStreamableHttpTransport}. + */ + public static class Builder { + + private final String baseUri; + + private ObjectMapper objectMapper; + + private HttpClient.Builder clientBuilder = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .connectTimeout(Duration.ofSeconds(10)); + + private String endpoint = DEFAULT_ENDPOINT; + + private boolean resumableStreams = true; + + private boolean openConnectionOnStartup = false; + + private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + + /** + * Creates a new builder with the specified base URI. + * @param baseUri the base URI of the MCP server + */ + private Builder(String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + } + + /** + * Sets the HTTP client builder. + * @param clientBuilder the HTTP client builder + * @return this builder + */ + public Builder clientBuilder(HttpClient.Builder clientBuilder) { + Assert.notNull(clientBuilder, "clientBuilder must not be null"); + this.clientBuilder = clientBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param clientCustomizer the consumer to customize the HTTP client builder + * @return this builder + */ + public Builder customizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + return this; + } + + /** + * Sets the HTTP request builder. + * @param requestBuilder the HTTP request builder + * @return this builder + */ + public Builder requestBuilder(HttpRequest.Builder requestBuilder) { + Assert.notNull(requestBuilder, "requestBuilder must not be null"); + this.requestBuilder = requestBuilder; + return this; + } + + /** + * Customizes the HTTP client builder. + * @param requestCustomizer the consumer to customize the HTTP request builder + * @return this builder + */ + public Builder customizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + return this; + } + + /** + * Configure the {@link ObjectMapper} to use. + * @param objectMapper instance to use + * @return the builder instance + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Configure the endpoint to make HTTP requests against. + * @param endpoint endpoint to use + * @return the builder instance + */ + public Builder endpoint(String endpoint) { + Assert.hasText(endpoint, "endpoint must be a non-empty String"); + this.endpoint = endpoint; + return this; + } + + /** + * Configure whether to use the stream resumability feature by keeping track of + * SSE event ids. + * @param resumableStreams if {@code true} event ids will be tracked and upon + * disconnection, the last seen id will be used upon reconnection as a header to + * resume consuming messages. + * @return the builder instance + */ + public Builder resumableStreams(boolean resumableStreams) { + this.resumableStreams = resumableStreams; + return this; + } + + /** + * Configure whether the client should open an SSE connection upon startup. Not + * all servers support this (although it is in theory possible with the current + * specification), so use with caution. By default, this value is {@code false}. + * @param openConnectionOnStartup if {@code true} the {@link #connect(Function)} + * method call will try to open an SSE connection before sending any JSON-RPC + * request + * @return the builder instance + */ + public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { + this.openConnectionOnStartup = openConnectionOnStartup; + return this; + } + + /** + * Construct a fresh instance of {@link HttpClientStreamableHttpTransport} using + * the current builder configuration. + * @return a new instance of {@link HttpClientStreamableHttpTransport} + */ + public HttpClientStreamableHttpTransport build() { + ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + + return new HttpClientStreamableHttpTransport(objectMapper, clientBuilder.build(), requestBuilder, baseUri, + endpoint, resumableStreams, openConnectionOnStartup); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java new file mode 100644 index 000000000..26b0d13bd --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java @@ -0,0 +1,289 @@ +/* +* Copyright 2024 - 2024 the original author or authors. +*/ +package io.modelcontextprotocol.client.transport; + +import java.net.http.HttpResponse; +import java.net.http.HttpResponse.BodySubscriber; +import java.net.http.HttpResponse.ResponseInfo; +import java.util.concurrent.atomic.AtomicReference; +import java.util.regex.Pattern; + +import org.reactivestreams.FlowAdapters; +import org.reactivestreams.Subscription; + +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.FluxSink; + +/** + * Utility class providing various {@link BodySubscriber} implementations for handling + * different types of HTTP response bodies in the context of Model Context Protocol (MCP) + * clients. + * + *

+ * Defines subscribers for processing Server-Sent Events (SSE), aggregate responses, and + * bodiless responses. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +class ResponseSubscribers { + + record SseEvent(String id, String event, String data) { + } + + sealed interface ResponseEvent permits SseResponseEvent, AggregateResponseEvent, DummyEvent { + + ResponseInfo responseInfo(); + + } + + record DummyEvent(ResponseInfo responseInfo) implements ResponseEvent { + + } + + record SseResponseEvent(ResponseInfo responseInfo, SseEvent sseEvent) implements ResponseEvent { + } + + record AggregateResponseEvent(ResponseInfo responseInfo, String data) implements ResponseEvent { + } + + static BodySubscriber sseToBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + return HttpResponse.BodySubscribers + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new SseLineSubscriber(responseInfo, sink))); + } + + static BodySubscriber aggregateBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + return HttpResponse.BodySubscribers + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new AggregateSubscriber(responseInfo, sink))); + } + + static BodySubscriber bodilessBodySubscriber(ResponseInfo responseInfo, FluxSink sink) { + return HttpResponse.BodySubscribers + .fromLineSubscriber(FlowAdapters.toFlowSubscriber(new BodilessResponseLineSubscriber(responseInfo, sink))); + } + + static class SseLineSubscriber extends BaseSubscriber { + + /** + * Pattern to extract data content from SSE "data:" lines. + */ + private static final Pattern EVENT_DATA_PATTERN = Pattern.compile("^data:(.+)$", Pattern.MULTILINE); + + /** + * Pattern to extract event ID from SSE "id:" lines. + */ + private static final Pattern EVENT_ID_PATTERN = Pattern.compile("^id:(.+)$", Pattern.MULTILINE); + + /** + * Pattern to extract event type from SSE "event:" lines. + */ + private static final Pattern EVENT_TYPE_PATTERN = Pattern.compile("^event:(.+)$", Pattern.MULTILINE); + + /** + * The sink for emitting parsed response events. + */ + private final FluxSink sink; + + /** + * StringBuilder for accumulating multi-line event data. + */ + private final StringBuilder eventBuilder; + + /** + * Current event's ID, if specified. + */ + private final AtomicReference currentEventId; + + /** + * Current event's type, if specified. + */ + private final AtomicReference currentEventType; + + /** + * The response information from the HTTP response. Send with each event to + * provide context. + */ + private ResponseInfo responseInfo; + + /** + * Creates a new LineSubscriber that will emit parsed SSE events to the provided + * sink. + * @param sink the {@link FluxSink} to emit parsed {@link ResponseEvent} objects + * to + */ + public SseLineSubscriber(ResponseInfo responseInfo, FluxSink sink) { + this.sink = sink; + this.eventBuilder = new StringBuilder(); + this.currentEventId = new AtomicReference<>(); + this.currentEventType = new AtomicReference<>(); + this.responseInfo = responseInfo; + } + + @Override + protected void hookOnSubscribe(Subscription subscription) { + + sink.onRequest(n -> { + subscription.request(n); + }); + + // Register disposal callback to cancel subscription when Flux is disposed + sink.onDispose(() -> { + subscription.cancel(); + }); + } + + @Override + protected void hookOnNext(String line) { + if (line.isEmpty()) { + // Empty line means end of event + if (this.eventBuilder.length() > 0) { + String eventData = this.eventBuilder.toString(); + SseEvent sseEvent = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); + + this.sink.next(new SseResponseEvent(responseInfo, sseEvent)); + this.eventBuilder.setLength(0); + } + } + else { + if (line.startsWith("data:")) { + var matcher = EVENT_DATA_PATTERN.matcher(line); + if (matcher.find()) { + this.eventBuilder.append(matcher.group(1).trim()).append("\n"); + } + } + else if (line.startsWith("id:")) { + var matcher = EVENT_ID_PATTERN.matcher(line); + if (matcher.find()) { + this.currentEventId.set(matcher.group(1).trim()); + } + } + else if (line.startsWith("event:")) { + var matcher = EVENT_TYPE_PATTERN.matcher(line); + if (matcher.find()) { + this.currentEventType.set(matcher.group(1).trim()); + } + } + } + } + + @Override + protected void hookOnComplete() { + if (this.eventBuilder.length() > 0) { + String eventData = this.eventBuilder.toString(); + SseEvent sseEvent = new SseEvent(currentEventId.get(), currentEventType.get(), eventData.trim()); + this.sink.next(new SseResponseEvent(responseInfo, sseEvent)); + } + this.sink.complete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + } + + static class AggregateSubscriber extends BaseSubscriber { + + /** + * The sink for emitting parsed response events. + */ + private final FluxSink sink; + + /** + * StringBuilder for accumulating multi-line event data. + */ + private final StringBuilder eventBuilder; + + /** + * The response information from the HTTP response. Send with each event to + * provide context. + */ + private ResponseInfo responseInfo; + + /** + * Creates a new JsonLineSubscriber that will emit parsed JSON-RPC messages. + * @param sink the {@link FluxSink} to emit parsed {@link ResponseEvent} objects + * to + */ + public AggregateSubscriber(ResponseInfo responseInfo, FluxSink sink) { + this.sink = sink; + this.eventBuilder = new StringBuilder(); + this.responseInfo = responseInfo; + } + + @Override + protected void hookOnSubscribe(Subscription subscription) { + sink.onRequest(subscription::request); + + // Register disposal callback to cancel subscription when Flux is disposed + sink.onDispose(subscription::cancel); + } + + @Override + protected void hookOnNext(String line) { + this.eventBuilder.append(line).append("\n"); + } + + @Override + protected void hookOnComplete() { + if (this.eventBuilder.length() > 0) { + String data = this.eventBuilder.toString(); + this.sink.next(new AggregateResponseEvent(responseInfo, data)); + } + this.sink.complete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + } + + static class BodilessResponseLineSubscriber extends BaseSubscriber { + + /** + * The sink for emitting parsed response events. + */ + private final FluxSink sink; + + private final ResponseInfo responseInfo; + + public BodilessResponseLineSubscriber(ResponseInfo responseInfo, FluxSink sink) { + this.sink = sink; + this.responseInfo = responseInfo; + } + + @Override + protected void hookOnSubscribe(Subscription subscription) { + + sink.onRequest(n -> { + subscription.request(n); + }); + + // Register disposal callback to cancel subscription when Flux is disposed + sink.onDispose(() -> { + subscription.cancel(); + }); + } + + @Override + protected void hookOnComplete() { + // emit dummy event to be able to inspect the response info + // this is a shortcut allowing for a more streamlined processing using + // operator composition instead of having to deal with the CompletableFuture + // along the Subscriber for inspecting the result + this.sink.next(new DummyEvent(responseInfo)); + this.sink.complete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + this.sink.error(throwable); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 8545348ed..5553445b6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -362,7 +362,7 @@ public Mono closeGracefully() { } else { logger.warn("Process not started"); - return Mono.empty(); + return Mono.empty(); } })).doOnNext(process -> { if (process.exitValue() != 0) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java index ecc6f8666..eb2b7edeb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java @@ -61,14 +61,19 @@ public long streamId() { @Override public Publisher consumeSseStream( Publisher, Iterable>> eventStream) { - return Flux.deferContextual(ctx -> Flux.from(eventStream).doOnError(e -> { - if (resumable && !(e instanceof McpTransportSessionNotFoundException)) { - Mono.from(reconnect.apply(this)).contextWrite(ctx).subscribe(); - } - }).doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(id -> { - String previousId = this.lastId.getAndSet(id); - logger.debug("Updating last id {} -> {} for stream {}", previousId, id, this.streamId); - })).flatMapIterable(Tuple2::getT2)); + + // @formatter:off + return Flux.deferContextual(ctx -> Flux.from(eventStream) + .doOnNext(idAndMessage -> idAndMessage.getT1().ifPresent(id -> { + String previousId = this.lastId.getAndSet(id); + logger.debug("Updating last id {} -> {} for stream {}", previousId, id, this.streamId); + })) + .doOnError(e -> { + if (resumable && !(e instanceof McpTransportSessionNotFoundException)) { + Mono.from(reconnect.apply(this)).contextWrite(ctx).subscribe(); + } + }) + .flatMapIterable(Tuple2::getT2)); // @formatter:on } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index fa0853d81..14b36d451 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -141,14 +141,18 @@ else if (message instanceof McpSchema.JSONRPCRequest request) { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); - return this.transport.sendMessage(errorResponse).then(Mono.empty()); - }).flatMap(this.transport::sendMessage).subscribe(); + return Mono.just(errorResponse); + }).flatMap(this.transport::sendMessage).onErrorComplete(t -> { + logger.warn("Issue sending response to the client, ", t); + return true; + }).subscribe(); } else if (message instanceof McpSchema.JSONRPCNotification notification) { logger.debug("Received notification: {}", notification); - handleIncomingNotification(notification) - .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())) - .subscribe(); + handleIncomingNotification(notification).onErrorComplete(t -> { + logger.error("Error handling notification: {}", t.getMessage()); + return true; + }).subscribe(); } else { logger.warn("Received unknown message type: {}", message); @@ -171,11 +175,7 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR } return handler.handle(request.params()) - .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)))); // TODO: add error message - // through the data field + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)); }); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 20a8c0cf5..7ba35bbf0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -15,8 +15,6 @@ */ package io.modelcontextprotocol; -import java.util.Map; - import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerSession.Factory; diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..b673ed612 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -0,0 +1,226 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ +package io.modelcontextprotocol.client; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransport; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThatCode; + +/** + * Resiliency test suite for the {@link McpAsyncClient} that can be used with different + * {@link McpTransport} implementations that support Streamable HTTP. + * + * The purpose of these tests is to allow validating the transport layer resiliency + * instead of the functionality offered by the logical layer of MCP concepts such as + * tools, resources, prompts, etc. + * + * @author Dariusz Jędrzejczyk + */ +// KEEP IN SYNC with the class in mcp-test module +public abstract class AbstractMcpAsyncClientResiliencyTests { + + private static final Logger logger = LoggerFactory.getLogger(AbstractMcpAsyncClientResiliencyTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + static void disconnect() { + long start = System.nanoTime(); + try { + // disconnect + // proxy.toxics().bandwidth("CUT_CONNECTION_DOWNSTREAM", + // ToxicDirection.DOWNSTREAM, 0); + // proxy.toxics().bandwidth("CUT_CONNECTION_UPSTREAM", + // ToxicDirection.UPSTREAM, 0); + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_DOWNSTREAM").remove(); + // proxy.toxics().get("CUT_CONNECTION_UPSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + static void restartMcpServer() { + container.stop(); + container.start(); + } + + abstract McpClientTransport createMcpTransport(); + + protected Duration getRequestTimeout() { + return Duration.ofSeconds(14); + } + + protected Duration getInitializationTimeout() { + return Duration.ofSeconds(2); + } + + McpAsyncClient client(McpClientTransport transport) { + return client(transport, Function.identity()); + } + + McpAsyncClient client(McpClientTransport transport, Function customizer) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(getRequestTimeout()) + .initializationTimeout(getInitializationTimeout()) + .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + builder = customizer.apply(builder); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + withClient(transport, Function.identity(), c); + } + + void withClient(McpClientTransport transport, Function customizer, + Consumer c) { + var client = client(transport, customizer); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPing() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionInvalidation() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + restartMcpServer(); + + // The first try will face the session mismatch exception and the second one + // will go through the re-initialization process. + StepVerifier.create(mcpAsyncClient.ping().retry(1)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testCallTool() { + withClient(createMcpTransport(), mcpAsyncClient -> { + AtomicReference> tools = new AtomicReference<>(); + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + StepVerifier.create(mcpAsyncClient.listTools()) + .consumeNextWith(list -> tools.set(list.tools())) + .verifyComplete(); + + disconnect(); + + String name = tools.get().get(0).name(); + // Assuming this is the echo tool + McpSchema.CallToolRequest request = new McpSchema.CallToolRequest(name, Map.of("message", "hello")); + StepVerifier.create(mcpAsyncClient.callTool(request)).expectError().verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.callTool(request)).expectNextCount(1).verifyComplete(); + }); + } + + @Test + void testSessionClose() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + // In case of Streamable HTTP this call should issue a HTTP DELETE request + // invalidating the session + StepVerifier.create(mcpAsyncClient.closeGracefully()).expectComplete().verify(); + // The next use should immediately re-initialize with no issue and send the + // request without any broken connections. + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java new file mode 100644 index 000000000..945278154 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientResiliencyTests.java @@ -0,0 +1,39 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.io.IOException; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import reactor.test.StepVerifier; + +@Timeout(15) +public class HttpClientStreamableHttpAsyncClientResiliencyTests extends AbstractMcpAsyncClientResiliencyTests { + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientStreamableHttpTransport.builder(host).build(); + } + + @Test + void testPingWithExactExceptionType() { + withClient(createMcpTransport(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectError(IOException.class).verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java new file mode 100644 index 000000000..aa081b51b --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java @@ -0,0 +1,41 @@ +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +@Timeout(15) +public class HttpClientStreamableHttpAsyncClientTests extends AbstractMcpAsyncClientTests { + + private String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + + return HttpClientStreamableHttpTransport.builder(host).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java new file mode 100644 index 000000000..8285f417f --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java @@ -0,0 +1,40 @@ +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +@Timeout(15) +public class HttpClientStreamableHttpSyncClientTests extends AbstractMcpSyncClientTests { + + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return HttpClientStreamableHttpTransport.builder(host).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + public void onClose() { + container.stop(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java new file mode 100644 index 000000000..0a72b785d --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientLostConnectionTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import static org.assertj.core.api.Assertions.assertThatCode; + +import java.io.IOException; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.ToxiproxyContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import eu.rekawek.toxiproxy.Proxy; +import eu.rekawek.toxiproxy.ToxiproxyClient; +import eu.rekawek.toxiproxy.model.ToxicDirection; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import reactor.test.StepVerifier; + +@Timeout(15) +public class HttpSseMcpAsyncClientLostConnectionTests { + + private static final Logger logger = LoggerFactory.getLogger(HttpSseMcpAsyncClientLostConnectionTests.class); + + static Network network = Network.newNetwork(); + static String host = "http://localhost:3001"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js sse") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withNetwork(network) + .withNetworkAliases("everything-server") + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + static ToxiproxyContainer toxiproxy = new ToxiproxyContainer("ghcr.io/shopify/toxiproxy:2.5.0").withNetwork(network) + .withExposedPorts(8474, 3000); + + static Proxy proxy; + + static { + container.start(); + + toxiproxy.start(); + + final ToxiproxyClient toxiproxyClient = new ToxiproxyClient(toxiproxy.getHost(), toxiproxy.getControlPort()); + try { + proxy = toxiproxyClient.createProxy("everything-server", "0.0.0.0:3000", "everything-server:3001"); + } + catch (IOException e) { + throw new RuntimeException("Can't create proxy!", e); + } + + final String ipAddressViaToxiproxy = toxiproxy.getHost(); + final int portViaToxiproxy = toxiproxy.getMappedPort(3000); + + host = "http://" + ipAddressViaToxiproxy + ":" + portViaToxiproxy; + } + + static void disconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().resetPeer("RESET_DOWNSTREAM", ToxicDirection.DOWNSTREAM, 0); + proxy.toxics().resetPeer("RESET_UPSTREAM", ToxicDirection.UPSTREAM, 0); + logger.info("Disconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to disconnect", e); + } + } + + static void reconnect() { + long start = System.nanoTime(); + try { + proxy.toxics().get("RESET_UPSTREAM").remove(); + proxy.toxics().get("RESET_DOWNSTREAM").remove(); + logger.info("Reconnect took {} ms", Duration.ofNanos(System.nanoTime() - start).toMillis()); + } + catch (IOException e) { + throw new RuntimeException("Failed to reconnect", e); + } + } + + McpAsyncClient client(McpClientTransport transport) { + AtomicReference client = new AtomicReference<>(); + + assertThatCode(() -> { + McpClient.AsyncSpec builder = McpClient.async(transport) + .requestTimeout(Duration.ofSeconds(14)) + .initializationTimeout(Duration.ofSeconds(2)) + .capabilities(McpSchema.ClientCapabilities.builder().roots(true).build()); + client.set(builder.build()); + }).doesNotThrowAnyException(); + + return client.get(); + } + + void withClient(McpClientTransport transport, Consumer c) { + var client = client(transport); + try { + c.accept(client); + } + finally { + StepVerifier.create(client.closeGracefully()).expectComplete().verify(Duration.ofSeconds(10)); + } + } + + @Test + void testPingWithEaxctExceptionType() { + withClient(HttpClientSseClientTransport.builder(host).build(), mcpAsyncClient -> { + StepVerifier.create(mcpAsyncClient.initialize()).expectNextCount(1).verifyComplete(); + + disconnect(); + + // Veryfiy that the exception type is IOException and not TimeoutException + StepVerifier.create(mcpAsyncClient.ping()).expectError(IOException.class).verify(); + + reconnect(); + + StepVerifier.create(mcpAsyncClient.ping()).expectNextCount(1).verifyComplete(); + }); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java index 1b66a98cd..6cb3f7b65 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpSseMcpAsyncClientTests.java @@ -4,12 +4,13 @@ package io.modelcontextprotocol.client; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.api.Timeout; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + /** * Tests for the {@link McpSyncClient} with {@link HttpClientSseClientTransport}. * diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 1b1c72012..e4348be25 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -7,23 +7,20 @@ import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; -import java.net.http.HttpResponse; import java.time.Duration; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; import reactor.core.publisher.Mono; @@ -34,11 +31,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.fasterxml.jackson.databind.ObjectMapper; /** * Tests for the {@link HttpClientSseClientTransport} class. @@ -371,25 +363,4 @@ void testChainedCustomizations() { customizedTransport.closeGracefully().block(); } - @Test - @SuppressWarnings("unchecked") - void testResolvingClientEndpoint() { - HttpClient httpClient = Mockito.mock(HttpClient.class); - HttpResponse httpResponse = Mockito.mock(HttpResponse.class); - CompletableFuture> future = new CompletableFuture<>(); - future.complete(httpResponse); - when(httpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(future); - - HttpClientSseClientTransport transport = new HttpClientSseClientTransport(httpClient, HttpRequest.newBuilder(), - "http://example.com", "http://example.com/sse", new ObjectMapper()); - - transport.connect(Function.identity()); - - ArgumentCaptor httpRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); - verify(httpClient).sendAsync(httpRequestCaptor.capture(), any(HttpResponse.BodyHandler.class)); - assertThat(httpRequestCaptor.getValue().uri()).isEqualTo(URI.create("http://example.com/sse")); - - transport.closeGracefully().block(); - } - }