From 2c07ee8efc6909761bdb9acbbb1162cb25393a3d Mon Sep 17 00:00:00 2001 From: Zachary German Date: Mon, 2 Jun 2025 20:39:12 +0000 Subject: [PATCH 1/5] Adding StreamableHttp server support via HTTPServlet with async support --- mcp-spring/mcp-spring-webflux/pom.xml | 22 + ...treamableHttpTransportIntegrationTest.java | 255 ++++++ .../server/transport/TomcatTestUtil.java | 63 ++ .../src/test/resources/logback-test.xml | 15 + .../server/McpAsyncServer.java | 1 - .../server/McpAsyncServerExchange.java | 18 +- .../server/McpAsyncStreamableHttpServer.java | 639 +++++++++++++ .../server/McpServerFeatures.java | 68 +- .../server/transport/SessionHandler.java | 57 ++ ...StreamableHttpServerTransportProvider.java | 850 ++++++++++++++++++ .../modelcontextprotocol/spec/McpSchema.java | 4 +- .../spec/McpServerSession.java | 9 + .../spec/McpStreamableHttpServerSession.java | 383 ++++++++ .../StreamableHttpMcpAsyncServerTests.java | 22 + .../StreamableHttpMcpSyncServerTests.java | 22 + ...mableHttpServerTransportProviderTests.java | 349 +++++++ 16 files changed, 2770 insertions(+), 7 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/resources/logback-test.xml create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncStreamableHttpServer.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableHttpServerSession.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpAsyncServerTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpSyncServerTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 26452fe95..c519b5580 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -127,6 +127,28 @@ test + + + org.apache.tomcat.embed + tomcat-embed-core + ${tomcat.version} + test + + + org.apache.tomcat.embed + tomcat-embed-websocket + ${tomcat.version} + test + + + + + jakarta.servlet + jakarta.servlet-api + ${jakarta.servlet.version} + test + + diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java new file mode 100644 index 000000000..42eec8db0 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java @@ -0,0 +1,255 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.time.Duration; +import java.util.Map; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpAsyncStreamableHttpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for @link{StreamableHttpServerTransportProvider} with + * + * @link{WebClientStreamableHttpTransport}. + */ +class StreamableHttpTransportIntegrationTest { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String ENDPOINT = "/mcp"; + + private StreamableHttpServerTransportProvider serverTransportProvider; + + private McpClient.AsyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + void setUp() { + serverTransportProvider = new StreamableHttpServerTransportProvider(new ObjectMapper(), ENDPOINT, null); + + // Set up session factory with proper server capabilities + McpSchema.ServerCapabilities serverCapabilities = new McpSchema.ServerCapabilities(null, null, null, null, null, + null); + serverTransportProvider.setStreamableHttpSessionFactory( + sessionId -> new io.modelcontextprotocol.spec.McpStreamableHttpServerSession(sessionId, + java.time.Duration.ofSeconds(30), + request -> reactor.core.publisher.Mono.just(new McpSchema.InitializeResult("2025-06-18", + serverCapabilities, new McpSchema.Implementation("Test Server", "1.0.0"), null)), + () -> reactor.core.publisher.Mono.empty(), java.util.Map.of(), java.util.Map.of())); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, serverTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + WebClientStreamableHttpTransport clientTransport = WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(ENDPOINT) + .objectMapper(new ObjectMapper()) + .build(); + + clientBuilder = McpClient.async(clientTransport) + .clientInfo(new McpSchema.Implementation("Test Client", "1.0.0")); + } + + @AfterEach + void tearDown() { + if (serverTransportProvider != null) { + serverTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void shouldInitializeSuccessfully() { + // The server is already configured via the session factory in setUp + var mcpClient = clientBuilder.build(); + try { + InitializeResult result = mcpClient.initialize().block(); + assertThat(result).isNotNull(); + assertThat(result.serverInfo().name()).isEqualTo("Test Server"); + } + finally { + mcpClient.close(); + } + } + + @Test + void shouldCallImmediateToolSuccessfully() { + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("Tool executed successfully")), null); + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("test-tool", "Test tool description", emptyJsonSchema), + (exchange, request) -> Mono.just(callResponse)); + + // Configure session factory with tool handler + McpSchema.ServerCapabilities serverCapabilities = new McpSchema.ServerCapabilities(null, null, null, null, null, + new McpSchema.ServerCapabilities.ToolCapabilities(true)); + serverTransportProvider + .setStreamableHttpSessionFactory(sessionId -> new io.modelcontextprotocol.spec.McpStreamableHttpServerSession( + sessionId, java.time.Duration.ofSeconds(30), + request -> reactor.core.publisher.Mono.just(new McpSchema.InitializeResult("2025-06-18", + serverCapabilities, new McpSchema.Implementation("Test Server", "1.0.0"), null)), + () -> reactor.core.publisher.Mono.empty(), + java.util.Map.of("tools/call", + (io.modelcontextprotocol.spec.McpStreamableHttpServerSession.RequestHandler) ( + exchange, params) -> tool.call().apply(exchange, (Map) params)), + java.util.Map.of())); + + var mcpClient = clientBuilder.build(); + try { + mcpClient.initialize().block(); + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())).block(); + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()) + .isEqualTo("Tool executed successfully"); + } + finally { + mcpClient.close(); + } + } + + @Test + void shouldCallStreamingToolSuccessfully() { + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + McpServerFeatures.AsyncStreamingToolSpecification streamingTool = new McpServerFeatures.AsyncStreamingToolSpecification( + new McpSchema.Tool("streaming-tool", "Streaming test tool", emptyJsonSchema), + (exchange, request) -> Flux.range(1, 3) + .map(i -> new CallToolResult(List.of(new McpSchema.TextContent("Step " + i)), null))); + + // Configure session factory with streaming tool handler + McpSchema.ServerCapabilities serverCapabilities = new McpSchema.ServerCapabilities(null, null, null, null, null, + new McpSchema.ServerCapabilities.ToolCapabilities(true)); + serverTransportProvider + .setStreamableHttpSessionFactory(sessionId -> new io.modelcontextprotocol.spec.McpStreamableHttpServerSession( + sessionId, java.time.Duration.ofSeconds(30), + request -> reactor.core.publisher.Mono.just(new McpSchema.InitializeResult("2025-06-18", + serverCapabilities, new McpSchema.Implementation("Test Server", "1.0.0"), null)), + () -> reactor.core.publisher.Mono.empty(), java.util.Map.of("tools/call", + (io.modelcontextprotocol.spec.McpStreamableHttpServerSession.StreamingRequestHandler) new io.modelcontextprotocol.spec.McpStreamableHttpServerSession.StreamingRequestHandler() { + @Override + public Mono handle( + io.modelcontextprotocol.server.McpAsyncServerExchange exchange, Object params) { + return streamingTool.call().apply(exchange, (Map) params).next(); + } + + @Override + public Flux handleStreaming( + io.modelcontextprotocol.server.McpAsyncServerExchange exchange, Object params) { + return streamingTool.call().apply(exchange, (Map) params); + } + }), + java.util.Map.of())); + + var mcpClient = clientBuilder.build(); + try { + mcpClient.initialize().block(); + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("streaming-tool", Map.of())) + .block(); + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).startsWith("Step"); + } + finally { + mcpClient.close(); + } + } + + @Test + void shouldReceiveNotificationThroughGetStream() throws InterruptedException { + CountDownLatch notificationLatch = new CountDownLatch(1); + AtomicReference receivedEvent = new AtomicReference<>(); + AtomicReference sessionId = new AtomicReference<>(); + + WebClient webClient = WebClient.create("http://localhost:" + PORT); + String initMessage = "{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2025-06-18\",\"capabilities\":{},\"clientInfo\":{\"name\":\"Test\",\"version\":\"1.0\"}}}"; + + // Initialize and get session ID + webClient.post() + .uri(ENDPOINT) + .header("Accept", "application/json, text/event-stream") + .header("Content-Type", "application/json") + .bodyValue(initMessage) + .retrieve() + .toBodilessEntity() + .doOnNext(response -> sessionId.set(response.getHeaders().getFirst("Mcp-Session-Id"))) + .block(); + + // Establish SSE stream + webClient.get() + .uri(ENDPOINT) + .header("Accept", "text/event-stream") + .header("Mcp-Session-Id", sessionId.get()) + .retrieve() + .bodyToFlux(String.class) + .filter(line -> line.contains("test/notification")) + .doOnNext(event -> { + receivedEvent.set(event); + notificationLatch.countDown(); + }) + .subscribe(); + + // Send notification after delay + Mono.delay(Duration.ofMillis(200)) + .then(serverTransportProvider.notifyClients("test/notification", "test message")) + .subscribe(); + + assertThat(notificationLatch.await(5, TimeUnit.SECONDS)).isTrue(); + assertThat(receivedEvent.get()).isNotNull(); + assertThat(receivedEvent.get()).contains("test/notification"); + } + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java new file mode 100644 index 000000000..a9fa4d5bb --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -0,0 +1,63 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; + +import jakarta.servlet.Servlet; +import org.apache.catalina.Context; +import org.apache.catalina.startup.Tomcat; + +/** + * @author Christian Tzolov + */ +public class TomcatTestUtil { + + TomcatTestUtil() { + // Prevent instantiation + } + + public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) { + + var tomcat = new Tomcat(); + tomcat.setPort(port); + + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + Context context = tomcat.addContext(contextPath, baseDir); + + // Add transport servlet to Tomcat + org.apache.catalina.Wrapper wrapper = context.createWrapper(); + wrapper.setName("mcpServlet"); + wrapper.setServlet(servlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addChild(wrapper); + context.addServletMappingDecoded("/*", "mcpServlet"); + + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); + + return tomcat; + } + + /** + * Finds an available port on the local machine. + * @return an available port number + * @throws IllegalStateException if no available port can be found + */ + public static int findAvailablePort() { + try (final ServerSocket socket = new ServerSocket()) { + socket.bind(new InetSocketAddress(0)); + return socket.getLocalPort(); + } + catch (final IOException e) { + throw new IllegalStateException("Cannot bind to an available port!", e); + } + } + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback-test.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback-test.xml new file mode 100644 index 000000000..37f43a17a --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/resources/logback-test.xml @@ -0,0 +1,15 @@ + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 02ad955b9..59b1afca3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -214,7 +214,6 @@ private Mono asyncInitializeRequestHandler( "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", initializeRequest.protocolVersion(), serverProtocolVersion); } - return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, this.serverInfo, this.instructions)); }); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index e56c695fa..412875ab3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -12,6 +12,8 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSession; +import io.modelcontextprotocol.spec.McpStreamableHttpServerSession; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -25,7 +27,7 @@ */ public class McpAsyncServerExchange { - private final McpServerSession session; + private final McpSession session; private final McpSchema.ClientCapabilities clientCapabilities; @@ -59,6 +61,20 @@ public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabili this.clientInfo = clientInfo; } + /** + * Create a new asynchronous exchange with the client. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param clientInfo The client implementation information. + */ + public McpAsyncServerExchange(McpStreamableHttpServerSession session, + McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + this.session = session; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + } + /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncStreamableHttpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncStreamableHttpServer.java new file mode 100644 index 000000000..7e13f93ca --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncStreamableHttpServer.java @@ -0,0 +1,639 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableHttpServerSession; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Streamable HTTP-based MCP server implementation that uses a single session class to + * manage all streams and transports efficiently. + * + *

+ * Featuring: + *

    + *
  • Single session class manages all transport streams
  • + *
  • Simplified transport registration and management
  • + *
  • Clear separation of concerns between session and transport
  • + *
  • Efficient resource management
  • + *
  • API for handling both immediate and streaming responses
  • + *
+ * + * @author Zachary German + */ +public class McpAsyncStreamableHttpServer { + + private static final Logger logger = LoggerFactory.getLogger(McpAsyncStreamableHttpServer.class); + + private final StreamableHttpServerTransportProvider httpTransportProvider; + + private final ObjectMapper objectMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final String instructions; + + private final Duration requestTimeout; + + private final McpUriTemplateManagerFactory uriTemplateManagerFactory; + + // Core server features + private final McpServerFeatures.Async features; + + /** + * Creates a new McpAsyncStreamableHttpServer. + */ + McpAsyncStreamableHttpServer(StreamableHttpServerTransportProvider httpTransportProvider, ObjectMapper objectMapper, + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + this.httpTransportProvider = httpTransportProvider; + this.objectMapper = objectMapper; + this.features = features; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.requestTimeout = requestTimeout; + this.uriTemplateManagerFactory = uriTemplateManagerFactory != null ? uriTemplateManagerFactory + : new DeafaultMcpUriTemplateManagerFactory(); + + setupRequestHandlers(); + setupSessionFactory(); + } + + /** + * Sets up the request handlers for standard MCP methods. + */ + private void setupRequestHandlers() { + Map> requestHandlers = new HashMap<>(); + + // Ping handler + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); + + // Tool handlers + if (serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, createToolsListHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, createToolsCallHandler()); + } + + // Resource handlers + if (serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, createResourcesListHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, createResourcesReadHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, createResourceTemplatesListHandler()); + } + + // Prompt handlers + if (serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, createPromptsListHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, createPromptsGetHandler()); + } + + // Logging handlers + if (serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, createLoggingSetLevelHandler()); + } + + // Completion handlers + if (serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, createCompletionCompleteHandler()); + } + + this.requestHandlers = requestHandlers; + } + + private Map> requestHandlers; + + private Map notificationHandlers; + + /** + * Sets up notification handlers. + */ + private void setupNotificationHandlers() { + Map handlers = new HashMap<>(); + + handlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> { + logger.info("[INIT] Received initialized notification - initialization complete!"); + return Mono.empty(); + }); + + // Roots change notification handler + handlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, createRootsListChangedHandler()); + + this.notificationHandlers = handlers; + } + + /** + * Sets up the session factory for the HTTP transport provider. + */ + private void setupSessionFactory() { + setupNotificationHandlers(); + + httpTransportProvider.setStreamableHttpSessionFactory(sessionId -> new McpStreamableHttpServerSession(sessionId, + requestTimeout, this::handleInitializeRequest, Mono::empty, requestHandlers, notificationHandlers)); + } + + /** + * Handles initialization requests from clients. + */ + private Mono handleInitializeRequest(McpSchema.InitializeRequest initializeRequest) { + return Mono.defer(() -> { + logger.info("[INIT] Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // Protocol version negotiation + String serverProtocolVersion = McpSchema.LATEST_PROTOCOL_VERSION; + if (!McpSchema.LATEST_PROTOCOL_VERSION.equals(initializeRequest.protocolVersion())) { + logger.warn("[INIT] Client requested protocol version: {}, server supports: {}", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + logger.debug("[INIT] Server capabilities: {}", serverCapabilities); + logger.debug("[INIT] Server info: {}", serverInfo); + logger.debug("[INIT] Instructions: {}", instructions); + McpSchema.InitializeResult result = new McpSchema.InitializeResult(serverProtocolVersion, + serverCapabilities, serverInfo, instructions); + logger.info("[INIT] Sending initialize response: {}", result); + return Mono.just(result); + }); + } + + // Request handler creation methods + private McpStreamableHttpServerSession.RequestHandler createToolsListHandler() { + return (exchange, params) -> { + var regularTools = features.tools().stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + var streamingTools = features.streamTools() + .stream() + .map(McpServerFeatures.AsyncStreamingToolSpecification::tool) + .toList(); + var allTools = new ArrayList<>(regularTools); + allTools.addAll(streamingTools); + return Mono.just(new McpSchema.ListToolsResult(allTools, null)); + }; + } + + private McpStreamableHttpServerSession.RequestHandler createToolsCallHandler() { + return new McpStreamableHttpServerSession.StreamingRequestHandler() { + @Override + public Mono handle(McpAsyncServerExchange exchange, Object params) { + var callToolRequest = objectMapper.convertValue(params, McpSchema.CallToolRequest.class); + + // Check regular tools first + var regularTool = features.tools() + .stream() + .filter(tool -> callToolRequest.name().equals(tool.tool().name())) + .findFirst(); + + if (regularTool.isPresent()) { + return regularTool.get().call().apply(exchange, callToolRequest.arguments()); + } + + // Check streaming tools (take first result) + var streamingTool = features.streamTools() + .stream() + .filter(tool -> callToolRequest.name().equals(tool.tool().name())) + .findFirst(); + + if (streamingTool.isPresent()) { + return streamingTool.get().call().apply(exchange, callToolRequest.arguments()).next(); + } + + return Mono.error(new RuntimeException("Tool not found: " + callToolRequest.name())); + } + + @Override + public Flux handleStreaming(McpAsyncServerExchange exchange, Object params) { + var callToolRequest = objectMapper.convertValue(params, McpSchema.CallToolRequest.class); + + // Check streaming tools first (preferred for streaming) + var streamingTool = features.streamTools() + .stream() + .filter(tool -> callToolRequest.name().equals(tool.tool().name())) + .findFirst(); + + if (streamingTool.isPresent()) { + return streamingTool.get().call().apply(exchange, callToolRequest.arguments()); + } + + // Fallback to regular tools (convert Mono to Flux) + var regularTool = features.tools() + .stream() + .filter(tool -> callToolRequest.name().equals(tool.tool().name())) + .findFirst(); + + if (regularTool.isPresent()) { + return regularTool.get().call().apply(exchange, callToolRequest.arguments()).flux(); + } + + return Flux.error(new RuntimeException("Tool not found: " + callToolRequest.name())); + } + }; + } + + private McpStreamableHttpServerSession.RequestHandler createResourcesListHandler() { + return (exchange, params) -> { + var resources = features.resources() + .values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resources, null)); + }; + } + + private McpStreamableHttpServerSession.RequestHandler createResourcesReadHandler() { + return (exchange, params) -> { + var resourceRequest = objectMapper.convertValue(params, McpSchema.ReadResourceRequest.class); + var resourceUri = resourceRequest.uri(); + + return features.resources() + .values() + .stream() + .filter(spec -> uriTemplateManagerFactory.create(spec.resource().uri()).matches(resourceUri)) + .findFirst() + .map(spec -> spec.readHandler().apply(exchange, resourceRequest)) + .orElse(Mono.error(new RuntimeException("Resource not found: " + resourceUri))); + }; + } + + private McpStreamableHttpServerSession.RequestHandler createResourceTemplatesListHandler() { + return (exchange, params) -> Mono + .just(new McpSchema.ListResourceTemplatesResult(features.resourceTemplates(), null)); + } + + private McpStreamableHttpServerSession.RequestHandler createPromptsListHandler() { + return (exchange, params) -> { + var prompts = features.prompts() + .values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + return Mono.just(new McpSchema.ListPromptsResult(prompts, null)); + }; + } + + private McpStreamableHttpServerSession.RequestHandler createPromptsGetHandler() { + return (exchange, params) -> { + var promptRequest = objectMapper.convertValue(params, McpSchema.GetPromptRequest.class); + + return features.prompts() + .values() + .stream() + .filter(spec -> spec.prompt().name().equals(promptRequest.name())) + .findFirst() + .map(spec -> spec.promptHandler().apply(exchange, promptRequest)) + .orElse(Mono.error(new RuntimeException("Prompt not found: " + promptRequest.name()))); + }; + } + + private McpStreamableHttpServerSession.RequestHandler createLoggingSetLevelHandler() { + return (exchange, params) -> { + var setLevelRequest = objectMapper.convertValue(params, McpSchema.SetLevelRequest.class); + exchange.setMinLoggingLevel(setLevelRequest.level()); + return Mono.just(Map.of()); + }; + } + + private McpStreamableHttpServerSession.RequestHandler createCompletionCompleteHandler() { + return (exchange, params) -> { + var completeRequest = objectMapper.convertValue(params, McpSchema.CompleteRequest.class); + + return features.completions() + .values() + .stream() + .filter(spec -> spec.referenceKey().equals(completeRequest.ref())) + .findFirst() + .map(spec -> spec.completionHandler().apply(exchange, completeRequest)) + .orElse(Mono.error(new RuntimeException("Completion not found: " + completeRequest.ref()))); + }; + } + + private McpStreamableHttpServerSession.NotificationHandler createRootsListChangedHandler() { + return (exchange, params) -> { + var rootsChangeConsumers = features.rootsChangeConsumers(); + if (rootsChangeConsumers.isEmpty()) { + return Mono + .fromRunnable(() -> logger.warn("Roots list changed notification, but no consumers provided")); + } + + return exchange.listRoots() + .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); + }; + } + + /** + * Get the server capabilities. + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return serverCapabilities; + } + + /** + * Get the server implementation information. + */ + public McpSchema.Implementation getServerInfo() { + return serverInfo; + } + + /** + * Gracefully closes the server. + */ + public Mono closeGracefully() { + return httpTransportProvider.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + httpTransportProvider.close(); + } + + /** + * Notifies clients that the list of available tools has changed. + */ + public Mono notifyToolsListChanged() { + return httpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + /** + * Notifies clients that the list of available resources has changed. + */ + public Mono notifyResourcesListChanged() { + return httpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + /** + * Notifies clients that resources have been updated. + */ + public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification notification) { + return httpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_UPDATED, notification); + } + + /** + * Notifies clients that the list of available prompts has changed. + */ + public Mono notifyPromptsListChanged() { + return httpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + /** + * Creates a new builder for configuring and creating McpAsyncStreamableHttpServer + * instances. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of McpAsyncStreamableHttpServer with Streamable HTTP + * transport. + * + *

+ * This builder provides a fluent API for configuring Streamable HTTP MCP + * servers with enhanced features: + *

    + *
  • Single session class managing all transport streams
  • + *
  • Resource management and lifecycle handling
  • + *
  • Clean separation between session and transport concerns
  • + *
  • Support for both immediate and streaming responses
  • + *
+ * + * @author Zachary German + */ + public static class Builder { + + private McpSchema.Implementation serverInfo; + + private McpSchema.ServerCapabilities serverCapabilities; + + private String instructions; + + private Duration requestTimeout = Duration.ofSeconds(30); + + private ObjectMapper objectMapper = new ObjectMapper(); + + private String mcpEndpoint = "/mcp"; + + private Supplier sessionIdProvider; + + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + + private final List tools = new ArrayList<>(); + + private final List streamTools = new ArrayList<>(); + + private final Map resources = new HashMap<>(); + + private final List resourceTemplates = new ArrayList<>(); + + private final Map prompts = new HashMap<>(); + + private final Map completions = new HashMap<>(); + + private final List, Mono>> rootsChangeConsumers = new ArrayList<>(); + + /** + * Sets the server implementation information. + */ + public Builder serverInfo(String name, String version) { + return serverInfo(name, null, version); + } + + /** + * Sets the server implementation information. + */ + public Builder serverInfo(String name, String title, String version) { + Assert.hasText(name, "Server name must not be empty"); + Assert.hasText(version, "Server version must not be empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server capabilities. + */ + public Builder serverCapabilities(McpSchema.ServerCapabilities capabilities) { + this.serverCapabilities = capabilities; + return this; + } + + /** + * Sets the server instructions. + */ + public Builder instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the request timeout duration. + */ + public Builder requestTimeout(Duration timeout) { + Assert.notNull(timeout, "Request timeout must not be null"); + this.requestTimeout = timeout; + return this; + } + + /** + * Sets the JSON object mapper. + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the MCP endpoint path. + */ + public Builder withMcpEndpoint(String endpoint) { + Assert.hasText(endpoint, "MCP endpoint must not be empty"); + this.mcpEndpoint = endpoint; + return this; + } + + /** + * Sets the session ID provider. + */ + public Builder withSessionIdProvider(Supplier provider) { + Assert.notNull(provider, "Session ID provider must not be null"); + this.sessionIdProvider = provider; + return this; + } + + /** + * Sets the URI template manager factory. + */ + public Builder withUriTemplateManagerFactory(McpUriTemplateManagerFactory factory) { + Assert.notNull(factory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = factory; + return this; + } + + /** + * Adds a tool specification. + */ + public Builder withTool(McpServerFeatures.AsyncToolSpecification toolSpec) { + Assert.notNull(toolSpec, "Tool specification must not be null"); + this.tools.add(toolSpec); + return this; + } + + /** + * Adds a streaming tool specification. + */ + public Builder withStreamingTool(McpServerFeatures.AsyncStreamingToolSpecification toolSpec) { + Assert.notNull(toolSpec, "Streaming tool specification must not be null"); + this.streamTools.add(toolSpec); + return this; + } + + /** + * Adds a resource specification. + */ + public Builder withResource(String uri, McpServerFeatures.AsyncResourceSpecification resourceSpec) { + Assert.hasText(uri, "Resource URI must not be empty"); + Assert.notNull(resourceSpec, "Resource specification must not be null"); + this.resources.put(uri, resourceSpec); + return this; + } + + /** + * Adds a resource template. + */ + public Builder withResourceTemplate(McpSchema.ResourceTemplate template) { + Assert.notNull(template, "Resource template must not be null"); + this.resourceTemplates.add(template); + return this; + } + + /** + * Adds a prompt specification. + */ + public Builder withPrompt(String name, McpServerFeatures.AsyncPromptSpecification promptSpec) { + Assert.hasText(name, "Prompt name must not be empty"); + Assert.notNull(promptSpec, "Prompt specification must not be null"); + this.prompts.put(name, promptSpec); + return this; + } + + /** + * Adds a completion specification. + */ + public Builder withCompletion(McpSchema.CompleteReference reference, + McpServerFeatures.AsyncCompletionSpecification completionSpec) { + Assert.notNull(reference, "Completion reference must not be null"); + Assert.notNull(completionSpec, "Completion specification must not be null"); + this.completions.put(reference, completionSpec); + return this; + } + + /** + * Adds a roots change consumer. + */ + public Builder withRootsChangeConsumer( + BiFunction, Mono> consumer) { + Assert.notNull(consumer, "Roots change consumer must not be null"); + this.rootsChangeConsumers.add(consumer); + return this; + } + + /** + * Builds the McpAsyncStreamableHttpServer instance. + */ + public McpAsyncStreamableHttpServer build() { + Assert.notNull(serverInfo, "Server info must be set"); + + // Create Streamable HTTP transport provider + StreamableHttpServerTransportProvider.Builder transportBuilder = StreamableHttpServerTransportProvider + .builder() + .withObjectMapper(objectMapper) + .withMcpEndpoint(mcpEndpoint); + + if (sessionIdProvider != null) { + transportBuilder.withSessionIdProvider(sessionIdProvider); + } + + StreamableHttpServerTransportProvider httpTransportProvider = transportBuilder.build(); + + // Create server features + McpServerFeatures.Async features = new McpServerFeatures.Async(serverInfo, serverCapabilities, tools, + resources, resourceTemplates, prompts, completions, rootsChangeConsumers, instructions, + streamTools); + + return new McpAsyncStreamableHttpServer(httpTransportProvider, objectMapper, features, requestTimeout, + uriTemplateManagerFactory); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 8311f5d41..b1e533c4a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -14,6 +14,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -44,7 +45,31 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s Map prompts, Map completions, List, Mono>> rootsChangeConsumers, - String instructions) { + String instructions, List streamTools) { + + /** + * Create an instance and validate the arguments (backward compatible + * constructor). + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param rootsChangeConsumers The list of consumers that will be notified when + * the roots list changes + * @param instructions The server instructions text + */ + public Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, Map resources, + List resourceTemplates, + Map prompts, + Map completions, + List, Mono>> rootsChangeConsumers, + String instructions) { + this(serverInfo, serverCapabilities, tools, resources, resourceTemplates, prompts, completions, + rootsChangeConsumers, instructions, List.of()); + } /** * Create an instance and validate the arguments. @@ -57,6 +82,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * @param rootsChangeConsumers The list of consumers that will be notified when * the roots list changes * @param instructions The server instructions text + * @param streamTools The list of streaming tool specifications */ Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, @@ -64,7 +90,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s Map prompts, Map completions, List, Mono>> rootsChangeConsumers, - String instructions) { + String instructions, List streamTools) { Assert.notNull(serverInfo, "Server info must not be null"); @@ -88,6 +114,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.completions = (completions != null) ? completions : Map.of(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); this.instructions = instructions; + this.streamTools = (streamTools != null) ? streamTools : List.of(); } /** @@ -128,7 +155,8 @@ static Async fromSync(Sync syncSpec) { } return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, - syncSpec.resourceTemplates(), prompts, completions, rootChangeConsumers, syncSpec.instructions()); + syncSpec.resourceTemplates(), prompts, completions, rootChangeConsumers, syncSpec.instructions(), + List.of()); } } @@ -251,6 +279,40 @@ static AsyncToolSpecification fromSync(SyncToolSpecification tool) { } } + /** + * Specification of a streaming tool with its asynchronous handler function that can + * return either a single result (Mono) or a stream of results (Flux). This enables + * tools to provide real-time streaming responses for long-running operations or + * progressive results. + * + *

+ * Example streaming tool specification:

{@code
+	 * new McpServerFeatures.AsyncStreamingToolSpecification(
+	 *     new Tool(
+	 *         "file_processor",
+	 *         "Processes files with streaming progress updates",
+	 *         new JsonSchemaObject()
+	 *             .required("file_path")
+	 *             .property("file_path", JsonSchemaType.STRING)
+	 *     ),
+	 *     (exchange, args) -> {
+	 *         String filePath = (String) args.get("file_path");
+	 *         return Flux.interval(Duration.ofSeconds(1))
+	 *             .take(10)
+	 *             .map(i -> new CallToolResult("Processing step " + i + " for " + filePath));
+	 *     }
+	 * )
+	 * }
+ * + * @param tool The tool definition including name, description, and parameter schema + * @param call The function that implements the tool's streaming logic, receiving + * arguments and returning a Flux of results that will be streamed to the client via + * SSE. + */ + public record AsyncStreamingToolSpecification(McpSchema.Tool tool, + BiFunction, Flux> call) { + } + /** * Specification of a resource with its asynchronous handler function. Resources * provide context to AI models by exposing data such as: diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java new file mode 100644 index 000000000..5a4331cf5 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java @@ -0,0 +1,57 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +/** + * Handler interface for session lifecycle and runtime events in the Streamable HTTP + * transport. + * + *

+ * This interface provides hooks for monitoring and responding to various session-related + * events that occur during the operation of the HTTP-based MCP server transport. + * Implementations can use these callbacks to: + *

    + *
  • Log session activities
  • + *
  • Implement custom session management logic
  • + *
  • Handle error conditions
  • + *
  • Perform cleanup operations
  • + *
+ * + * @author Zachary German + */ +public interface SessionHandler { + + /** + * Called when a new session is created. + * @param sessionId The ID of the newly created session + * @param context Additional context information (may be null) + */ + void onSessionCreate(String sessionId, Object context); + + /** + * Called when a session is closed. + * @param sessionId The ID of the closed session + */ + void onSessionClose(String sessionId); + + /** + * Called when a session is not found for a given session ID. + * @param sessionId The session ID that was not found + * @param request The HTTP request that referenced the missing session + * @param response The HTTP response that will be sent to the client + */ + void onSessionNotFound(String sessionId, HttpServletRequest request, HttpServletResponse response); + + /** + * Called when an error occurs while sending a notification to a session. + * @param sessionId The ID of the session where the error occurred + * @param error The error that occurred + */ + void onSendNotificationError(String sessionId, Throwable error); + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java new file mode 100644 index 000000000..87b614848 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java @@ -0,0 +1,850 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpStreamableHttpServerSession; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ReadListener; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.nio.charset.StandardCharsets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.context.Context; + +/** + * MCP Streamable HTTP transport provider that uses a single session class to manage all + * streams and transports. + * + *

+ * Key improvements over the original implementation: + *

    + *
  • Manages server-client sessions, including transport registration. + *
  • Handles HTTP requests and HTTP/SSE responses and streams. + *
  • Provides callbacks for session lifecycle and errors. + *
  • Supports graceful shutdown. + *
  • Enforces allowed 'Origin' header values if configured. + *
  • Provides a default session ID provider if none is configured. + *
+ * + * @author Zachary German + */ +@WebServlet(asyncSupported = true) +public class StreamableHttpServerTransportProvider extends HttpServlet implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(StreamableHttpServerTransportProvider.class); + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static final String SESSION_ID_HEADER = "Mcp-Session-Id"; + + public static final String LAST_EVENT_ID_HEADER = "Last-Event-Id"; + + public static final String MESSAGE_EVENT_TYPE = "message"; + + public static final String ACCEPT_HEADER = "Accept"; + + public static final String ORIGIN_HEADER = "Origin"; + + public static final String ALLOW_ORIGIN_HEADER = "Access-Control-Allow-Origin"; + + public static final String ALLOW_ORIGIN_DEFAULT_VALUE = "*"; + + public static final String CACHE_CONTROL_HEADER = "Cache-Control"; + + public static final String CONNECTION_HEADER = "Connection"; + + public static final String CACHE_CONTROL_NO_CACHE = "no-cache"; + + public static final String CONNECTION_KEEP_ALIVE = "keep-alive"; + + public static final String MCP_SESSION_ID = "MCP-Session-ID"; + + public static final String DEFAULT_MCP_ENDPOINT = "/mcp"; + + /** com.fasterxml.jackson.databind.ObjectMapper */ + private static final ObjectMapper DEFAULT_OBJECT_MAPPER = new ObjectMapper(); + + /** UUID.randomUUID().toString() */ + private static final Supplier DEFAULT_SESSION_ID_PROVIDER = () -> UUID.randomUUID().toString(); + + /** JSON object mapper for serialization/deserialization */ + private final ObjectMapper objectMapper; + + /** The endpoint path for handling MCP requests */ + private final String mcpEndpoint; + + /** Supplier for generating unique session IDs */ + private final Supplier sessionIdProvider; + + /** Sessions map, keyed by Session ID */ + private final Map sessions = new ConcurrentHashMap<>(); + + /** Flag indicating if the transport is in the process of shutting down */ + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + /** Optional allowed 'Origin' header value list. Not enforced if empty. */ + private final List allowedOrigins = new ArrayList<>(); + + /** Callback interface for session lifecycle and errors */ + private SessionHandler sessionHandler; + + private McpStreamableHttpServerSession.Factory streamableHttpSessionFactory; + + /** + *
    + *
  • Manages server-client sessions, including transport registration. + *
  • Handles HTTP requests and HTTP/SSE responses and streams. + *
+ * @param objectMapper ObjectMapper - Default: + * com.fasterxml.jackson.databind.ObjectMapper + * @param mcpEndpoint String - Default: '/mcp' + * @param sessionIdProvider Supplier(String) - Default: UUID.randomUUID().toString() + */ + public StreamableHttpServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + Supplier sessionIdProvider) { + this.objectMapper = Objects.requireNonNullElse(objectMapper, DEFAULT_OBJECT_MAPPER); + this.mcpEndpoint = Objects.requireNonNullElse(mcpEndpoint, DEFAULT_MCP_ENDPOINT); + this.sessionIdProvider = Objects.requireNonNullElse(sessionIdProvider, DEFAULT_SESSION_ID_PROVIDER); + } + + /** + *
    + *
  • Manages server-client sessions, including transport registration. + *
  • Handles HTTP requests and HTTP/SSE responses and streams. + *
+ * @param objectMapper ObjectMapper - Default: + * com.fasterxml.jackson.databind.ObjectMapper + * @param mcpEndpoint String - Default: '/mcp' + * @param sessionIdProvider Supplier(String) - Default: UUID.randomUUID().toString() + */ + public StreamableHttpServerTransportProvider() { + this(null, null, null); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + // This method is required by the interface but not used in this implementation + } + + public void setStreamableHttpSessionFactory(McpStreamableHttpServerSession.Factory sessionFactory) { + this.streamableHttpSessionFactory = sessionFactory; + } + + public void setSessionHandler(SessionHandler sessionHandler) { + this.sessionHandler = sessionHandler; + } + + public void setAllowedOrigins(List allowedOrigins) { + this.allowedOrigins.clear(); + this.allowedOrigins.addAll(allowedOrigins); + } + + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params).doOnError(e -> { + logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()); + if (sessionHandler != null) { + sessionHandler.onSendNotificationError(session.getId(), e); + } + }).onErrorComplete()) + .then(); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + isClosing.set(true); + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.closeGracefully() + .doOnError(e -> logger.error("Error closing session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + }); + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + logger.info("GET request received for URI: '{}' with headers: {}", requestURI, extractHeaders(request)); + + if (!validateOrigin(request, response) || !validateEndpoint(requestURI, response) + || !validateNotClosing(response)) { + return; + } + + String acceptHeader = request.getHeader(ACCEPT_HEADER); + if (acceptHeader == null || !acceptHeader.contains(TEXT_EVENT_STREAM)) { + logger.debug("Accept header missing or does not include {}", TEXT_EVENT_STREAM); + sendErrorResponse(response, "Accept header must include text/event-stream"); + return; + } + + String sessionId = request.getHeader(SESSION_ID_HEADER); + if (sessionId == null) { + sendErrorResponse(response, "Session ID missing in request header"); + return; + } + + McpStreamableHttpServerSession session = sessions.get(sessionId); + if (session == null) { + handleSessionNotFound(sessionId, request, response); + return; + } + + // Set up SSE connection + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader(CACHE_CONTROL_HEADER, CACHE_CONTROL_NO_CACHE); + response.setHeader(CONNECTION_HEADER, CONNECTION_KEEP_ALIVE); + response.setHeader(SESSION_ID_HEADER, sessionId); + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + String lastEventId = request.getHeader(LAST_EVENT_ID_HEADER); + String transportId = "sse-" + request.getRequestId(); + + SseTransport sseTransport = new SseTransport(objectMapper, response, asyncContext, lastEventId); + session.registerTransport(transportId, sseTransport); + + logger.debug("Registered SSE transport {} for session {}", transportId, sessionId); + } + + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + logger.info("POST request received for URI: '{}' with headers: {}", requestURI, extractHeaders(request)); + + if (!validateOrigin(request, response) || !validateEndpoint(requestURI, response) + || !validateNotClosing(response)) { + return; + } + + String acceptHeader = request.getHeader(ACCEPT_HEADER); + if (acceptHeader == null + || (!acceptHeader.contains(APPLICATION_JSON) || !acceptHeader.contains(TEXT_EVENT_STREAM))) { + logger.debug("Accept header validation failed. Header: {}", acceptHeader); + sendErrorResponse(response, "Accept header must include both application/json and text/event-stream"); + return; + } + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + StringBuilder body = new StringBuilder(); + ServletInputStream inputStream = request.getInputStream(); + + inputStream.setReadListener(new ReadListener() { + @Override + public void onDataAvailable() throws IOException { + int len; + byte[] buffer = new byte[1024]; + while (inputStream.isReady() && (len = inputStream.read(buffer)) != -1) { + body.append(new String(buffer, 0, len, StandardCharsets.UTF_8)); + } + } + + @Override + public void onAllDataRead() throws IOException { + try { + logger.debug("Parsing JSON-RPC message: {}", body.toString()); + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + + boolean isInitializeRequest = false; + String sessionId = request.getHeader(SESSION_ID_HEADER); + + if (message instanceof McpSchema.JSONRPCRequest req + && McpSchema.METHOD_INITIALIZE.equals(req.method())) { + isInitializeRequest = true; + logger.debug("Detected initialize request"); + if (sessionId == null) { + sessionId = sessionIdProvider.get(); + logger.debug("Created new session ID for initialize request: {}", sessionId); + } + } + + if (!isInitializeRequest && sessionId == null) { + sendErrorResponse(response, "Session ID missing in request header"); + asyncContext.complete(); + return; + } + + McpStreamableHttpServerSession session = getOrCreateSession(sessionId, isInitializeRequest); + if (session == null) { + logger.error("Failed to create session for sessionId: {}", sessionId); + handleSessionNotFound(sessionId, request, response); + asyncContext.complete(); + return; + } + logger.debug("Using session: {}", sessionId); + + response.setHeader(SESSION_ID_HEADER, sessionId); + + // Determine response type and create appropriate transport + ResponseType responseType = detectResponseType(message, session); + String transportId = "req-" + request.getRequestId(); + + if (responseType == ResponseType.STREAM) { + logger.debug("Handling STREAM response type"); + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader(CACHE_CONTROL_HEADER, CACHE_CONTROL_NO_CACHE); + response.setHeader(CONNECTION_HEADER, CONNECTION_KEEP_ALIVE); + + SseTransport sseTransport = new SseTransport(objectMapper, response, asyncContext, null); + session.registerTransport(transportId, sseTransport); + } + else { + logger.debug("Handling IMMEDIATE response type"); + // Only set content type for requests, not notifications + if (message instanceof McpSchema.JSONRPCRequest) { + logger.debug("Setting content type to APPLICATION_JSON for request response"); + response.setContentType(APPLICATION_JSON); + } + else { + logger.debug("Not setting content type for notification (empty response expected)"); + } + + HttpTransport httpTransport = new HttpTransport(objectMapper, response, asyncContext); + session.registerTransport(transportId, httpTransport); + } + + // Handle the message + logger.debug("About to handle message: {} with transport: {}", message.getClass().getSimpleName(), + transportId); + + // For notifications, we need to handle the HTTP response manually + // since no JSON response is sent + if (message instanceof McpSchema.JSONRPCNotification) { + session.handleMessage(message, transportId).doOnSuccess(v -> { + logger.debug("Message handling completed successfully for transport: {}", transportId); + logger.debug("[NOTIFICATION] Sending empty HTTP response for notification"); + try { + if (!response.isCommitted()) { + response.setStatus(HttpServletResponse.SC_OK); + response.setCharacterEncoding("UTF-8"); + } + asyncContext.complete(); + } + catch (Exception e) { + logger.error("Failed to send notification response: {}", e.getMessage()); + asyncContext.complete(); + } + }).doOnError(e -> { + logger.error("Error in message handling: {}", e.getMessage(), e); + asyncContext.complete(); + }).doFinally(signalType -> { + logger.debug("Unregistering transport: {} with signal: {}", transportId, signalType); + session.unregisterTransport(transportId); + }).contextWrite(Context.of(MCP_SESSION_ID, sessionId)).subscribe(); + } + else { + // For requests, let the transport handle the response + session.handleMessage(message, transportId) + .doOnSuccess(v -> logger.info("Message handling completed successfully for transport: {}", + transportId)) + .doOnError(e -> logger.error("Error in message handling: {}", e.getMessage(), e)) + .doFinally(signalType -> { + logger.debug("Unregistering transport: {} with signal: {}", transportId, signalType); + session.unregisterTransport(transportId); + }) + .contextWrite(Context.of(MCP_SESSION_ID, sessionId)) + .subscribe(null, error -> { + logger.error("Error in message handling chain: {}", error.getMessage(), error); + asyncContext.complete(); + }); + } + + } + catch (Exception e) { + logger.error("Error processing message: {}", e.getMessage()); + sendErrorResponse(response, "Invalid JSON-RPC message: " + e.getMessage()); + asyncContext.complete(); + } + } + + @Override + public void onError(Throwable t) { + logger.error("Error reading request body: {}", t.getMessage()); + try { + sendErrorResponse(response, "Error reading request: " + t.getMessage()); + } + catch (IOException e) { + logger.error("Failed to write error response", e); + } + asyncContext.complete(); + } + }); + } + + @Override + protected void doDelete(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + String sessionId = request.getHeader(SESSION_ID_HEADER); + if (sessionId == null) { + sendErrorResponse(response, "Session ID missing in request header"); + return; + } + + McpStreamableHttpServerSession session = sessions.remove(sessionId); + if (session == null) { + handleSessionNotFound(sessionId, request, response); + return; + } + + session.closeGracefully().contextWrite(Context.of(MCP_SESSION_ID, sessionId)).subscribe(); + logger.debug("Session closed: {}", sessionId); + if (sessionHandler != null) { + sessionHandler.onSessionClose(sessionId); + } + + response.setStatus(HttpServletResponse.SC_OK); + } + + private boolean validateOrigin(HttpServletRequest request, HttpServletResponse response) throws IOException { + if (!allowedOrigins.isEmpty()) { + String origin = request.getHeader(ORIGIN_HEADER); + if (!allowedOrigins.contains(origin)) { + logger.debug("Origin header does not match allowed origins: '{}'", origin); + response.sendError(HttpServletResponse.SC_FORBIDDEN); + return false; + } + else { + response.setHeader(ALLOW_ORIGIN_HEADER, origin); + } + } + else { + response.setHeader(ALLOW_ORIGIN_HEADER, ALLOW_ORIGIN_DEFAULT_VALUE); + } + return true; + } + + private boolean validateEndpoint(String requestURI, HttpServletResponse response) throws IOException { + if (!requestURI.endsWith(mcpEndpoint)) { + logger.debug("URI does not match MCP endpoint: '{}'", mcpEndpoint); + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return false; + } + return true; + } + + private boolean validateNotClosing(HttpServletResponse response) throws IOException { + if (isClosing.get()) { + logger.debug("Server is shutting down, rejecting request"); + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return false; + } + return true; + } + + protected McpStreamableHttpServerSession getOrCreateSession(String sessionId, boolean createIfMissing) { + McpStreamableHttpServerSession session = sessions.get(sessionId); + logger.debug("Looking for session: {}, found: {}", sessionId, session != null); + if (session == null && createIfMissing) { + logger.debug("Creating new session: {}", sessionId); + session = streamableHttpSessionFactory.create(sessionId); + sessions.put(sessionId, session); + logger.debug("Created new session: {}", sessionId); + if (sessionHandler != null) { + sessionHandler.onSessionCreate(sessionId, null); + } + } + return session; + } + + private ResponseType detectResponseType(McpSchema.JSONRPCMessage message, McpStreamableHttpServerSession session) { + if (message instanceof McpSchema.JSONRPCRequest request) { + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + return ResponseType.IMMEDIATE; + } + + // Check if handler returns Flux (streaming) or Mono (immediate) + var handler = session.getRequestHandler(request.method()); + if (handler != null && handler instanceof McpStreamableHttpServerSession.StreamingRequestHandler) { + return ResponseType.STREAM; + } + return ResponseType.IMMEDIATE; + } + else { + return ResponseType.IMMEDIATE; + } + } + + private void handleSessionNotFound(String sessionId, HttpServletRequest request, HttpServletResponse response) + throws IOException { + sendErrorResponse(response, "Session not found: " + sessionId); + if (sessionHandler != null) { + sessionHandler.onSessionNotFound(sessionId, request, response); + } + } + + private void sendErrorResponse(HttpServletResponse response, String message) throws IOException { + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + response.getWriter().write(createErrorJson(message)); + } + + private String createErrorJson(String message) { + try { + return objectMapper.writeValueAsString(new McpError(message)); + } + catch (IOException e) { + logger.error("Failed to serialize error message", e); + return "{\"error\":\"" + message + "\"}"; + } + } + + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + private Map extractHeaders(HttpServletRequest request) { + Map headers = new HashMap<>(); + Enumeration headerNames = request.getHeaderNames(); + while (headerNames.hasMoreElements()) { + String name = headerNames.nextElement(); + headers.put(name, request.getHeader(name)); + } + return headers; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ObjectMapper objectMapper = DEFAULT_OBJECT_MAPPER; + + private String mcpEndpoint = DEFAULT_MCP_ENDPOINT; + + private Supplier sessionIdProvider = DEFAULT_SESSION_ID_PROVIDER; + + public Builder withObjectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + public Builder withMcpEndpoint(String mcpEndpoint) { + Assert.hasText(mcpEndpoint, "MCP endpoint must not be empty"); + this.mcpEndpoint = mcpEndpoint; + return this; + } + + public Builder withSessionIdProvider(Supplier sessionIdProvider) { + Assert.notNull(sessionIdProvider, "SessionIdProvider must not be null"); + this.sessionIdProvider = sessionIdProvider; + return this; + } + + public StreamableHttpServerTransportProvider build() { + return new StreamableHttpServerTransportProvider(objectMapper, mcpEndpoint, sessionIdProvider); + } + + } + + private enum ResponseType { + + IMMEDIATE, STREAM + + } + + /** + * SSE transport implementation. + */ + private static class SseTransport implements McpServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(SseTransport.class); + + private final ObjectMapper objectMapper; + + private final HttpServletResponse response; + + private final AsyncContext asyncContext; + + private final Sinks.Many eventSink = Sinks.many().unicast().onBackpressureBuffer(); + + private final Map eventHistory = new ConcurrentHashMap<>(); + + private final AtomicLong eventCounter = new AtomicLong(0); + + public SseTransport(ObjectMapper objectMapper, HttpServletResponse response, AsyncContext asyncContext, + String lastEventId) { + this.objectMapper = objectMapper; + this.response = response; + this.asyncContext = asyncContext; + + setupSseStream(lastEventId); + } + + private void setupSseStream(String lastEventId) { + try { + PrintWriter writer = response.getWriter(); + + eventSink.asFlux().doOnNext(event -> { + try { + if (event.id() != null) { + writer.write("id: " + event.id() + "\n"); + } + if (event.event() != null) { + writer.write("event: " + event.event() + "\n"); + } + writer.write("data: " + event.data() + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + catch (IOException e) { + logger.debug("Error writing to SSE stream: {}", e.getMessage()); + asyncContext.complete(); + } + }).doOnComplete(() -> { + try { + writer.close(); + } + finally { + asyncContext.complete(); + } + }).doOnError(e -> { + logger.error("Error in SSE stream: {}", e.getMessage()); + asyncContext.complete(); + }).contextWrite(Context.of(MCP_SESSION_ID, response.getHeader(SESSION_ID_HEADER))).subscribe(); + + // Replay events if requested + if (lastEventId != null) { + replayEventsAfter(lastEventId); + } + + } + catch (IOException e) { + logger.error("Failed to setup SSE stream: {}", e.getMessage()); + asyncContext.complete(); + } + } + + private void replayEventsAfter(String lastEventId) { + try { + long lastId = Long.parseLong(lastEventId); + for (long i = lastId + 1; i <= eventCounter.get(); i++) { + SseEvent event = eventHistory.get(String.valueOf(i)); + if (event != null) { + eventSink.tryEmitNext(event); + } + } + } + catch (NumberFormatException e) { + logger.warn("Invalid last event ID: {}", lastEventId); + } + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + try { + String jsonText = objectMapper.writeValueAsString(message); + String eventId = String.valueOf(eventCounter.incrementAndGet()); + SseEvent event = new SseEvent(eventId, MESSAGE_EVENT_TYPE, jsonText); + + eventHistory.put(eventId, event); + logger.debug("Sending SSE event {}: {}", eventId, jsonText); + eventSink.tryEmitNext(event); + + if (message instanceof McpSchema.JSONRPCResponse) { + logger.debug("Completing SSE stream after sending response"); + eventSink.tryEmitComplete(); + } + + return Mono.empty(); + } + catch (Exception e) { + logger.error("Failed to send message: {}", e.getMessage()); + return Mono.error(e); + } + } + + /** + * Sends a stream of messages for Flux responses. + */ + public Mono sendMessageStream(Flux messageStream) { + return messageStream.doOnNext(message -> { + try { + String jsonText = objectMapper.writeValueAsString(message); + String eventId = String.valueOf(eventCounter.incrementAndGet()); + SseEvent event = new SseEvent(eventId, MESSAGE_EVENT_TYPE, jsonText); + + eventHistory.put(eventId, event); + logger.debug("Sending SSE stream event {}: {}", eventId, jsonText); + eventSink.tryEmitNext(event); + } + catch (Exception e) { + logger.error("Failed to send stream message: {}", e.getMessage()); + eventSink.tryEmitError(e); + } + }).doOnComplete(() -> { + logger.debug("Completing SSE stream after sending all stream messages"); + eventSink.tryEmitComplete(); + }).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + eventSink.tryEmitComplete(); + logger.debug("SSE transport closed gracefully"); + }); + } + + private record SseEvent(String id, String event, String data) { + } + + } + + /** + * HTTP transport implementation for immediate responses. + */ + private static class HttpTransport implements McpServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(HttpTransport.class); + + private final ObjectMapper objectMapper; + + private final HttpServletResponse response; + + private final AsyncContext asyncContext; + + public HttpTransport(ObjectMapper objectMapper, HttpServletResponse response, AsyncContext asyncContext) { + this.objectMapper = objectMapper; + this.response = response; + this.asyncContext = asyncContext; + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + if (response.isCommitted()) { + logger.warn("Response already committed, cannot send message"); + return; + } + + response.setCharacterEncoding("UTF-8"); + response.setStatus(HttpServletResponse.SC_OK); + + // For notifications, don't write any content (empty response) + if (message instanceof McpSchema.JSONRPCNotification) { + logger.debug("Sending empty 200 response for notification"); + // Just complete the response with no content + } + else { + // For requests/responses, write JSON content + String jsonText = objectMapper.writeValueAsString(message); + PrintWriter writer = response.getWriter(); + writer.write(jsonText); + writer.flush(); + logger.debug("Successfully sent immediate response: {}", jsonText); + } + } + catch (Exception e) { + logger.error("Failed to send message: {}", e.getMessage(), e); + try { + if (!response.isCommitted()) { + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + } + } + catch (Exception ignored) { + } + } + finally { + asyncContext.complete(); + } + }); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + try { + asyncContext.complete(); + } + catch (Exception e) { + logger.debug("Error completing async context: {}", e.getMessage()); + } + logger.debug("HTTP transport closed gracefully"); + }); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 9be585cea..efc8425e6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -27,7 +27,7 @@ /** * Based on the JSON-RPC 2.0 * specification and the Model + * "https://github.com/modelcontextprotocol/specification/blob/main/schema/2025-06-18/schema.ts">Model * Context Protocol Schema. * * @author Christian Tzolov @@ -40,7 +40,7 @@ public final class McpSchema { private McpSchema() { } - public static final String LATEST_PROTOCOL_VERSION = "2024-11-05"; + public static final String LATEST_PROTOCOL_VERSION = "2025-06-18"; public static final String JSONRPC_VERSION = "2.0"; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d859..f38df44b9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -3,6 +3,7 @@ import java.time.Duration; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -104,6 +105,14 @@ public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Impl this.clientInfo.lazySet(clientInfo); } + public McpSchema.ClientCapabilities getClientCapabilities() { + return this.clientCapabilities.get(); + } + + public McpSchema.Implementation getClientInfo() { + return this.clientInfo.get(); + } + private String generateRequestId() { return this.id + "-" + this.requestCounter.getAndIncrement(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableHttpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableHttpServerSession.java new file mode 100644 index 000000000..05960cca7 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableHttpServerSession.java @@ -0,0 +1,383 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.core.publisher.Sinks; + +/** + * Streamable HTTP MCP server session that manages multiple transport streams and handles both + * immediate and streaming responses through a unified interface. + * + *

+ * This session implementation provides: + *

    + *
  • Unified management of multiple transport streams per session
  • + *
  • Support for both immediate JSON responses and SSE streaming
  • + *
  • Automatic response type detection based on handler return types
  • + *
  • Proper lifecycle management for all associated transports
  • + *
+ */ +public class McpStreamableHttpServerSession implements McpSession { + + private static final Logger logger = LoggerFactory.getLogger(McpStreamableHttpServerSession.class); + + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap transports = new ConcurrentHashMap<>(); + + private final String id; + + private final Duration requestTimeout; + + private final AtomicLong requestCounter = new AtomicLong(0); + + private final InitRequestHandler initRequestHandler; + + private final InitNotificationHandler initNotificationHandler; + + private final Map> requestHandlers; + + private final Map notificationHandlers; + + private final Sinks.One exchangeSink = Sinks.one(); + + private final AtomicReference clientCapabilities = new AtomicReference<>(); + + private final AtomicReference clientInfo = new AtomicReference<>(); + + private static final int STATE_UNINITIALIZED = 0; + + private static final int STATE_INITIALIZING = 1; + + private static final int STATE_INITIALIZED = 2; + + private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + + /** + * Creates a new Streamable HTTP MCP server session. + * @param id session id + * @param requestTimeout timeout for requests + * @param initHandler initialization request handler + * @param initNotificationHandler initialization notification handler + * @param requestHandlers map of request handlers + * @param notificationHandlers map of notification handlers + */ + public McpStreamableHttpServerSession(String id, Duration requestTimeout, InitRequestHandler initHandler, + InitNotificationHandler initNotificationHandler, Map> requestHandlers, + Map notificationHandlers) { + this.id = id; + this.requestTimeout = requestTimeout; + this.initRequestHandler = initHandler; + this.initNotificationHandler = initNotificationHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + /** + * Registers a transport for this session. + * @param transportId unique identifier for the transport + * @param transport the transport instance + */ + public void registerTransport(String transportId, McpServerTransport transport) { + transports.put(transportId, transport); + logger.debug("Registered transport {} for session {}", transportId, id); + } + + /** + * Unregisters a transport from this session. + * @param transportId the transport identifier to remove + */ + public void unregisterTransport(String transportId) { + McpServerTransport removed = transports.remove(transportId); + if (removed != null) { + logger.debug("Unregistered transport {} from session {}", transportId, id); + } + } + + /** + * Gets a transport by its identifier. + * @param transportId the transport identifier + * @return the transport, or null if not found + */ + public McpServerTransport getTransport(String transportId) { + return transports.get(transportId); + } + + /** + * Handles a message using the specified transport. + * @param message the JSON-RPC message + * @param transportId the transport to use for responses + * @return a Mono that completes when the message is processed + */ + public Mono handleMessage(McpSchema.JSONRPCMessage message, String transportId) { + McpServerTransport transport = transports.get(transportId); + if (transport == null) { + return Mono.error(new RuntimeException("Transport not found: " + transportId)); + } + + return Mono.defer(() -> { + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received Response: {}", response); + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } + return Mono.empty(); + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + return handleIncomingRequest(request, transport).onErrorResume(error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)); + return transport.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(transport::sendMessage); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + logger.debug("Received notification: {}", notification); + return handleIncomingNotification(notification) + .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); + } + else { + logger.warn("Received unknown message type: {}", message); + return Mono.empty(); + } + }); + } + + /** + * Handles incoming JSON-RPC requests. + */ + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request, + McpServerTransport transport) { + return Mono.defer(() -> { + Mono resultMono; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + logger.info("[INIT] Processing initialize request for session {}", id); + McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), + new TypeReference() { + }); + + this.state.lazySet(STATE_INITIALIZING); + logger.debug("[INIT] Session {} state set to INITIALIZING", id); + this.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); + logger.debug("[INIT] Session {} client info stored", id); + resultMono = this.initRequestHandler.handle(initializeRequest); + } + else { + var handler = this.requestHandlers.get(request.method()); + if (handler == null) { + MethodNotFoundError error = getMethodNotFoundError(request.method()); + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + + // Wait for initialization to complete, then handle the request + resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + } + + return resultMono + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)))); + }); + } + + /** + * Handles incoming JSON-RPC notifications. + */ + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + return Mono.defer(() -> { + if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { + logger.info("[INIT] Received initialized notification for session {}", id); + this.state.lazySet(STATE_INITIALIZED); + logger.debug("[INIT] Session {} state set to INITIALIZED", id); + McpAsyncServerExchange exchange = new McpAsyncServerExchange(this, clientCapabilities.get(), + clientInfo.get()); + logger.debug("[INIT] Created exchange for session {}: {}", id, exchange); + exchangeSink.tryEmitValue(exchange); + logger.info("[INIT] Session {} initialization complete - exchange emitted", id); + return this.initNotificationHandler.handle(); + } + + var handler = notificationHandlers.get(notification.method()); + if (handler == null) { + logger.error("No handler registered for notification method: {}", notification.method()); + return Mono.empty(); + } + return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); + }); + } + + /** + * Retrieve the session id. + */ + public String getId() { + return this.id; + } + + /** + * Called upon successful initialization sequence. + */ + public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + this.clientCapabilities.lazySet(clientCapabilities); + this.clientInfo.lazySet(clientInfo); + } + + public McpSchema.ClientCapabilities getClientCapabilities() { + return this.clientCapabilities.get(); + } + + public McpSchema.Implementation getClientInfo() { + return this.clientInfo.get(); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + private MethodNotFoundError getMethodNotFoundError(String method) { + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + + @Override + public Mono closeGracefully() { + return Flux.fromIterable(transports.values()).flatMap(McpServerTransport::closeGracefully).then(); + } + + @Override + public void close() { + transports.values().forEach(McpServerTransport::close); + transports.clear(); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + String requestId = this.generateRequestId(); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, + requestId, requestParams); + + Flux.fromIterable(transports.values()) + .flatMap(transport -> transport.sendMessage(jsonrpcRequest)) + .subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + sink.error(error); + }); + }).timeout(requestTimeout).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + McpServerTransport transport = transports.values().iterator().next(); + sink.next(transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + method, params); + + return Flux.fromIterable(transports.values()) + .flatMap(transport -> transport.sendMessage(jsonrpcNotification)) + .then(); + } + + /** + * Gets a request handler by method name. + */ + public RequestHandler getRequestHandler(String method) { + return requestHandlers.get(method); + } + + private String generateRequestId() { + return this.id + "-" + this.requestCounter.getAndIncrement(); + } + + /** + * Request handler for the initialization request. + */ + public interface InitRequestHandler { + + Mono handle(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Notification handler for the initialization notification from the client. + */ + public interface InitNotificationHandler { + + Mono handle(); + + } + + /** + * A handler for client-initiated notifications. + */ + public interface NotificationHandler { + + Mono handle(McpAsyncServerExchange exchange, Object params); + + } + + /** + * A handler for client-initiated requests. + */ + public interface RequestHandler { + + Mono handle(McpAsyncServerExchange exchange, Object params); + + } + + /** + * A handler for streaming requests that return Flux. + */ + public interface StreamingRequestHandler extends RequestHandler { + + Flux handleStreaming(McpAsyncServerExchange exchange, Object params); + + } + + /** + * Factory for creating Streamable HTTP MCP server sessions. + */ + @FunctionalInterface + public interface Factory { + + McpStreamableHttpServerSession create(String sessionId); + + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpAsyncServerTests.java new file mode 100644 index 000000000..13114e5c9 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpAsyncServerTests.java @@ -0,0 +1,22 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpAsyncServer} using {@link StreamableHttpServerTransportProvider}. + */ +@Timeout(15) // Giving extra time beyond the client timeout +class StreamableHttpMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + return new StreamableHttpServerTransportProvider(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpSyncServerTests.java new file mode 100644 index 000000000..568abd741 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpSyncServerTests.java @@ -0,0 +1,22 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpSyncServer} using {@link StreamableHttpServerTransportProvider}. + */ +@Timeout(15) // Giving extra time beyond the client timeout +class StreamableHttpMcpSyncServerTests extends AbstractMcpSyncServerTests { + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + return new StreamableHttpServerTransportProvider(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java new file mode 100644 index 000000000..07980cf97 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java @@ -0,0 +1,349 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableHttpServerSession; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link StreamableHttpServerTransportProvider}. + */ +class StreamableHttpServerTransportProviderTests { + + private StreamableHttpServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; + + private McpStreamableHttpServerSession.Factory sessionFactory; + + private McpStreamableHttpServerSession mockSession; + + @BeforeEach + void setUp() { + objectMapper = new ObjectMapper(); + mockSession = mock(McpStreamableHttpServerSession.class); + sessionFactory = mock(McpStreamableHttpServerSession.Factory.class); + + when(sessionFactory.create(anyString())).thenReturn(mockSession); + when(mockSession.getId()).thenReturn("test-session-id"); + when(mockSession.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession.sendNotification(anyString(), any())).thenReturn(Mono.empty()); + + transportProvider = new StreamableHttpServerTransportProvider(objectMapper, "/mcp", null); + transportProvider.setStreamableHttpSessionFactory(sessionFactory); + } + + @Test + void shouldCreateSessionOnFirstRequest() { + // Test session creation directly through the getOrCreateSession method + String sessionId = "test-session-123"; + + McpStreamableHttpServerSession session = transportProvider.getOrCreateSession(sessionId, true); + + assertThat(session).isNotNull(); + verify(sessionFactory).create(sessionId); + } + + @Test + void shouldHandleSSERequest() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + AsyncContext asyncContext = mock(AsyncContext.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + String sessionId = "test-session-123"; + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader("Mcp-Session-Id")).thenReturn(sessionId); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(request.startAsync()).thenReturn(asyncContext); + when(response.getWriter()).thenReturn(printWriter); + when(response.getHeader("Mcp-Session-Id")).thenReturn(sessionId); + + // First create a session + transportProvider.getOrCreateSession(sessionId, true); + + transportProvider.doGet(request, response); + + verify(response).setContentType("text/event-stream"); + verify(response).setCharacterEncoding("UTF-8"); + verify(response).setHeader("Cache-Control", "no-cache"); + verify(response).setHeader("Connection", "keep-alive"); + } + + @Test + void shouldNotifyClients() { + String sessionId = "test-session-123"; + transportProvider.getOrCreateSession(sessionId, true); + + String method = "test/notification"; + String params = "test message"; + + StepVerifier.create(transportProvider.notifyClients(method, params)).verifyComplete(); + + // Verify that the session was created + assertThat(transportProvider.getOrCreateSession(sessionId, false)).isNotNull(); + } + + @Test + void shouldCloseGracefully() { + String sessionId = "test-session-123"; + transportProvider.getOrCreateSession(sessionId, true); + + StepVerifier.create(transportProvider.closeGracefully()).verifyComplete(); + + verify(mockSession).closeGracefully(); + } + + @Test + void shouldHandleInvalidRequestURI() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + + when(request.getRequestURI()).thenReturn("/wrong-path"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + + transportProvider.doGet(request, response); + transportProvider.doPost(request, response); + transportProvider.doDelete(request, response); + + verify(response, times(3)).sendError(HttpServletResponse.SC_NOT_FOUND); + } + + @Test + void shouldRejectNonJSONContentType() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("POST"); + when(request.getHeader("Content-Type")).thenReturn("text/plain"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(printWriter); + + transportProvider.doPost(request, response); + + // The implementation uses sendErrorResponse which sets status to 400, not + // sendError with 415 + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType("application/json"); + } + + @Test + void shouldRejectInvalidAcceptHeader() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeader("Accept")).thenReturn("text/html"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(printWriter); + + transportProvider.doGet(request, response); + + // The implementation uses sendErrorResponse which sets status to 400, not + // sendError with 406 + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType("application/json"); + } + + @Test + void shouldRequireSessionIdForSSE() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader("Mcp-Session-Id")).thenReturn(null); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(printWriter); + + transportProvider.doGet(request, response); + + // The implementation uses sendErrorResponse which sets status to 400 + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType("application/json"); + } + + @Test + void shouldHandleSessionCleanup() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + + String sessionId = "test-session-123"; + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("DELETE"); + when(request.getHeader("Mcp-Session-Id")).thenReturn(sessionId); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + + // Create a session first + transportProvider.getOrCreateSession(sessionId, true); + + transportProvider.doDelete(request, response); + + verify(response).setStatus(HttpServletResponse.SC_OK); + verify(mockSession).closeGracefully(); + } + + @Test + void shouldHandleDeleteNonExistentSession() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("DELETE"); + when(request.getHeader("Mcp-Session-Id")).thenReturn("non-existent-session"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(printWriter); + + transportProvider.doDelete(request, response); + + // The implementation uses sendErrorResponse which sets status to 400, not + // sendError with 404 + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType("application/json"); + } + + @Test + void shouldHandleMultipleSessions() { + String sessionId1 = "session-1"; + String sessionId2 = "session-2"; + + // Create separate mock sessions for each ID + McpStreamableHttpServerSession mockSession1 = mock(McpStreamableHttpServerSession.class); + McpStreamableHttpServerSession mockSession2 = mock(McpStreamableHttpServerSession.class); + when(mockSession1.getId()).thenReturn(sessionId1); + when(mockSession2.getId()).thenReturn(sessionId2); + when(mockSession1.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession2.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession1.sendNotification(anyString(), any())).thenReturn(Mono.empty()); + when(mockSession2.sendNotification(anyString(), any())).thenReturn(Mono.empty()); + + // Configure factory to return different sessions for different IDs + when(sessionFactory.create(sessionId1)).thenReturn(mockSession1); + when(sessionFactory.create(sessionId2)).thenReturn(mockSession2); + + McpStreamableHttpServerSession session1 = transportProvider.getOrCreateSession(sessionId1, true); + McpStreamableHttpServerSession session2 = transportProvider.getOrCreateSession(sessionId2, true); + + assertThat(session1).isNotNull(); + assertThat(session2).isNotNull(); + assertThat(session1).isNotSameAs(session2); + + // Verify both sessions are created with different IDs + verify(sessionFactory, times(2)).create(anyString()); + } + + @Test + void shouldReuseExistingSession() { + String sessionId = "test-session-123"; + + McpStreamableHttpServerSession session1 = transportProvider.getOrCreateSession(sessionId, true); + McpStreamableHttpServerSession session2 = transportProvider.getOrCreateSession(sessionId, false); + + assertThat(session1).isSameAs(session2); + verify(sessionFactory, times(1)).create(sessionId); + } + + @Test + void shouldHandleAsyncTimeout() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + AsyncContext asyncContext = mock(AsyncContext.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader("Mcp-Session-Id")).thenReturn("test-session"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(request.startAsync()).thenReturn(asyncContext); + when(response.getWriter()).thenReturn(printWriter); + when(response.getHeader("Mcp-Session-Id")).thenReturn("test-session"); + + transportProvider.getOrCreateSession("test-session", true); + transportProvider.doGet(request, response); + + verify(asyncContext).setTimeout(0L); // Updated to match actual implementation + } + + @Test + void shouldBuildWithCustomConfiguration() { + ObjectMapper customMapper = new ObjectMapper(); + String customEndpoint = "/custom-mcp"; + + StreamableHttpServerTransportProvider provider = StreamableHttpServerTransportProvider.builder() + .withObjectMapper(customMapper) + .withMcpEndpoint(customEndpoint) + .withSessionIdProvider(() -> "custom-session-id") + .build(); + + assertThat(provider).isNotNull(); + } + + @Test + void shouldHandleBuilderValidation() { + try { + StreamableHttpServerTransportProvider.builder().withObjectMapper(null).build(); + } + catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("ObjectMapper must not be null"); + } + + try { + StreamableHttpServerTransportProvider.builder().withMcpEndpoint("").build(); + } + catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("MCP endpoint must not be empty"); + } + } + +} \ No newline at end of file From a01872d624ff0a7d2c47597b51b0c520a8394d44 Mon Sep 17 00:00:00 2001 From: Zachary German Date: Sun, 29 Jun 2025 21:06:33 +0000 Subject: [PATCH 2/5] Integrated dedicated GET/listening stream and multi-transport management --- ...treamableHttpTransportIntegrationTest.java | 26 +- .../server/McpAsyncServer.java | 6 +- .../server/McpAsyncServerExchange.java | 15 - .../server/McpAsyncStreamableHttpServer.java | 38 +- ...StreamableHttpServerTransportProvider.java | 65 ++- .../spec/McpServerSession.java | 165 +++++++- .../spec/McpStreamableHttpServerSession.java | 383 ------------------ ...mableHttpServerTransportProviderTests.java | 24 +- 8 files changed, 235 insertions(+), 487 deletions(-) delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableHttpServerSession.java diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java index 42eec8db0..d99f34959 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java @@ -58,12 +58,12 @@ void setUp() { // Set up session factory with proper server capabilities McpSchema.ServerCapabilities serverCapabilities = new McpSchema.ServerCapabilities(null, null, null, null, null, null); - serverTransportProvider.setStreamableHttpSessionFactory( - sessionId -> new io.modelcontextprotocol.spec.McpStreamableHttpServerSession(sessionId, - java.time.Duration.ofSeconds(30), - request -> reactor.core.publisher.Mono.just(new McpSchema.InitializeResult("2025-06-18", - serverCapabilities, new McpSchema.Implementation("Test Server", "1.0.0"), null)), - () -> reactor.core.publisher.Mono.empty(), java.util.Map.of(), java.util.Map.of())); + serverTransportProvider + .setStreamableHttpSessionFactory(sessionId -> new io.modelcontextprotocol.spec.McpServerSession(sessionId, + java.time.Duration.ofSeconds(30), + request -> reactor.core.publisher.Mono.just(new McpSchema.InitializeResult("2025-06-18", + serverCapabilities, new McpSchema.Implementation("Test Server", "1.0.0"), null)), + () -> reactor.core.publisher.Mono.empty(), java.util.Map.of(), java.util.Map.of())); tomcat = TomcatTestUtil.createTomcatServer("", PORT, serverTransportProvider); try { @@ -132,14 +132,14 @@ void shouldCallImmediateToolSuccessfully() { McpSchema.ServerCapabilities serverCapabilities = new McpSchema.ServerCapabilities(null, null, null, null, null, new McpSchema.ServerCapabilities.ToolCapabilities(true)); serverTransportProvider - .setStreamableHttpSessionFactory(sessionId -> new io.modelcontextprotocol.spec.McpStreamableHttpServerSession( - sessionId, java.time.Duration.ofSeconds(30), + .setStreamableHttpSessionFactory(sessionId -> new io.modelcontextprotocol.spec.McpServerSession(sessionId, + java.time.Duration.ofSeconds(30), request -> reactor.core.publisher.Mono.just(new McpSchema.InitializeResult("2025-06-18", serverCapabilities, new McpSchema.Implementation("Test Server", "1.0.0"), null)), () -> reactor.core.publisher.Mono.empty(), java.util.Map.of("tools/call", - (io.modelcontextprotocol.spec.McpStreamableHttpServerSession.RequestHandler) ( - exchange, params) -> tool.call().apply(exchange, (Map) params)), + (io.modelcontextprotocol.spec.McpServerSession.RequestHandler) (exchange, + params) -> tool.call().apply(exchange, (Map) params)), java.util.Map.of())); var mcpClient = clientBuilder.build(); @@ -174,12 +174,12 @@ void shouldCallStreamingToolSuccessfully() { McpSchema.ServerCapabilities serverCapabilities = new McpSchema.ServerCapabilities(null, null, null, null, null, new McpSchema.ServerCapabilities.ToolCapabilities(true)); serverTransportProvider - .setStreamableHttpSessionFactory(sessionId -> new io.modelcontextprotocol.spec.McpStreamableHttpServerSession( - sessionId, java.time.Duration.ofSeconds(30), + .setStreamableHttpSessionFactory(sessionId -> new io.modelcontextprotocol.spec.McpServerSession(sessionId, + java.time.Duration.ofSeconds(30), request -> reactor.core.publisher.Mono.just(new McpSchema.InitializeResult("2025-06-18", serverCapabilities, new McpSchema.Implementation("Test Server", "1.0.0"), null)), () -> reactor.core.publisher.Mono.empty(), java.util.Map.of("tools/call", - (io.modelcontextprotocol.spec.McpStreamableHttpServerSession.StreamingRequestHandler) new io.modelcontextprotocol.spec.McpStreamableHttpServerSession.StreamingRequestHandler() { + (io.modelcontextprotocol.spec.McpServerSession.StreamingRequestHandler) new io.modelcontextprotocol.spec.McpServerSession.StreamingRequestHandler() { @Override public Mono handle( io.modelcontextprotocol.server.McpAsyncServerExchange exchange, Object params) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 59b1afca3..b63880360 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -183,9 +183,9 @@ public class McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider.setSessionFactory(listeningTransport -> new McpServerSession(UUID.randomUUID().toString(), + requestTimeout, listeningTransport, this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, + notificationHandlers)); } // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 412875ab3..893a2f812 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -13,7 +13,6 @@ import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSession; -import io.modelcontextprotocol.spec.McpStreamableHttpServerSession; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -61,20 +60,6 @@ public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabili this.clientInfo = clientInfo; } - /** - * Create a new asynchronous exchange with the client. - * @param session The server session representing a 1-1 interaction. - * @param clientCapabilities The client capabilities that define the supported - * features and functionality. - * @param clientInfo The client implementation information. - */ - public McpAsyncServerExchange(McpStreamableHttpServerSession session, - McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { - this.session = session; - this.clientCapabilities = clientCapabilities; - this.clientInfo = clientInfo; - } - /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncStreamableHttpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncStreamableHttpServer.java index 7e13f93ca..8fc96b001 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncStreamableHttpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncStreamableHttpServer.java @@ -15,7 +15,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; -import io.modelcontextprotocol.spec.McpStreamableHttpServerSession; +import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; @@ -86,7 +86,7 @@ public class McpAsyncStreamableHttpServer { * Sets up the request handlers for standard MCP methods. */ private void setupRequestHandlers() { - Map> requestHandlers = new HashMap<>(); + Map> requestHandlers = new HashMap<>(); // Ping handler requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); @@ -123,15 +123,15 @@ private void setupRequestHandlers() { this.requestHandlers = requestHandlers; } - private Map> requestHandlers; + private Map> requestHandlers; - private Map notificationHandlers; + private Map notificationHandlers; /** * Sets up notification handlers. */ private void setupNotificationHandlers() { - Map handlers = new HashMap<>(); + Map handlers = new HashMap<>(); handlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> { logger.info("[INIT] Received initialized notification - initialization complete!"); @@ -150,7 +150,7 @@ private void setupNotificationHandlers() { private void setupSessionFactory() { setupNotificationHandlers(); - httpTransportProvider.setStreamableHttpSessionFactory(sessionId -> new McpStreamableHttpServerSession(sessionId, + httpTransportProvider.setStreamableHttpSessionFactory(sessionId -> new McpServerSession(sessionId, requestTimeout, this::handleInitializeRequest, Mono::empty, requestHandlers, notificationHandlers)); } @@ -181,7 +181,7 @@ private Mono handleInitializeRequest(McpSchema.Initi } // Request handler creation methods - private McpStreamableHttpServerSession.RequestHandler createToolsListHandler() { + private McpServerSession.RequestHandler createToolsListHandler() { return (exchange, params) -> { var regularTools = features.tools().stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); var streamingTools = features.streamTools() @@ -194,8 +194,8 @@ private McpStreamableHttpServerSession.RequestHandler }; } - private McpStreamableHttpServerSession.RequestHandler createToolsCallHandler() { - return new McpStreamableHttpServerSession.StreamingRequestHandler() { + private McpServerSession.RequestHandler createToolsCallHandler() { + return new McpServerSession.StreamingRequestHandler() { @Override public Mono handle(McpAsyncServerExchange exchange, Object params) { var callToolRequest = objectMapper.convertValue(params, McpSchema.CallToolRequest.class); @@ -252,7 +252,7 @@ public Flux handleStreaming(McpAsyncServerExchange exc }; } - private McpStreamableHttpServerSession.RequestHandler createResourcesListHandler() { + private McpServerSession.RequestHandler createResourcesListHandler() { return (exchange, params) -> { var resources = features.resources() .values() @@ -263,7 +263,7 @@ private McpStreamableHttpServerSession.RequestHandler createResourcesReadHandler() { + private McpServerSession.RequestHandler createResourcesReadHandler() { return (exchange, params) -> { var resourceRequest = objectMapper.convertValue(params, McpSchema.ReadResourceRequest.class); var resourceUri = resourceRequest.uri(); @@ -278,12 +278,12 @@ private McpStreamableHttpServerSession.RequestHandler createResourceTemplatesListHandler() { + private McpServerSession.RequestHandler createResourceTemplatesListHandler() { return (exchange, params) -> Mono .just(new McpSchema.ListResourceTemplatesResult(features.resourceTemplates(), null)); } - private McpStreamableHttpServerSession.RequestHandler createPromptsListHandler() { + private McpServerSession.RequestHandler createPromptsListHandler() { return (exchange, params) -> { var prompts = features.prompts() .values() @@ -294,7 +294,7 @@ private McpStreamableHttpServerSession.RequestHandler createPromptsGetHandler() { + private McpServerSession.RequestHandler createPromptsGetHandler() { return (exchange, params) -> { var promptRequest = objectMapper.convertValue(params, McpSchema.GetPromptRequest.class); @@ -308,7 +308,7 @@ private McpStreamableHttpServerSession.RequestHandler }; } - private McpStreamableHttpServerSession.RequestHandler createLoggingSetLevelHandler() { + private McpServerSession.RequestHandler createLoggingSetLevelHandler() { return (exchange, params) -> { var setLevelRequest = objectMapper.convertValue(params, McpSchema.SetLevelRequest.class); exchange.setMinLoggingLevel(setLevelRequest.level()); @@ -316,7 +316,7 @@ private McpStreamableHttpServerSession.RequestHandler createLoggingSetLe }; } - private McpStreamableHttpServerSession.RequestHandler createCompletionCompleteHandler() { + private McpServerSession.RequestHandler createCompletionCompleteHandler() { return (exchange, params) -> { var completeRequest = objectMapper.convertValue(params, McpSchema.CompleteRequest.class); @@ -330,7 +330,7 @@ private McpStreamableHttpServerSession.RequestHandler }; } - private McpStreamableHttpServerSession.NotificationHandler createRootsListChangedHandler() { + private McpServerSession.NotificationHandler createRootsListChangedHandler() { return (exchange, params) -> { var rootsChangeConsumers = features.rootsChangeConsumers(); if (rootsChangeConsumers.isEmpty()) { @@ -418,8 +418,8 @@ public static Builder builder() { * transport. * *

- * This builder provides a fluent API for configuring Streamable HTTP MCP - * servers with enhanced features: + * This builder provides a fluent API for configuring Streamable HTTP MCP servers with + * enhanced features: *