diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index 26452fe95..c519b5580 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -127,6 +127,28 @@ test + + + org.apache.tomcat.embed + tomcat-embed-core + ${tomcat.version} + test + + + org.apache.tomcat.embed + tomcat-embed-websocket + ${tomcat.version} + test + + + + + jakarta.servlet + jakarta.servlet-api + ${jakarta.servlet.version} + test + + diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index 42b91d14e..1cc7b2feb 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -12,6 +12,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.McpSchema.McpId; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -161,7 +163,7 @@ void testBuilderPattern() { @Test void testMessageProcessing() { // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Simulate receiving the message @@ -192,7 +194,7 @@ void testResponseMessageProcessing() { """); // Create and send a request message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Verify message handling @@ -216,7 +218,7 @@ void testErrorMessageProcessing() { """); // Create and send a request message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Verify message handling @@ -246,7 +248,7 @@ void testGracefulShutdown() { StepVerifier.create(transport.closeGracefully()).verifyComplete(); // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Verify message is not processed after shutdown @@ -292,10 +294,10 @@ void testMultipleMessageProcessing() { """); // Create and send corresponding messages - JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", "id1", + JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", McpId.of("id1"), Map.of("key", "value1")); - JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", "id2", + JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", McpId.of("id2"), Map.of("key", "value2")); // Verify both messages are processed diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java new file mode 100644 index 000000000..ec8b594e5 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpTransportIntegrationTest.java @@ -0,0 +1,256 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Map; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpAsyncStreamableHttpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for @link{StreamableHttpServerTransportProvider} with + * + * @link{WebClientStreamableHttpTransport}. + */ +class StreamableHttpTransportIntegrationTest { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String ENDPOINT = "/mcp"; + + private StreamableHttpServerTransportProvider serverTransportProvider; + + private McpClient.AsyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + void setUp() { + serverTransportProvider = new StreamableHttpServerTransportProvider(new ObjectMapper(), ENDPOINT, null); + + // Set up session factory with proper server capabilities + McpSchema.ServerCapabilities serverCapabilities = new McpSchema.ServerCapabilities(null, null, null, null, null, + null); + serverTransportProvider + .setStreamableHttpSessionFactory(sessionId -> new io.modelcontextprotocol.spec.McpServerSession(sessionId, + java.time.Duration.ofSeconds(30), + request -> reactor.core.publisher.Mono.just(new McpSchema.InitializeResult("2024-11-05", + serverCapabilities, new McpSchema.Implementation("Test Server", "1.0.0"), null)), + () -> reactor.core.publisher.Mono.empty(), java.util.Map.of(), java.util.Map.of())); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, serverTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + WebClientStreamableHttpTransport clientTransport = WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(ENDPOINT) + .objectMapper(new ObjectMapper()) + .build(); + + clientBuilder = McpClient.async(clientTransport) + .clientInfo(new McpSchema.Implementation("Test Client", "1.0.0")); + } + + @AfterEach + void tearDown() { + if (serverTransportProvider != null) { + serverTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void shouldInitializeSuccessfully() { + // The server is already configured via the session factory in setUp + var mcpClient = clientBuilder.build(); + try { + InitializeResult result = mcpClient.initialize().block(); + assertThat(result).isNotNull(); + assertThat(result.serverInfo().name()).isEqualTo("Test Server"); + } + finally { + mcpClient.close(); + } + } + + @Test + void shouldCallImmediateToolSuccessfully() { + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("Tool executed successfully")), null); + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("test-tool", "Test tool description", emptyJsonSchema), + (exchange, request) -> Mono.just(callResponse)); + + // Configure session factory with tool handler + McpSchema.ServerCapabilities serverCapabilities = new McpSchema.ServerCapabilities(null, null, null, null, null, + new McpSchema.ServerCapabilities.ToolCapabilities(true)); + serverTransportProvider + .setStreamableHttpSessionFactory(sessionId -> new io.modelcontextprotocol.spec.McpServerSession(sessionId, + java.time.Duration.ofSeconds(30), + request -> reactor.core.publisher.Mono.just(new McpSchema.InitializeResult("2024-11-05", + serverCapabilities, new McpSchema.Implementation("Test Server", "1.0.0"), null)), + () -> reactor.core.publisher.Mono.empty(), + java.util.Map.of("tools/call", + (io.modelcontextprotocol.spec.McpServerSession.RequestHandler) (exchange, + params) -> tool.call().apply(exchange, (Map) params)), + java.util.Map.of())); + + var mcpClient = clientBuilder.build(); + try { + mcpClient.initialize().block(); + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())).block(); + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()) + .isEqualTo("Tool executed successfully"); + } + finally { + mcpClient.close(); + } + } + + @Test + void shouldCallStreamingToolSuccessfully() { + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + McpServerFeatures.AsyncStreamingToolSpecification streamingTool = new McpServerFeatures.AsyncStreamingToolSpecification( + new McpSchema.Tool("streaming-tool", "Streaming test tool", emptyJsonSchema), + (exchange, request) -> Flux.range(1, 3) + .map(i -> new CallToolResult(List.of(new McpSchema.TextContent("Step " + i)), null))); + + // Configure session factory with streaming tool handler + McpSchema.ServerCapabilities serverCapabilities = new McpSchema.ServerCapabilities(null, null, null, null, null, + new McpSchema.ServerCapabilities.ToolCapabilities(true)); + serverTransportProvider + .setStreamableHttpSessionFactory(sessionId -> new io.modelcontextprotocol.spec.McpServerSession(sessionId, + java.time.Duration.ofSeconds(30), + request -> reactor.core.publisher.Mono.just(new McpSchema.InitializeResult("2024-11-05", + serverCapabilities, new McpSchema.Implementation("Test Server", "1.0.0"), null)), + () -> reactor.core.publisher.Mono.empty(), java.util.Map.of("tools/call", + (io.modelcontextprotocol.spec.McpServerSession.StreamingRequestHandler) new io.modelcontextprotocol.spec.McpServerSession.StreamingRequestHandler() { + @Override + public Mono handle( + io.modelcontextprotocol.server.McpAsyncServerExchange exchange, Object params) { + return streamingTool.call().apply(exchange, (Map) params).next(); + } + + @Override + public Flux handleStreaming( + io.modelcontextprotocol.server.McpAsyncServerExchange exchange, Object params) { + return streamingTool.call().apply(exchange, (Map) params); + } + }), + java.util.Map.of())); + + var mcpClient = clientBuilder.build(); + try { + mcpClient.initialize().block(); + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("streaming-tool", Map.of())) + .block(); + assertThat(result).isNotNull(); + assertThat(result.content()).hasSize(1); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).startsWith("Step"); + } + finally { + mcpClient.close(); + } + } + + @Test + void shouldReceiveNotificationThroughGetStream() throws InterruptedException { + CountDownLatch notificationLatch = new CountDownLatch(1); + AtomicReference receivedEvent = new AtomicReference<>(); + AtomicReference sessionId = new AtomicReference<>(); + + WebClient webClient = WebClient.create("http://localhost:" + PORT); + String initMessage = "{\"jsonrpc\":\"2.0\",\"id\":1,\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"Test\",\"version\":\"1.0\"}}}"; + + // Initialize and get session ID + webClient.post() + .uri(ENDPOINT) + .header("Accept", "application/json, text/event-stream") + .header("Content-Type", "application/json") + .bodyValue(initMessage) + .retrieve() + .toBodilessEntity() + .doOnNext(response -> sessionId.set(response.getHeaders().getFirst("Mcp-Session-Id"))) + .block(); + + // Establish SSE stream + webClient.get() + .uri(ENDPOINT) + .header("Accept", "text/event-stream") + .header("Mcp-Session-Id", sessionId.get()) + .retrieve() + .bodyToFlux(String.class) + .filter(line -> line.contains("test/notification")) + .doOnNext(event -> { + receivedEvent.set(event); + notificationLatch.countDown(); + }) + .subscribe(); + + // Send notification after delay + Mono.delay(Duration.ofMillis(200)) + .then(serverTransportProvider.notifyClients("test/notification", "test message")) + .subscribe(); + + assertThat(notificationLatch.await(5, TimeUnit.SECONDS)).isTrue(); + assertThat(receivedEvent.get()).isNotNull(); + assertThat(receivedEvent.get()).contains("test/notification"); + } + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java new file mode 100644 index 000000000..a9fa4d5bb --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -0,0 +1,63 @@ +/* +* Copyright 2025 - 2025 the original author or authors. +*/ +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; + +import jakarta.servlet.Servlet; +import org.apache.catalina.Context; +import org.apache.catalina.startup.Tomcat; + +/** + * @author Christian Tzolov + */ +public class TomcatTestUtil { + + TomcatTestUtil() { + // Prevent instantiation + } + + public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) { + + var tomcat = new Tomcat(); + tomcat.setPort(port); + + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + Context context = tomcat.addContext(contextPath, baseDir); + + // Add transport servlet to Tomcat + org.apache.catalina.Wrapper wrapper = context.createWrapper(); + wrapper.setName("mcpServlet"); + wrapper.setServlet(servlet); + wrapper.setLoadOnStartup(1); + wrapper.setAsyncSupported(true); + context.addChild(wrapper); + context.addServletMappingDecoded("/*", "mcpServlet"); + + var connector = tomcat.getConnector(); + connector.setAsyncTimeout(3000); + + return tomcat; + } + + /** + * Finds an available port on the local machine. + * @return an available port number + * @throws IllegalStateException if no available port can be found + */ + public static int findAvailablePort() { + try (final ServerSocket socket = new ServerSocket()) { + socket.bind(new InetSocketAddress(0)); + return socket.getLocalPort(); + } + catch (final IOException e) { + throw new IllegalStateException("Cannot bind to an available port!", e); + } + } + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/test/resources/logback-test.xml b/mcp-spring/mcp-spring-webflux/src/test/resources/logback-test.xml new file mode 100644 index 000000000..37f43a17a --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/resources/logback-test.xml @@ -0,0 +1,15 @@ + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 02ad955b9..bc8623433 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -88,12 +88,16 @@ public class McpAsyncServer { private final McpSchema.ServerCapabilities serverCapabilities; + private final boolean isStreamableHttp; + private final McpSchema.Implementation serverInfo; private final String instructions; private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + private final CopyOnWriteArrayList streamTools = new CopyOnWriteArrayList<>(); + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); @@ -119,7 +123,7 @@ public class McpAsyncServer { */ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, McpServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory) { + McpUriTemplateManagerFactory uriTemplateManagerFactory, boolean isStreamableHttp) { this.mcpTransportProvider = mcpTransportProvider; this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); @@ -131,6 +135,7 @@ public class McpAsyncServer { this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.isStreamableHttp = isStreamableHttp; Map> requestHandlers = new HashMap<>(); @@ -183,9 +188,16 @@ public class McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider.setSessionFactory(listeningTransport -> new McpServerSession(UUID.randomUUID().toString(), + requestTimeout, listeningTransport, this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, + notificationHandlers)); + } + + // Alternate constructor for HTTP+SSE servers (past spec) + McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + this(mcpTransportProvider, objectMapper, features, requestTimeout, uriTemplateManagerFactory, false); } // --------------------------------------- @@ -214,7 +226,6 @@ private Mono asyncInitializeRequestHandler( "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", initializeRequest.protocolVersion(), serverProtocolVersion); } - return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, this.serverInfo, this.instructions)); }); @@ -330,6 +341,69 @@ public Mono removeTool(String toolName) { }); } + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addStreamTool(McpServerFeatures.AsyncStreamingToolSpecification toolSpecification) { + if (toolSpecification == null) { + return Mono.error(new McpError("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolSpecification.call() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.streamTools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { + return Mono + .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); + } + + this.streamTools.add(toolSpecification); + logger.debug("Added tool handler: {}", toolSpecification.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeStreamTool(String toolName) { + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools + .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); + } + /** * Notifies clients that the list of available tools has changed. * @return A Mono that completes when all clients have been notified @@ -340,29 +414,97 @@ public Mono notifyToolsListChanged() { private McpServerSession.RequestHandler toolsListRequestHandler() { return (exchange, params) -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + List tools = new ArrayList<>(); + tools.addAll(this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList()); + tools.addAll( + this.streamTools.stream().map(McpServerFeatures.AsyncStreamingToolSpecification::tool).toList()); return Mono.just(new McpSchema.ListToolsResult(tools, null)); }; } private McpServerSession.RequestHandler toolsCallRequestHandler() { - return (exchange, params) -> { - McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, - new TypeReference() { - }); + if (isStreamableHttp) { + return new McpServerSession.StreamingRequestHandler() { + @Override + public Mono handle(McpAsyncServerExchange exchange, Object params) { + var callToolRequest = objectMapper.convertValue(params, McpSchema.CallToolRequest.class); + + // Check regular tools first + var regularTool = tools.stream() + .filter(tool -> callToolRequest.name().equals(tool.tool().name())) + .findFirst(); + + if (regularTool.isPresent()) { + return regularTool.get().call().apply(exchange, callToolRequest.arguments()); + } + + // Check streaming tools (take first result) + var streamingTool = streamTools.stream() + .filter(tool -> callToolRequest.name().equals(tool.tool().name())) + .findFirst(); + + if (streamingTool.isPresent()) { + return streamingTool.get().call().apply(exchange, callToolRequest.arguments()).next(); + } + + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } - Optional toolSpecification = this.tools.stream() - .filter(tr -> callToolRequest.name().equals(tr.tool().name())) - .findAny(); + @Override + public Flux handleStreaming(McpAsyncServerExchange exchange, Object params) { + var callToolRequest = objectMapper.convertValue(params, McpSchema.CallToolRequest.class); - if (toolSpecification.isEmpty()) { - return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); - } + // Check streaming tools first (preferred for streaming) + var streamingTool = streamTools.stream() + .filter(tool -> callToolRequest.name().equals(tool.tool().name())) + .findFirst(); - return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) - .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); - }; + if (streamingTool.isPresent()) { + return streamingTool.get().call().apply(exchange, callToolRequest.arguments()); + } + + // Fallback to regular tools (convert Mono to Flux) + var regularTool = tools.stream() + .filter(tool -> callToolRequest.name().equals(tool.tool().name())) + .findFirst(); + + if (regularTool.isPresent()) { + return regularTool.get().call().apply(exchange, callToolRequest.arguments()).flux(); + } + + return Flux.error(new McpError("Tool not found: " + callToolRequest.name())); + } + }; + } + else { + return (exchange, params) -> { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + // Check regular tools first + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isPresent()) { + return toolSpecification.get().call().apply(exchange, callToolRequest.arguments()); + } + + // Check streaming tools (take first result) + Optional streamToolSpecification = this.streamTools + .stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (streamToolSpecification.isPresent()) { + return streamToolSpecification.get().call().apply(exchange, callToolRequest.arguments()).next(); + } + + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + }; + } } // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncStreamableHttpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncStreamableHttpServer.java new file mode 100644 index 000000000..090aa1d09 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncStreamableHttpServer.java @@ -0,0 +1,650 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * The Model Context Protocol (MCP) server implementation that provides asynchronous + * communication using Project Reactor's Mono and Flux types. + * + *

+ * This server implements the MCP specification, enabling AI models to expose tools, + * resources, and prompts through a standardized interface. Key features include: + *

    + *
  • Asynchronous communication using reactive programming patterns + *
  • Dynamic tool registration and management + *
  • Resource handling with URI-based addressing + *
  • Prompt template management + *
  • Real-time client notifications for state changes + *
  • Structured logging with configurable severity levels + *
  • Support for client-side AI model sampling + *
+ * + *

+ * The server follows a lifecycle: + *

    + *
  1. Initialization - Accepts client connections and negotiates capabilities + *
  2. Normal Operation - Handles client requests and sends notifications + *
  3. Graceful Shutdown - Ensures clean connection termination + *
+ * + *

+ * This implementation uses Project Reactor for non-blocking operations, making it + * suitable for high-throughput scenarios and reactive applications. All operations return + * Mono or Flux types that can be composed into reactive pipelines. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Jihoon Kim + * @author Zachary German + */ +public class McpAsyncStreamableHttpServer { + + private static final Logger logger = LoggerFactory.getLogger(McpAsyncStreamableHttpServer.class); + + private final StreamableHttpServerTransportProvider httpTransportProvider; + + private final ObjectMapper objectMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final String instructions; + + private final Duration requestTimeout; + + private final McpUriTemplateManagerFactory uriTemplateManagerFactory; + + // Core server features + private final McpServerFeatures.Async features; + + /** + * Creates a new McpAsyncStreamableHttpServer. + */ + McpAsyncStreamableHttpServer(StreamableHttpServerTransportProvider httpTransportProvider, ObjectMapper objectMapper, + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + this.httpTransportProvider = httpTransportProvider; + this.objectMapper = objectMapper; + this.features = features; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.requestTimeout = requestTimeout; + this.uriTemplateManagerFactory = uriTemplateManagerFactory != null ? uriTemplateManagerFactory + : new DeafaultMcpUriTemplateManagerFactory(); + + setupRequestHandlers(); + setupSessionFactory(); + } + + public McpServerTransportProvider getTransportProvider() { + return this.httpTransportProvider; + } + + /** + * Sets up the request handlers for standard MCP methods. + */ + private void setupRequestHandlers() { + Map> requestHandlers = new HashMap<>(); + + // Ping handler + requestHandlers.put(McpSchema.METHOD_PING, (exchange, params) -> Mono.just(Map.of())); + + // Tool handlers + if (serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, createToolsListHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, createToolsCallHandler()); + } + + // Resource handlers + if (serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, createResourcesListHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, createResourcesReadHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, createResourceTemplatesListHandler()); + } + + // Prompt handlers + if (serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, createPromptsListHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, createPromptsGetHandler()); + } + + // Logging handlers + if (serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, createLoggingSetLevelHandler()); + } + + // Completion handlers + if (serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, createCompletionCompleteHandler()); + } + + this.requestHandlers = requestHandlers; + } + + private Map> requestHandlers; + + private Map notificationHandlers; + + /** + * Sets up notification handlers. + */ + private void setupNotificationHandlers() { + Map handlers = new HashMap<>(); + + handlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> { + logger.info("[INIT] Received initialized notification - initialization complete!"); + return Mono.empty(); + }); + + // Roots change notification handler + handlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, createRootsListChangedHandler()); + + this.notificationHandlers = handlers; + } + + /** + * Sets up the session factory for the HTTP transport provider. + */ + private void setupSessionFactory() { + setupNotificationHandlers(); + + httpTransportProvider.setStreamableHttpSessionFactory(sessionId -> new McpServerSession(sessionId, + requestTimeout, this::handleInitializeRequest, Mono::empty, requestHandlers, notificationHandlers)); + } + + /** + * Handles initialization requests from clients. + */ + private Mono handleInitializeRequest(McpSchema.InitializeRequest initializeRequest) { + return Mono.defer(() -> { + logger.info("[INIT] Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // Protocol version negotiation + String serverProtocolVersion = McpSchema.LATEST_PROTOCOL_VERSION; + if (!McpSchema.LATEST_PROTOCOL_VERSION.equals(initializeRequest.protocolVersion())) { + logger.warn("[INIT] Client requested protocol version: {}, server supports: {}", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + logger.debug("[INIT] Server capabilities: {}", serverCapabilities); + logger.debug("[INIT] Server info: {}", serverInfo); + logger.debug("[INIT] Instructions: {}", instructions); + McpSchema.InitializeResult result = new McpSchema.InitializeResult(serverProtocolVersion, + serverCapabilities, serverInfo, instructions); + logger.info("[INIT] Sending initialize response: {}", result); + return Mono.just(result); + }); + } + + // Request handler creation methods + private McpServerSession.RequestHandler createToolsListHandler() { + return (exchange, params) -> { + var regularTools = features.tools().stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + var streamingTools = features.streamTools() + .stream() + .map(McpServerFeatures.AsyncStreamingToolSpecification::tool) + .toList(); + var allTools = new ArrayList<>(regularTools); + allTools.addAll(streamingTools); + return Mono.just(new McpSchema.ListToolsResult(allTools, null)); + }; + } + + private McpServerSession.RequestHandler createToolsCallHandler() { + return new McpServerSession.StreamingRequestHandler() { + @Override + public Mono handle(McpAsyncServerExchange exchange, Object params) { + var callToolRequest = objectMapper.convertValue(params, McpSchema.CallToolRequest.class); + + // Check regular tools first + var regularTool = features.tools() + .stream() + .filter(tool -> callToolRequest.name().equals(tool.tool().name())) + .findFirst(); + + if (regularTool.isPresent()) { + return regularTool.get().call().apply(exchange, callToolRequest.arguments()); + } + + // Check streaming tools (take first result) + var streamingTool = features.streamTools() + .stream() + .filter(tool -> callToolRequest.name().equals(tool.tool().name())) + .findFirst(); + + if (streamingTool.isPresent()) { + return streamingTool.get().call().apply(exchange, callToolRequest.arguments()).next(); + } + + return Mono.error(new RuntimeException("Tool not found: " + callToolRequest.name())); + } + + @Override + public Flux handleStreaming(McpAsyncServerExchange exchange, Object params) { + var callToolRequest = objectMapper.convertValue(params, McpSchema.CallToolRequest.class); + + // Check streaming tools first (preferred for streaming) + var streamingTool = features.streamTools() + .stream() + .filter(tool -> callToolRequest.name().equals(tool.tool().name())) + .findFirst(); + + if (streamingTool.isPresent()) { + return streamingTool.get().call().apply(exchange, callToolRequest.arguments()); + } + + // Fallback to regular tools (convert Mono to Flux) + var regularTool = features.tools() + .stream() + .filter(tool -> callToolRequest.name().equals(tool.tool().name())) + .findFirst(); + + if (regularTool.isPresent()) { + return regularTool.get().call().apply(exchange, callToolRequest.arguments()).flux(); + } + + return Flux.error(new RuntimeException("Tool not found: " + callToolRequest.name())); + } + }; + } + + private McpServerSession.RequestHandler createResourcesListHandler() { + return (exchange, params) -> { + var resources = features.resources() + .values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resources, null)); + }; + } + + private McpServerSession.RequestHandler createResourcesReadHandler() { + return (exchange, params) -> { + var resourceRequest = objectMapper.convertValue(params, McpSchema.ReadResourceRequest.class); + var resourceUri = resourceRequest.uri(); + + return features.resources() + .values() + .stream() + .filter(spec -> uriTemplateManagerFactory.create(spec.resource().uri()).matches(resourceUri)) + .findFirst() + .map(spec -> spec.readHandler().apply(exchange, resourceRequest)) + .orElse(Mono.error(new RuntimeException("Resource not found: " + resourceUri))); + }; + } + + private McpServerSession.RequestHandler createResourceTemplatesListHandler() { + return (exchange, params) -> Mono + .just(new McpSchema.ListResourceTemplatesResult(features.resourceTemplates(), null)); + } + + private McpServerSession.RequestHandler createPromptsListHandler() { + return (exchange, params) -> { + var prompts = features.prompts() + .values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + return Mono.just(new McpSchema.ListPromptsResult(prompts, null)); + }; + } + + private McpServerSession.RequestHandler createPromptsGetHandler() { + return (exchange, params) -> { + var promptRequest = objectMapper.convertValue(params, McpSchema.GetPromptRequest.class); + + return features.prompts() + .values() + .stream() + .filter(spec -> spec.prompt().name().equals(promptRequest.name())) + .findFirst() + .map(spec -> spec.promptHandler().apply(exchange, promptRequest)) + .orElse(Mono.error(new RuntimeException("Prompt not found: " + promptRequest.name()))); + }; + } + + private McpServerSession.RequestHandler createLoggingSetLevelHandler() { + return (exchange, params) -> { + var setLevelRequest = objectMapper.convertValue(params, McpSchema.SetLevelRequest.class); + exchange.setMinLoggingLevel(setLevelRequest.level()); + return Mono.just(Map.of()); + }; + } + + private McpServerSession.RequestHandler createCompletionCompleteHandler() { + return (exchange, params) -> { + var completeRequest = objectMapper.convertValue(params, McpSchema.CompleteRequest.class); + + return features.completions() + .values() + .stream() + .filter(spec -> spec.referenceKey().equals(completeRequest.ref())) + .findFirst() + .map(spec -> spec.completionHandler().apply(exchange, completeRequest)) + .orElse(Mono.error(new RuntimeException("Completion not found: " + completeRequest.ref()))); + }; + } + + private McpServerSession.NotificationHandler createRootsListChangedHandler() { + return (exchange, params) -> { + var rootsChangeConsumers = features.rootsChangeConsumers(); + if (rootsChangeConsumers.isEmpty()) { + return Mono + .fromRunnable(() -> logger.warn("Roots list changed notification, but no consumers provided")); + } + + return exchange.listRoots() + .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); + }; + } + + /** + * Get the server capabilities. + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return serverCapabilities; + } + + /** + * Get the server implementation information. + */ + public McpSchema.Implementation getServerInfo() { + return serverInfo; + } + + /** + * Gracefully closes the server. + */ + public Mono closeGracefully() { + return httpTransportProvider.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + httpTransportProvider.close(); + } + + /** + * Notifies clients that the list of available tools has changed. + */ + public Mono notifyToolsListChanged() { + return httpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + /** + * Notifies clients that the list of available resources has changed. + */ + public Mono notifyResourcesListChanged() { + return httpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + /** + * Notifies clients that resources have been updated. + */ + public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification notification) { + return httpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_UPDATED, notification); + } + + /** + * Notifies clients that the list of available prompts has changed. + */ + public Mono notifyPromptsListChanged() { + return httpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + /** + * Creates a new builder for configuring and creating McpAsyncStreamableHttpServer + * instances. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of McpAsyncServer + */ + public static class Builder { + + private McpSchema.Implementation serverInfo; + + private McpSchema.ServerCapabilities serverCapabilities; + + private String instructions; + + private Duration requestTimeout = Duration.ofSeconds(30); + + private ObjectMapper objectMapper = new ObjectMapper(); + + private String mcpEndpoint = "/mcp"; + + private Supplier sessionIdProvider; + + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + + private final List tools = new ArrayList<>(); + + private final List streamTools = new ArrayList<>(); + + private final Map resources = new HashMap<>(); + + private final List resourceTemplates = new ArrayList<>(); + + private final Map prompts = new HashMap<>(); + + private final Map completions = new HashMap<>(); + + private final List, Mono>> rootsChangeConsumers = new ArrayList<>(); + + /** + * Sets the server implementation information. + */ + public Builder serverInfo(String name, String version) { + return serverInfo(name, null, version); + } + + /** + * Sets the server implementation information. + */ + public Builder serverInfo(String name, String title, String version) { + Assert.hasText(name, "Server name must not be empty"); + Assert.hasText(version, "Server version must not be empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server capabilities. + */ + public Builder serverCapabilities(McpSchema.ServerCapabilities capabilities) { + this.serverCapabilities = capabilities; + return this; + } + + /** + * Sets the server instructions. + */ + public Builder instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the request timeout duration. + */ + public Builder requestTimeout(Duration timeout) { + Assert.notNull(timeout, "Request timeout must not be null"); + this.requestTimeout = timeout; + return this; + } + + /** + * Sets the JSON object mapper. + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the MCP endpoint path. + */ + public Builder withMcpEndpoint(String endpoint) { + Assert.hasText(endpoint, "MCP endpoint must not be empty"); + this.mcpEndpoint = endpoint; + return this; + } + + /** + * Sets the session ID provider. + */ + public Builder withSessionIdProvider(Supplier provider) { + Assert.notNull(provider, "Session ID provider must not be null"); + this.sessionIdProvider = provider; + return this; + } + + /** + * Sets the URI template manager factory. + */ + public Builder withUriTemplateManagerFactory(McpUriTemplateManagerFactory factory) { + Assert.notNull(factory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = factory; + return this; + } + + /** + * Adds a tool specification. + */ + public Builder withTool(McpServerFeatures.AsyncToolSpecification toolSpec) { + Assert.notNull(toolSpec, "Tool specification must not be null"); + this.tools.add(toolSpec); + return this; + } + + /** + * Adds a streaming tool specification. + */ + public Builder withStreamingTool(McpServerFeatures.AsyncStreamingToolSpecification toolSpec) { + Assert.notNull(toolSpec, "Streaming tool specification must not be null"); + this.streamTools.add(toolSpec); + return this; + } + + /** + * Adds a resource specification. + */ + public Builder withResource(String uri, McpServerFeatures.AsyncResourceSpecification resourceSpec) { + Assert.hasText(uri, "Resource URI must not be empty"); + Assert.notNull(resourceSpec, "Resource specification must not be null"); + this.resources.put(uri, resourceSpec); + return this; + } + + /** + * Adds a resource template. + */ + public Builder withResourceTemplate(McpSchema.ResourceTemplate template) { + Assert.notNull(template, "Resource template must not be null"); + this.resourceTemplates.add(template); + return this; + } + + /** + * Adds a prompt specification. + */ + public Builder withPrompt(String name, McpServerFeatures.AsyncPromptSpecification promptSpec) { + Assert.hasText(name, "Prompt name must not be empty"); + Assert.notNull(promptSpec, "Prompt specification must not be null"); + this.prompts.put(name, promptSpec); + return this; + } + + /** + * Adds a completion specification. + */ + public Builder withCompletion(McpSchema.CompleteReference reference, + McpServerFeatures.AsyncCompletionSpecification completionSpec) { + Assert.notNull(reference, "Completion reference must not be null"); + Assert.notNull(completionSpec, "Completion specification must not be null"); + this.completions.put(reference, completionSpec); + return this; + } + + /** + * Adds a roots change consumer. + */ + public Builder withRootsChangeConsumer( + BiFunction, Mono> consumer) { + Assert.notNull(consumer, "Roots change consumer must not be null"); + this.rootsChangeConsumers.add(consumer); + return this; + } + + /** + * Builds the McpAsyncStreamableHttpServer instance. + */ + public McpAsyncStreamableHttpServer build() { + Assert.notNull(serverInfo, "Server info must be set"); + + // Create Streamable HTTP transport provider + StreamableHttpServerTransportProvider.Builder transportBuilder = StreamableHttpServerTransportProvider + .builder() + .withObjectMapper(objectMapper) + .withMcpEndpoint(mcpEndpoint); + + if (sessionIdProvider != null) { + transportBuilder.withSessionIdProvider(sessionIdProvider); + } + + StreamableHttpServerTransportProvider httpTransportProvider = transportBuilder.build(); + + // Create server features + McpServerFeatures.Async features = new McpServerFeatures.Async(serverInfo, serverCapabilities, tools, + resources, resourceTemplates, prompts, completions, rootsChangeConsumers, instructions, + streamTools); + + return new McpAsyncStreamableHttpServer(httpTransportProvider, objectMapper, features, requestTimeout, + uriTemplateManagerFactory); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 8311f5d41..b1e533c4a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -14,6 +14,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -44,7 +45,31 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s Map prompts, Map completions, List, Mono>> rootsChangeConsumers, - String instructions) { + String instructions, List streamTools) { + + /** + * Create an instance and validate the arguments (backward compatible + * constructor). + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param rootsChangeConsumers The list of consumers that will be notified when + * the roots list changes + * @param instructions The server instructions text + */ + public Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, Map resources, + List resourceTemplates, + Map prompts, + Map completions, + List, Mono>> rootsChangeConsumers, + String instructions) { + this(serverInfo, serverCapabilities, tools, resources, resourceTemplates, prompts, completions, + rootsChangeConsumers, instructions, List.of()); + } /** * Create an instance and validate the arguments. @@ -57,6 +82,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * @param rootsChangeConsumers The list of consumers that will be notified when * the roots list changes * @param instructions The server instructions text + * @param streamTools The list of streaming tool specifications */ Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, List tools, Map resources, @@ -64,7 +90,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s Map prompts, Map completions, List, Mono>> rootsChangeConsumers, - String instructions) { + String instructions, List streamTools) { Assert.notNull(serverInfo, "Server info must not be null"); @@ -88,6 +114,7 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s this.completions = (completions != null) ? completions : Map.of(); this.rootsChangeConsumers = (rootsChangeConsumers != null) ? rootsChangeConsumers : List.of(); this.instructions = instructions; + this.streamTools = (streamTools != null) ? streamTools : List.of(); } /** @@ -128,7 +155,8 @@ static Async fromSync(Sync syncSpec) { } return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, - syncSpec.resourceTemplates(), prompts, completions, rootChangeConsumers, syncSpec.instructions()); + syncSpec.resourceTemplates(), prompts, completions, rootChangeConsumers, syncSpec.instructions(), + List.of()); } } @@ -251,6 +279,40 @@ static AsyncToolSpecification fromSync(SyncToolSpecification tool) { } } + /** + * Specification of a streaming tool with its asynchronous handler function that can + * return either a single result (Mono) or a stream of results (Flux). This enables + * tools to provide real-time streaming responses for long-running operations or + * progressive results. + * + *

+ * Example streaming tool specification:

{@code
+	 * new McpServerFeatures.AsyncStreamingToolSpecification(
+	 *     new Tool(
+	 *         "file_processor",
+	 *         "Processes files with streaming progress updates",
+	 *         new JsonSchemaObject()
+	 *             .required("file_path")
+	 *             .property("file_path", JsonSchemaType.STRING)
+	 *     ),
+	 *     (exchange, args) -> {
+	 *         String filePath = (String) args.get("file_path");
+	 *         return Flux.interval(Duration.ofSeconds(1))
+	 *             .take(10)
+	 *             .map(i -> new CallToolResult("Processing step " + i + " for " + filePath));
+	 *     }
+	 * )
+	 * }
+ * + * @param tool The tool definition including name, description, and parameter schema + * @param call The function that implements the tool's streaming logic, receiving + * arguments and returning a Flux of results that will be streamed to the client via + * SSE. + */ + public record AsyncStreamingToolSpecification(McpSchema.Tool tool, + BiFunction, Flux> call) { + } + /** * Specification of a resource with its asynchronous handler function. Resources * provide context to AI models by exposing data such as: diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java new file mode 100644 index 000000000..5a4331cf5 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/SessionHandler.java @@ -0,0 +1,57 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +/** + * Handler interface for session lifecycle and runtime events in the Streamable HTTP + * transport. + * + *

+ * This interface provides hooks for monitoring and responding to various session-related + * events that occur during the operation of the HTTP-based MCP server transport. + * Implementations can use these callbacks to: + *

    + *
  • Log session activities
  • + *
  • Implement custom session management logic
  • + *
  • Handle error conditions
  • + *
  • Perform cleanup operations
  • + *
+ * + * @author Zachary German + */ +public interface SessionHandler { + + /** + * Called when a new session is created. + * @param sessionId The ID of the newly created session + * @param context Additional context information (may be null) + */ + void onSessionCreate(String sessionId, Object context); + + /** + * Called when a session is closed. + * @param sessionId The ID of the closed session + */ + void onSessionClose(String sessionId); + + /** + * Called when a session is not found for a given session ID. + * @param sessionId The session ID that was not found + * @param request The HTTP request that referenced the missing session + * @param response The HTTP response that will be sent to the client + */ + void onSessionNotFound(String sessionId, HttpServletRequest request, HttpServletResponse response); + + /** + * Called when an error occurs while sending a notification to a session. + * @param sessionId The ID of the session where the error occurred + * @param error The error that occurred + */ + void onSendNotificationError(String sessionId, Throwable error); + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java new file mode 100644 index 000000000..086957e99 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java @@ -0,0 +1,910 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.SseEvent; +import io.modelcontextprotocol.util.Assert; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ReadListener; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import java.nio.charset.StandardCharsets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.context.Context; + +import static java.util.Objects.requireNonNullElse; + +/** + * MCP Streamable HTTP transport provider that uses a single session class to manage all + * streams and transports. + * + *

+ * Key improvements over the original implementation: + *

    + *
  • Manages server-client sessions, including transport registration. + *
  • Handles HTTP requests and HTTP/SSE responses and streams. + *
  • Provides callbacks for session lifecycle and errors. + *
  • Supports graceful shutdown. + *
  • Enforces allowed 'Origin' header values if configured. + *
  • Provides a default session ID provider if none is configured. + *
+ * + * @author Zachary German + */ +@WebServlet(asyncSupported = true) +public class StreamableHttpServerTransportProvider extends HttpServlet implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(StreamableHttpServerTransportProvider.class); + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static final String SESSION_ID_HEADER = "Mcp-Session-Id"; + + public static final String LAST_EVENT_ID_HEADER = "Last-Event-Id"; + + public static final String MESSAGE_EVENT_TYPE = "message"; + + public static final String ACCEPT_HEADER = "Accept"; + + public static final String ORIGIN_HEADER = "Origin"; + + public static final String ALLOW_ORIGIN_HEADER = "Access-Control-Allow-Origin"; + + public static final String ALLOW_ORIGIN_DEFAULT_VALUE = "*"; + + public static final String PROTOCOL_VERSION_HEADER = "MCP-Protocol-Version"; + + public static final String CACHE_CONTROL_HEADER = "Cache-Control"; + + public static final String CONNECTION_HEADER = "Connection"; + + public static final String CACHE_CONTROL_NO_CACHE = "no-cache"; + + public static final String CONNECTION_KEEP_ALIVE = "keep-alive"; + + public static final String MCP_SESSION_ID = "MCP-Session-ID"; + + public static final String DEFAULT_MCP_ENDPOINT = "/mcp"; + + /** com.fasterxml.jackson.databind.ObjectMapper */ + private static final ObjectMapper DEFAULT_OBJECT_MAPPER = new ObjectMapper(); + + /** UUID.randomUUID().toString() */ + private static final Supplier DEFAULT_SESSION_ID_PROVIDER = () -> UUID.randomUUID().toString(); + + /** JSON object mapper for serialization/deserialization */ + private final ObjectMapper objectMapper; + + /** The endpoint path for handling MCP requests */ + private final String mcpEndpoint; + + /** Supplier for generating unique session IDs */ + private final Supplier sessionIdProvider; + + /** Sessions map, keyed by Session ID */ + private static final Map sessions = new ConcurrentHashMap<>(); + + /** Flag indicating if the transport is in the process of shutting down */ + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + /** Optional allowed 'Origin' header value list. Not enforced if empty. */ + private final List allowedOrigins = new ArrayList<>(); + + /** Callback interface for session lifecycle and errors */ + private SessionHandler sessionHandler; + + /** Factory for McpServerSession takes session IDs */ + private McpServerSession.StreamableHttpSessionFactory streamableHttpSessionFactory; + + /** + *
    + *
  • Manages server-client sessions, including transport registration. + *
  • Handles HTTP requests and HTTP/SSE responses and streams. + *
+ * @param objectMapper ObjectMapper - Default: + * com.fasterxml.jackson.databind.ObjectMapper + * @param mcpEndpoint String - Default: '/mcp' + * @param sessionIdProvider Supplier(String) - Default: UUID.randomUUID().toString() + */ + public StreamableHttpServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + Supplier sessionIdProvider) { + this.objectMapper = requireNonNullElse(objectMapper, DEFAULT_OBJECT_MAPPER); + this.mcpEndpoint = requireNonNullElse(mcpEndpoint, DEFAULT_MCP_ENDPOINT); + this.sessionIdProvider = requireNonNullElse(sessionIdProvider, DEFAULT_SESSION_ID_PROVIDER); + } + + /** + *
    + *
  • Manages server-client sessions, including transport registration. + *
  • Handles HTTP requests and HTTP/SSE responses and streams. + *
+ * @param objectMapper ObjectMapper - Default: + * com.fasterxml.jackson.databind.ObjectMapper + * @param mcpEndpoint String - Default: '/mcp' + * @param sessionIdProvider Supplier(String) - Default: UUID.randomUUID().toString() + */ + public StreamableHttpServerTransportProvider() { + this(null, null, null); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + // Required but not used for this implementation + } + + public void setStreamableHttpSessionFactory(McpServerSession.StreamableHttpSessionFactory sessionFactory) { + this.streamableHttpSessionFactory = sessionFactory; + } + + public void setSessionHandler(SessionHandler sessionHandler) { + this.sessionHandler = sessionHandler; + } + + public void setAllowedOrigins(List allowedOrigins) { + this.allowedOrigins.clear(); + this.allowedOrigins.addAll(allowedOrigins); + } + + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params).doOnError(e -> { + logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()); + if (sessionHandler != null) { + sessionHandler.onSendNotificationError(session.getId(), e); + } + }).onErrorComplete()) + .then(); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + isClosing.set(true); + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.closeGracefully() + .doOnError(e -> logger.error("Error closing session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + }); + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + logger.info("GET request received for URI: '{}' with headers: {}", requestURI, extractHeaders(request)); + + if (!validateOrigin(request, response) || !validateEndpoint(requestURI, response) + || !validateNotClosing(response)) { + return; + } + + String acceptHeader = request.getHeader(ACCEPT_HEADER); + if (acceptHeader == null || !acceptHeader.contains(TEXT_EVENT_STREAM)) { + logger.debug("Accept header missing or does not include {}", TEXT_EVENT_STREAM); + sendErrorResponse(response, "Accept header must include text/event-stream"); + return; + } + + String sessionId = request.getHeader(SESSION_ID_HEADER); + if (sessionId == null) { + sendErrorResponse(response, "Session ID missing in request header"); + return; + } + + McpServerSession session = sessions.get(sessionId); + if (session == null) { + handleSessionNotFound(sessionId, request, response); + return; + } + + // Delayed until version negotiation is implemented. + /* + * if (session.getState().equals(session.STATE_INITIALIZED) && + * request.getHeader(PROTOCOL_VERSION_HEADER) == null) { + * sendErrorResponse(response, "Protocol version missing in request header"); } + */ + + // Set up SSE connection + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader(CACHE_CONTROL_HEADER, CACHE_CONTROL_NO_CACHE); + response.setHeader(CONNECTION_HEADER, CONNECTION_KEEP_ALIVE); + response.setHeader(SESSION_ID_HEADER, sessionId); + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + String lastEventId = request.getHeader(LAST_EVENT_ID_HEADER); + + if (lastEventId == null) { // Just opening a listening stream + SseTransport sseTransport = new SseTransport(objectMapper, response, asyncContext, lastEventId, + session.LISTENING_TRANSPORT, sessionId); + session.registerTransport(session.LISTENING_TRANSPORT, sseTransport); + logger.debug("Registered SSE transport {} for session {}", session.LISTENING_TRANSPORT, sessionId); + } + else { // Asking for a stream to replay events from a previous request + SseTransport sseTransport = new SseTransport(objectMapper, response, asyncContext, lastEventId, + request.getRequestId(), sessionId); + session.registerTransport(request.getRequestId(), sseTransport); + logger.debug("Registered SSE transport {} for session {}", request.getRequestId(), sessionId); + } + } + + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + logger.info("POST request received for URI: '{}' with headers: {}", requestURI, extractHeaders(request)); + + if (!validateOrigin(request, response) || !validateEndpoint(requestURI, response) + || !validateNotClosing(response)) { + return; + } + + String acceptHeader = request.getHeader(ACCEPT_HEADER); + if (acceptHeader == null + || (!acceptHeader.contains(APPLICATION_JSON) || !acceptHeader.contains(TEXT_EVENT_STREAM))) { + logger.debug("Accept header validation failed. Header: {}", acceptHeader); + sendErrorResponse(response, "Accept header must include both application/json and text/event-stream"); + return; + } + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + StringBuilder body = new StringBuilder(); + ServletInputStream inputStream = request.getInputStream(); + + inputStream.setReadListener(new ReadListener() { + @Override + public void onDataAvailable() throws IOException { + int len; + byte[] buffer = new byte[1024]; + while (inputStream.isReady() && (len = inputStream.read(buffer)) != -1) { + body.append(new String(buffer, 0, len, StandardCharsets.UTF_8)); + } + } + + @Override + public void onAllDataRead() throws IOException { + try { + logger.debug("Parsing JSON-RPC message: {}", body.toString()); + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + + boolean isInitializeRequest = false; + String sessionId = request.getHeader(SESSION_ID_HEADER); + + if (message instanceof McpSchema.JSONRPCRequest req + && McpSchema.METHOD_INITIALIZE.equals(req.method())) { + isInitializeRequest = true; + logger.debug("Detected initialize request"); + if (sessionId == null) { + sessionId = sessionIdProvider.get(); + logger.debug("Created new session ID for initialize request: {}", sessionId); + } + } + + if (!isInitializeRequest && sessionId == null) { + sendErrorResponse(response, "Session ID missing in request header"); + asyncContext.complete(); + return; + } + + McpServerSession session = getOrCreateSession(sessionId, isInitializeRequest); + if (session == null) { + logger.error("Failed to create session for sessionId: {}", sessionId); + handleSessionNotFound(sessionId, request, response); + asyncContext.complete(); + return; + } + + // Delayed until version negotiation is implemented. + /* + * if (session.getState().equals(session.STATE_INITIALIZED) && + * request.getHeader(PROTOCOL_VERSION_HEADER) == null) { + * sendErrorResponse(response, + * "Protocol version missing in request header"); } + */ + + logger.debug("Using session: {}", sessionId); + + response.setHeader(SESSION_ID_HEADER, sessionId); + + // Determine response type and create appropriate transport if needed + ResponseType responseType = detectResponseType(message, session); + final String transportId; + if (message instanceof JSONRPCRequest req) { + transportId = req.id().toString(); + } + else if (message instanceof JSONRPCResponse resp) { + transportId = resp.id().toString(); + } + else { + transportId = null; + } + + if (responseType == ResponseType.STREAM) { + logger.debug("Handling STREAM response type"); + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader(CACHE_CONTROL_HEADER, CACHE_CONTROL_NO_CACHE); + response.setHeader(CONNECTION_HEADER, CONNECTION_KEEP_ALIVE); + + SseTransport sseTransport = new SseTransport(objectMapper, response, asyncContext, null, + transportId, sessionId); + session.registerTransport(transportId, sseTransport); + } + else { + logger.debug("Handling IMMEDIATE response type"); + // Only set content type for requests, not notifications + if (message instanceof McpSchema.JSONRPCRequest) { + logger.debug("Setting content type to APPLICATION_JSON for request response"); + response.setContentType(APPLICATION_JSON); + } + else { + logger.debug("Not setting content type for notification (empty response expected)"); + } + + if (transportId != null) { // Not needed for notifications (null + // transportId) + HttpTransport httpTransport = new HttpTransport(objectMapper, response, asyncContext); + session.registerTransport(transportId, httpTransport); + } + } + + // Handle the message + logger.debug("About to handle message: {} with transport: {}", message.getClass().getSimpleName(), + transportId); + + // For notifications, we need to handle the HTTP response manually + // since no JSON response is sent + if (message instanceof McpSchema.JSONRPCNotification) { + session.handle(message).doOnSuccess(v -> { + logger.debug("Message handling completed successfully for transport: {}", transportId); + logger.debug("[NOTIFICATION] Sending empty HTTP response for notification"); + try { + if (!response.isCommitted()) { + response.setStatus(HttpServletResponse.SC_OK); + response.setCharacterEncoding("UTF-8"); + } + asyncContext.complete(); + } + catch (Exception e) { + logger.error("Failed to send notification response: {}", e.getMessage()); + asyncContext.complete(); + } + }).doOnError(e -> { + logger.error("Error in message handling: {}", e.getMessage(), e); + asyncContext.complete(); + }).doFinally(signalType -> { + logger.debug("Unregistering transport: {} with signal: {}", transportId, signalType); + session.unregisterTransport(transportId); + }).contextWrite(Context.of(MCP_SESSION_ID, sessionId)).subscribe(); + } + else { + // For requests, let the transport handle the response + session.handle(message) + .doOnSuccess(v -> logger.info("Message handling completed successfully for transport: {}", + transportId)) + .doOnError(e -> logger.error("Error in message handling: {}", e.getMessage(), e)) + .doFinally(signalType -> { + logger.debug("Unregistering transport: {} with signal: {}", transportId, signalType); + session.unregisterTransport(transportId); + }) + .contextWrite(Context.of(MCP_SESSION_ID, sessionId)) + .subscribe(null, error -> { + logger.error("Error in message handling chain: {}", error.getMessage(), error); + asyncContext.complete(); + }); + } + + } + catch (Exception e) { + logger.error("Error processing message: {}", e.getMessage()); + sendErrorResponse(response, "Invalid JSON-RPC message: " + e.getMessage()); + asyncContext.complete(); + } + } + + @Override + public void onError(Throwable t) { + logger.error("Error reading request body: {}", t.getMessage()); + try { + sendErrorResponse(response, "Error reading request: " + t.getMessage()); + } + catch (IOException e) { + logger.error("Failed to write error response", e); + } + asyncContext.complete(); + } + }); + } + + @Override + protected void doDelete(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + String sessionId = request.getHeader(SESSION_ID_HEADER); + if (sessionId == null) { + sendErrorResponse(response, "Session ID missing in request header"); + return; + } + + McpServerSession session = sessions.remove(sessionId); + if (session == null) { + handleSessionNotFound(sessionId, request, response); + return; + } + + session.closeGracefully().contextWrite(Context.of(MCP_SESSION_ID, sessionId)).subscribe(); + logger.debug("Session closed: {}", sessionId); + if (sessionHandler != null) { + sessionHandler.onSessionClose(sessionId); + } + + response.setStatus(HttpServletResponse.SC_OK); + } + + private boolean validateOrigin(HttpServletRequest request, HttpServletResponse response) throws IOException { + if (!allowedOrigins.isEmpty()) { + String origin = request.getHeader(ORIGIN_HEADER); + if (!allowedOrigins.contains(origin)) { + logger.debug("Origin header does not match allowed origins: '{}'", origin); + response.sendError(HttpServletResponse.SC_FORBIDDEN); + return false; + } + else { + response.setHeader(ALLOW_ORIGIN_HEADER, origin); + } + } + else { + response.setHeader(ALLOW_ORIGIN_HEADER, ALLOW_ORIGIN_DEFAULT_VALUE); + } + return true; + } + + private boolean validateEndpoint(String requestURI, HttpServletResponse response) throws IOException { + if (!requestURI.endsWith(mcpEndpoint)) { + logger.debug("URI does not match MCP endpoint: '{}'", mcpEndpoint); + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return false; + } + return true; + } + + private boolean validateNotClosing(HttpServletResponse response) throws IOException { + if (isClosing.get()) { + logger.debug("Server is shutting down, rejecting request"); + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return false; + } + return true; + } + + protected McpServerSession getOrCreateSession(String sessionId, boolean createIfMissing) { + McpServerSession session = sessions.get(sessionId); + logger.debug("Looking for session: {}, found: {}", sessionId, session != null); + if (session == null && createIfMissing) { + logger.debug("Creating new session: {}", sessionId); + session = streamableHttpSessionFactory.create(sessionId); + sessions.put(sessionId, session); + logger.debug("Created new session: {}", sessionId); + if (sessionHandler != null) { + sessionHandler.onSessionCreate(sessionId, null); + } + } + return session; + } + + private ResponseType detectResponseType(McpSchema.JSONRPCMessage message, McpServerSession session) { + if (message instanceof McpSchema.JSONRPCRequest request) { + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + return ResponseType.IMMEDIATE; + } + + // Check if handler returns Flux (streaming) or Mono (immediate) + var handler = session.getRequestHandler(request.method()); + if (handler != null && handler instanceof McpServerSession.StreamingRequestHandler) { + return ResponseType.STREAM; + } + return ResponseType.IMMEDIATE; + } + else { + return ResponseType.IMMEDIATE; + } + } + + private void handleSessionNotFound(String sessionId, HttpServletRequest request, HttpServletResponse response) + throws IOException { + sendErrorResponse(response, "Session not found: " + sessionId); + if (sessionHandler != null) { + sessionHandler.onSessionNotFound(sessionId, request, response); + } + } + + private void sendErrorResponse(HttpServletResponse response, String message) throws IOException { + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + response.getWriter().write(createErrorJson(message)); + } + + private String createErrorJson(String message) { + try { + return objectMapper.writeValueAsString(new McpError(message)); + } + catch (IOException e) { + logger.error("Failed to serialize error message", e); + return "{\"error\":\"" + message + "\"}"; + } + } + + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + private Map extractHeaders(HttpServletRequest request) { + Map headers = new HashMap<>(); + Enumeration headerNames = request.getHeaderNames(); + while (headerNames.hasMoreElements()) { + String name = headerNames.nextElement(); + headers.put(name, request.getHeader(name)); + } + return headers; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private ObjectMapper objectMapper = DEFAULT_OBJECT_MAPPER; + + private String mcpEndpoint = DEFAULT_MCP_ENDPOINT; + + private Supplier sessionIdProvider = DEFAULT_SESSION_ID_PROVIDER; + + public Builder withObjectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + public Builder withMcpEndpoint(String mcpEndpoint) { + Assert.hasText(mcpEndpoint, "MCP endpoint must not be empty"); + this.mcpEndpoint = mcpEndpoint; + return this; + } + + public Builder withSessionIdProvider(Supplier sessionIdProvider) { + Assert.notNull(sessionIdProvider, "SessionIdProvider must not be null"); + this.sessionIdProvider = sessionIdProvider; + return this; + } + + public StreamableHttpServerTransportProvider build() { + return new StreamableHttpServerTransportProvider(objectMapper, mcpEndpoint, sessionIdProvider); + } + + } + + private enum ResponseType { + + IMMEDIATE, STREAM + + } + + /** + * SSE transport implementation. + */ + private static class SseTransport implements McpServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(SseTransport.class); + + private final ObjectMapper objectMapper; + + private final HttpServletResponse response; + + private final AsyncContext asyncContext; + + private final Sinks.Many eventSink = Sinks.many().unicast().onBackpressureBuffer(); + + private final Map eventHistory = new ConcurrentHashMap<>(); + + private final String id; + + private final String sessionId; + + public SseTransport(ObjectMapper objectMapper, HttpServletResponse response, AsyncContext asyncContext, + String lastEventId, String transportId, String sessionId) { + this.objectMapper = objectMapper; + this.response = response; + this.asyncContext = asyncContext; + this.id = transportId; + this.sessionId = sessionId; + + setupSseStream(lastEventId); + } + + private void setupSseStream(String lastEventId) { + try { + PrintWriter writer = response.getWriter(); + + eventSink.asFlux().doOnNext(event -> { + try { + if (event.id() != null) { + writer.write("id: " + event.id() + "\n"); + } + if (event.event() != null) { + writer.write("event: " + event.event() + "\n"); + } + writer.write("data: " + event.data() + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + catch (IOException e) { + logger.debug("Error writing to SSE stream: {}", e.getMessage()); + asyncContext.complete(); + } + }).doOnComplete(() -> { + try { + writer.close(); + } + finally { + asyncContext.complete(); + } + }).doOnError(e -> { + logger.error("Error in SSE stream: {}", e.getMessage()); + asyncContext.complete(); + }).contextWrite(Context.of(MCP_SESSION_ID, response.getHeader(SESSION_ID_HEADER))).subscribe(); + + // Replay events if requested + if (lastEventId != null) { + replayEventsAfter(lastEventId); + } + + } + catch (IOException e) { + logger.error("Failed to setup SSE stream: {}", e.getMessage()); + asyncContext.complete(); + } + } + + private void replayEventsAfter(String lastEventId) { + try { + McpServerSession session = sessions.get(sessionId); + String transportIdOfLastEventId = session.getTransportIdForEvent(lastEventId); + Map transportEventHistory = session + .getTransportEventHistory(transportIdOfLastEventId); + List eventIds = transportEventHistory.keySet() + .stream() + .map(Long::parseLong) + .filter(key -> key > Long.parseLong(lastEventId)) + .sorted() + .map(String::valueOf) + .collect(Collectors.toList()); + for (String eventId : eventIds) { + SseEvent event = transportEventHistory.get(eventId); + if (event != null) { + eventSink.tryEmitNext(event); + } + } + logger.debug("Completing SSE stream after replaying events"); + eventSink.tryEmitComplete(); + } + catch (NumberFormatException e) { + logger.warn("Invalid last event ID: {}", lastEventId); + } + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + try { + String jsonText = objectMapper.writeValueAsString(message); + String eventId = sessions.get(sessionId).incrementAndGetEventId(id); + SseEvent event = new SseEvent(eventId, MESSAGE_EVENT_TYPE, jsonText); + + eventHistory.put(eventId, event); + logger.debug("Sending SSE event {}: {}", eventId, jsonText); + eventSink.tryEmitNext(event); + + if (message instanceof McpSchema.JSONRPCResponse) { + logger.debug("Completing SSE stream after sending response"); + eventSink.tryEmitComplete(); + sessions.get(sessionId).setTransportEventHistory(id, eventHistory); + } + + return Mono.empty(); + } + catch (Exception e) { + logger.error("Failed to send message: {}", e.getMessage()); + return Mono.error(e); + } + } + + /** + * Sends a stream of messages for Flux responses. + */ + public Mono sendMessageStream(Flux messageStream) { + return messageStream.doOnNext(message -> { + try { + String jsonText = objectMapper.writeValueAsString(message); + String eventId = sessions.get(sessionId).incrementAndGetEventId(id); + SseEvent event = new SseEvent(eventId, MESSAGE_EVENT_TYPE, jsonText); + + eventHistory.put(eventId, event); + logger.debug("Sending SSE stream event {}: {}", eventId, jsonText); + eventSink.tryEmitNext(event); + } + catch (Exception e) { + logger.error("Failed to send stream message: {}", e.getMessage()); + eventSink.tryEmitError(e); + } + }).doOnComplete(() -> { + logger.debug("Completing SSE stream after sending all stream messages"); + eventSink.tryEmitComplete(); + sessions.get(sessionId).setTransportEventHistory(id, eventHistory); + }).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + eventSink.tryEmitComplete(); + sessions.get(sessionId).setTransportEventHistory(id, eventHistory); + logger.debug("SSE transport closed gracefully"); + }); + } + + } + + /** + * HTTP transport implementation for immediate responses. + */ + private static class HttpTransport implements McpServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(HttpTransport.class); + + private final ObjectMapper objectMapper; + + private final HttpServletResponse response; + + private final AsyncContext asyncContext; + + public HttpTransport(ObjectMapper objectMapper, HttpServletResponse response, AsyncContext asyncContext) { + this.objectMapper = objectMapper; + this.response = response; + this.asyncContext = asyncContext; + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + if (response.isCommitted()) { + logger.warn("Response already committed, cannot send message"); + return; + } + + response.setCharacterEncoding("UTF-8"); + response.setStatus(HttpServletResponse.SC_OK); + + // For notifications, don't write any content (empty response) + if (message instanceof McpSchema.JSONRPCNotification) { + logger.debug("Sending empty 200 response for notification"); + // Just complete the response with no content + } + else { + // For requests/responses, write JSON content + String jsonText = objectMapper.writeValueAsString(message); + PrintWriter writer = response.getWriter(); + writer.write(jsonText); + writer.flush(); + logger.debug("Successfully sent immediate response: {}", jsonText); + } + } + catch (Exception e) { + logger.error("Failed to send message: {}", e.getMessage(), e); + try { + if (!response.isCommitted()) { + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + } + } + catch (Exception ignored) { + } + } + finally { + asyncContext.complete(); + } + }); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + try { + asyncContext.complete(); + } + catch (Exception e) { + logger.debug("Error completing async context: {}", e.getMessage()); + } + logger.debug("HTTP transport closed gracefully"); + }); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index cc7d2abf8..454f1fc4b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -5,6 +5,8 @@ package io.modelcontextprotocol.spec; import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.spec.McpSchema.McpId; import io.modelcontextprotocol.util.Assert; import org.reactivestreams.Publisher; import org.slf4j.Logger; @@ -47,7 +49,7 @@ public class McpClientSession implements McpSession { private final McpClientTransport transport; /** Map of pending responses keyed by request ID */ - private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); /** Map of request handlers keyed by method name */ private final ConcurrentHashMap> requestHandlers = new ConcurrentHashMap<>(); @@ -231,10 +233,10 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti /** * Generates a unique request ID in a non-blocking way. Combines a session-specific * prefix with an atomic counter to ensure uniqueness. - * @return A unique request ID string + * @return A unique request ID from String */ - private String generateRequestId() { - return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement(); + private McpId generateRequestId() { + return McpId.of(this.sessionPrefix + "-" + this.requestCounter.getAndIncrement()); } /** @@ -247,7 +249,7 @@ private String generateRequestId() { */ @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { - String requestId = this.generateRequestId(); + McpId requestId = this.generateRequestId(); return Mono.deferContextual(ctx -> Mono.create(pendingResponseSink -> { logger.debug("Sending message for method {}", method); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java index 13e43240b..6d177e0f9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java @@ -15,7 +15,7 @@ public McpError(JSONRPCError jsonRpcError) { } public McpError(Object error) { - super(error.toString()); + super(String.valueOf(error)); } public JSONRPCError getJsonRpcError() { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 9be585cea..f7414d0c9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -18,12 +18,25 @@ import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.annotation.JsonTypeInfo.As; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; + import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static java.util.Objects.requireNonNull; + /** * Based on the JSON-RPC 2.0 * specification and the { + + @Override + public McpId deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + JsonToken t = p.getCurrentToken(); + if (t == JsonToken.VALUE_STRING) { + return new McpId(p.getText()); + } + else if (t.isNumeric()) { + return new McpId(p.getNumberValue()); + } + throw JsonMappingException.from(p, "MCP 'id' must be a non-null String or Number"); + } + + } + + public static class Serializer extends JsonSerializer { + + @Override + public void serialize(McpId id, JsonGenerator gen, SerializerProvider serializers) throws IOException { + if (id.isString()) { + gen.writeString(id.asString()); + } + else { + gen.writeNumber(id.asNumber().toString()); + } + } + + } + + } + public sealed interface Request permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest, CompleteRequest, GetPromptRequest, PaginatedRequest, ReadResourceRequest { @@ -200,18 +312,16 @@ public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotificati @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - // TODO: batching support // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) public record JSONRPCRequest( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @JsonProperty("method") String method, - @JsonProperty("id") Object id, + @JsonProperty("id") McpId id, @JsonProperty("params") Object params) implements JSONRPCMessage { } // @formatter:on @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - // TODO: batching support // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) public record JSONRPCNotification( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, @@ -221,11 +331,10 @@ public record JSONRPCNotification( // @formatter:off @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) - // TODO: batching support // @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) public record JSONRPCResponse( // @formatter:off @JsonProperty("jsonrpc") String jsonrpc, - @JsonProperty("id") Object id, + @JsonProperty("id") McpId id, @JsonProperty("result") Object result, @JsonProperty("error") JSONRPCError error) implements JSONRPCMessage { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d859..7f976621b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -1,16 +1,23 @@ package io.modelcontextprotocol.spec; import java.time.Duration; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.spec.SseEvent; +import io.modelcontextprotocol.spec.McpSchema.McpId; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; import reactor.core.publisher.Sinks; @@ -23,7 +30,13 @@ public class McpServerSession implements McpSession { private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class); - private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap transports = new ConcurrentHashMap<>(); + + private McpServerTransport listeningTransport; + + public static final String LISTENING_TRANSPORT = "listeningTransport"; private final String id; @@ -40,26 +53,29 @@ public class McpServerSession implements McpSession { private final Map notificationHandlers; - private final McpServerTransport transport; - private final Sinks.One exchangeSink = Sinks.one(); private final AtomicReference clientCapabilities = new AtomicReference<>(); private final AtomicReference clientInfo = new AtomicReference<>(); - private static final int STATE_UNINITIALIZED = 0; + public static final int STATE_UNINITIALIZED = 0; - private static final int STATE_INITIALIZING = 1; + public static final int STATE_INITIALIZING = 1; - private static final int STATE_INITIALIZED = 2; + public static final int STATE_INITIALIZED = 2; private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + private final AtomicLong eventCounter = new AtomicLong(0); + + private final Map eventTransports = new ConcurrentHashMap<>(); + + private final Map> transportEventHistories = new ConcurrentHashMap<>(); + /** * Creates a new server session with the given parameters and the transport to use. * @param id session id - * @param transport the transport to use * @param initHandler called when a * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the * server @@ -69,18 +85,33 @@ public class McpServerSession implements McpSession { * @param requestHandlers map of request handlers to use * @param notificationHandlers map of notification handlers to use */ - public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, + public McpServerSession(String id, Duration requestTimeout, McpServerTransport listeningTransport, InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, Map> requestHandlers, Map notificationHandlers) { this.id = id; this.requestTimeout = requestTimeout; - this.transport = transport; + this.listeningTransport = listeningTransport; this.initRequestHandler = initHandler; this.initNotificationHandler = initNotificationHandler; this.requestHandlers = requestHandlers; this.notificationHandlers = notificationHandlers; } + // Alternate constructor used by StreamableHttp servers + public McpServerSession(String id, Duration requestTimeout, InitRequestHandler initHandler, + InitNotificationHandler initNotificationHandler, Map> requestHandlers, + Map notificationHandlers) { + this(id, requestTimeout, null, initHandler, initNotificationHandler, requestHandlers, notificationHandlers); + } + + /** + * Retrieve the session initialization state + * @return session initialization state + */ + public int getState() { + return state.intValue(); + } + /** * Retrieve the session id. * @return session id @@ -89,6 +120,90 @@ public String getId() { return this.id; } + /** + * Increments the session-specific event counter, maps it to the given transport ID + * for replayability support, then returns the event ID + * @param transportId + * @return an event ID unique to the session + */ + public String incrementAndGetEventId(String transportId) { + final String eventId = String.valueOf(eventCounter.incrementAndGet()); + eventTransports.put(eventId, transportId); + return eventId; + } + + /** + * Used for replayability support to get the transport ID of a given event ID + * @param eventId + * @return The ID of the transport instance that the given event ID was sent over + */ + public String getTransportIdForEvent(String eventId) { + return eventTransports.get(eventId); + } + + /** + * Used for replayability support to set the event history of a given transport ID + * @param transportId + * @param eventHistory + */ + public void setTransportEventHistory(String transportId, Map eventHistory) { + transportEventHistories.put(transportId, eventHistory); + } + + /** + * Used for replayability support to retrieve the entire event history for a given + * transport ID + * @param transportId + * @return Map of SseEvent objects, keyed by event ID + */ + public Map getTransportEventHistory(String transportId) { + return transportEventHistories.get(transportId); + } + + /** + * Registers a transport for this session. + * @param transportId unique identifier for the transport + * @param transport the transport instance + */ + public void registerTransport(String transportId, McpServerTransport transport) { + if (transportId.equals(LISTENING_TRANSPORT)) { + this.listeningTransport = transport; + logger.debug("Registered listening transport for session {}", id); + return; + } + transports.put(transportId, transport); + logger.debug("Registered transport {} for session {}", transportId, id); + } + + /** + * Unregisters a transport from this session. + * @param transportId the transport identifier to remove + */ + public void unregisterTransport(String transportId) { + if (transportId.equals(LISTENING_TRANSPORT)) { + this.listeningTransport = null; + logger.debug("Unregistered listening transport for session {}", id); + return; + } + McpServerTransport removed = transports.remove(transportId); + if (removed != null) { + logger.debug("Unregistered transport {} from session {}", transportId, id); + } + } + + /** + * Gets a transport by its identifier. + * @param transportId the transport identifier + * @return the transport, or null if not found + */ + public McpServerTransport getTransport(String transportId) { + if (transportId.equals(LISTENING_TRANSPORT)) { + return this.listeningTransport; + } + logger.debug("Found transport {} in session {}", transportId, id); + return transports.get(transportId); + } + /** * Called upon successful initialization sequence between the client and the server * with the client capabilities and information. @@ -104,19 +219,35 @@ public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Impl this.clientInfo.lazySet(clientInfo); } - private String generateRequestId() { - return this.id + "-" + this.requestCounter.getAndIncrement(); + public McpSchema.ClientCapabilities getClientCapabilities() { + return this.clientCapabilities.get(); + } + + public McpSchema.Implementation getClientInfo() { + return this.clientInfo.get(); + } + + private McpId generateRequestId() { + return McpId.of(this.id + "-" + this.requestCounter.getAndIncrement()); + } + + /** + * Gets a request handler by method name. + */ + public RequestHandler getRequestHandler(String method) { + return requestHandlers.get(method); } @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { - String requestId = this.generateRequestId(); + McpId requestId = this.generateRequestId(); return Mono.create(sink -> { this.pendingResponses.put(requestId, sink); McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, requestId, requestParams); - this.transport.sendMessage(jsonrpcRequest).subscribe(v -> { + + Flux.from(listeningTransport.sendMessage(jsonrpcRequest)).subscribe(v -> { }, error -> { this.pendingResponses.remove(requestId); sink.error(error); @@ -125,13 +256,12 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc if (jsonRpcResponse.error() != null) { sink.error(new McpError(jsonRpcResponse.error())); } + else if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } else { - if (typeRef.getType().equals(Void.class)) { - sink.complete(); - } - else { - sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); - } + T result = listeningTransport.unmarshalFrom(jsonRpcResponse.result(), typeRef); + sink.next(result); } }); } @@ -140,7 +270,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc public Mono sendNotification(String method, Object params) { McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); - return this.transport.sendMessage(jsonrpcNotification); + return this.listeningTransport.sendMessage(jsonrpcNotification); } /** @@ -170,13 +300,28 @@ public Mono handle(McpSchema.JSONRPCMessage message) { } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); + final String transportId; + if (transports.isEmpty()) { + transportId = LISTENING_TRANSPORT; + } + else { + transportId = request.id().toString(); + } return handleIncomingRequest(request).onErrorResume(error -> { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); - // TODO: Should the error go to SSE or back as POST return? - return this.transport.sendMessage(errorResponse).then(Mono.empty()); - }).flatMap(this.transport::sendMessage); + McpServerTransport transport = getTransport(transportId); + return transport != null ? transport.sendMessage(errorResponse).then(Mono.empty()) : Mono.empty(); + }).flatMap(response -> { + McpServerTransport transport = getTransport(transportId); + if (transport != null) { + return transport.sendMessage(response); + } + else { + return Mono.error(new RuntimeException("Transport not found: " + transportId)); + } + }); } else if (message instanceof McpSchema.JSONRPCNotification notification) { // TODO handle errors for communication to without initialization @@ -203,8 +348,10 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR Mono resultMono; if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { // TODO handle situation where already initialized! - McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), - new TypeReference() { + McpSchema.InitializeRequest initializeRequest = transports.isEmpty() ? listeningTransport + .unmarshalFrom(request.params(), new TypeReference() { + }) : transports.get(String.valueOf(request.id())) + .unmarshalFrom(request.params(), new TypeReference() { }); this.state.lazySet(STATE_INITIALIZING); @@ -222,6 +369,9 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR error.message(), error.data()))); } + // We would need to add request.id() as a parameter to handler.handle() if + // we want client-request-driven requests/notifications to go to the + // related stream resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); } return resultMono @@ -264,12 +414,28 @@ private MethodNotFoundError getMethodNotFoundError(String method) { @Override public Mono closeGracefully() { - return this.transport.closeGracefully(); + return Mono.defer(() -> { + List> closeTasks = new ArrayList<>(); + + // Add listening transport if it exists + if (listeningTransport != null) { + closeTasks.add(listeningTransport.closeGracefully()); + } + + // Add all transports from the map + closeTasks.addAll(transports.values().stream().map(McpServerTransport::closeGracefully).toList()); + + return Mono.when(closeTasks); + }); } @Override public void close() { - this.transport.close(); + if (listeningTransport != null) { + listeningTransport.close(); + } + transports.values().forEach(McpServerTransport::close); + transports.clear(); } /** @@ -334,6 +500,25 @@ public interface RequestHandler { } + /** + * A handler for client-initiated requests return Flux. + * + * @param the type of the response that is expected as a result of handling the + * request. + */ + public interface StreamingRequestHandler extends RequestHandler { + + /** + * Handles a request from the client which invokes a streamTool. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return Flux that will emit the response to the request. + */ + Flux handleStreaming(McpAsyncServerExchange exchange, Object params); + + } + /** * Factory for creating server sessions which delegate to a provided 1:1 transport * with a connected client. @@ -350,4 +535,21 @@ public interface Factory { } + /** + * Factory for creating server sessions which delegate to a provided 1:1 transport + * with a connected client. + */ + @FunctionalInterface + public interface StreamableHttpSessionFactory { + + /** + * Creates a new 1:1 representation of the client-server interaction. + * @param transportId ID of the JSONRPCRequest/JSONRPCResponse the transport is + * serving. + * @return a new server session. + */ + McpServerSession create(String transportId); + + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/SseEvent.java b/mcp/src/main/java/io/modelcontextprotocol/spec/SseEvent.java new file mode 100644 index 000000000..f5f288cdd --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/SseEvent.java @@ -0,0 +1,4 @@ +package io.modelcontextprotocol.spec; + +public record SseEvent(String id, String event, String data) { +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index 3e89c8cef..cde86fe9c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -17,6 +17,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.McpId; import io.modelcontextprotocol.spec.McpSchema.PaginatedRequest; import io.modelcontextprotocol.spec.McpSchema.Root; import org.junit.jupiter.api.Test; @@ -172,7 +173,7 @@ void testRootsListRequestHandling() { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_ROOTS_LIST, "test-id", null); + McpSchema.METHOD_ROOTS_LIST, McpId.of("test-id"), null); transport.simulateIncomingMessage(request); // Verify response @@ -180,7 +181,7 @@ void testRootsListRequestHandling() { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.id().toString()).isEqualTo("test-id"); assertThat(response.result()) .isEqualTo(new McpSchema.ListRootsResult(List.of(new Root("file:///test/path", "test-root")))); assertThat(response.error()).isNull(); @@ -309,7 +310,7 @@ void testSamplingCreateMessageRequestHandling() { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest); + McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, McpId.of("test-id"), messageRequest); transport.simulateIncomingMessage(request); // Verify response @@ -317,7 +318,7 @@ void testSamplingCreateMessageRequestHandling() { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.id().toString()).isEqualTo("test-id"); assertThat(response.error()).isNull(); McpSchema.CreateMessageResult result = transport.unmarshalFrom(response.result(), @@ -350,7 +351,7 @@ void testSamplingCreateMessageRequestHandlingWithoutCapability() { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, "test-id", messageRequest); + McpSchema.METHOD_SAMPLING_CREATE_MESSAGE, McpId.of("test-id"), messageRequest); transport.simulateIncomingMessage(request); // Verify error response @@ -358,7 +359,7 @@ void testSamplingCreateMessageRequestHandlingWithoutCapability() { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.id().toString()).isEqualTo("test-id"); assertThat(response.result()).isNull(); assertThat(response.error()).isNotNull(); assertThat(response.error().message()).contains("Method not found: sampling/createMessage"); @@ -414,7 +415,7 @@ void testElicitationCreateRequestHandling() { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + McpSchema.METHOD_ELICITATION_CREATE, McpId.of("test-id"), elicitRequest); transport.simulateIncomingMessage(request); // Verify response @@ -422,7 +423,7 @@ void testElicitationCreateRequestHandling() { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.id().toString()).isEqualTo("test-id"); assertThat(response.error()).isNull(); McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { @@ -459,7 +460,7 @@ void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + McpSchema.METHOD_ELICITATION_CREATE, McpId.of("test-id"), elicitRequest); transport.simulateIncomingMessage(request); // Verify response @@ -467,7 +468,7 @@ void testElicitationFailRequestHandling(McpSchema.ElicitResult.Action action) { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.id().toString()).isEqualTo("test-id"); assertThat(response.error()).isNull(); McpSchema.ElicitResult result = transport.unmarshalFrom(response.result(), new TypeReference<>() { @@ -498,7 +499,7 @@ void testElicitationCreateRequestHandlingWithoutCapability() { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_ELICITATION_CREATE, "test-id", elicitRequest); + McpSchema.METHOD_ELICITATION_CREATE, McpId.of("test-id"), elicitRequest); transport.simulateIncomingMessage(request); // Verify error response @@ -506,7 +507,7 @@ void testElicitationCreateRequestHandlingWithoutCapability() { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("test-id"); + assertThat(response.id().toString()).isEqualTo("test-id"); assertThat(response.result()).isNull(); assertThat(response.error()).isNotNull(); assertThat(response.error().message()).contains("Method not found: elicitation/create"); @@ -535,7 +536,7 @@ void testPingMessageRequestHandling() { // Simulate incoming ping request from server McpSchema.JSONRPCRequest pingRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, - McpSchema.METHOD_PING, "ping-id", null); + McpSchema.METHOD_PING, McpId.of("ping-id"), null); transport.simulateIncomingMessage(pingRequest); // Verify response @@ -543,7 +544,7 @@ void testPingMessageRequestHandling() { assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCResponse.class); McpSchema.JSONRPCResponse response = (McpSchema.JSONRPCResponse) sentMessage; - assertThat(response.id()).isEqualTo("ping-id"); + assertThat(response.id().toString()).isEqualTo("ping-id"); assertThat(response.error()).isNull(); assertThat(response.result()).isInstanceOf(Map.class); assertThat(((Map) response.result())).isEmpty(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index e4348be25..fdcb15933 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -17,6 +17,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.McpSchema.McpId; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -106,7 +108,7 @@ void cleanup() { @Test void testMessageProcessing() { // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Simulate receiving the message @@ -137,7 +139,7 @@ void testResponseMessageProcessing() { """); // Create and send a request message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Verify message handling @@ -161,7 +163,7 @@ void testErrorMessageProcessing() { """); // Create and send a request message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Verify message handling @@ -191,7 +193,7 @@ void testGracefulShutdown() { StepVerifier.create(transport.closeGracefully()).verifyComplete(); // Create a test message - JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"), Map.of("key", "value")); // Verify message is not processed after shutdown @@ -236,10 +238,10 @@ void testMultipleMessageProcessing() { """); // Create and send corresponding messages - JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", "id1", + JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", McpId.of("id1"), Map.of("key", "value1")); - JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", "id2", + JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", McpId.of("id2"), Map.of("key", "value2")); // Verify both messages are processed diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java index f643f1ba3..2acce4d40 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java @@ -10,6 +10,8 @@ import io.modelcontextprotocol.MockMcpServerTransport; import io.modelcontextprotocol.MockMcpServerTransportProvider; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.McpId; + import org.junit.jupiter.api.Test; import static org.assertj.core.api.Assertions.assertThat; @@ -23,7 +25,7 @@ class McpServerProtocolVersionTests { private static final McpSchema.Implementation CLIENT_INFO = new McpSchema.Implementation("test-client", "1.0.0"); - private McpSchema.JSONRPCRequest jsonRpcInitializeRequest(String requestId, String protocolVersion) { + private McpSchema.JSONRPCRequest jsonRpcInitializeRequest(McpId requestId, String protocolVersion) { return new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, requestId, new McpSchema.InitializeRequest(protocolVersion, null, CLIENT_INFO)); } @@ -34,7 +36,7 @@ void shouldUseLatestVersionByDefault() { var transportProvider = new MockMcpServerTransportProvider(serverTransport); McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); - String requestId = UUID.randomUUID().toString(); + McpId requestId = McpId.of(UUID.randomUUID().toString()); transportProvider .simulateIncomingMessage(jsonRpcInitializeRequest(requestId, McpSchema.LATEST_PROTOCOL_VERSION)); @@ -60,7 +62,7 @@ void shouldNegotiateSpecificVersion() { server.setProtocolVersions(List.of(oldVersion, McpSchema.LATEST_PROTOCOL_VERSION)); - String requestId = UUID.randomUUID().toString(); + McpId requestId = McpId.of(UUID.randomUUID().toString()); transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, oldVersion)); @@ -83,7 +85,7 @@ void shouldSuggestLatestVersionForUnsupportedVersion() { McpAsyncServer server = McpServer.async(transportProvider).serverInfo(SERVER_INFO).build(); - String requestId = UUID.randomUUID().toString(); + McpId requestId = McpId.of(UUID.randomUUID().toString()); transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, unsupportedVersion)); @@ -111,7 +113,8 @@ void shouldUseHighestVersionWhenMultipleSupported() { server.setProtocolVersions(List.of(oldVersion, middleVersion, latestVersion)); - String requestId = UUID.randomUUID().toString(); + McpId requestId = McpId.of(UUID.randomUUID().toString()); + transportProvider.simulateIncomingMessage(jsonRpcInitializeRequest(requestId, latestVersion)); McpSchema.JSONRPCMessage response = serverTransport.getLastSentMessage(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpAsyncServerTests.java new file mode 100644 index 000000000..13114e5c9 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpAsyncServerTests.java @@ -0,0 +1,22 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpAsyncServer} using {@link StreamableHttpServerTransportProvider}. + */ +@Timeout(15) // Giving extra time beyond the client timeout +class StreamableHttpMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + return new StreamableHttpServerTransportProvider(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpSyncServerTests.java new file mode 100644 index 000000000..568abd741 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StreamableHttpMcpSyncServerTests.java @@ -0,0 +1,22 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import org.junit.jupiter.api.Timeout; + +/** + * Tests for {@link McpSyncServer} using {@link StreamableHttpServerTransportProvider}. + */ +@Timeout(15) // Giving extra time beyond the client timeout +class StreamableHttpMcpSyncServerTests extends AbstractMcpSyncServerTests { + + @Override + protected McpServerTransportProvider createMcpTransportProvider() { + return new StreamableHttpServerTransportProvider(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java new file mode 100644 index 000000000..78a1ab15a --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java @@ -0,0 +1,349 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletInputStream; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link StreamableHttpServerTransportProvider}. + */ +class StreamableHttpServerTransportProviderTests { + + private StreamableHttpServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; + + private McpServerSession.StreamableHttpSessionFactory sessionFactory; + + private McpServerSession mockSession; + + @BeforeEach + void setUp() { + objectMapper = new ObjectMapper(); + mockSession = mock(McpServerSession.class); + sessionFactory = mock(McpServerSession.StreamableHttpSessionFactory.class); + + when(sessionFactory.create(anyString())).thenReturn(mockSession); + when(mockSession.getId()).thenReturn("test-session-id"); + when(mockSession.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession.sendNotification(anyString(), any())).thenReturn(Mono.empty()); + + transportProvider = new StreamableHttpServerTransportProvider(objectMapper, "/mcp", null); + transportProvider.setStreamableHttpSessionFactory(sessionFactory); + } + + @Test + void shouldCreateSessionOnFirstRequest() { + // Test session creation directly through the getOrCreateSession method + String sessionId = "test-session-1"; + + McpServerSession session = transportProvider.getOrCreateSession(sessionId, true); + + assertThat(session).isNotNull(); + verify(sessionFactory).create(sessionId); + } + + @Test + void shouldHandleSSERequest() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + AsyncContext asyncContext = mock(AsyncContext.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + String sessionId = "test-session-2"; + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader("Mcp-Session-Id")).thenReturn(sessionId); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(request.startAsync()).thenReturn(asyncContext); + when(response.getWriter()).thenReturn(printWriter); + when(response.getHeader("Mcp-Session-Id")).thenReturn(sessionId); + + // First create a session + transportProvider.getOrCreateSession(sessionId, true); + + transportProvider.doGet(request, response); + + verify(response).setContentType("text/event-stream"); + verify(response).setCharacterEncoding("UTF-8"); + verify(response).setHeader("Cache-Control", "no-cache"); + verify(response).setHeader("Connection", "keep-alive"); + } + + @Test + void shouldNotifyClients() { + String sessionId = "test-session-3"; + transportProvider.getOrCreateSession(sessionId, true); + + String method = "test/notification"; + String params = "test message"; + + StepVerifier.create(transportProvider.notifyClients(method, params)).verifyComplete(); + + // Verify that the session was created + assertThat(transportProvider.getOrCreateSession(sessionId, false)).isNotNull(); + } + + @Test + void shouldCloseGracefully() { + String sessionId = "test-session-4"; + transportProvider.getOrCreateSession(sessionId, true); + + StepVerifier.create(transportProvider.closeGracefully()).verifyComplete(); + + verify(mockSession).closeGracefully(); + } + + @Test + void shouldHandleInvalidRequestURI() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + + when(request.getRequestURI()).thenReturn("/wrong-path"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + + transportProvider.doGet(request, response); + transportProvider.doPost(request, response); + transportProvider.doDelete(request, response); + + verify(response, times(3)).sendError(HttpServletResponse.SC_NOT_FOUND); + } + + @Test + void shouldRejectNonJSONContentType() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("POST"); + when(request.getHeader("Content-Type")).thenReturn("text/plain"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(printWriter); + + transportProvider.doPost(request, response); + + // The implementation uses sendErrorResponse which sets status to 400, not + // sendError with 415 + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType("application/json"); + } + + @Test + void shouldRejectInvalidAcceptHeader() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeader("Accept")).thenReturn("text/html"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(printWriter); + + transportProvider.doGet(request, response); + + // The implementation uses sendErrorResponse which sets status to 400, not + // sendError with 406 + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType("application/json"); + } + + @Test + void shouldRequireSessionIdForSSE() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader("Mcp-Session-Id")).thenReturn(null); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(printWriter); + + transportProvider.doGet(request, response); + + // The implementation uses sendErrorResponse which sets status to 400 + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType("application/json"); + } + + @Test + void shouldHandleSessionCleanup() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + + String sessionId = "test-session-5"; + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("DELETE"); + when(request.getHeader("Mcp-Session-Id")).thenReturn(sessionId); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + + // Create a session first + transportProvider.getOrCreateSession(sessionId, true); + + transportProvider.doDelete(request, response); + + verify(response).setStatus(HttpServletResponse.SC_OK); + verify(mockSession).closeGracefully(); + } + + @Test + void shouldHandleDeleteNonExistentSession() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("DELETE"); + when(request.getHeader("Mcp-Session-Id")).thenReturn("non-existent-session"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(printWriter); + + transportProvider.doDelete(request, response); + + // The implementation uses sendErrorResponse which sets status to 400, not + // sendError with 404 + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType("application/json"); + } + + @Test + void shouldHandleMultipleSessions() { + String sessionId1 = "session-1"; + String sessionId2 = "session-2"; + + // Create separate mock sessions for each ID + McpServerSession mockSession1 = mock(McpServerSession.class); + McpServerSession mockSession2 = mock(McpServerSession.class); + when(mockSession1.getId()).thenReturn(sessionId1); + when(mockSession2.getId()).thenReturn(sessionId2); + when(mockSession1.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession2.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession1.sendNotification(anyString(), any())).thenReturn(Mono.empty()); + when(mockSession2.sendNotification(anyString(), any())).thenReturn(Mono.empty()); + + // Configure factory to return different sessions for different IDs + when(sessionFactory.create(sessionId1)).thenReturn(mockSession1); + when(sessionFactory.create(sessionId2)).thenReturn(mockSession2); + + McpServerSession session1 = transportProvider.getOrCreateSession(sessionId1, true); + McpServerSession session2 = transportProvider.getOrCreateSession(sessionId2, true); + + assertThat(session1).isNotNull(); + assertThat(session2).isNotNull(); + assertThat(session1).isNotSameAs(session2); + + // Verify both sessions are created with different IDs + verify(sessionFactory, times(2)).create(anyString()); + } + + @Test + void shouldReuseExistingSession() { + String sessionId = "test-session-6"; + + McpServerSession session1 = transportProvider.getOrCreateSession(sessionId, true); + McpServerSession session2 = transportProvider.getOrCreateSession(sessionId, false); + + assertThat(session1).isSameAs(session2); + verify(sessionFactory, times(1)).create(sessionId); + } + + @Test + void shouldHandleAsyncTimeout() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + AsyncContext asyncContext = mock(AsyncContext.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getMethod()).thenReturn("GET"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader("Mcp-Session-Id")).thenReturn("test-session"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(request.startAsync()).thenReturn(asyncContext); + when(response.getWriter()).thenReturn(printWriter); + when(response.getHeader("Mcp-Session-Id")).thenReturn("test-session"); + + transportProvider.getOrCreateSession("test-session", true); + transportProvider.doGet(request, response); + + verify(asyncContext).setTimeout(0L); // Updated to match actual implementation + } + + @Test + void shouldBuildWithCustomConfiguration() { + ObjectMapper customMapper = new ObjectMapper(); + String customEndpoint = "/custom-mcp"; + + StreamableHttpServerTransportProvider provider = StreamableHttpServerTransportProvider.builder() + .withObjectMapper(customMapper) + .withMcpEndpoint(customEndpoint) + .withSessionIdProvider(() -> "custom-session-id") + .build(); + + assertThat(provider).isNotNull(); + } + + @Test + void shouldHandleBuilderValidation() { + try { + StreamableHttpServerTransportProvider.builder().withObjectMapper(null).build(); + } + catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("ObjectMapper must not be null"); + } + + try { + StreamableHttpServerTransportProvider.builder().withMcpEndpoint("").build(); + } + catch (IllegalArgumentException e) { + assertThat(e.getMessage()).contains("MCP endpoint must not be empty"); + } + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index f72be43e0..4c7bfbc0c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -9,6 +9,8 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.MockMcpClientTransport; +import io.modelcontextprotocol.spec.McpSchema.McpId; + import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -144,7 +146,7 @@ void testRequestHandling() { // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, - "test-id", echoMessage); + McpId.of("test-id"), echoMessage); transport.simulateIncomingMessage(request); // Verify response @@ -179,7 +181,7 @@ void testNotificationHandling() { void testUnknownMethodHandling() { // Simulate incoming request for unknown method McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "unknown.method", - "test-id", null); + McpId.of("test-id"), null); transport.simulateIncomingMessage(request); // Verify error response diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index ea063e4e3..782a4cf4c 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -19,6 +19,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; +import io.modelcontextprotocol.spec.McpSchema.McpId; import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; import net.javacrumbs.jsonunit.core.Option; @@ -240,8 +241,8 @@ void testJSONRPCRequest() throws Exception { Map params = new HashMap<>(); params.put("key", "value"); - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", 1, - params); + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", + McpId.of(1), params); String value = mapper.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) @@ -272,7 +273,8 @@ void testJSONRPCResponse() throws Exception { Map result = new HashMap<>(); result.put("result_key", "result_value"); - McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, result, null); + McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, McpId.of(1), + result, null); String value = mapper.writeValueAsString(response); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) @@ -287,7 +289,8 @@ void testJSONRPCResponseWithError() throws Exception { McpSchema.JSONRPCResponse.JSONRPCError error = new McpSchema.JSONRPCResponse.JSONRPCError( McpSchema.ErrorCodes.INVALID_REQUEST, "Invalid request", null); - McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, 1, null, error); + McpSchema.JSONRPCResponse response = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, McpId.of(1), null, + error); String value = mapper.writeValueAsString(response); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER)