diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java b/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java
new file mode 100644
index 00000000..e33fafa6
--- /dev/null
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java
@@ -0,0 +1,348 @@
+package io.modelcontextprotocol.client;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import io.modelcontextprotocol.spec.McpClientSession;
+import io.modelcontextprotocol.spec.McpError;
+import io.modelcontextprotocol.spec.McpSchema;
+import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
+import io.modelcontextprotocol.util.Assert;
+import reactor.core.publisher.Mono;
+import reactor.core.publisher.Sinks;
+import reactor.util.context.ContextView;
+
+/**
+ * Handles the protocol initialization phase between client and server
+ *
+ *
+ * The initialization phase MUST be the first interaction between client and server.
+ * During this phase, the client and server perform the following operations:
+ *
+ * - Establish protocol version compatibility
+ * - Exchange and negotiate capabilities
+ * - Share implementation details
+ *
+ *
+ * Client Initialization Process
+ *
+ * The client MUST initiate this phase by sending an initialize request containing:
+ *
+ * - Protocol version supported
+ * - Client capabilities
+ * - Client implementation information
+ *
+ *
+ *
+ * After successful initialization, the client MUST send an initialized notification to
+ * indicate it is ready to begin normal operations.
+ *
+ * Server Response
+ *
+ * The server MUST respond with its own capabilities and information.
+ *
+ * Protocol Version Negotiation
+ *
+ * In the initialize request, the client MUST send a protocol version it supports. This
+ * SHOULD be the latest version supported by the client.
+ *
+ *
+ * If the server supports the requested protocol version, it MUST respond with the same
+ * version. Otherwise, the server MUST respond with another protocol version it supports.
+ * This SHOULD be the latest version supported by the server.
+ *
+ *
+ * If the client does not support the version in the server's response, it SHOULD
+ * disconnect.
+ *
+ * Request Restrictions
+ *
+ * Important: The following restrictions apply during initialization:
+ *
+ * - The client SHOULD NOT send requests other than pings before the server has
+ * responded to the initialize request
+ * - The server SHOULD NOT send requests other than pings and logging before receiving
+ * the initialized notification
+ *
+ */
+class LifecycleInitializer {
+
+ private static final Logger logger = LoggerFactory.getLogger(LifecycleInitializer.class);
+
+ /**
+ * The MCP session supplier that manages bidirectional JSON-RPC communication between
+ * clients and servers.
+ */
+ private final Function sessionSupplier;
+
+ private final McpSchema.ClientCapabilities clientCapabilities;
+
+ private final McpSchema.Implementation clientInfo;
+
+ private List protocolVersions;
+
+ private final AtomicReference initializationRef = new AtomicReference<>();
+
+ /**
+ * The max timeout to await for the client-server connection to be initialized.
+ */
+ private final Duration initializationTimeout;
+
+ public LifecycleInitializer(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo,
+ List protocolVersions, Duration initializationTimeout,
+ Function sessionSupplier) {
+
+ Assert.notNull(sessionSupplier, "Session supplier must not be null");
+ Assert.notNull(clientCapabilities, "Client capabilities must not be null");
+ Assert.notNull(clientInfo, "Client info must not be null");
+ Assert.notEmpty(protocolVersions, "Protocol versions must not be empty");
+ Assert.notNull(initializationTimeout, "Initialization timeout must not be null");
+
+ this.sessionSupplier = sessionSupplier;
+ this.clientCapabilities = clientCapabilities;
+ this.clientInfo = clientInfo;
+ this.protocolVersions = Collections.unmodifiableList(new ArrayList<>(protocolVersions));
+ this.initializationTimeout = initializationTimeout;
+ }
+
+ /**
+ * This method is package-private and used for test only. Should not be called by user
+ * code.
+ * @param protocolVersions the Client supported protocol versions.
+ */
+ void setProtocolVersions(List protocolVersions) {
+ this.protocolVersions = protocolVersions;
+ }
+
+ /**
+ * Represents the initialization state of the MCP client.
+ */
+ interface Initialization {
+
+ /**
+ * Returns the MCP client session that is used to communicate with the server.
+ * This session is established during the initialization process and is used for
+ * sending requests and notifications.
+ * @return The MCP client session
+ */
+ McpClientSession mcpSession();
+
+ /**
+ * Returns the result of the MCP initialization process. This result contains
+ * information about the protocol version, capabilities, server info, and
+ * instructions provided by the server during the initialization phase.
+ * @return The result of the MCP initialization process
+ */
+ McpSchema.InitializeResult initializeResult();
+
+ }
+
+ /**
+ * Default implementation of the {@link Initialization} interface that manages the MCP
+ * client initialization process.
+ */
+ private static class DefaultInitialization implements Initialization {
+
+ /**
+ * A sink that emits the result of the MCP initialization process. It allows
+ * subscribers to wait for the initialization to complete.
+ */
+ private final Sinks.One initSink;
+
+ /**
+ * Holds the result of the MCP initialization process. It is used to cache the
+ * result for future requests.
+ */
+ private final AtomicReference result;
+
+ /**
+ * Holds the MCP client session that is used to communicate with the server. It is
+ * set during the initialization process and used for sending requests and
+ * notifications.
+ */
+ private final AtomicReference mcpClientSession;
+
+ private DefaultInitialization() {
+ this.initSink = Sinks.one();
+ this.result = new AtomicReference<>();
+ this.mcpClientSession = new AtomicReference<>();
+ }
+
+ // ---------------------------------------------------
+ // Public access for mcpSession and initializeResult because they are
+ // used in by the McpAsyncClient.
+ // ----------------------------------------------------
+ public McpClientSession mcpSession() {
+ return this.mcpClientSession.get();
+ }
+
+ public McpSchema.InitializeResult initializeResult() {
+ return this.result.get();
+ }
+
+ // ---------------------------------------------------
+ // Private accessors used internally by the LifecycleInitializer to set the MCP
+ // client session and complete the initialization process.
+ // ---------------------------------------------------
+ private void setMcpClientSession(McpClientSession mcpClientSession) {
+ this.mcpClientSession.set(mcpClientSession);
+ }
+
+ /**
+ * Returns a Mono that completes when the MCP client initialization is complete.
+ * This allows subscribers to wait for the initialization to finish before
+ * proceeding with further operations.
+ * @return A Mono that emits the result of the MCP initialization process
+ */
+ private Mono await() {
+ return this.initSink.asMono();
+ }
+
+ /**
+ * Completes the initialization process with the given result. It caches the
+ * result and emits it to all subscribers waiting for the initialization to
+ * complete.
+ * @param initializeResult The result of the MCP initialization process
+ */
+ private void complete(McpSchema.InitializeResult initializeResult) {
+ // first ensure the result is cached
+ this.result.set(initializeResult);
+ // inform all the subscribers waiting for the initialization
+ this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST);
+ }
+
+ private void error(Throwable t) {
+ this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST);
+ }
+
+ private void close() {
+ this.mcpSession().close();
+ }
+
+ private Mono closeGracefully() {
+ return this.mcpSession().closeGracefully();
+ }
+
+ }
+
+ public boolean isInitialized() {
+ return this.currentInitializationResult() != null;
+ }
+
+ public McpSchema.InitializeResult currentInitializationResult() {
+ DefaultInitialization current = this.initializationRef.get();
+ McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null;
+ return initializeResult;
+ }
+
+ /**
+ * Hook to handle exceptions that occur during the MCP transport session.
+ *
+ * If the exception is a {@link McpTransportSessionNotFoundException}, it indicates
+ * that the session was not found, and we should re-initialize the client.
+ *
+ * @param t The exception to handle
+ */
+ public void handleException(Throwable t) {
+ logger.warn("Handling exception", t);
+ if (t instanceof McpTransportSessionNotFoundException) {
+ DefaultInitialization previous = this.initializationRef.getAndSet(null);
+ if (previous != null) {
+ previous.close();
+ }
+ // Providing an empty operation since we are only interested in triggering
+ // the implicit initialization step.
+ withIntitialization("re-initializing", result -> Mono.empty()).subscribe();
+ }
+ }
+
+ /**
+ * Utility method to ensure the initialization is established before executing an
+ * operation.
+ * @param The type of the result Mono
+ * @param actionName The action to perform when the client is initialized
+ * @param operation The operation to execute when the client is initialized
+ * @return A Mono that completes with the result of the operation
+ */
+ public Mono withIntitialization(String actionName, Function> operation) {
+ return Mono.deferContextual(ctx -> {
+ DefaultInitialization newInit = new DefaultInitialization();
+ DefaultInitialization previous = this.initializationRef.compareAndExchange(null, newInit);
+
+ boolean needsToInitialize = previous == null;
+ logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization");
+
+ Mono initializationJob = needsToInitialize ? doInitialize(newInit, ctx)
+ : previous.await();
+
+ return initializationJob.map(initializeResult -> this.initializationRef.get())
+ .timeout(this.initializationTimeout)
+ .onErrorResume(ex -> {
+ logger.warn("Failed to initialize", ex);
+ return Mono.error(new McpError("Client failed to initialize " + actionName));
+ })
+ .flatMap(operation);
+ });
+ }
+
+ private Mono doInitialize(DefaultInitialization initialization, ContextView ctx) {
+ initialization.setMcpClientSession(this.sessionSupplier.apply(ctx));
+
+ McpClientSession mcpClientSession = initialization.mcpSession();
+
+ String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1);
+
+ McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(latestVersion,
+ this.clientCapabilities, this.clientInfo);
+
+ Mono result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE,
+ initializeRequest, McpAsyncClient.INITIALIZE_RESULT_TYPE_REF);
+
+ return result.flatMap(initializeResult -> {
+ logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}",
+ initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(),
+ initializeResult.instructions());
+
+ if (!this.protocolVersions.contains(initializeResult.protocolVersion())) {
+ return Mono.error(new McpError(
+ "Unsupported protocol version from the server: " + initializeResult.protocolVersion()));
+ }
+
+ return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null)
+ .thenReturn(initializeResult);
+ }).doOnNext(initialization::complete).onErrorResume(ex -> {
+ initialization.error(ex);
+ return Mono.error(ex);
+ });
+ }
+
+ /**
+ * Closes the current initialization if it exists.
+ */
+ public void close() {
+ DefaultInitialization current = this.initializationRef.getAndSet(null);
+ if (current != null) {
+ current.close();
+ }
+ }
+
+ /**
+ * Gracefully closes the current initialization if it exists.
+ * @return A Mono that completes when the connection is closed
+ */
+ public Mono> closeGracefully() {
+ return Mono.defer(() -> {
+ DefaultInitialization current = this.initializationRef.getAndSet(null);
+ Mono> sessionClose = current != null ? current.closeGracefully() : Mono.empty();
+ return sessionClose;
+ });
+ }
+
+}
\ No newline at end of file
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
index fa76f397..cf8142c6 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java
@@ -10,7 +10,6 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import org.slf4j.Logger;
@@ -19,8 +18,6 @@
import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.spec.McpClientSession;
-import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler;
-import io.modelcontextprotocol.spec.McpClientSession.RequestHandler;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
@@ -36,13 +33,12 @@
import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;
import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest;
import io.modelcontextprotocol.spec.McpSchema.Root;
-import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
+import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler;
+import io.modelcontextprotocol.spec.McpClientSession.RequestHandler;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
-import reactor.core.publisher.Sinks;
-import reactor.util.context.ContextView;
/**
* The Model Context Protocol (MCP) client implementation that provides asynchronous
@@ -104,13 +100,6 @@ public class McpAsyncClient {
public static final TypeReference LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() {
};
- private final AtomicReference initializationRef = new AtomicReference<>();
-
- /**
- * The max timeout to await for the client-server connection to be initialized.
- */
- private final Duration initializationTimeout;
-
/**
* Client capabilities.
*/
@@ -154,15 +143,9 @@ public class McpAsyncClient {
private final McpClientTransport transport;
/**
- * Supported protocol versions.
+ * The lifecycle initializer that manages the client-server connection initialization.
*/
- private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION);
-
- /**
- * The MCP session supplier that manages bidirectional JSON-RPC communication between
- * clients and servers.
- */
- private final Function sessionSupplier;
+ private final LifecycleInitializer initializer;
/**
* Create a new McpAsyncClient with the given transport and session request-response
@@ -183,7 +166,6 @@ public class McpAsyncClient {
this.clientCapabilities = features.clientCapabilities();
this.transport = transport;
this.roots = new ConcurrentHashMap<>(features.roots());
- this.initializationTimeout = initializationTimeout;
// Request Handlers
Map> requestHandlers = new HashMap<>();
@@ -271,28 +253,11 @@ public class McpAsyncClient {
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE,
asyncLoggingNotificationHandler(loggingConsumersFinal));
- this.transport.setExceptionHandler(this::handleException);
- this.sessionSupplier = ctx -> new McpClientSession(requestTimeout, transport, requestHandlers,
- notificationHandlers, con -> con.contextWrite(ctx));
- }
-
- private void handleException(Throwable t) {
- logger.warn("Handling exception", t);
- if (t instanceof McpTransportSessionNotFoundException) {
- Initialization previous = this.initializationRef.getAndSet(null);
- if (previous != null) {
- previous.close();
- }
- // Providing an empty operation since we are only interested in triggering the
- // implicit initialization step.
- withSession("re-initializing", result -> Mono.empty()).subscribe();
- }
- }
-
- private McpSchema.InitializeResult currentInitializationResult() {
- Initialization current = this.initializationRef.get();
- McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null;
- return initializeResult;
+ this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo,
+ List.of(McpSchema.LATEST_PROTOCOL_VERSION), initializationTimeout,
+ ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers,
+ con -> con.contextWrite(ctx)));
+ this.transport.setExceptionHandler(this.initializer::handleException);
}
/**
@@ -300,7 +265,7 @@ private McpSchema.InitializeResult currentInitializationResult() {
* @return The server capabilities
*/
public McpSchema.ServerCapabilities getServerCapabilities() {
- McpSchema.InitializeResult initializeResult = currentInitializationResult();
+ McpSchema.InitializeResult initializeResult = this.initializer.currentInitializationResult();
return initializeResult != null ? initializeResult.capabilities() : null;
}
@@ -310,7 +275,7 @@ public McpSchema.ServerCapabilities getServerCapabilities() {
* @return The server instructions
*/
public String getServerInstructions() {
- McpSchema.InitializeResult initializeResult = currentInitializationResult();
+ McpSchema.InitializeResult initializeResult = this.initializer.currentInitializationResult();
return initializeResult != null ? initializeResult.instructions() : null;
}
@@ -319,7 +284,7 @@ public String getServerInstructions() {
* @return The server implementation details
*/
public McpSchema.Implementation getServerInfo() {
- McpSchema.InitializeResult initializeResult = currentInitializationResult();
+ McpSchema.InitializeResult initializeResult = this.initializer.currentInitializationResult();
return initializeResult != null ? initializeResult.serverInfo() : null;
}
@@ -328,8 +293,7 @@ public McpSchema.Implementation getServerInfo() {
* @return true if the client-server connection is initialized
*/
public boolean isInitialized() {
- Initialization current = this.initializationRef.get();
- return current != null && (current.result.get() != null);
+ return this.initializer.isInitialized();
}
/**
@@ -352,10 +316,7 @@ public McpSchema.Implementation getClientInfo() {
* Closes the client connection immediately.
*/
public void close() {
- Initialization current = this.initializationRef.getAndSet(null);
- if (current != null) {
- current.close();
- }
+ this.initializer.close();
this.transport.close();
}
@@ -365,9 +326,7 @@ public void close() {
*/
public Mono closeGracefully() {
return Mono.defer(() -> {
- Initialization current = this.initializationRef.getAndSet(null);
- Mono> sessionClose = current != null ? current.closeGracefully() : Mono.empty();
- return sessionClose.then(transport.closeGracefully());
+ return this.initializer.closeGracefully().then(transport.closeGracefully());
});
}
@@ -401,118 +360,7 @@ public Mono closeGracefully() {
*
*/
public Mono initialize() {
- return withSession("by explicit API call", init -> Mono.just(init.get()));
- }
-
- private Mono doInitialize(Initialization initialization, ContextView ctx) {
- initialization.setMcpClientSession(this.sessionSupplier.apply(ctx));
-
- McpClientSession mcpClientSession = initialization.mcpSession();
-
- String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1);
-
- McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off
- latestVersion,
- this.clientCapabilities,
- this.clientInfo); // @formatter:on
-
- Mono result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE,
- initializeRequest, INITIALIZE_RESULT_TYPE_REF);
-
- return result.flatMap(initializeResult -> {
- logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}",
- initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(),
- initializeResult.instructions());
-
- if (!this.protocolVersions.contains(initializeResult.protocolVersion())) {
- return Mono.error(new McpError(
- "Unsupported protocol version from the server: " + initializeResult.protocolVersion()));
- }
-
- return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null)
- .thenReturn(initializeResult);
- }).doOnNext(initialization::complete).onErrorResume(ex -> {
- initialization.error(ex);
- return Mono.error(ex);
- });
- }
-
- private static class Initialization {
-
- private final Sinks.One initSink = Sinks.one();
-
- private final AtomicReference result = new AtomicReference<>();
-
- private final AtomicReference mcpClientSession = new AtomicReference<>();
-
- static Initialization create() {
- return new Initialization();
- }
-
- void setMcpClientSession(McpClientSession mcpClientSession) {
- this.mcpClientSession.set(mcpClientSession);
- }
-
- McpClientSession mcpSession() {
- return this.mcpClientSession.get();
- }
-
- McpSchema.InitializeResult get() {
- return this.result.get();
- }
-
- Mono await() {
- return this.initSink.asMono();
- }
-
- void complete(McpSchema.InitializeResult initializeResult) {
- // first ensure the result is cached
- this.result.set(initializeResult);
- // inform all the subscribers waiting for the initialization
- this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST);
- }
-
- void error(Throwable t) {
- this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST);
- }
-
- void close() {
- this.mcpSession().close();
- }
-
- Mono closeGracefully() {
- return this.mcpSession().closeGracefully();
- }
-
- }
-
- /**
- * Utility method to handle the common pattern of ensuring initialization before
- * executing an operation.
- * @param The type of the result Mono
- * @param actionName The action to perform when the client is initialized
- * @param operation The operation to execute when the client is initialized
- * @return A Mono that completes with the result of the operation
- */
- private Mono withSession(String actionName, Function> operation) {
- return Mono.deferContextual(ctx -> {
- Initialization newInit = Initialization.create();
- Initialization previous = this.initializationRef.compareAndExchange(null, newInit);
-
- boolean needsToInitialize = previous == null;
- logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization");
-
- Mono initializationJob = needsToInitialize ? doInitialize(newInit, ctx)
- : previous.await();
-
- return initializationJob.map(initializeResult -> this.initializationRef.get())
- .timeout(this.initializationTimeout)
- .onErrorResume(ex -> {
- logger.warn("Failed to initialize", ex);
- return Mono.error(new McpError("Client failed to initialize " + actionName));
- })
- .flatMap(operation);
- });
+ return this.initializer.withIntitialization("by explicit API call", init -> Mono.just(init.initializeResult()));
}
// --------------------------
@@ -524,7 +372,7 @@ private Mono withSession(String actionName, Function ping() {
- return this.withSession("pinging the server",
+ return this.initializer.withIntitialization("pinging the server",
init -> init.mcpSession().sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF));
}
@@ -605,7 +453,7 @@ public Mono removeRoot(String rootUri) {
* @return A Mono that completes when the notification is sent.
*/
public Mono rootsListChangedNotification() {
- return this.withSession("sending roots list changed notification",
+ return this.initializer.withIntitialization("sending roots list changed notification",
init -> init.mcpSession().sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED));
}
@@ -664,8 +512,8 @@ private RequestHandler elicitationCreateHandler() {
* @see #listTools()
*/
public Mono callTool(McpSchema.CallToolRequest callToolRequest) {
- return this.withSession("calling tools", init -> {
- if (init.get().capabilities().tools() == null) {
+ return this.initializer.withIntitialization("calling tools", init -> {
+ if (init.initializeResult().capabilities().tools() == null) {
return Mono.error(new McpError("Server does not provide tools capability"));
}
return init.mcpSession()
@@ -693,8 +541,8 @@ public Mono listTools() {
* @return A Mono that emits the list of tools result
*/
public Mono listTools(String cursor) {
- return this.withSession("listing tools", init -> {
- if (init.get().capabilities().tools() == null) {
+ return this.initializer.withIntitialization("listing tools", init -> {
+ if (init.initializeResult().capabilities().tools() == null) {
return Mono.error(new McpError("Server does not provide tools capability"));
}
return init.mcpSession()
@@ -757,8 +605,8 @@ public Mono listResources() {
* @see #readResource(McpSchema.Resource)
*/
public Mono listResources(String cursor) {
- return this.withSession("listing resources", init -> {
- if (init.get().capabilities().resources() == null) {
+ return this.initializer.withIntitialization("listing resources", init -> {
+ if (init.initializeResult().capabilities().resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
return init.mcpSession()
@@ -789,8 +637,8 @@ public Mono readResource(McpSchema.Resource resour
* @see McpSchema.ReadResourceResult
*/
public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) {
- return this.withSession("reading resources", init -> {
- if (init.get().capabilities().resources() == null) {
+ return this.initializer.withIntitialization("reading resources", init -> {
+ if (init.initializeResult().capabilities().resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
return init.mcpSession()
@@ -827,8 +675,8 @@ public Mono listResourceTemplates() {
* @see McpSchema.ListResourceTemplatesResult
*/
public Mono listResourceTemplates(String cursor) {
- return this.withSession("listing resource templates", init -> {
- if (init.get().capabilities().resources() == null) {
+ return this.initializer.withIntitialization("listing resource templates", init -> {
+ if (init.initializeResult().capabilities().resources() == null) {
return Mono.error(new McpError("Server does not provide the resources capability"));
}
return init.mcpSession()
@@ -847,7 +695,7 @@ public Mono listResourceTemplates(String
* @see #unsubscribeResource(McpSchema.UnsubscribeRequest)
*/
public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) {
- return this.withSession("subscribing to resources", init -> init.mcpSession()
+ return this.initializer.withIntitialization("subscribing to resources", init -> init.mcpSession()
.sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE));
}
@@ -861,7 +709,7 @@ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest)
* @see #subscribeResource(McpSchema.SubscribeRequest)
*/
public Mono unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) {
- return this.withSession("unsubscribing from resources", init -> init.mcpSession()
+ return this.initializer.withIntitialization("unsubscribing from resources", init -> init.mcpSession()
.sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, VOID_TYPE_REFERENCE));
}
@@ -927,7 +775,7 @@ public Mono listPrompts() {
* @see #getPrompt(GetPromptRequest)
*/
public Mono listPrompts(String cursor) {
- return this.withSession("listing prompts", init -> init.mcpSession()
+ return this.initializer.withIntitialization("listing prompts", init -> init.mcpSession()
.sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF));
}
@@ -941,7 +789,7 @@ public Mono listPrompts(String cursor) {
* @see #listPrompts()
*/
public Mono getPrompt(GetPromptRequest getPromptRequest) {
- return this.withSession("getting prompts", init -> init.mcpSession()
+ return this.initializer.withIntitialization("getting prompts", init -> init.mcpSession()
.sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF));
}
@@ -992,7 +840,7 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) {
return Mono.error(new McpError("Logging level must not be null"));
}
- return this.withSession("setting logging level", init -> {
+ return this.initializer.withIntitialization("setting logging level", init -> {
var params = new McpSchema.SetLevelRequest(loggingLevel);
return init.mcpSession().sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, OBJECT_TYPE_REF).then();
});
@@ -1004,7 +852,7 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) {
* @param protocolVersions the Client supported protocol versions.
*/
void setProtocolVersions(List protocolVersions) {
- this.protocolVersions = protocolVersions;
+ this.initializer.setProtocolVersions(protocolVersions);
}
// --------------------------
@@ -1024,7 +872,7 @@ void setProtocolVersions(List protocolVersions) {
* @see McpSchema.CompleteResult
*/
public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) {
- return this.withSession("complete completions", init -> init.mcpSession()
+ return this.initializer.withIntitialization("complete completions", init -> init.mcpSession()
.sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF));
}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java
new file mode 100644
index 00000000..c8d69192
--- /dev/null
+++ b/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java
@@ -0,0 +1,412 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ */
+
+package io.modelcontextprotocol.client;
+
+import java.time.Duration;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import io.modelcontextprotocol.spec.McpClientSession;
+import io.modelcontextprotocol.spec.McpError;
+import io.modelcontextprotocol.spec.McpSchema;
+import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
+import reactor.core.publisher.Mono;
+import reactor.core.scheduler.Schedulers;
+import reactor.test.StepVerifier;
+import reactor.test.scheduler.VirtualTimeScheduler;
+import reactor.util.context.Context;
+import reactor.util.context.ContextView;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link LifecycleInitializer}.
+ */
+class LifecycleInitializerTests {
+
+ private static final Duration INITIALIZATION_TIMEOUT = Duration.ofSeconds(5);
+
+ private static final McpSchema.ClientCapabilities CLIENT_CAPABILITIES = McpSchema.ClientCapabilities.builder()
+ .build();
+
+ private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0");
+
+ private static final List PROTOCOL_VERSIONS = List.of("1.0.0", "2.0.0");
+
+ private static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult("2.0.0",
+ McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"),
+ "Test instructions");
+
+ @Mock
+ private McpClientSession mockClientSession;
+
+ @Mock
+ private Function mockSessionSupplier;
+
+ private LifecycleInitializer initializer;
+
+ @BeforeEach
+ void setUp() {
+ MockitoAnnotations.openMocks(this);
+
+ when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession);
+ when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()))
+ .thenReturn(Mono.just(MOCK_INIT_RESULT));
+ when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()))
+ .thenReturn(Mono.empty());
+ when(mockClientSession.closeGracefully()).thenReturn(Mono.empty());
+
+ initializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS,
+ INITIALIZATION_TIMEOUT, mockSessionSupplier);
+ }
+
+ @Test
+ void constructorShouldValidateParameters() {
+ assertThatThrownBy(() -> new LifecycleInitializer(null, CLIENT_INFO, PROTOCOL_VERSIONS, INITIALIZATION_TIMEOUT,
+ mockSessionSupplier))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Client capabilities must not be null");
+
+ assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, null, PROTOCOL_VERSIONS,
+ INITIALIZATION_TIMEOUT, mockSessionSupplier))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Client info must not be null");
+
+ assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, null,
+ INITIALIZATION_TIMEOUT, mockSessionSupplier))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Protocol versions must not be empty");
+
+ assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, List.of(),
+ INITIALIZATION_TIMEOUT, mockSessionSupplier))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Protocol versions must not be empty");
+
+ assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, null,
+ mockSessionSupplier))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Initialization timeout must not be null");
+
+ assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS,
+ INITIALIZATION_TIMEOUT, null))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("Session supplier must not be null");
+ }
+
+ @Test
+ void shouldInitializeSuccessfully() {
+ StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
+ .assertNext(result -> {
+ assertThat(result).isEqualTo(MOCK_INIT_RESULT);
+ assertThat(initializer.isInitialized()).isTrue();
+ assertThat(initializer.currentInitializationResult()).isEqualTo(MOCK_INIT_RESULT);
+ })
+ .verifyComplete();
+
+ verify(mockClientSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(McpSchema.InitializeRequest.class),
+ any());
+ verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null));
+ }
+
+ @Test
+ void shouldUseLatestProtocolVersionInInitializeRequest() {
+ AtomicReference capturedRequest = new AtomicReference<>();
+
+ when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> {
+ capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1));
+ return Mono.just(MOCK_INIT_RESULT);
+ });
+
+ StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
+ .assertNext(result -> {
+ assertThat(capturedRequest.get().protocolVersion()).isEqualTo("2.0.0"); // Latest
+ // version
+ assertThat(capturedRequest.get().capabilities()).isEqualTo(CLIENT_CAPABILITIES);
+ assertThat(capturedRequest.get().clientInfo()).isEqualTo(CLIENT_INFO);
+ })
+ .verifyComplete();
+ }
+
+ @Test
+ void shouldFailForUnsupportedProtocolVersion() {
+ McpSchema.InitializeResult unsupportedResult = new McpSchema.InitializeResult("999.0.0", // Unsupported
+ // version
+ McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"),
+ "Test instructions");
+
+ when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()))
+ .thenReturn(Mono.just(unsupportedResult));
+
+ StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
+ .expectError(McpError.class)
+ .verify();
+
+ verify(mockClientSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any());
+ }
+
+ @Test
+ void shouldTimeoutOnSlowInitialization() {
+ VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet();
+
+ Duration INITIALIZE_TIMEOUT = Duration.ofSeconds(1);
+ Duration SLOW_RESPONSE_DELAY = Duration.ofSeconds(5);
+
+ LifecycleInitializer shortTimeoutInitializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO,
+ PROTOCOL_VERSIONS, INITIALIZE_TIMEOUT, mockSessionSupplier);
+
+ when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()))
+ .thenReturn(Mono.just(MOCK_INIT_RESULT).delayElement(SLOW_RESPONSE_DELAY, virtualTimeScheduler));
+
+ StepVerifier
+ .withVirtualTime(() -> shortTimeoutInitializer.withIntitialization("test",
+ init -> Mono.just(init.initializeResult())), () -> virtualTimeScheduler, Long.MAX_VALUE)
+ .expectSubscription()
+ .expectNoEvent(INITIALIZE_TIMEOUT)
+ .expectError(McpError.class)
+ .verify();
+ }
+
+ @Test
+ void shouldReuseExistingInitialization() {
+ // First initialization
+ StepVerifier.create(initializer.withIntitialization("test1", init -> Mono.just("result1")))
+ .expectNext("result1")
+ .verifyComplete();
+
+ // Second call should reuse the same initialization
+ StepVerifier.create(initializer.withIntitialization("test2", init -> Mono.just("result2")))
+ .expectNext("result2")
+ .verifyComplete();
+
+ // Verify session was created only once
+ verify(mockSessionSupplier, times(1)).apply(any(ContextView.class));
+ verify(mockClientSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any());
+ }
+
+ @Test
+ void shouldHandleConcurrentInitializationRequests() {
+ AtomicInteger sessionCreationCount = new AtomicInteger(0);
+
+ when(mockSessionSupplier.apply(any(ContextView.class))).thenAnswer(invocation -> {
+ sessionCreationCount.incrementAndGet();
+ return mockClientSession;
+ });
+
+ // Start multiple concurrent initializations using subscribeOn with parallel
+ // scheduler
+ Mono init1 = initializer.withIntitialization("test1", init -> Mono.just("result1"))
+ .subscribeOn(Schedulers.parallel());
+ Mono init2 = initializer.withIntitialization("test2", init -> Mono.just("result2"))
+ .subscribeOn(Schedulers.parallel());
+ Mono init3 = initializer.withIntitialization("test3", init -> Mono.just("result3"))
+ .subscribeOn(Schedulers.parallel());
+
+ StepVerifier.create(Mono.zip(init1, init2, init3)).assertNext(tuple -> {
+ assertThat(tuple.getT1()).isEqualTo("result1");
+ assertThat(tuple.getT2()).isEqualTo("result2");
+ assertThat(tuple.getT3()).isEqualTo("result3");
+ }).verifyComplete();
+
+ // Should only create one session despite concurrent requests
+ assertThat(sessionCreationCount.get()).isEqualTo(1);
+ verify(mockClientSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any());
+ }
+
+ @Test
+ void shouldHandleInitializationFailure() {
+ when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()))
+ .thenReturn(Mono.error(new RuntimeException("Connection failed")));
+
+ StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
+ .expectError(McpError.class)
+ .verify();
+
+ assertThat(initializer.isInitialized()).isFalse();
+ assertThat(initializer.currentInitializationResult()).isNull();
+ }
+
+ @Test
+ void shouldHandleTransportSessionNotFoundException() {
+ // successful initialization first
+ StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
+ .expectNext(MOCK_INIT_RESULT)
+ .verifyComplete();
+
+ assertThat(initializer.isInitialized()).isTrue();
+
+ // Simulate transport session not found
+ initializer.handleException(new McpTransportSessionNotFoundException("Session not found"));
+
+ assertThat(initializer.isInitialized()).isTrue();
+
+ // Verify that the session was closed and re-initialized
+ verify(mockClientSession).close();
+
+ // Verify session was created 2 times (once for initial and once for
+ // re-initialization)
+ verify(mockSessionSupplier, times(2)).apply(any(ContextView.class));
+ }
+
+ @Test
+ void shouldHandleOtherExceptions() {
+ // Simulate a successful initialization first
+ StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
+ .expectNext(MOCK_INIT_RESULT)
+ .verifyComplete();
+
+ assertThat(initializer.isInitialized()).isTrue();
+
+ // Simulate other exception (should not trigger re-initialization)
+ initializer.handleException(new RuntimeException("Some other error"));
+
+ // Should still be initialized
+ assertThat(initializer.isInitialized()).isTrue();
+ verify(mockClientSession, never()).close();
+ // Verify that the session was not re-created
+ verify(mockSessionSupplier, times(1)).apply(any(ContextView.class));
+ }
+
+ @Test
+ void shouldCloseGracefully() {
+ StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
+ .expectNext(MOCK_INIT_RESULT)
+ .verifyComplete();
+
+ StepVerifier.create(initializer.closeGracefully()).verifyComplete();
+
+ verify(mockClientSession).closeGracefully();
+ assertThat(initializer.isInitialized()).isFalse();
+ }
+
+ @Test
+ void shouldCloseImmediately() {
+ StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
+ .expectNext(MOCK_INIT_RESULT)
+ .verifyComplete();
+
+ // Close immediately
+ initializer.close();
+
+ verify(mockClientSession).close();
+ assertThat(initializer.isInitialized()).isFalse();
+ }
+
+ @Test
+ void shouldHandleCloseWithoutInitialization() {
+ // Close without initialization should not throw
+ initializer.close();
+
+ StepVerifier.create(initializer.closeGracefully()).verifyComplete();
+
+ verify(mockClientSession, never()).close();
+ verify(mockClientSession, never()).closeGracefully();
+ }
+
+ @Test
+ void shouldSetProtocolVersionsForTesting() {
+ List newVersions = List.of("3.0.0", "4.0.0");
+ initializer.setProtocolVersions(newVersions);
+
+ AtomicReference capturedRequest = new AtomicReference<>();
+
+ when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> {
+ capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1));
+ return Mono.just(new McpSchema.InitializeResult("4.0.0", McpSchema.ServerCapabilities.builder().build(),
+ new McpSchema.Implementation("test-server", "1.0.0"), "Test instructions"));
+ });
+
+ StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
+ .assertNext(result -> {
+ // Latest from new versions
+ assertThat(capturedRequest.get().protocolVersion()).isEqualTo("4.0.0");
+ })
+ .verifyComplete();
+ }
+
+ @Test
+ void shouldPassContextToSessionSupplier() {
+ String contextKey = "test.key";
+ String contextValue = "test.value";
+
+ AtomicReference capturedContext = new AtomicReference<>();
+
+ when(mockSessionSupplier.apply(any(ContextView.class))).thenAnswer(invocation -> {
+ capturedContext.set(invocation.getArgument(0));
+ return mockClientSession;
+ });
+
+ StepVerifier
+ .create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))
+ .contextWrite(Context.of(contextKey, contextValue)))
+ .expectNext(MOCK_INIT_RESULT)
+ .verifyComplete();
+
+ assertThat(capturedContext.get().hasKey(contextKey)).isTrue();
+ assertThat((String) capturedContext.get().get(contextKey)).isEqualTo(contextValue);
+ }
+
+ @Test
+ void shouldProvideAccessToMcpSessionAndInitializeResult() {
+ StepVerifier.create(initializer.withIntitialization("test", init -> {
+ assertThat(init.mcpSession()).isEqualTo(mockClientSession);
+ assertThat(init.initializeResult()).isEqualTo(MOCK_INIT_RESULT);
+ return Mono.just("success");
+ })).expectNext("success").verifyComplete();
+ }
+
+ @Test
+ void shouldHandleNotificationFailure() {
+ when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()))
+ .thenReturn(Mono.error(new RuntimeException("Notification failed")));
+
+ StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())))
+ .expectError(RuntimeException.class)
+ .verify();
+
+ verify(mockClientSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any());
+ verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null));
+ }
+
+ @Test
+ void shouldReturnNullWhenNotInitialized() {
+ assertThat(initializer.isInitialized()).isFalse();
+ assertThat(initializer.currentInitializationResult()).isNull();
+ }
+
+ @Test
+ void shouldReinitializeAfterTransportSessionException() {
+ // First initialization
+ StepVerifier.create(initializer.withIntitialization("test1", init -> Mono.just("result1")))
+ .expectNext("result1")
+ .verifyComplete();
+
+ // Simulate transport session exception
+ initializer.handleException(new McpTransportSessionNotFoundException("Session lost"));
+
+ // Should be able to initialize again
+ StepVerifier.create(initializer.withIntitialization("test2", init -> Mono.just("result2")))
+ .expectNext("result2")
+ .verifyComplete();
+
+ // Verify two separate initializations occurred
+ verify(mockSessionSupplier, times(2)).apply(any(ContextView.class));
+ verify(mockClientSession, times(2)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any());
+ }
+
+}