From 27852d357859dfee981633e56a41730e9490b4ad Mon Sep 17 00:00:00 2001 From: bzsurbhi Date: Wed, 2 Jul 2025 13:31:18 -0700 Subject: [PATCH 1/2] feat: Add completions context --- .../WebFluxSseIntegrationTests.java | 2 +- .../server/McpAsyncServer.java | 9 +- .../modelcontextprotocol/spec/McpSchema.java | 7 +- .../server/McpCompletionTests.java | 301 ++++++++++++++++++ 4 files changed, 316 insertions(+), 3 deletions(-) create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 2f85654e8..994047945 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -1011,7 +1011,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { assertThat(initResult).isNotNull(); CompleteRequest request = new CompleteRequest(new PromptReference("ref/prompt", "code_review"), - new CompleteRequest.CompleteArgument("language", "py")); + new CompleteRequest.CompleteArgument("language", "py"), null); CompleteResult result = mcpClient.completeCompletion(request); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 02ad955b9..6e290a997 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -721,6 +721,7 @@ private McpSchema.CompleteRequest parseCompletionParams(Object object) { Map params = (Map) object; Map refMap = (Map) params.get("ref"); Map argMap = (Map) params.get("argument"); + Map contextMap = (Map) params.get("context"); String refType = (String) refMap.get("type"); @@ -735,7 +736,13 @@ private McpSchema.CompleteRequest parseCompletionParams(Object object) { McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName, argValue); - return new McpSchema.CompleteRequest(ref, argument); + McpSchema.CompleteRequest.CompleteContext context = null; + if (contextMap != null) { + Map arguments = (Map) contextMap.get("arguments"); + context = new McpSchema.CompleteRequest.CompleteContext(arguments); + } + + return new McpSchema.CompleteRequest(ref, argument, context); } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 597130946..2d15f1f45 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -1447,11 +1447,16 @@ public String identifier() { @JsonIgnoreProperties(ignoreUnknown = true) public record CompleteRequest(// @formatter:off @JsonProperty("ref") McpSchema.CompleteReference ref, - @JsonProperty("argument") CompleteArgument argument) implements Request { + @JsonProperty("argument") CompleteArgument argument, + @JsonProperty("context") CompleteContext context) implements Request { public record CompleteArgument( @JsonProperty("name") String name, @JsonProperty("value") String value) { + } + + public record CompleteContext( + @JsonProperty("arguments") Map arguments) { }// @formatter:on } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java new file mode 100644 index 000000000..d6b995d6f --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java @@ -0,0 +1,301 @@ +package io.modelcontextprotocol.server; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import static org.assertj.core.api.Assertions.assertThat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceReference; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpError; + +/** + * Tests for completion functionality with context support. + */ +class McpCompletionTests { + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + McpClient.SyncSpec clientBuilder; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + // Create and con figure the transport provider + mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", 3400, mcpServerTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + 3400).build()); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Test + void testCompletionHandlerReceivesContext() { + AtomicReference receivedRequest = new AtomicReference<>(); + BiFunction completionHandler = (exchange, request) -> { + receivedRequest.set(request); + return new CompleteResult(new CompleteResult.CompleteCompletion(List.of("test-completion"), 1, false)); + }; + + ResourceReference resourceRef = new ResourceReference("ref/resource", "test://resource/{param}"); + + McpSchema.Resource resource = new McpSchema.Resource("test://resource/{param}", "Test Resource", + "A resource for testing", "text/plain", 123L, null); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .resources(new McpServerFeatures.SyncResourceSpecification(resource, + (exchange, req) -> new ReadResourceResult(List.of()))) + .completions(new McpServerFeatures.SyncCompletionSpecification(resourceRef, completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Test with context + CompleteRequest request = new CompleteRequest(resourceRef, + new CompleteRequest.CompleteArgument("param", "test"), + new CompleteRequest.CompleteContext(Map.of("previous", "value"))); + + CompleteResult result = mcpClient.completeCompletion(request); + + // Verify handler received the context + assertThat(receivedRequest.get().context()).isNotNull(); + assertThat(receivedRequest.get().context().arguments()).containsEntry("previous", "value"); + assertThat(result.completion().values()).containsExactly("test-completion"); + } + + mcpServer.close(); + } + + @Test + void testCompletionBackwardCompatibility() { + AtomicReference contextWasNull = new AtomicReference<>(false); + BiFunction completionHandler = (exchange, request) -> { + contextWasNull.set(request.context() == null); + return new CompleteResult( + new CompleteResult.CompleteCompletion(List.of("no-context-completion"), 1, false)); + }; + + McpSchema.Prompt prompt = new Prompt("test-prompt", "this is a test prompt", + List.of(new PromptArgument("arg", "string", false))); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification(prompt, + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new PromptReference("ref/prompt", "test-prompt"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Test without context + CompleteRequest request = new CompleteRequest(new PromptReference("ref/prompt", "test-prompt"), + new CompleteRequest.CompleteArgument("arg", "val"), null); + + CompleteResult result = mcpClient.completeCompletion(request); + + // Verify context was null + assertThat(contextWasNull.get()).isTrue(); + assertThat(result.completion().values()).containsExactly("no-context-completion"); + } + + mcpServer.close(); + } + + @Test + void testDependentCompletionScenario() { + BiFunction completionHandler = (exchange, request) -> { + // Simulate database/table completion scenario + if (request.ref() instanceof ResourceReference resourceRef) { + if ("db://{database}/{table}".equals(resourceRef.uri())) { + if ("database".equals(request.argument().name())) { + // Complete database names + return new CompleteResult(new CompleteResult.CompleteCompletion( + List.of("users_db", "products_db", "analytics_db"), 3, false)); + } + else if ("table".equals(request.argument().name())) { + // Complete table names based on selected database + if (request.context() != null && request.context().arguments() != null) { + String db = request.context().arguments().get("database"); + if ("users_db".equals(db)) { + return new CompleteResult(new CompleteResult.CompleteCompletion( + List.of("users", "sessions", "permissions"), 3, false)); + } + else if ("products_db".equals(db)) { + return new CompleteResult(new CompleteResult.CompleteCompletion( + List.of("products", "categories", "inventory"), 3, false)); + } + } + } + } + } + return new CompleteResult(new CompleteResult.CompleteCompletion(List.of(), 0, false)); + }; + + McpSchema.Resource resource = new McpSchema.Resource("db://{database}/{table}", "Database Table", + "Resource representing a table in a database", "application/json", 456L, null); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .resources(new McpServerFeatures.SyncResourceSpecification(resource, + (exchange, req) -> new ReadResourceResult(List.of()))) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new ResourceReference("ref/resource", "db://{database}/{table}"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // First, complete database + CompleteRequest dbRequest = new CompleteRequest( + new ResourceReference("ref/resource", "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("database", ""), null); + + CompleteResult dbResult = mcpClient.completeCompletion(dbRequest); + assertThat(dbResult.completion().values()).contains("users_db", "products_db"); + + // Then complete table with database context + CompleteRequest tableRequest = new CompleteRequest( + new ResourceReference("ref/resource", "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("table", ""), + new CompleteRequest.CompleteContext(Map.of("database", "users_db"))); + + CompleteResult tableResult = mcpClient.completeCompletion(tableRequest); + assertThat(tableResult.completion().values()).containsExactly("users", "sessions", "permissions"); + + // Different database gives different tables + CompleteRequest tableRequest2 = new CompleteRequest( + new ResourceReference("ref/resource", "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("table", ""), + new CompleteRequest.CompleteContext(Map.of("database", "products_db"))); + + CompleteResult tableResult2 = mcpClient.completeCompletion(tableRequest2); + assertThat(tableResult2.completion().values()).containsExactly("products", "categories", "inventory"); + } + + mcpServer.close(); + } + + @Test + void testCompletionErrorOnMissingContext() { + BiFunction completionHandler = (exchange, request) -> { + if (request.ref() instanceof ResourceReference resourceRef) { + if ("db://{database}/{table}".equals(resourceRef.uri())) { + if ("table".equals(request.argument().name())) { + // Check if database context is provided + if (request.context() == null || request.context().arguments() == null + || !request.context().arguments().containsKey("database")) { + throw new McpError("Please select a database first to see available tables"); + } + // Normal completion if context is provided + String db = request.context().arguments().get("database"); + if ("test_db".equals(db)) { + return new CompleteResult(new CompleteResult.CompleteCompletion( + List.of("users", "orders", "products"), 3, false)); + } + } + } + } + return new CompleteResult(new CompleteResult.CompleteCompletion(List.of(), 0, false)); + }; + + McpSchema.Resource resource = new McpSchema.Resource("db://{database}/{table}", "Database Table", + "Resource representing a table in a database", "application/json", 456L, null); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .resources(new McpServerFeatures.SyncResourceSpecification(resource, + (exchange, req) -> new ReadResourceResult(List.of()))) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new ResourceReference("ref/resource", "db://{database}/{table}"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample" + "client", "0.0.0")) + .build();) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Try to complete table without database context - should raise error + CompleteRequest requestWithoutContext = new CompleteRequest( + new ResourceReference("ref/resource", "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("table", ""), null); + + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.completeCompletion(requestWithoutContext)) + .withMessageContaining("Please select a database first"); + + // Now complete with proper context - should work normally + CompleteRequest requestWithContext = new CompleteRequest( + new ResourceReference("ref/resource", "db://{database}/{table}"), + new CompleteRequest.CompleteArgument("table", ""), + new CompleteRequest.CompleteContext(Map.of("database", "test_db"))); + + CompleteResult resultWithContext = mcpClient.completeCompletion(requestWithContext); + assertThat(resultWithContext.completion().values()).containsExactly("users", "orders", "products"); + } + + mcpServer.close(); + } + +} \ No newline at end of file From cd82731c50b63bc2eed16651ed0afdf53d81543f Mon Sep 17 00:00:00 2001 From: bzsurbhi Date: Wed, 2 Jul 2025 13:39:25 -0700 Subject: [PATCH 2/2] feat: Add completions context --- mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java | 1 + .../java/io/modelcontextprotocol/server/McpCompletionTests.java | 1 + 2 files changed, 2 insertions(+) diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 2d15f1f45..bfffee723 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -32,6 +32,7 @@ * * @author Christian Tzolov * @author Luca Chang + * @author Surbhi Bansal */ public final class McpSchema { diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java index d6b995d6f..828d7511f 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java @@ -34,6 +34,7 @@ /** * Tests for completion functionality with context support. + * @author Surbhi Bansal */ class McpCompletionTests {