diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index e7b7c8ee9..e0e1094cc 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -125,13 +125,14 @@ public Mono connect(Function, Mono> onClose = () -> { - DefaultMcpTransportSession transportSession = this.activeSession.get(); - return transportSession.sessionId().isEmpty() ? Mono.empty() - : webClient.delete().uri(this.endpoint).headers(httpHeaders -> { - httpHeaders.add("mcp-session-id", transportSession.sessionId().get()); - }).retrieve().toBodilessEntity().doOnError(e -> logger.info("Got response {}", e)).then(); - }; + Function> onClose = sessionId -> sessionId == null ? Mono.empty() + : webClient.delete().uri(this.endpoint).headers(httpHeaders -> { + httpHeaders.add("mcp-session-id", sessionId); + }) + .retrieve() + .toBodilessEntity() + .doOnError(e -> logger.warn("Got error when closing transport", e)) + .then(); return new DefaultMcpTransportSession(onClose); } @@ -192,6 +193,7 @@ private Mono reconnect(McpTransportStream stream) { }) .exchangeToFlux(response -> { if (isEventStream(response)) { + logger.debug("Established SSE stream via GET"); return eventStream(stream, response); } else if (isNotAllowed(response)) { @@ -208,6 +210,7 @@ else if (isNotFound(response)) { }).flux(); } }) + .flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) .onErrorComplete(t -> { this.handleException(t); return true; @@ -274,6 +277,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { else { MediaType mediaType = contentType.get(); if (mediaType.isCompatibleWith(MediaType.TEXT_EVENT_STREAM)) { + logger.debug("Established SSE stream via POST"); // communicate to caller that the message was delivered sink.success(); // starting a stream diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 049bea008..460bc0195 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -19,6 +19,8 @@ import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; import io.modelcontextprotocol.spec.McpSchema.GetPromptRequest; import io.modelcontextprotocol.spec.McpSchema.Prompt; import io.modelcontextprotocol.spec.McpSchema.Resource; @@ -77,7 +79,9 @@ McpAsyncClient client(McpClientTransport transport, Function Mono.just(new CreateMessageResult(McpSchema.Role.USER, + new McpSchema.TextContent("Oh, hi!"), "modelId", CreateMessageResult.StopReason.END_TURN))) + .capabilities(ClientCapabilities.builder().roots(true).sampling().build()); builder = customizer.apply(builder); client.set(builder.build()); }).doesNotThrowAnyException(); @@ -189,6 +193,22 @@ void testCallTool() { }); } + @Test + void testSampling() { + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("sampleLLM", + Map.of("prompt", "Hello MCP Spring AI!")); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }) + .verifyComplete(); + }); + } + @Test void testCallToolWithInvalidTool() { withClient(createMcpTransport(), mcpAsyncClient -> { @@ -424,6 +444,20 @@ void testInitializeWithSamplingCapability() { }); } + @Test + void testInitializeWithElicitationCapability() { + ClientCapabilities capabilities = ClientCapabilities.builder().elicitation().build(); + ElicitResult elicitResult = ElicitResult.builder() + .message(ElicitResult.Action.ACCEPT) + .content(Map.of("foo", "bar")) + .build(); + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).elicitation(request -> Mono.just(elicitResult)), + client -> { + StepVerifier.create(client.initialize()).expectNextMatches(Objects::nonNull).verifyComplete(); + }); + } + @Test void testInitializeWithAllCapabilities() { var capabilities = ClientCapabilities.builder() @@ -435,7 +469,11 @@ void testInitializeWithAllCapabilities() { Function> samplingHandler = request -> Mono .just(CreateMessageResult.builder().message("test").model("test-model").build()); - withClient(createMcpTransport(), builder -> builder.capabilities(capabilities).sampling(samplingHandler), + Function> elicitationHandler = request -> Mono + .just(ElicitResult.builder().message(ElicitResult.Action.ACCEPT).content(Map.of("foo", "bar")).build()); + + withClient(createMcpTransport(), + builder -> builder.capabilities(capabilities).sampling(samplingHandler).elicitation(elicitationHandler), client -> StepVerifier.create(client.initialize()).assertNext(result -> { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java index d06d5b325..56cdeaf7f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java @@ -10,7 +10,7 @@ import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Supplier; +import java.util.function.Function; /** * Default implementation of {@link McpTransportSession} which manages the open @@ -29,9 +29,9 @@ public class DefaultMcpTransportSession implements McpTransportSession sessionId = new AtomicReference<>(); - private final Supplier> onClose; + private final Function> onClose; - public DefaultMcpTransportSession(Supplier> onClose) { + public DefaultMcpTransportSession(Function> onClose) { this.onClose = onClose; } @@ -73,7 +73,8 @@ public void close() { @Override public Mono closeGracefully() { - return Mono.from(this.onClose.get()).then(Mono.fromRunnable(this.openConnections::dispose)); + return Mono.from(this.onClose.apply(this.sessionId.get())) + .then(Mono.fromRunnable(this.openConnections::dispose)); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java index 37f9e71a7..1b7c77c57 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientTests.java @@ -80,7 +80,9 @@ McpAsyncClient client(McpClientTransport transport, Function Mono.just(new CreateMessageResult(McpSchema.Role.USER, + new McpSchema.TextContent("Oh, hi!"), "modelId", CreateMessageResult.StopReason.END_TURN))) + .capabilities(ClientCapabilities.builder().roots(true).sampling().build()); builder = customizer.apply(builder); client.set(builder.build()); }).doesNotThrowAnyException(); @@ -192,6 +194,22 @@ void testCallTool() { }); } + @Test + void testSampling() { + withClient(createMcpTransport(), mcpAsyncClient -> { + CallToolRequest callToolRequest = new CallToolRequest("sampleLLM", + Map.of("prompt", "Hello MCP Spring AI!")); + StepVerifier.create(mcpAsyncClient.initialize().then(mcpAsyncClient.callTool(callToolRequest))) + .consumeNextWith(callToolResult -> { + assertThat(callToolResult).isNotNull().satisfies(result -> { + assertThat(result.content()).isNotNull(); + assertThat(result.isError()).isNull(); + }); + }) + .verifyComplete(); + }); + } + @Test void testCallToolWithInvalidTool() { withClient(createMcpTransport(), mcpAsyncClient -> {