Skip to content

Commit ef140fc

Browse files
committed
refactor: extract MCP client initialization logic into LifecyleInitializer
- Create new LifecyleInitializer class to handle protocol initialization phase - Move initialization logic from McpAsyncClient to dedicated initializer - Add javadocs for MCP initialization process - Implement protocol version negotiation and capability exchange - Add exception handling for transport session recovery - Include test suite for LifecyleInitializer - Simplify McpAsyncClient by delegating initialization responsibilities This refactoring improves separation of concerns and makes the initialization process more maintainable and testable. Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
1 parent c711f83 commit ef140fc

File tree

3 files changed

+786
-187
lines changed

3 files changed

+786
-187
lines changed
Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
package io.modelcontextprotocol.client;
2+
3+
import java.time.Duration;
4+
import java.util.ArrayList;
5+
import java.util.Collections;
6+
import java.util.List;
7+
import java.util.concurrent.atomic.AtomicReference;
8+
import java.util.function.Function;
9+
10+
import org.slf4j.Logger;
11+
import org.slf4j.LoggerFactory;
12+
13+
import io.modelcontextprotocol.spec.McpClientSession;
14+
import io.modelcontextprotocol.spec.McpError;
15+
import io.modelcontextprotocol.spec.McpSchema;
16+
import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
17+
import io.modelcontextprotocol.util.Assert;
18+
import reactor.core.publisher.Mono;
19+
import reactor.core.publisher.Sinks;
20+
import reactor.util.context.ContextView;
21+
22+
/**
23+
* <b>Handles the protocol initialization phase between client and server</b>
24+
*
25+
* <p>
26+
* The initialization phase MUST be the first interaction between client and server.
27+
* During this phase, the client and server perform the following operations:
28+
* <ul>
29+
* <li>Establish protocol version compatibility</li>
30+
* <li>Exchange and negotiate capabilities</li>
31+
* <li>Share implementation details</li>
32+
* </ul>
33+
*
34+
* <b>Client Initialization Process</b>
35+
* <p>
36+
* The client MUST initiate this phase by sending an initialize request containing:
37+
* <ul>
38+
* <li>Protocol version supported</li>
39+
* <li>Client capabilities</li>
40+
* <li>Client implementation information</li>
41+
* </ul>
42+
*
43+
* <p>
44+
* After successful initialization, the client MUST send an initialized notification to
45+
* indicate it is ready to begin normal operations.
46+
*
47+
* <b>Server Response</b>
48+
* <p>
49+
* The server MUST respond with its own capabilities and information.
50+
*
51+
* <b>Protocol Version Negotiation</b>
52+
* <p>
53+
* In the initialize request, the client MUST send a protocol version it supports. This
54+
* SHOULD be the latest version supported by the client.
55+
*
56+
* <p>
57+
* If the server supports the requested protocol version, it MUST respond with the same
58+
* version. Otherwise, the server MUST respond with another protocol version it supports.
59+
* This SHOULD be the latest version supported by the server.
60+
*
61+
* <p>
62+
* If the client does not support the version in the server's response, it SHOULD
63+
* disconnect.
64+
*
65+
* <b>Request Restrictions</b>
66+
* <p>
67+
* <strong>Important:</strong> The following restrictions apply during initialization:
68+
* <ul>
69+
* <li>The client SHOULD NOT send requests other than pings before the server has
70+
* responded to the initialize request</li>
71+
* <li>The server SHOULD NOT send requests other than pings and logging before receiving
72+
* the initialized notification</li>
73+
* </ul>
74+
*/
75+
public class LifecyleInitializer {
76+
77+
private static final Logger logger = LoggerFactory.getLogger(LifecyleInitializer.class);
78+
79+
/**
80+
* The MCP session supplier that manages bidirectional JSON-RPC communication between
81+
* clients and servers.
82+
*/
83+
private final Function<ContextView, McpClientSession> sessionSupplier;
84+
85+
private final McpSchema.ClientCapabilities clientCapabilities;
86+
87+
private final McpSchema.Implementation clientInfo;
88+
89+
private List<String> protocolVersions;
90+
91+
private final AtomicReference<DefaultInitialization> initializationRef = new AtomicReference<>();
92+
93+
/**
94+
* The max timeout to await for the client-server connection to be initialized.
95+
*/
96+
private final Duration initializationTimeout;
97+
98+
public LifecyleInitializer(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo,
99+
List<String> protocolVersions, Duration initializationTimeout,
100+
Function<ContextView, McpClientSession> sessionSupplier) {
101+
102+
Assert.notNull(sessionSupplier, "Session supplier must not be null");
103+
Assert.notNull(clientCapabilities, "Client capabilities must not be null");
104+
Assert.notNull(clientInfo, "Client info must not be null");
105+
Assert.notEmpty(protocolVersions, "Protocol versions must not be empty");
106+
Assert.notNull(initializationTimeout, "Initialization timeout must not be null");
107+
108+
this.sessionSupplier = sessionSupplier;
109+
this.clientCapabilities = clientCapabilities;
110+
this.clientInfo = clientInfo;
111+
this.protocolVersions = Collections.unmodifiableList(new ArrayList<>(protocolVersions));
112+
this.initializationTimeout = initializationTimeout;
113+
}
114+
115+
/**
116+
* This method is package-private and used for test only. Should not be called by user
117+
* code.
118+
* @param protocolVersions the Client supported protocol versions.
119+
*/
120+
void setProtocolVersions(List<String> protocolVersions) {
121+
this.protocolVersions = protocolVersions;
122+
}
123+
124+
/**
125+
* Represents the initialization state of the MCP client.
126+
*/
127+
interface Initialization {
128+
129+
/**
130+
* Returns the MCP client session that is used to communicate with the server.
131+
* This session is established during the initialization process and is used for
132+
* sending requests and notifications.
133+
* @return The MCP client session
134+
*/
135+
McpClientSession mcpSession();
136+
137+
/**
138+
* Returns the result of the MCP initialization process. This result contains
139+
* information about the protocol version, capabilities, server info, and
140+
* instructions provided by the server during the initialization phase.
141+
* @return The result of the MCP initialization process
142+
*/
143+
McpSchema.InitializeResult initializeResult();
144+
145+
}
146+
147+
/**
148+
* Default implementation of the {@link Initialization} interface that manages the MCP
149+
* client initialization process.
150+
*/
151+
private static class DefaultInitialization implements Initialization {
152+
153+
/**
154+
* A sink that emits the result of the MCP initialization process. It allows
155+
* subscribers to wait for the initialization to complete.
156+
*/
157+
private final Sinks.One<McpSchema.InitializeResult> initSink;
158+
159+
/**
160+
* Holds the result of the MCP initialization process. It is used to cache the
161+
* result for future requests.
162+
*/
163+
private final AtomicReference<McpSchema.InitializeResult> result;
164+
165+
/**
166+
* Holds the MCP client session that is used to communicate with the server. It is
167+
* set during the initialization process and used for sending requests and
168+
* notifications.
169+
*/
170+
private final AtomicReference<McpClientSession> mcpClientSession;
171+
172+
private DefaultInitialization() {
173+
this.initSink = Sinks.one();
174+
this.result = new AtomicReference<>();
175+
this.mcpClientSession = new AtomicReference<>();
176+
}
177+
178+
// ---------------------------------------------------
179+
// Public access for mcpSession and initializeResult because they are
180+
// used in by the McpAsyncClient.
181+
// ----------------------------------------------------
182+
public McpClientSession mcpSession() {
183+
return this.mcpClientSession.get();
184+
}
185+
186+
public McpSchema.InitializeResult initializeResult() {
187+
return this.result.get();
188+
}
189+
190+
// ---------------------------------------------------
191+
// Private accessors used internally by the LifecycleInitializer to set the MCP
192+
// client session and complete the initialization process.
193+
// ---------------------------------------------------
194+
private void setMcpClientSession(McpClientSession mcpClientSession) {
195+
this.mcpClientSession.set(mcpClientSession);
196+
}
197+
198+
/**
199+
* Returns a Mono that completes when the MCP client initialization is complete.
200+
* This allows subscribers to wait for the initialization to finish before
201+
* proceeding with further operations.
202+
* @return A Mono that emits the result of the MCP initialization process
203+
*/
204+
private Mono<McpSchema.InitializeResult> await() {
205+
return this.initSink.asMono();
206+
}
207+
208+
/**
209+
* Completes the initialization process with the given result. It caches the
210+
* result and emits it to all subscribers waiting for the initialization to
211+
* complete.
212+
* @param initializeResult The result of the MCP initialization process
213+
*/
214+
private void complete(McpSchema.InitializeResult initializeResult) {
215+
// first ensure the result is cached
216+
this.result.set(initializeResult);
217+
// inform all the subscribers waiting for the initialization
218+
this.initSink.emitValue(initializeResult, Sinks.EmitFailureHandler.FAIL_FAST);
219+
}
220+
221+
private void error(Throwable t) {
222+
this.initSink.emitError(t, Sinks.EmitFailureHandler.FAIL_FAST);
223+
}
224+
225+
private void close() {
226+
this.mcpSession().close();
227+
}
228+
229+
private Mono<Void> closeGracefully() {
230+
return this.mcpSession().closeGracefully();
231+
}
232+
233+
}
234+
235+
public boolean isInitialized() {
236+
return currentInitializationResult() != null;
237+
}
238+
239+
public McpSchema.InitializeResult currentInitializationResult() {
240+
DefaultInitialization current = this.initializationRef.get();
241+
McpSchema.InitializeResult initializeResult = current != null ? current.result.get() : null;
242+
return initializeResult;
243+
}
244+
245+
/**
246+
* Hook to handle exceptions that occur during the MCP transport session.
247+
* <p>
248+
* If the exception is a {@link McpTransportSessionNotFoundException}, it indicates
249+
* that the session was not found, and we should re-initialize the client.
250+
* </p>
251+
* @param t The exception to handle
252+
*/
253+
public void handleException(Throwable t) {
254+
logger.warn("Handling exception", t);
255+
if (t instanceof McpTransportSessionNotFoundException) {
256+
DefaultInitialization previous = this.initializationRef.getAndSet(null);
257+
if (previous != null) {
258+
previous.close();
259+
}
260+
// Providing an empty operation since we are only interested in triggering
261+
// the implicit initialization step.
262+
withIntitialization("re-initializing", result -> Mono.empty()).subscribe();
263+
}
264+
}
265+
266+
/**
267+
* Utility method to ensure the initialization is established before executing an
268+
* operation.
269+
* @param <T> The type of the result Mono
270+
* @param actionName The action to perform when the client is initialized
271+
* @param operation The operation to execute when the client is initialized
272+
* @return A Mono that completes with the result of the operation
273+
*/
274+
public <T> Mono<T> withIntitialization(String actionName, Function<Initialization, Mono<T>> operation) {
275+
return Mono.deferContextual(ctx -> {
276+
DefaultInitialization newInit = new DefaultInitialization();
277+
DefaultInitialization previous = this.initializationRef.compareAndExchange(null, newInit);
278+
279+
boolean needsToInitialize = previous == null;
280+
logger.debug(needsToInitialize ? "Initialization process started" : "Joining previous initialization");
281+
282+
Mono<McpSchema.InitializeResult> initializationJob = needsToInitialize ? doInitialize(newInit, ctx)
283+
: previous.await();
284+
285+
return initializationJob.map(initializeResult -> this.initializationRef.get())
286+
.timeout(this.initializationTimeout)
287+
.onErrorResume(ex -> {
288+
logger.warn("Failed to initialize", ex);
289+
return Mono.error(new McpError("Client failed to initialize " + actionName));
290+
})
291+
.flatMap(operation);
292+
});
293+
}
294+
295+
private Mono<McpSchema.InitializeResult> doInitialize(DefaultInitialization initialization, ContextView ctx) {
296+
initialization.setMcpClientSession(this.sessionSupplier.apply(ctx));
297+
298+
McpClientSession mcpClientSession = initialization.mcpSession();
299+
300+
String latestVersion = this.protocolVersions.get(this.protocolVersions.size() - 1);
301+
302+
McpSchema.InitializeRequest initializeRequest = new McpSchema.InitializeRequest(latestVersion,
303+
this.clientCapabilities, this.clientInfo);
304+
305+
Mono<McpSchema.InitializeResult> result = mcpClientSession.sendRequest(McpSchema.METHOD_INITIALIZE,
306+
initializeRequest, McpAsyncClient.INITIALIZE_RESULT_TYPE_REF);
307+
308+
return result.flatMap(initializeResult -> {
309+
logger.info("Server response with Protocol: {}, Capabilities: {}, Info: {} and Instructions {}",
310+
initializeResult.protocolVersion(), initializeResult.capabilities(), initializeResult.serverInfo(),
311+
initializeResult.instructions());
312+
313+
if (!this.protocolVersions.contains(initializeResult.protocolVersion())) {
314+
return Mono.error(new McpError(
315+
"Unsupported protocol version from the server: " + initializeResult.protocolVersion()));
316+
}
317+
318+
return mcpClientSession.sendNotification(McpSchema.METHOD_NOTIFICATION_INITIALIZED, null)
319+
.thenReturn(initializeResult);
320+
}).doOnNext(initialization::complete).onErrorResume(ex -> {
321+
initialization.error(ex);
322+
return Mono.error(ex);
323+
});
324+
}
325+
326+
/**
327+
* Closes the current initialization if it exists.
328+
*/
329+
public void close() {
330+
DefaultInitialization current = this.initializationRef.getAndSet(null);
331+
if (current != null) {
332+
current.close();
333+
}
334+
}
335+
336+
/**
337+
* Gracefully closes the current initialization if it exists.
338+
* @return A Mono that completes when the connection is closed
339+
*/
340+
public Mono<?> closeGracefully() {
341+
return Mono.defer(() -> {
342+
DefaultInitialization current = this.initializationRef.getAndSet(null);
343+
Mono<?> sessionClose = current != null ? current.closeGracefully() : Mono.empty();
344+
return sessionClose;
345+
});
346+
}
347+
348+
}

0 commit comments

Comments
 (0)