From 2be29c71c3c02fd20f4b6418c43395e80637421b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 25 Jun 2025 12:15:35 +0200 Subject: [PATCH 1/2] feat: Propagate Context to eager connect via McpClientSession MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In order to allow the initial connection to have contextual information in the reactive chain, the McpClientSession should be able to transform the McpClientTransport#connect result, e.g. to attach Context items. This change introduces a new constructor for sessions that makes it possible. Signed-off-by: Dariusz Jędrzejczyk --- .../client/McpAsyncClient.java | 17 +++++++------- .../spec/McpClientSession.java | 22 ++++++++++++++++++- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index a7dac4c0..e25bf839 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -11,7 +11,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; -import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -42,6 +41,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; +import reactor.util.context.ContextView; /** * The Model Context Protocol (MCP) client implementation that provides asynchronous @@ -161,7 +161,7 @@ public class McpAsyncClient { * The MCP session supplier that manages bidirectional JSON-RPC communication between * clients and servers. */ - private final Supplier sessionSupplier; + private final Function sessionSupplier; /** * Create a new McpAsyncClient with the given transport and session request-response @@ -268,8 +268,8 @@ public class McpAsyncClient { asyncLoggingNotificationHandler(loggingConsumersFinal)); this.transport.setExceptionHandler(this::handleException); - this.sessionSupplier = () -> new McpClientSession(requestTimeout, transport, requestHandlers, - notificationHandlers); + this.sessionSupplier = ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, + notificationHandlers, con -> con.contextWrite(ctx)); } @@ -401,9 +401,8 @@ public Mono initialize() { return withSession("by explicit API call", init -> Mono.just(init.get())); } - private Mono doInitialize(Initialization initialization) { - - initialization.setMcpClientSession(this.sessionSupplier.get()); + private Mono doInitialize(Initialization initialization, ContextView ctx) { + initialization.setMcpClientSession(this.sessionSupplier.apply(ctx)); McpClientSession mcpClientSession = initialization.mcpSession(); @@ -493,14 +492,14 @@ Mono closeGracefully() { * @return A Mono that completes with the result of the operation */ private Mono withSession(String actionName, Function> operation) { - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { Initialization newInit = Initialization.create(); Initialization previous = this.initializationRef.compareAndExchange(null, newInit); boolean needsToInitialize = previous == null; logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization"); - Mono initializationJob = needsToInitialize ? doInitialize(newInit) + Mono initializationJob = needsToInitialize ? doInitialize(newInit, ctx) : previous.await(); return initializationJob.map(initializeResult -> this.initializationRef.get()) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index fa0853d8..36aa1881 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -6,6 +6,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.util.Assert; +import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -16,6 +17,7 @@ import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; /** * Default implementation of the MCP (Model Context Protocol) session that manages @@ -99,9 +101,27 @@ public interface NotificationHandler { * @param transport Transport implementation for message exchange * @param requestHandlers Map of method names to request handlers * @param notificationHandlers Map of method names to notification handlers + * @deprecated Use + * {@link #McpClientSession(Duration, McpClientTransport, Map, Map, Function)} */ + @Deprecated public McpClientSession(Duration requestTimeout, McpClientTransport transport, Map> requestHandlers, Map notificationHandlers) { + this(requestTimeout, transport, requestHandlers, notificationHandlers, Function.identity()); + } + + /** + * Creates a new McpClientSession with the specified configuration and handlers. + * @param requestTimeout Duration to wait for responses + * @param transport Transport implementation for message exchange + * @param requestHandlers Map of method names to request handlers + * @param notificationHandlers Map of method names to notification handlers + * @param connectHook Hook that allows transforming the connection Publisher prior to + * subscribing + */ + public McpClientSession(Duration requestTimeout, McpClientTransport transport, + Map> requestHandlers, Map notificationHandlers, + Function, ? extends Publisher> connectHook) { Assert.notNull(requestTimeout, "The requestTimeout can not be null"); Assert.notNull(transport, "The transport can not be null"); @@ -113,7 +133,7 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, this.requestHandlers.putAll(requestHandlers); this.notificationHandlers.putAll(notificationHandlers); - this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe(); + this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe(); } private void dismissPendingResponses() { From 3e1a8893573f857a804560baebd288cb4557bd37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Thu, 26 Jun 2025 12:53:28 +0200 Subject: [PATCH 2/2] Added a test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../client/McpAsyncClient.java | 1 - .../client/McpAsyncClientTests.java | 81 +++++++++++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e25bf839..617cec17 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -270,7 +270,6 @@ public class McpAsyncClient { this.transport.setExceptionHandler(this::handleException); this.sessionSupplier = ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers, con -> con.contextWrite(ctx)); - } private void handleException(Throwable t) { diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java new file mode 100644 index 00000000..14ca8279 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java @@ -0,0 +1,81 @@ +package io.modelcontextprotocol.client; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThatCode; + +class McpAsyncClientTests { + + public static final McpSchema.Implementation MOCK_SERVER_INFO = new McpSchema.Implementation("test-server", + "1.0.0"); + + public static final McpSchema.ServerCapabilities MOCK_SERVER_CAPABILITIES = McpSchema.ServerCapabilities.builder() + .build(); + + public static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult( + McpSchema.LATEST_PROTOCOL_VERSION, MOCK_SERVER_CAPABILITIES, MOCK_SERVER_INFO, "Test instructions"); + + private static final String CONTEXT_KEY = "context.key"; + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + @Test + void validateContextPassedToTransportConnect() { + McpClientTransport transport = new McpClientTransport() { + Function, Mono> handler; + + final AtomicReference contextValue = new AtomicReference<>(); + + @Override + public Mono connect( + Function, Mono> handler) { + return Mono.deferContextual(ctx -> { + this.handler = handler; + if (ctx.hasKey(CONTEXT_KEY)) { + this.contextValue.set(ctx.get(CONTEXT_KEY)); + } + return Mono.empty(); + }); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + if (!"hello".equals(this.contextValue.get())) { + return Mono.error(new RuntimeException("Context value not propagated via #connect method")); + } + // We're only interested in handling the init request to provide an init + // response + if (!(message instanceof McpSchema.JSONRPCRequest)) { + return Mono.empty(); + } + McpSchema.JSONRPCResponse initResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + ((McpSchema.JSONRPCRequest) message).id(), MOCK_INIT_RESULT, null); + return handler.apply(Mono.just(initResponse)).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return OBJECT_MAPPER.convertValue(data, typeRef); + } + }; + + assertThatCode(() -> { + McpAsyncClient client = McpClient.async(transport).build(); + client.initialize().contextWrite(ctx -> ctx.put(CONTEXT_KEY, "hello")).block(); + }).doesNotThrowAnyException(); + } + +}