diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java b/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java new file mode 100644 index 00000000..e33fafa6 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java @@ -0,0 +1,348 @@ +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.context.ContextView; + +/** + * Handles the protocol initialization phase between client and server + * + *

+ * The initialization phase MUST be the first interaction between client and server. + * During this phase, the client and server perform the following operations: + *

+ * + * Client Initialization Process + *

+ * The client MUST initiate this phase by sending an initialize request containing: + *

+ * + *

+ * After successful initialization, the client MUST send an initialized notification to + * indicate it is ready to begin normal operations. + * + * Server Response + *

+ * The server MUST respond with its own capabilities and information. + * + * Protocol Version Negotiation + *

+ * In the initialize request, the client MUST send a protocol version it supports. This + * SHOULD be the latest version supported by the client. + * + *

+ * If the server supports the requested protocol version, it MUST respond with the same + * version. Otherwise, the server MUST respond with another protocol version it supports. + * This SHOULD be the latest version supported by the server. + * + *

+ * If the client does not support the version in the server's response, it SHOULD + * disconnect. + * + * Request Restrictions + *

+ * Important: The following restrictions apply during initialization: + *

+ */ +class LifecycleInitializer { + + private static final Logger logger = LoggerFactory.getLogger(LifecycleInitializer.class); + + /** + * The MCP session supplier that manages bidirectional JSON-RPC communication between + * clients and servers. + */ + private final Function sessionSupplier; + + private final McpSchema.ClientCapabilities clientCapabilities; + + private final McpSchema.Implementation clientInfo; + + private List protocolVersions; + + private final AtomicReference initializationRef = new AtomicReference<>(); + + /** + * The max timeout to await for the client-server connection to be initialized. + */ + private final Duration initializationTimeout; + + public LifecycleInitializer(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + List protocolVersions, Duration initializationTimeout, + Function sessionSupplier) { + + Assert.notNull(sessionSupplier, "Session supplier must not be null"); + Assert.notNull(clientCapabilities, "Client capabilities must not be null"); + Assert.notNull(clientInfo, "Client info must not be null"); + Assert.notEmpty(protocolVersions, "Protocol versions must not be empty"); + Assert.notNull(initializationTimeout, "Initialization timeout must not be null"); + + this.sessionSupplier = sessionSupplier; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + this.protocolVersions = Collections.unmodifiableList(new ArrayList<>(protocolVersions)); + this.initializationTimeout = initializationTimeout; + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + + /** + * Represents the initialization state of the MCP client. + */ + interface Initialization { + + /** + * Returns the MCP client session that is used to communicate with the server. + * This session is established during the initialization process and is used for + * sending requests and notifications. + * @return The MCP client session + */ + McpClientSession mcpSession(); + + /** + * Returns the result of the MCP initialization process. This result contains + * information about the protocol version, capabilities, server info, and + * instructions provided by the server during the initialization phase. + * @return The result of the MCP initialization process + */ + McpSchema.InitializeResult initializeResult(); + + } + + /** + * Default implementation of the {@link Initialization} interface that manages the MCP + * client initialization process. + */ + private static class DefaultInitialization implements Initialization { + + /** + * A sink that emits the result of the MCP initialization process. It allows + * subscribers to wait for the initialization to complete. + */ + private final Sinks.One initSink; + + /** + * Holds the result of the MCP initialization process. It is used to cache the + * result for future requests. + */ + private final AtomicReference result; + + /** + * Holds the MCP client session that is used to communicate with the server. It is + * set during the initialization process and used for sending requests and + * notifications. + */ + private final AtomicReference mcpClientSession; + + private DefaultInitialization() { + this.initSink = Sinks.one(); + this.result = new AtomicReference<>(); + this.mcpClientSession = new AtomicReference<>(); + } + + // --------------------------------------------------- + // Public access for mcpSession and initializeResult because they are + // used in by the McpAsyncClient. + // ---------------------------------------------------- + public McpClientSession mcpSession() { + return this.mcpClientSession.get(); + } + + public McpSchema.InitializeResult initializeResult() { + return this.result.get(); + } + + // --------------------------------------------------- + // Private accessors used internally by the LifecycleInitializer to set the MCP + // client session and complete the initialization process. + // --------------------------------------------------- + private void setMcpClientSession(McpClientSession mcpClientSession) { + this.mcpClientSession.set(mcpClientSession); + } + + /** + * Returns a Mono that completes when the MCP client initialization is complete. + * This allows subscribers to wait for the initialization to finish before + * proceeding with further operations. + * @return A Mono that emits the result of the MCP initialization process + */ + private Mono await() { + return this.initSink.asMono(); + } + + /** + * Completes the initialization process with the given result. It caches the + * result and emits it to all subscribers waiting for the initialization to + * complete. + * @param initializeResult The result of the MCP initialization process + */ + private void complete(McpSchema.InitializeResult initializeResult) { + // first ensure the result is cached + this.result.set(initializeResult); + // inform all the subscribers waiting for the initialization + this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST); + } + + private void error(Throwable t) { + this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); + } + + private void close() { + this.mcpSession().close(); + } + + private Mono closeGracefully() { + return this.mcpSession().closeGracefully(); + } + + } + + public boolean isInitialized() { + return this.currentInitializationResult() != null; + } + + public McpSchema.InitializeResult currentInitializationResult() { + DefaultInitialization current = this.initializationRef.get(); + McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; + return initializeResult; + } + + /** + * Hook to handle exceptions that occur during the MCP transport session. + *

+ * If the exception is a {@link McpTransportSessionNotFoundException}, it indicates + * that the session was not found, and we should re-initialize the client. + *

+ * @param t The exception to handle + */ + public void handleException(Throwable t) { + logger.warn("Handling exception", t); + if (t instanceof McpTransportSessionNotFoundException) { + DefaultInitialization previous = this.initializationRef.getAndSet(null); + if (previous != null) { + previous.close(); + } + // Providing an empty operation since we are only interested in triggering + // the implicit initialization step. + withIntitialization("re-initializing", result -> Mono.empty()).subscribe(); + } + } + + /** + * Utility method to ensure the initialization is established before executing an + * operation. + * @param The type of the result Mono + * @param actionName The action to perform when the client is initialized + * @param operation The operation to execute when the client is initialized + * @return A Mono that completes with the result of the operation + */ + public Mono withIntitialization(String actionName, Function> operation) { + return Mono.deferContextual(ctx -> { + DefaultInitialization newInit = new DefaultInitialization(); + DefaultInitialization previous = this.initializationRef.compareAndExchange(null, newInit); + + boolean needsToInitialize = previous == null; + logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization"); + + Mono initializationJob = needsToInitialize ? doInitialize(newInit, ctx) + : previous.await(); + + return initializationJob.map(initializeResult -> this.initializationRef.get()) + .timeout(this.initializationTimeout) + .onErrorResume(ex -> { + logger.warn("Failed to initialize", ex); + return Mono.error(new McpError("Client failed to initialize " + actionName)); + }) + .flatMap(operation); + }); + } + + private Mono doInitialize(DefaultInitialization initialization, ContextView ctx) { + initialization.setMcpClientSession(this.sessionSupplier.apply(ctx)); + + McpClientSession mcpClientSession = initialization.mcpSession(); + + String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(latestVersion, + this.clientCapabilities, this.clientInfo); + + Mono result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE, + initializeRequest, McpAsyncClient.INITIALIZE_RESULT_TYPE_REF); + + return result.flatMap(initializeResult -> { + logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", + initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(), + initializeResult.instructions()); + + if (!this.protocolVersions.contains(initializeResult.protocolVersion())) { + return Mono.error(new McpError( + "Unsupported protocol version from the server: " + initializeResult.protocolVersion())); + } + + return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) + .thenReturn(initializeResult); + }).doOnNext(initialization::complete).onErrorResume(ex -> { + initialization.error(ex); + return Mono.error(ex); + }); + } + + /** + * Closes the current initialization if it exists. + */ + public void close() { + DefaultInitialization current = this.initializationRef.getAndSet(null); + if (current != null) { + current.close(); + } + } + + /** + * Gracefully closes the current initialization if it exists. + * @return A Mono that completes when the connection is closed + */ + public Mono closeGracefully() { + return Mono.defer(() -> { + DefaultInitialization current = this.initializationRef.getAndSet(null); + Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); + return sessionClose; + }); + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index fa76f397..cf8142c6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -10,7 +10,6 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import org.slf4j.Logger; @@ -19,8 +18,6 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpClientSession; -import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; -import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; @@ -36,13 +33,12 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import io.modelcontextprotocol.spec.McpClientSession.NotificationHandler; +import io.modelcontextprotocol.spec.McpClientSession.RequestHandler; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.Sinks; -import reactor.util.context.ContextView; /** * The Model Context Protocol (MCP) client implementation that provides asynchronous @@ -104,13 +100,6 @@ public class McpAsyncClient { public static final TypeReference LOGGING_MESSAGE_NOTIFICATION_TYPE_REF = new TypeReference<>() { }; - private final AtomicReference initializationRef = new AtomicReference<>(); - - /** - * The max timeout to await for the client-server connection to be initialized. - */ - private final Duration initializationTimeout; - /** * Client capabilities. */ @@ -154,15 +143,9 @@ public class McpAsyncClient { private final McpClientTransport transport; /** - * Supported protocol versions. + * The lifecycle initializer that manages the client-server connection initialization. */ - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); - - /** - * The MCP session supplier that manages bidirectional JSON-RPC communication between - * clients and servers. - */ - private final Function sessionSupplier; + private final LifecycleInitializer initializer; /** * Create a new McpAsyncClient with the given transport and session request-response @@ -183,7 +166,6 @@ public class McpAsyncClient { this.clientCapabilities = features.clientCapabilities(); this.transport = transport; this.roots = new ConcurrentHashMap<>(features.roots()); - this.initializationTimeout = initializationTimeout; // Request Handlers Map> requestHandlers = new HashMap<>(); @@ -271,28 +253,11 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); - this.transport.setExceptionHandler(this::handleException); - this.sessionSupplier = ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, - notificationHandlers, con -> con.contextWrite(ctx)); - } - - private void handleException(Throwable t) { - logger.warn("Handling exception", t); - if (t instanceof McpTransportSessionNotFoundException) { - Initialization previous = this.initializationRef.getAndSet(null); - if (previous != null) { - previous.close(); - } - // Providing an empty operation since we are only interested in triggering the - // implicit initialization step. - withSession("re-initializing", result -> Mono.empty()).subscribe(); - } - } - - private McpSchema.InitializeResult currentInitializationResult() { - Initialization current = this.initializationRef.get(); - McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null; - return initializeResult; + this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo, + List.of(McpSchema.LATEST_PROTOCOL_VERSION), initializationTimeout, + ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers, + con -> con.contextWrite(ctx))); + this.transport.setExceptionHandler(this.initializer::handleException); } /** @@ -300,7 +265,7 @@ private McpSchema.InitializeResult currentInitializationResult() { * @return The server capabilities */ public McpSchema.ServerCapabilities getServerCapabilities() { - McpSchema.InitializeResult initializeResult = currentInitializationResult(); + McpSchema.InitializeResult initializeResult = this.initializer.currentInitializationResult(); return initializeResult != null ? initializeResult.capabilities() : null; } @@ -310,7 +275,7 @@ public McpSchema.ServerCapabilities getServerCapabilities() { * @return The server instructions */ public String getServerInstructions() { - McpSchema.InitializeResult initializeResult = currentInitializationResult(); + McpSchema.InitializeResult initializeResult = this.initializer.currentInitializationResult(); return initializeResult != null ? initializeResult.instructions() : null; } @@ -319,7 +284,7 @@ public String getServerInstructions() { * @return The server implementation details */ public McpSchema.Implementation getServerInfo() { - McpSchema.InitializeResult initializeResult = currentInitializationResult(); + McpSchema.InitializeResult initializeResult = this.initializer.currentInitializationResult(); return initializeResult != null ? initializeResult.serverInfo() : null; } @@ -328,8 +293,7 @@ public McpSchema.Implementation getServerInfo() { * @return true if the client-server connection is initialized */ public boolean isInitialized() { - Initialization current = this.initializationRef.get(); - return current != null && (current.result.get() != null); + return this.initializer.isInitialized(); } /** @@ -352,10 +316,7 @@ public McpSchema.Implementation getClientInfo() { * Closes the client connection immediately. */ public void close() { - Initialization current = this.initializationRef.getAndSet(null); - if (current != null) { - current.close(); - } + this.initializer.close(); this.transport.close(); } @@ -365,9 +326,7 @@ public void close() { */ public Mono closeGracefully() { return Mono.defer(() -> { - Initialization current = this.initializationRef.getAndSet(null); - Mono sessionClose = current != null ? current.closeGracefully() : Mono.empty(); - return sessionClose.then(transport.closeGracefully()); + return this.initializer.closeGracefully().then(transport.closeGracefully()); }); } @@ -401,118 +360,7 @@ public Mono closeGracefully() { *

*/ public Mono initialize() { - return withSession("by explicit API call", init -> Mono.just(init.get())); - } - - private Mono doInitialize(Initialization initialization, ContextView ctx) { - initialization.setMcpClientSession(this.sessionSupplier.apply(ctx)); - - McpClientSession mcpClientSession = initialization.mcpSession(); - - String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); - - McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(// @formatter:off - latestVersion, - this.clientCapabilities, - this.clientInfo); // @formatter:on - - Mono result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE, - initializeRequest, INITIALIZE_RESULT_TYPE_REF); - - return result.flatMap(initializeResult -> { - logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}", - initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(), - initializeResult.instructions()); - - if (!this.protocolVersions.contains(initializeResult.protocolVersion())) { - return Mono.error(new McpError( - "Unsupported protocol version from the server: " + initializeResult.protocolVersion())); - } - - return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null) - .thenReturn(initializeResult); - }).doOnNext(initialization::complete).onErrorResume(ex -> { - initialization.error(ex); - return Mono.error(ex); - }); - } - - private static class Initialization { - - private final Sinks.One initSink = Sinks.one(); - - private final AtomicReference result = new AtomicReference<>(); - - private final AtomicReference mcpClientSession = new AtomicReference<>(); - - static Initialization create() { - return new Initialization(); - } - - void setMcpClientSession(McpClientSession mcpClientSession) { - this.mcpClientSession.set(mcpClientSession); - } - - McpClientSession mcpSession() { - return this.mcpClientSession.get(); - } - - McpSchema.InitializeResult get() { - return this.result.get(); - } - - Mono await() { - return this.initSink.asMono(); - } - - void complete(McpSchema.InitializeResult initializeResult) { - // first ensure the result is cached - this.result.set(initializeResult); - // inform all the subscribers waiting for the initialization - this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST); - } - - void error(Throwable t) { - this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST); - } - - void close() { - this.mcpSession().close(); - } - - Mono closeGracefully() { - return this.mcpSession().closeGracefully(); - } - - } - - /** - * Utility method to handle the common pattern of ensuring initialization before - * executing an operation. - * @param The type of the result Mono - * @param actionName The action to perform when the client is initialized - * @param operation The operation to execute when the client is initialized - * @return A Mono that completes with the result of the operation - */ - private Mono withSession(String actionName, Function> operation) { - return Mono.deferContextual(ctx -> { - Initialization newInit = Initialization.create(); - Initialization previous = this.initializationRef.compareAndExchange(null, newInit); - - boolean needsToInitialize = previous == null; - logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization"); - - Mono initializationJob = needsToInitialize ? doInitialize(newInit, ctx) - : previous.await(); - - return initializationJob.map(initializeResult -> this.initializationRef.get()) - .timeout(this.initializationTimeout) - .onErrorResume(ex -> { - logger.warn("Failed to initialize", ex); - return Mono.error(new McpError("Client failed to initialize " + actionName)); - }) - .flatMap(operation); - }); + return this.initializer.withIntitialization("by explicit API call", init -> Mono.just(init.initializeResult())); } // -------------------------- @@ -524,7 +372,7 @@ private Mono withSession(String actionName, Function ping() { - return this.withSession("pinging the server", + return this.initializer.withIntitialization("pinging the server", init -> init.mcpSession().sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF)); } @@ -605,7 +453,7 @@ public Mono removeRoot(String rootUri) { * @return A Mono that completes when the notification is sent. */ public Mono rootsListChangedNotification() { - return this.withSession("sending roots list changed notification", + return this.initializer.withIntitialization("sending roots list changed notification", init -> init.mcpSession().sendNotification(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED)); } @@ -664,8 +512,8 @@ private RequestHandler elicitationCreateHandler() { * @see #listTools() */ public Mono callTool(McpSchema.CallToolRequest callToolRequest) { - return this.withSession("calling tools", init -> { - if (init.get().capabilities().tools() == null) { + return this.initializer.withIntitialization("calling tools", init -> { + if (init.initializeResult().capabilities().tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } return init.mcpSession() @@ -693,8 +541,8 @@ public Mono listTools() { * @return A Mono that emits the list of tools result */ public Mono listTools(String cursor) { - return this.withSession("listing tools", init -> { - if (init.get().capabilities().tools() == null) { + return this.initializer.withIntitialization("listing tools", init -> { + if (init.initializeResult().capabilities().tools() == null) { return Mono.error(new McpError("Server does not provide tools capability")); } return init.mcpSession() @@ -757,8 +605,8 @@ public Mono listResources() { * @see #readResource(McpSchema.Resource) */ public Mono listResources(String cursor) { - return this.withSession("listing resources", init -> { - if (init.get().capabilities().resources() == null) { + return this.initializer.withIntitialization("listing resources", init -> { + if (init.initializeResult().capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } return init.mcpSession() @@ -789,8 +637,8 @@ public Mono readResource(McpSchema.Resource resour * @see McpSchema.ReadResourceResult */ public Mono readResource(McpSchema.ReadResourceRequest readResourceRequest) { - return this.withSession("reading resources", init -> { - if (init.get().capabilities().resources() == null) { + return this.initializer.withIntitialization("reading resources", init -> { + if (init.initializeResult().capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } return init.mcpSession() @@ -827,8 +675,8 @@ public Mono listResourceTemplates() { * @see McpSchema.ListResourceTemplatesResult */ public Mono listResourceTemplates(String cursor) { - return this.withSession("listing resource templates", init -> { - if (init.get().capabilities().resources() == null) { + return this.initializer.withIntitialization("listing resource templates", init -> { + if (init.initializeResult().capabilities().resources() == null) { return Mono.error(new McpError("Server does not provide the resources capability")); } return init.mcpSession() @@ -847,7 +695,7 @@ public Mono listResourceTemplates(String * @see #unsubscribeResource(McpSchema.UnsubscribeRequest) */ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) { - return this.withSession("subscribing to resources", init -> init.mcpSession() + return this.initializer.withIntitialization("subscribing to resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_SUBSCRIBE, subscribeRequest, VOID_TYPE_REFERENCE)); } @@ -861,7 +709,7 @@ public Mono subscribeResource(McpSchema.SubscribeRequest subscribeRequest) * @see #subscribeResource(McpSchema.SubscribeRequest) */ public Mono unsubscribeResource(McpSchema.UnsubscribeRequest unsubscribeRequest) { - return this.withSession("unsubscribing from resources", init -> init.mcpSession() + return this.initializer.withIntitialization("unsubscribing from resources", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_RESOURCES_UNSUBSCRIBE, unsubscribeRequest, VOID_TYPE_REFERENCE)); } @@ -927,7 +775,7 @@ public Mono listPrompts() { * @see #getPrompt(GetPromptRequest) */ public Mono listPrompts(String cursor) { - return this.withSession("listing prompts", init -> init.mcpSession() + return this.initializer.withIntitialization("listing prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_LIST, new PaginatedRequest(cursor), LIST_PROMPTS_RESULT_TYPE_REF)); } @@ -941,7 +789,7 @@ public Mono listPrompts(String cursor) { * @see #listPrompts() */ public Mono getPrompt(GetPromptRequest getPromptRequest) { - return this.withSession("getting prompts", init -> init.mcpSession() + return this.initializer.withIntitialization("getting prompts", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_PROMPT_GET, getPromptRequest, GET_PROMPT_RESULT_TYPE_REF)); } @@ -992,7 +840,7 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) { return Mono.error(new McpError("Logging level must not be null")); } - return this.withSession("setting logging level", init -> { + return this.initializer.withIntitialization("setting logging level", init -> { var params = new McpSchema.SetLevelRequest(loggingLevel); return init.mcpSession().sendRequest(McpSchema.METHOD_LOGGING_SET_LEVEL, params, OBJECT_TYPE_REF).then(); }); @@ -1004,7 +852,7 @@ public Mono setLoggingLevel(LoggingLevel loggingLevel) { * @param protocolVersions the Client supported protocol versions. */ void setProtocolVersions(List protocolVersions) { - this.protocolVersions = protocolVersions; + this.initializer.setProtocolVersions(protocolVersions); } // -------------------------- @@ -1024,7 +872,7 @@ void setProtocolVersions(List protocolVersions) { * @see McpSchema.CompleteResult */ public Mono completeCompletion(McpSchema.CompleteRequest completeRequest) { - return this.withSession("complete completions", init -> init.mcpSession() + return this.initializer.withIntitialization("complete completions", init -> init.mcpSession() .sendRequest(McpSchema.METHOD_COMPLETION_COMPLETE, completeRequest, COMPLETION_COMPLETE_RESULT_TYPE_REF)); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java new file mode 100644 index 00000000..c8d69192 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/LifecycleInitializerTests.java @@ -0,0 +1,412 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.client; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; +import reactor.util.context.Context; +import reactor.util.context.ContextView; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link LifecycleInitializer}. + */ +class LifecycleInitializerTests { + + private static final Duration INITIALIZATION_TIMEOUT = Duration.ofSeconds(5); + + private static final McpSchema.ClientCapabilities CLIENT_CAPABILITIES = McpSchema.ClientCapabilities.builder() + .build(); + + private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); + + private static final List PROTOCOL_VERSIONS = List.of("1.0.0", "2.0.0"); + + private static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult("2.0.0", + McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), + "Test instructions"); + + @Mock + private McpClientSession mockClientSession; + + @Mock + private Function mockSessionSupplier; + + private LifecycleInitializer initializer; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + when(mockSessionSupplier.apply(any(ContextView.class))).thenReturn(mockClientSession); + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(MOCK_INIT_RESULT)); + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.empty()); + when(mockClientSession.closeGracefully()).thenReturn(Mono.empty()); + + initializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, mockSessionSupplier); + } + + @Test + void constructorShouldValidateParameters() { + assertThatThrownBy(() -> new LifecycleInitializer(null, CLIENT_INFO, PROTOCOL_VERSIONS, INITIALIZATION_TIMEOUT, + mockSessionSupplier)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Client capabilities must not be null"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, null, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, mockSessionSupplier)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Client info must not be null"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, null, + INITIALIZATION_TIMEOUT, mockSessionSupplier)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Protocol versions must not be empty"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, List.of(), + INITIALIZATION_TIMEOUT, mockSessionSupplier)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Protocol versions must not be empty"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, null, + mockSessionSupplier)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Initialization timeout must not be null"); + + assertThatThrownBy(() -> new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, PROTOCOL_VERSIONS, + INITIALIZATION_TIMEOUT, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Session supplier must not be null"); + } + + @Test + void shouldInitializeSuccessfully() { + StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + .assertNext(result -> { + assertThat(result).isEqualTo(MOCK_INIT_RESULT); + assertThat(initializer.isInitialized()).isTrue(); + assertThat(initializer.currentInitializationResult()).isEqualTo(MOCK_INIT_RESULT); + }) + .verifyComplete(); + + verify(mockClientSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(McpSchema.InitializeRequest.class), + any()); + verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null)); + } + + @Test + void shouldUseLatestProtocolVersionInInitializeRequest() { + AtomicReference capturedRequest = new AtomicReference<>(); + + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> { + capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1)); + return Mono.just(MOCK_INIT_RESULT); + }); + + StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + .assertNext(result -> { + assertThat(capturedRequest.get().protocolVersion()).isEqualTo("2.0.0"); // Latest + // version + assertThat(capturedRequest.get().capabilities()).isEqualTo(CLIENT_CAPABILITIES); + assertThat(capturedRequest.get().clientInfo()).isEqualTo(CLIENT_INFO); + }) + .verifyComplete(); + } + + @Test + void shouldFailForUnsupportedProtocolVersion() { + McpSchema.InitializeResult unsupportedResult = new McpSchema.InitializeResult("999.0.0", // Unsupported + // version + McpSchema.ServerCapabilities.builder().build(), new McpSchema.Implementation("test-server", "1.0.0"), + "Test instructions"); + + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(unsupportedResult)); + + StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(McpError.class) + .verify(); + + verify(mockClientSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any()); + } + + @Test + void shouldTimeoutOnSlowInitialization() { + VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + + Duration INITIALIZE_TIMEOUT = Duration.ofSeconds(1); + Duration SLOW_RESPONSE_DELAY = Duration.ofSeconds(5); + + LifecycleInitializer shortTimeoutInitializer = new LifecycleInitializer(CLIENT_CAPABILITIES, CLIENT_INFO, + PROTOCOL_VERSIONS, INITIALIZE_TIMEOUT, mockSessionSupplier); + + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.just(MOCK_INIT_RESULT).delayElement(SLOW_RESPONSE_DELAY, virtualTimeScheduler)); + + StepVerifier + .withVirtualTime(() -> shortTimeoutInitializer.withIntitialization("test", + init -> Mono.just(init.initializeResult())), () -> virtualTimeScheduler, Long.MAX_VALUE) + .expectSubscription() + .expectNoEvent(INITIALIZE_TIMEOUT) + .expectError(McpError.class) + .verify(); + } + + @Test + void shouldReuseExistingInitialization() { + // First initialization + StepVerifier.create(initializer.withIntitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + // Second call should reuse the same initialization + StepVerifier.create(initializer.withIntitialization("test2", init -> Mono.just("result2"))) + .expectNext("result2") + .verifyComplete(); + + // Verify session was created only once + verify(mockSessionSupplier, times(1)).apply(any(ContextView.class)); + verify(mockClientSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + } + + @Test + void shouldHandleConcurrentInitializationRequests() { + AtomicInteger sessionCreationCount = new AtomicInteger(0); + + when(mockSessionSupplier.apply(any(ContextView.class))).thenAnswer(invocation -> { + sessionCreationCount.incrementAndGet(); + return mockClientSession; + }); + + // Start multiple concurrent initializations using subscribeOn with parallel + // scheduler + Mono init1 = initializer.withIntitialization("test1", init -> Mono.just("result1")) + .subscribeOn(Schedulers.parallel()); + Mono init2 = initializer.withIntitialization("test2", init -> Mono.just("result2")) + .subscribeOn(Schedulers.parallel()); + Mono init3 = initializer.withIntitialization("test3", init -> Mono.just("result3")) + .subscribeOn(Schedulers.parallel()); + + StepVerifier.create(Mono.zip(init1, init2, init3)).assertNext(tuple -> { + assertThat(tuple.getT1()).isEqualTo("result1"); + assertThat(tuple.getT2()).isEqualTo("result2"); + assertThat(tuple.getT3()).isEqualTo("result3"); + }).verifyComplete(); + + // Should only create one session despite concurrent requests + assertThat(sessionCreationCount.get()).isEqualTo(1); + verify(mockClientSession, times(1)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + } + + @Test + void shouldHandleInitializationFailure() { + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())) + .thenReturn(Mono.error(new RuntimeException("Connection failed"))); + + StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(McpError.class) + .verify(); + + assertThat(initializer.isInitialized()).isFalse(); + assertThat(initializer.currentInitializationResult()).isNull(); + } + + @Test + void shouldHandleTransportSessionNotFoundException() { + // successful initialization first + StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(initializer.isInitialized()).isTrue(); + + // Simulate transport session not found + initializer.handleException(new McpTransportSessionNotFoundException("Session not found")); + + assertThat(initializer.isInitialized()).isTrue(); + + // Verify that the session was closed and re-initialized + verify(mockClientSession).close(); + + // Verify session was created 2 times (once for initial and once for + // re-initialization) + verify(mockSessionSupplier, times(2)).apply(any(ContextView.class)); + } + + @Test + void shouldHandleOtherExceptions() { + // Simulate a successful initialization first + StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(initializer.isInitialized()).isTrue(); + + // Simulate other exception (should not trigger re-initialization) + initializer.handleException(new RuntimeException("Some other error")); + + // Should still be initialized + assertThat(initializer.isInitialized()).isTrue(); + verify(mockClientSession, never()).close(); + // Verify that the session was not re-created + verify(mockSessionSupplier, times(1)).apply(any(ContextView.class)); + } + + @Test + void shouldCloseGracefully() { + StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + StepVerifier.create(initializer.closeGracefully()).verifyComplete(); + + verify(mockClientSession).closeGracefully(); + assertThat(initializer.isInitialized()).isFalse(); + } + + @Test + void shouldCloseImmediately() { + StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + // Close immediately + initializer.close(); + + verify(mockClientSession).close(); + assertThat(initializer.isInitialized()).isFalse(); + } + + @Test + void shouldHandleCloseWithoutInitialization() { + // Close without initialization should not throw + initializer.close(); + + StepVerifier.create(initializer.closeGracefully()).verifyComplete(); + + verify(mockClientSession, never()).close(); + verify(mockClientSession, never()).closeGracefully(); + } + + @Test + void shouldSetProtocolVersionsForTesting() { + List newVersions = List.of("3.0.0", "4.0.0"); + initializer.setProtocolVersions(newVersions); + + AtomicReference capturedRequest = new AtomicReference<>(); + + when(mockClientSession.sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any())).thenAnswer(invocation -> { + capturedRequest.set((McpSchema.InitializeRequest) invocation.getArgument(1)); + return Mono.just(new McpSchema.InitializeResult("4.0.0", McpSchema.ServerCapabilities.builder().build(), + new McpSchema.Implementation("test-server", "1.0.0"), "Test instructions")); + }); + + StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + .assertNext(result -> { + // Latest from new versions + assertThat(capturedRequest.get().protocolVersion()).isEqualTo("4.0.0"); + }) + .verifyComplete(); + } + + @Test + void shouldPassContextToSessionSupplier() { + String contextKey = "test.key"; + String contextValue = "test.value"; + + AtomicReference capturedContext = new AtomicReference<>(); + + when(mockSessionSupplier.apply(any(ContextView.class))).thenAnswer(invocation -> { + capturedContext.set(invocation.getArgument(0)); + return mockClientSession; + }); + + StepVerifier + .create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult())) + .contextWrite(Context.of(contextKey, contextValue))) + .expectNext(MOCK_INIT_RESULT) + .verifyComplete(); + + assertThat(capturedContext.get().hasKey(contextKey)).isTrue(); + assertThat((String) capturedContext.get().get(contextKey)).isEqualTo(contextValue); + } + + @Test + void shouldProvideAccessToMcpSessionAndInitializeResult() { + StepVerifier.create(initializer.withIntitialization("test", init -> { + assertThat(init.mcpSession()).isEqualTo(mockClientSession); + assertThat(init.initializeResult()).isEqualTo(MOCK_INIT_RESULT); + return Mono.just("success"); + })).expectNext("success").verifyComplete(); + } + + @Test + void shouldHandleNotificationFailure() { + when(mockClientSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), any())) + .thenReturn(Mono.error(new RuntimeException("Notification failed"))); + + StepVerifier.create(initializer.withIntitialization("test", init -> Mono.just(init.initializeResult()))) + .expectError(RuntimeException.class) + .verify(); + + verify(mockClientSession).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + verify(mockClientSession).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_INITIALIZED), eq(null)); + } + + @Test + void shouldReturnNullWhenNotInitialized() { + assertThat(initializer.isInitialized()).isFalse(); + assertThat(initializer.currentInitializationResult()).isNull(); + } + + @Test + void shouldReinitializeAfterTransportSessionException() { + // First initialization + StepVerifier.create(initializer.withIntitialization("test1", init -> Mono.just("result1"))) + .expectNext("result1") + .verifyComplete(); + + // Simulate transport session exception + initializer.handleException(new McpTransportSessionNotFoundException("Session lost")); + + // Should be able to initialize again + StepVerifier.create(initializer.withIntitialization("test2", init -> Mono.just("result2"))) + .expectNext("result2") + .verifyComplete(); + + // Verify two separate initializations occurred + verify(mockSessionSupplier, times(2)).apply(any(ContextView.class)); + verify(mockClientSession, times(2)).sendRequest(eq(McpSchema.METHOD_INITIALIZE), any(), any()); + } + +}