Skip to content

Commit 606d21b

Browse files
ZachGermanZachary German
authored andcommitted
Adding StreamableHttpServerTransportProvider class and unit tests
1 parent 9ebff0c commit 606d21b

File tree

10 files changed

+1532
-3
lines changed

10 files changed

+1532
-3
lines changed

mcp-spring/mcp-spring-webflux/pom.xml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,28 @@
127127
<scope>test</scope>
128128
</dependency>
129129

130+
<!-- Tomcat dependencies for testing -->
131+
<dependency>
132+
<groupId>org.apache.tomcat.embed</groupId>
133+
<artifactId>tomcat-embed-core</artifactId>
134+
<version>${tomcat.version}</version>
135+
<scope>test</scope>
136+
</dependency>
137+
<dependency>
138+
<groupId>org.apache.tomcat.embed</groupId>
139+
<artifactId>tomcat-embed-websocket</artifactId>
140+
<version>${tomcat.version}</version>
141+
<scope>test</scope>
142+
</dependency>
143+
144+
<!-- Used by the StreamableHttpServerTransportProvider -->
145+
<dependency>
146+
<groupId>jakarta.servlet</groupId>
147+
<artifactId>jakarta.servlet-api</artifactId>
148+
<version>${jakarta.servlet.version}</version>
149+
<scope>test</scope>
150+
</dependency>
151+
130152
</dependencies>
131153

132154

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.server.transport;
6+
7+
import java.util.Map;
8+
import java.util.List;
9+
10+
import com.fasterxml.jackson.databind.ObjectMapper;
11+
import io.modelcontextprotocol.client.McpClient;
12+
import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
13+
import io.modelcontextprotocol.server.McpServer;
14+
import io.modelcontextprotocol.server.McpServerFeatures;
15+
import io.modelcontextprotocol.spec.McpSchema;
16+
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
17+
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
18+
import org.apache.catalina.LifecycleException;
19+
import org.apache.catalina.LifecycleState;
20+
import org.apache.catalina.startup.Tomcat;
21+
import org.junit.jupiter.api.AfterEach;
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.springframework.web.reactive.function.client.WebClient;
25+
26+
import static org.assertj.core.api.Assertions.assertThat;
27+
28+
/**
29+
* Integration tests for {@link StreamableHttpServerTransportProvider} with
30+
* {@link WebClientStreamableHttpTransport}.
31+
*/
32+
class StreamableHttpServerTransportProviderIntegrationTests {
33+
34+
private static final int PORT = TomcatTestUtil.findAvailablePort();
35+
36+
private static final String ENDPOINT = "/mcp";
37+
38+
private StreamableHttpServerTransportProvider serverTransportProvider;
39+
40+
private McpClient.SyncSpec clientBuilder;
41+
42+
private Tomcat tomcat;
43+
44+
@BeforeEach
45+
void setUp() {
46+
serverTransportProvider = new StreamableHttpServerTransportProvider(new ObjectMapper(), ENDPOINT, null);
47+
48+
tomcat = TomcatTestUtil.createTomcatServer("", PORT, serverTransportProvider);
49+
try {
50+
tomcat.start();
51+
assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED);
52+
}
53+
catch (Exception e) {
54+
throw new RuntimeException("Failed to start Tomcat", e);
55+
}
56+
57+
WebClientStreamableHttpTransport clientTransport = WebClientStreamableHttpTransport
58+
.builder(WebClient.builder().baseUrl("http://localhost:" + PORT))
59+
.endpoint(ENDPOINT)
60+
.objectMapper(new ObjectMapper())
61+
.build();
62+
63+
clientBuilder = McpClient.sync(clientTransport)
64+
.clientInfo(new McpSchema.Implementation("Test Client", "1.0.0"));
65+
}
66+
67+
@AfterEach
68+
void tearDown() {
69+
if (serverTransportProvider != null) {
70+
serverTransportProvider.closeGracefully().block();
71+
}
72+
if (tomcat != null) {
73+
try {
74+
tomcat.stop();
75+
tomcat.destroy();
76+
}
77+
catch (LifecycleException e) {
78+
throw new RuntimeException("Failed to stop Tomcat", e);
79+
}
80+
}
81+
}
82+
83+
@Test
84+
void shouldInitializeSuccessfully() {
85+
var mcpServer = McpServer.sync(serverTransportProvider).serverInfo("Test Server", "1.0.0").build();
86+
87+
try (var mcpClient = clientBuilder.build()) {
88+
InitializeResult result = mcpClient.initialize();
89+
assertThat(result).isNotNull();
90+
assertThat(result.serverInfo().name()).isEqualTo("Test Server");
91+
}
92+
93+
mcpServer.close();
94+
}
95+
96+
@Test
97+
void shouldCallToolSuccessfully() {
98+
String emptyJsonSchema = """
99+
{
100+
"$schema": "http://json-schema.org/draft-07/schema#",
101+
"type": "object",
102+
"properties": {}
103+
}
104+
""";
105+
106+
var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("Tool executed successfully")), null);
107+
McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(
108+
new McpSchema.Tool("test-tool", "Test tool description", emptyJsonSchema),
109+
(exchange, request) -> callResponse);
110+
111+
var mcpServer = McpServer.sync(serverTransportProvider).serverInfo("Test Server", "1.0.0").tools(tool).build();
112+
113+
try (var mcpClient = clientBuilder.build()) {
114+
mcpClient.initialize();
115+
116+
CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of()));
117+
118+
assertThat(response).isNotNull();
119+
assertThat(response).isEqualTo(callResponse);
120+
}
121+
122+
mcpServer.close();
123+
}
124+
125+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright 2025 - 2025 the original author or authors.
3+
*/
4+
package io.modelcontextprotocol.server.transport;
5+
6+
import java.io.IOException;
7+
import java.net.InetSocketAddress;
8+
import java.net.ServerSocket;
9+
10+
import jakarta.servlet.Servlet;
11+
import org.apache.catalina.Context;
12+
import org.apache.catalina.startup.Tomcat;
13+
14+
/**
15+
* @author Christian Tzolov
16+
*/
17+
public class TomcatTestUtil {
18+
19+
TomcatTestUtil() {
20+
// Prevent instantiation
21+
}
22+
23+
public static Tomcat createTomcatServer(String contextPath, int port, Servlet servlet) {
24+
25+
var tomcat = new Tomcat();
26+
tomcat.setPort(port);
27+
28+
String baseDir = System.getProperty("java.io.tmpdir");
29+
tomcat.setBaseDir(baseDir);
30+
31+
Context context = tomcat.addContext(contextPath, baseDir);
32+
33+
// Add transport servlet to Tomcat
34+
org.apache.catalina.Wrapper wrapper = context.createWrapper();
35+
wrapper.setName("mcpServlet");
36+
wrapper.setServlet(servlet);
37+
wrapper.setLoadOnStartup(1);
38+
wrapper.setAsyncSupported(true);
39+
context.addChild(wrapper);
40+
context.addServletMappingDecoded("/*", "mcpServlet");
41+
42+
var connector = tomcat.getConnector();
43+
connector.setAsyncTimeout(3000);
44+
45+
return tomcat;
46+
}
47+
48+
/**
49+
* Finds an available port on the local machine.
50+
* @return an available port number
51+
* @throws IllegalStateException if no available port can be found
52+
*/
53+
public static int findAvailablePort() {
54+
try (final ServerSocket socket = new ServerSocket()) {
55+
socket.bind(new InetSocketAddress(0));
56+
return socket.getLocalPort();
57+
}
58+
catch (final IOException e) {
59+
throw new IllegalStateException("Cannot bind to an available port!", e);
60+
}
61+
}
62+
63+
}

mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ private Mono<McpSchema.InitializeResult> asyncInitializeRequestHandler(
214214
"Client requested unsupported protocol version: {}, so the server will suggest the {} version instead",
215215
initializeRequest.protocolVersion(), serverProtocolVersion);
216216
}
217-
217+
System.out.println("---------------Server sending initalize response-----------");
218218
return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities,
219219
this.serverInfo, this.instructions));
220220
});
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package io.modelcontextprotocol.server.transport;
2+
3+
import io.modelcontextprotocol.spec.McpServerSession;
4+
import jakarta.servlet.http.HttpServletRequest;
5+
import jakarta.servlet.http.HttpServletResponse;
6+
7+
/**
8+
* Handler interface for session lifecycle and runtime events in MCP transport providers.
9+
* Allows users to hook into session creation, closing, and error scenarios.
10+
*/
11+
public interface SessionHandler {
12+
13+
/**
14+
* Called when a new session is created.
15+
* @param sessionId The unique session identifier
16+
* @param session The created session instance
17+
*/
18+
default void onSessionCreate(String sessionId, McpServerSession session) {
19+
}
20+
21+
/**
22+
* Called when a session is destroyed or removed.
23+
* @param sessionId The unique session identifier
24+
*/
25+
default void onSessionClose(String sessionId) {
26+
}
27+
28+
/**
29+
* Called when a client attempts to use a session that doesn't exist.
30+
* @param sessionId The requested session identifier
31+
* @param request The HTTP request that referenced the missing session
32+
* @param response The (default) HTTP response that will be sent to the client
33+
*/
34+
default void onSessionNotFound(String sessionId, HttpServletRequest request, HttpServletResponse response) {
35+
}
36+
37+
/**
38+
* Called when an error occurs during the sending of a notification.
39+
* @param sessionId The unique session identifier
40+
* @param error The error that occurred during the notification sending
41+
*/
42+
default void onSendNotificationError(String sessionId, Throwable error) {
43+
}
44+
45+
}

0 commit comments

Comments
 (0)