Skip to content

Commit 2be29c7

Browse files
committed
feat: Propagate Context to eager connect via McpClientSession
In order to allow the initial connection to have contextual information in the reactive chain, the McpClientSession should be able to transform the McpClientTransport#connect result, e.g. to attach Context items. This change introduces a new constructor for sessions that makes it possible. Signed-off-by: Dariusz Jędrzejczyk <dariusz.jedrzejczyk@broadcom.com>
1 parent 9ebff0c commit 2be29c7

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import java.util.concurrent.ConcurrentHashMap;
1212
import java.util.concurrent.atomic.AtomicReference;
1313
import java.util.function.Function;
14-
import java.util.function.Supplier;
1514

1615
import org.slf4j.Logger;
1716
import org.slf4j.LoggerFactory;
@@ -42,6 +41,7 @@
4241
import reactor.core.publisher.Flux;
4342
import reactor.core.publisher.Mono;
4443
import reactor.core.publisher.Sinks;
44+
import reactor.util.context.ContextView;
4545

4646
/**
4747
* The Model Context Protocol (MCP) client implementation that provides asynchronous
@@ -161,7 +161,7 @@ public class McpAsyncClient {
161161
* The MCP session supplier that manages bidirectional JSON-RPC communication between
162162
* clients and servers.
163163
*/
164-
private final Supplier<McpClientSession> sessionSupplier;
164+
private final Function<ContextView, McpClientSession> sessionSupplier;
165165

166166
/**
167167
* Create a new McpAsyncClient with the given transport and session request-response
@@ -268,8 +268,8 @@ public class McpAsyncClient {
268268
asyncLoggingNotificationHandler(loggingConsumersFinal));
269269

270270
this.transport.setExceptionHandler(this::handleException);
271-
this.sessionSupplier = () -> new McpClientSession(requestTimeout, transport, requestHandlers,
272-
notificationHandlers);
271+
this.sessionSupplier = ctx -> new McpClientSession(requestTimeout, transport, requestHandlers,
272+
notificationHandlers, con -> con.contextWrite(ctx));
273273

274274
}
275275

@@ -401,9 +401,8 @@ public Mono<McpSchema.InitializeResult> initialize() {
401401
return withSession("by explicit API call", init -> Mono.just(init.get()));
402402
}
403403

404-
private Mono<McpSchema.InitializeResult> doInitialize(Initialization initialization) {
405-
406-
initialization.setMcpClientSession(this.sessionSupplier.get());
404+
private Mono<McpSchema.InitializeResult> doInitialize(Initialization initialization, ContextView ctx) {
405+
initialization.setMcpClientSession(this.sessionSupplier.apply(ctx));
407406

408407
McpClientSession mcpClientSession = initialization.mcpSession();
409408

@@ -493,14 +492,14 @@ Mono<Void> closeGracefully() {
493492
* @return A Mono that completes with the result of the operation
494493
*/
495494
private <T> Mono<T> withSession(String actionName, Function<Initialization, Mono<T>> operation) {
496-
return Mono.defer(() -> {
495+
return Mono.deferContextual(ctx -> {
497496
Initialization newInit = Initialization.create();
498497
Initialization previous = this.initializationRef.compareAndExchange(null, newInit);
499498

500499
boolean needsToInitialize = previous == null;
501500
logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization");
502501

503-
Mono<McpSchema.InitializeResult> initializationJob = needsToInitialize ? doInitialize(newInit)
502+
Mono<McpSchema.InitializeResult> initializationJob = needsToInitialize ? doInitialize(newInit, ctx)
504503
: previous.await();
505504

506505
return initializationJob.map(initializeResult -> this.initializationRef.get())

mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import com.fasterxml.jackson.core.type.TypeReference;
88
import io.modelcontextprotocol.util.Assert;
9+
import org.reactivestreams.Publisher;
910
import org.slf4j.Logger;
1011
import org.slf4j.LoggerFactory;
1112
import reactor.core.publisher.Mono;
@@ -16,6 +17,7 @@
1617
import java.util.UUID;
1718
import java.util.concurrent.ConcurrentHashMap;
1819
import java.util.concurrent.atomic.AtomicLong;
20+
import java.util.function.Function;
1921

2022
/**
2123
* Default implementation of the MCP (Model Context Protocol) session that manages
@@ -99,9 +101,27 @@ public interface NotificationHandler {
99101
* @param transport Transport implementation for message exchange
100102
* @param requestHandlers Map of method names to request handlers
101103
* @param notificationHandlers Map of method names to notification handlers
104+
* @deprecated Use
105+
* {@link #McpClientSession(Duration, McpClientTransport, Map, Map, Function)}
102106
*/
107+
@Deprecated
103108
public McpClientSession(Duration requestTimeout, McpClientTransport transport,
104109
Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers) {
110+
this(requestTimeout, transport, requestHandlers, notificationHandlers, Function.identity());
111+
}
112+
113+
/**
114+
* Creates a new McpClientSession with the specified configuration and handlers.
115+
* @param requestTimeout Duration to wait for responses
116+
* @param transport Transport implementation for message exchange
117+
* @param requestHandlers Map of method names to request handlers
118+
* @param notificationHandlers Map of method names to notification handlers
119+
* @param connectHook Hook that allows transforming the connection Publisher prior to
120+
* subscribing
121+
*/
122+
public McpClientSession(Duration requestTimeout, McpClientTransport transport,
123+
Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers,
124+
Function<? super Mono<Void>, ? extends Publisher<Void>> connectHook) {
105125

106126
Assert.notNull(requestTimeout, "The requestTimeout can not be null");
107127
Assert.notNull(transport, "The transport can not be null");
@@ -113,7 +133,7 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport,
113133
this.requestHandlers.putAll(requestHandlers);
114134
this.notificationHandlers.putAll(notificationHandlers);
115135

116-
this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe();
136+
this.transport.connect(mono -> mono.doOnNext(this::handle)).transform(connectHook).subscribe();
117137
}
118138

119139
private void dismissPendingResponses() {

0 commit comments

Comments
 (0)