Skip to content

Commit f972a51

Browse files
ZachGermanZachary German
authored andcommitted
Adding StreamableHttpServerTransportProvider class and unit tests
1 parent 606d21b commit f972a51

File tree

4 files changed

+121
-59
lines changed

4 files changed

+121
-59
lines changed

mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ private Mono<McpSchema.InitializeResult> asyncInitializeRequestHandler(
214214
"Client requested unsupported protocol version: {}, so the server will suggest the {} version instead",
215215
initializeRequest.protocolVersion(), serverProtocolVersion);
216216
}
217-
System.out.println("---------------Server sending initalize response-----------");
218217
return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities,
219218
this.serverInfo, this.instructions));
220219
});

mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java

Lines changed: 91 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
325325
throws ServletException, IOException {
326326

327327
String requestURI = request.getRequestURI();
328-
logger.debug("POST request received for URI: '{}' with headers: {}", requestURI, extractHeaders(request));
328+
logger.info("POST request received for URI: '{}' with headers: {}", requestURI, extractHeaders(request));
329329

330330
if (!requestURI.endsWith(mcpEndpoint)) {
331331
logger.debug("URI does not match MCP endpoint: '{}'", mcpEndpoint);
@@ -373,6 +373,7 @@ public void onDataAvailable() throws IOException {
373373
public void onAllDataRead() throws IOException {
374374
try {
375375
// Parse the JSON-RPC message
376+
logger.debug("Parsing JSON-RPC message: {}", body.toString());
376377
JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString());
377378

378379
// Check if this is an initialize request
@@ -398,30 +399,59 @@ public void onAllDataRead() throws IOException {
398399
return;
399400
}
400401

402+
logger.debug("Creating session for sessionId: {}, isInitialize: {}", sessionId,
403+
isInitializeRequest);
401404
StreamableHttpSession session = getOrCreateSession(sessionId, isInitializeRequest);
402405
if (session == null) {
406+
logger.error("Failed to create session for sessionId: {}", sessionId);
403407
handleSessionNotFound(sessionId, request, response);
404408
asyncContext.complete();
405409
return;
406410
}
407411

412+
// Detect response type early
413+
StreamableHttpSession.ResponseType responseType = session.detectResponseType(message);
414+
logger.debug("Detected response type: {} for message: {}", responseType, message);
415+
416+
// Set session ID header
417+
response.setHeader(SESSION_ID_HEADER, sessionId);
418+
408419
// Make variables effectively final for lambda
409420
final String finalSessionId = sessionId;
410421
final StreamableHttpSession finalSession = session;
411422
final boolean finalIsInitializeRequest = isInitializeRequest;
412423

413-
// Handle message and determine response type
414-
session.handleMessage(message, request, response, asyncContext).flatMap(responseType -> {
415-
response.setHeader(SESSION_ID_HEADER, finalSessionId);
416-
if (responseType == StreamableHttpSession.ResponseType.STREAM) {
417-
return setupSseResponse(finalSession, finalSessionId, response, request.getRequestId(),
418-
asyncContext, finalIsInitializeRequest);
419-
}
420-
return setupImmediateResponse(finalSession, finalSessionId, response, asyncContext);
421-
}).subscribe(null, error -> {
422-
logger.error("Error handling message: {}", error.getMessage());
423-
asyncContext.complete();
424-
});
424+
logger.debug("About to handle message with response type: {}", responseType);
425+
if (responseType == StreamableHttpSession.ResponseType.STREAM) {
426+
logger.debug("Handling STREAM response type");
427+
// Set up SSE response first, then handle message
428+
sendStreamResponse(finalSession, response, request.getRequestId(), asyncContext)
429+
.doOnSuccess(v -> logger.debug("Stream response setup completed"))
430+
.doOnError(e -> logger.error("Error in stream response setup: {}", e.getMessage(), e))
431+
.then(session.handleMessage(message, request, response, asyncContext))
432+
.doOnSuccess(v -> logger.debug("Message handling completed successfully"))
433+
.doOnError(e -> logger.error("Error in message handling: {}", e.getMessage(), e))
434+
.contextWrite(Context.of(MCP_SESSION_ID, sessionId))
435+
.subscribe(null, error -> {
436+
logger.error("Error in STREAM handling chain: {}", error.getMessage(), error);
437+
asyncContext.complete();
438+
});
439+
}
440+
else {
441+
logger.debug("Handling IMMEDIATE response type");
442+
// Handle message then set up immediate response
443+
session.handleMessage(message, request, response, asyncContext)
444+
.doOnSuccess(v -> logger.debug("Message handling completed successfully"))
445+
.doOnError(e -> logger.error("Error in message handling: {}", e.getMessage(), e))
446+
.then(sendImmediateResponse(finalSession, response, asyncContext))
447+
.doOnSuccess(v -> logger.debug("Immediate response setup completed"))
448+
.doOnError(e -> logger.error("Error in immediate response setup: {}", e.getMessage(), e))
449+
.contextWrite(Context.of(MCP_SESSION_ID, sessionId))
450+
.subscribe(null, error -> {
451+
logger.error("Error in IMMEDIATE handling chain: {}", error.getMessage(), error);
452+
asyncContext.complete();
453+
});
454+
}
425455
}
426456
catch (Exception e) {
427457
logger.error("Error processing message: {}", e.getMessage());
@@ -515,10 +545,9 @@ private StreamableHttpSession getOrCreateSession(String sessionId, boolean creat
515545
/**
516546
* Sets up immediate response for session-based handling.
517547
*/
518-
private Mono<Void> setupImmediateResponse(StreamableHttpSession session, String sessionId,
519-
HttpServletResponse response, AsyncContext asyncContext) {
548+
private Mono<Void> sendImmediateResponse(StreamableHttpSession session, HttpServletResponse response,
549+
AsyncContext asyncContext) {
520550
return Mono.fromRunnable(() -> {
521-
response.setHeader(SESSION_ID_HEADER, sessionId);
522551
try {
523552
response.getWriter().flush();
524553
}
@@ -529,25 +558,11 @@ private Mono<Void> setupImmediateResponse(StreamableHttpSession session, String
529558
}
530559

531560
/**
532-
* Sets up SSE response
533-
*/
534-
private Mono<Void> setupSseResponse(StreamableHttpSession session, String sessionId, HttpServletResponse response,
535-
String requestId, AsyncContext asyncContext, boolean isInitializeRequest) {
536-
StreamableHttpSseStream sseStream = session.getOrCreateSseStream(requestId);
537-
return setupSseResponseForStream(sseStream, response, asyncContext, isInitializeRequest).doFirst(() -> {
538-
response.setHeader(SESSION_ID_HEADER, sessionId);
539-
if (isInitializeRequest) {
540-
response.setHeader(SESSION_ID_HEADER, sessionId);
541-
}
542-
});
543-
}
544-
545-
/**
546-
* Sets up SSE response for a specific stream.
561+
* Sends SSE response for the given requestId
547562
*/
548-
private Mono<Void> setupSseResponseForStream(StreamableHttpSseStream sseStream, HttpServletResponse response,
549-
AsyncContext asyncContext, boolean isInitializeRequest) {
550-
return Mono.fromRunnable(() -> {
563+
private Mono<Void> sendStreamResponse(StreamableHttpSession session, HttpServletResponse response, String requestId,
564+
AsyncContext asyncContext) {
565+
return Mono.create(sink -> {
551566
try {
552567
// Set up SSE connection
553568
response.setContentType(TEXT_EVENT_STREAM);
@@ -558,6 +573,7 @@ private Mono<Void> setupSseResponseForStream(StreamableHttpSseStream sseStream,
558573
PrintWriter writer = response.getWriter();
559574

560575
// Subscribe to the SSE stream and write events to the response
576+
StreamableHttpSseStream sseStream = session.getOrCreateSseStream(requestId);
561577
sseStream.getEventFlux().doOnNext(event -> {
562578
try {
563579
if (event.id() != null) {
@@ -588,10 +604,14 @@ private Mono<Void> setupSseResponseForStream(StreamableHttpSseStream sseStream,
588604
logger.error("Error in SSE stream: {}", e.getMessage());
589605
asyncContext.complete();
590606
}).subscribe();
607+
608+
// Signal that SSE subscription is ready
609+
sink.success();
591610
}
592611
catch (IOException e) {
593-
logger.error("Failed to setup SSE response: {}", e.getMessage());
612+
logger.error("Failed to send SSE response: {}", e.getMessage());
594613
asyncContext.complete();
614+
sink.error(e);
595615
}
596616
});
597617
}
@@ -765,38 +785,51 @@ public StreamableHttpSseStream getOrCreateSseStream(String streamName) {
765785
/**
766786
* Handles a message using the appropriate transport based on response type.
767787
*/
768-
public Mono<ResponseType> handleMessage(McpSchema.JSONRPCMessage message, HttpServletRequest request,
788+
public Mono<Void> handleMessage(McpSchema.JSONRPCMessage message, HttpServletRequest request,
769789
HttpServletResponse response, AsyncContext asyncContext) {
770790

771-
// Immediate transport only needs 1 session per StreamableHttpSession
772-
if (immediateSession == null) {
773-
// Create session with both transports on first message
774-
immediateTransport = new HttpServerTransport(objectMapper, response, asyncContext);
775-
immediateSession = sessionFactory.create(immediateTransport);
776-
}
777-
778-
// Streaming transport requires an inner session for each individual request
779-
McpServerTransport streamTransport = getOrCreateSseStream(request.getRequestId()).getTransport();
780-
streamSession = sessionFactory.create(streamTransport);
781-
streamSessions.add(streamSession);
782-
783791
ResponseType responseType = detectResponseType(message);
792+
logger.debug("Handling message with response type: {}, message: {}", responseType, message);
784793

785794
if (responseType == ResponseType.STREAM) {
786-
// Handle the message with the session uses streamed responses
787-
return streamSession.handle(message)
788-
.then(Mono.just(responseType))
789-
.onErrorReturn(ResponseType.IMMEDIATE);
795+
// Streaming transport requires an inner session for each individual
796+
// request
797+
McpServerTransport streamTransport = getOrCreateSseStream(request.getRequestId()).getTransport();
798+
streamSession = sessionFactory.create(streamTransport);
799+
// Force the session to skip initialization check by setting the flag via
800+
// reflection
801+
try {
802+
java.lang.reflect.Field skipField = streamSession.getClass()
803+
.getDeclaredField("skipInitializationCheck");
804+
skipField.setAccessible(true);
805+
skipField.set(streamSession, true);
806+
}
807+
catch (Exception e) {
808+
logger.warn("Could not set skipInitializationCheck flag: {}", e.getMessage());
809+
}
810+
// Copy client info from immediate session if available
811+
if (immediateSession != null) {
812+
streamSession.init(immediateSession.getClientCapabilities(), immediateSession.getClientInfo());
813+
}
814+
streamSessions.add(streamSession);
815+
return streamSession.handle(message);
816+
}
817+
else {
818+
// Immediate transport only needs 1 session per StreamableHttpSession
819+
if (immediateSession == null) {
820+
// Create session with immediate transport
821+
immediateTransport = new HttpServerTransport(objectMapper, response, asyncContext);
822+
immediateSession = sessionFactory.create(immediateTransport);
823+
}
824+
// Handle the message with the session that uses immediate responses
825+
return immediateSession.handle(message);
790826
}
791-
792-
// Handle the message with the session that uses immediate responses
793-
return immediateSession.handle(message).then(Mono.just(responseType)).onErrorReturn(ResponseType.IMMEDIATE);
794827
}
795828

796829
/**
797830
* Detects response type based on message characteristics.
798831
*/
799-
private ResponseType detectResponseType(McpSchema.JSONRPCMessage message) {
832+
public ResponseType detectResponseType(McpSchema.JSONRPCMessage message) {
800833
if (message instanceof McpSchema.JSONRPCRequest request) {
801834
// Initialize requests should return immediate JSON response
802835
if (McpSchema.METHOD_INITIALIZE.equals(request.method())) {
@@ -889,6 +922,7 @@ public void sendEvent(String eventType, String data) {
889922
String eventId = String.valueOf(++eventCounter);
890923
SseEvent event = new SseEvent(eventId, eventType, data);
891924
eventHistory.put(eventId, event);
925+
logger.debug("Sending SSE event {}: {}", eventId, data);
892926
eventSink.tryEmitNext(event);
893927
}
894928

@@ -963,6 +997,7 @@ public StreamableHttpServerTransport(StreamableHttpSseStream sseStream) {
963997
public Mono<Void> sendMessage(JSONRPCMessage message) {
964998
try {
965999
String jsonText = sseStream.objectMapper.writeValueAsString(message);
1000+
logger.debug("StreamableHttpServerTransport sending message: {}", jsonText);
9661001
sseStream.sendEvent(MESSAGE_EVENT_TYPE, jsonText);
9671002

9681003
// Complete stream after sending response to avoid hanging

mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public final class McpSchema {
4040
private McpSchema() {
4141
}
4242

43-
public static final String LATEST_PROTOCOL_VERSION = "2025-06-18";
43+
public static final String LATEST_PROTOCOL_VERSION = "2024-11-05";
4444

4545
public static final String JSONRPC_VERSION = "2.0";
4646

mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ public class McpServerSession implements McpSession {
5656

5757
private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED);
5858

59+
private final boolean skipInitializationCheck;
60+
5961
/**
6062
* Creates a new server session with the given parameters and the transport to use.
6163
* @param id session id
@@ -72,13 +74,22 @@ public class McpServerSession implements McpSession {
7274
public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport,
7375
InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler,
7476
Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers) {
77+
this(id, requestTimeout, transport, initHandler, initNotificationHandler, requestHandlers, notificationHandlers,
78+
false);
79+
}
80+
81+
public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport,
82+
InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler,
83+
Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers,
84+
boolean skipInitializationCheck) {
7585
this.id = id;
7686
this.requestTimeout = requestTimeout;
7787
this.transport = transport;
7888
this.initRequestHandler = initHandler;
7989
this.initNotificationHandler = initNotificationHandler;
8090
this.requestHandlers = requestHandlers;
8191
this.notificationHandlers = notificationHandlers;
92+
this.skipInitializationCheck = skipInitializationCheck;
8293
}
8394

8495
/**
@@ -104,6 +115,14 @@ public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Impl
104115
this.clientInfo.lazySet(clientInfo);
105116
}
106117

118+
public McpSchema.ClientCapabilities getClientCapabilities() {
119+
return this.clientCapabilities.get();
120+
}
121+
122+
public McpSchema.Implementation getClientInfo() {
123+
return this.clientInfo.get();
124+
}
125+
107126
private String generateRequestId() {
108127
return this.id + "-" + this.requestCounter.getAndIncrement();
109128
}
@@ -222,7 +241,16 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
222241
error.message(), error.data())));
223242
}
224243

225-
resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params()));
244+
if (skipInitializationCheck) {
245+
// Create a dummy exchange for sessions that skip initialization
246+
McpAsyncServerExchange dummyExchange = new McpAsyncServerExchange(this, clientCapabilities.get(),
247+
clientInfo.get());
248+
resultMono = handler.handle(dummyExchange, request.params());
249+
}
250+
else {
251+
resultMono = this.exchangeSink.asMono()
252+
.flatMap(exchange -> handler.handle(exchange, request.params()));
253+
}
226254
}
227255
return resultMono
228256
.map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null))

0 commit comments

Comments
 (0)