Skip to content

Commit 0db4c0f

Browse files
minguncletzolov
andcommitted
feat(webmvc): Add support for custom context paths in WebMvcSseServerTransportProvider
Adds the ability to specify a base URL for message endpoints in WebMvcSseServerTransportProvider, enabling proper handling of custom servlet context paths in Spring WebMVC applications. This ensures that clients receive the correct full endpoint URL when connecting through SSE. - Add messageBaseUrl field to WebMvcSseServerTransportProvider - Create new constructor that accepts messageBaseUrl parameter - Update endpoint event to include base URL in the message endpoint - Add TomcatTestUtil class to simplify test server creation - Add WebMvcSseCustomContextPathTests to verify custom context path functionality - Refactor WebMvcSseIntegrationTests to use the new TomcatTestUtil Co-authored-by: Christian Tzolov <christian.tzolov@broadcom.com> Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
1 parent 734153a commit 0db4c0f

File tree

5 files changed

+216
-63
lines changed

5 files changed

+216
-63
lines changed

mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
9191

9292
private final String sseEndpoint;
9393

94+
private final String messageBaseUrl;
95+
9496
private final RouterFunction<ServerResponse> routerFunction;
9597

9698
private McpServerSession.Factory sessionFactory;
@@ -105,6 +107,20 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
105107
*/
106108
private volatile boolean isClosing = false;
107109

110+
/**
111+
* Constructs a new WebMvcSseServerTransportProvider instance with the default SSE
112+
* endpoint.
113+
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
114+
* of messages.
115+
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
116+
* messages via HTTP POST. This endpoint will be communicated to clients through the
117+
* SSE connection's initial endpoint event.
118+
* @throws IllegalArgumentException if either objectMapper or messageEndpoint is null
119+
*/
120+
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
121+
this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
122+
}
123+
108124
/**
109125
* Constructs a new WebMvcSseServerTransportProvider instance.
110126
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
@@ -116,11 +132,30 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi
116132
* @throws IllegalArgumentException if any parameter is null
117133
*/
118134
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
135+
this(objectMapper, "", messageEndpoint, sseEndpoint);
136+
}
137+
138+
/**
139+
* Constructs a new WebMvcSseServerTransportProvider instance.
140+
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
141+
* of messages.
142+
* @param messageBaseUrl The base URL for the message endpoint, used to construct the
143+
* full endpoint URL for clients.
144+
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
145+
* messages via HTTP POST. This endpoint will be communicated to clients through the
146+
* SSE connection's initial endpoint event.
147+
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
148+
* @throws IllegalArgumentException if any parameter is null
149+
*/
150+
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageBaseUrl, String messageEndpoint,
151+
String sseEndpoint) {
119152
Assert.notNull(objectMapper, "ObjectMapper must not be null");
153+
Assert.notNull(messageBaseUrl, "Message base URL must not be null");
120154
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
121155
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
122156

123157
this.objectMapper = objectMapper;
158+
this.messageBaseUrl = messageBaseUrl;
124159
this.messageEndpoint = messageEndpoint;
125160
this.sseEndpoint = sseEndpoint;
126161
this.routerFunction = RouterFunctions.route()
@@ -129,20 +164,6 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
129164
.build();
130165
}
131166

132-
/**
133-
* Constructs a new WebMvcSseServerTransportProvider instance with the default SSE
134-
* endpoint.
135-
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
136-
* of messages.
137-
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
138-
* messages via HTTP POST. This endpoint will be communicated to clients through the
139-
* SSE connection's initial endpoint event.
140-
* @throws IllegalArgumentException if either objectMapper or messageEndpoint is null
141-
*/
142-
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
143-
this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
144-
}
145-
146167
@Override
147168
public void setSessionFactory(McpServerSession.Factory sessionFactory) {
148169
this.sessionFactory = sessionFactory;
@@ -248,7 +269,7 @@ private ServerResponse handleSseConnection(ServerRequest request) {
248269
try {
249270
sseBuilder.id(sessionId)
250271
.event(ENDPOINT_EVENT_TYPE)
251-
.data(messageEndpoint + "?sessionId=" + sessionId);
272+
.data(this.messageBaseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
252273
}
253274
catch (Exception e) {
254275
logger.error("Failed to send initial endpoint event: {}", e.getMessage());
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright 2025 - 2025 the original author or authors.
3+
*/
4+
package io.modelcontextprotocol.server;
5+
6+
import org.apache.catalina.Context;
7+
import org.apache.catalina.startup.Tomcat;
8+
9+
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
10+
import org.springframework.web.servlet.DispatcherServlet;
11+
12+
/**
13+
* @author Christian Tzolov
14+
*/
15+
public class TomcatTestUtil {
16+
17+
public record TomcatServer(Tomcat tomcat, AnnotationConfigWebApplicationContext appContext) {
18+
}
19+
20+
public TomcatServer createTomcatServer(String contextPath, int port, Class<?> componentClass) {
21+
22+
// Set up Tomcat first
23+
var tomcat = new Tomcat();
24+
tomcat.setPort(port);
25+
26+
// Set Tomcat base directory to java.io.tmpdir to avoid permission issues
27+
String baseDir = System.getProperty("java.io.tmpdir");
28+
tomcat.setBaseDir(baseDir);
29+
30+
// Use the same directory for document base
31+
Context context = tomcat.addContext(contextPath, baseDir);
32+
33+
// Create and configure Spring WebMvc context
34+
var appContext = new AnnotationConfigWebApplicationContext();
35+
appContext.register(componentClass);
36+
appContext.setServletContext(context.getServletContext());
37+
appContext.refresh();
38+
39+
// Create DispatcherServlet with our Spring context
40+
DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext);
41+
42+
// Add servlet to Tomcat and get the wrapper
43+
var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet);
44+
wrapper.setLoadOnStartup(1);
45+
wrapper.setAsyncSupported(true);
46+
context.addServletMappingDecoded("/*", "dispatcherServlet");
47+
48+
try {
49+
// Configure and start the connector with async support
50+
var connector = tomcat.getConnector();
51+
connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests
52+
}
53+
catch (Exception e) {
54+
throw new RuntimeException("Failed to start Tomcat", e);
55+
}
56+
57+
return new TomcatServer(tomcat, appContext);
58+
}
59+
60+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright 2024 - 2024 the original author or authors.
3+
*/
4+
package io.modelcontextprotocol.server;
5+
6+
import com.fasterxml.jackson.databind.ObjectMapper;
7+
import io.modelcontextprotocol.client.McpClient;
8+
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
9+
import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
10+
import io.modelcontextprotocol.spec.McpSchema;
11+
import org.apache.catalina.LifecycleException;
12+
import org.apache.catalina.LifecycleState;
13+
import org.junit.jupiter.api.AfterEach;
14+
import org.junit.jupiter.api.BeforeEach;
15+
import org.junit.jupiter.api.Test;
16+
17+
import org.springframework.context.annotation.Bean;
18+
import org.springframework.context.annotation.Configuration;
19+
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
20+
import org.springframework.web.servlet.function.RouterFunction;
21+
import org.springframework.web.servlet.function.ServerResponse;
22+
23+
import static org.assertj.core.api.Assertions.assertThat;
24+
25+
public class WebMvcSseCustomContextPathTests {
26+
27+
private static final String CUSTOM_CONTEXT_PATH = "/app/1";
28+
29+
private static final int PORT = 8183;
30+
31+
private static final String MESSAGE_ENDPOINT = "/mcp/message";
32+
33+
private WebMvcSseServerTransportProvider mcpServerTransportProvider;
34+
35+
McpClient.SyncSpec clientBuilder;
36+
37+
private TomcatTestUtil.TomcatServer tomcatServer;
38+
39+
@BeforeEach
40+
public void before() {
41+
42+
tomcatServer = new TomcatTestUtil().createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class);
43+
44+
try {
45+
tomcatServer.tomcat().start();
46+
assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED);
47+
}
48+
catch (Exception e) {
49+
throw new RuntimeException("Failed to start Tomcat", e);
50+
}
51+
52+
var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT)
53+
.sseEndpoint(CUSTOM_CONTEXT_PATH + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT)
54+
.build();
55+
56+
clientBuilder = McpClient.sync(clientTransport);
57+
58+
mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class);
59+
}
60+
61+
@AfterEach
62+
public void after() {
63+
if (mcpServerTransportProvider != null) {
64+
mcpServerTransportProvider.closeGracefully().block();
65+
}
66+
if (tomcatServer.appContext() != null) {
67+
tomcatServer.appContext().close();
68+
}
69+
if (tomcatServer.tomcat() != null) {
70+
try {
71+
tomcatServer.tomcat().stop();
72+
tomcatServer.tomcat().destroy();
73+
}
74+
catch (LifecycleException e) {
75+
throw new RuntimeException("Failed to stop Tomcat", e);
76+
}
77+
}
78+
}
79+
80+
@Test
81+
void testCustomContextPath() {
82+
McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").build();
83+
var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")).build();
84+
assertThat(client.initialize()).isNotNull();
85+
}
86+
87+
@Configuration
88+
@EnableWebMvc
89+
static class TestConfig {
90+
91+
@Bean
92+
public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() {
93+
94+
return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT,
95+
WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT);
96+
}
97+
98+
@Bean
99+
public RouterFunction<ServerResponse> routerFunction(WebMvcSseServerTransportProvider transportProvider) {
100+
return transportProvider.getRouterFunction();
101+
}
102+
103+
}
104+
105+
}

mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java

Lines changed: 14 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,8 @@
2525
import io.modelcontextprotocol.spec.McpSchema.Root;
2626
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
2727
import io.modelcontextprotocol.spec.McpSchema.Tool;
28-
import org.apache.catalina.Context;
2928
import org.apache.catalina.LifecycleException;
3029
import org.apache.catalina.LifecycleState;
31-
import org.apache.catalina.startup.Tomcat;
3230
import org.junit.jupiter.api.AfterEach;
3331
import org.junit.jupiter.api.BeforeEach;
3432
import org.junit.jupiter.api.Test;
@@ -38,15 +36,12 @@
3836
import org.springframework.context.annotation.Bean;
3937
import org.springframework.context.annotation.Configuration;
4038
import org.springframework.web.client.RestClient;
41-
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
42-
import org.springframework.web.servlet.DispatcherServlet;
4339
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
4440
import org.springframework.web.servlet.function.RouterFunction;
4541
import org.springframework.web.servlet.function.ServerResponse;
4642

4743
import static org.assertj.core.api.Assertions.assertThat;
4844
import static org.awaitility.Awaitility.await;
49-
import static org.junit.Assert.assertThat;
5045
import static org.mockito.Mockito.mock;
5146

5247
public class WebMvcSseIntegrationTests {
@@ -75,69 +70,40 @@ public RouterFunction<ServerResponse> routerFunction(WebMvcSseServerTransportPro
7570

7671
}
7772

78-
private Tomcat tomcat;
79-
80-
private AnnotationConfigWebApplicationContext appContext;
73+
private TomcatTestUtil.TomcatServer tomcatServer;
8174

8275
@BeforeEach
8376
public void before() {
8477

85-
// Set up Tomcat first
86-
tomcat = new Tomcat();
87-
tomcat.setPort(PORT);
88-
89-
// Set Tomcat base directory to java.io.tmpdir to avoid permission issues
90-
String baseDir = System.getProperty("java.io.tmpdir");
91-
tomcat.setBaseDir(baseDir);
92-
93-
// Use the same directory for document base
94-
Context context = tomcat.addContext("", baseDir);
95-
96-
// Create and configure Spring WebMvc context
97-
appContext = new AnnotationConfigWebApplicationContext();
98-
appContext.register(TestConfig.class);
99-
appContext.setServletContext(context.getServletContext());
100-
appContext.refresh();
101-
102-
// Get the transport from Spring context
103-
mcpServerTransportProvider = appContext.getBean(WebMvcSseServerTransportProvider.class);
104-
105-
// Create DispatcherServlet with our Spring context
106-
DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext);
107-
// dispatcherServlet.setThrowExceptionIfNoHandlerFound(true);
108-
109-
// Add servlet to Tomcat and get the wrapper
110-
var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet);
111-
wrapper.setLoadOnStartup(1);
112-
wrapper.setAsyncSupported(true);
113-
context.addServletMappingDecoded("/*", "dispatcherServlet");
78+
tomcatServer = new TomcatTestUtil().createTomcatServer("", PORT, TestConfig.class);
11479

11580
try {
116-
// Configure and start the connector with async support
117-
var connector = tomcat.getConnector();
118-
connector.setAsyncTimeout(3000); // 3 seconds timeout for async requests
119-
tomcat.start();
120-
assertThat(tomcat.getServer().getState() == LifecycleState.STARTED);
81+
tomcatServer.tomcat().start();
82+
assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED);
12183
}
12284
catch (Exception e) {
12385
throw new RuntimeException("Failed to start Tomcat", e);
12486
}
12587

126-
this.clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT));
88+
clientBuilder = McpClient.sync(new HttpClientSseClientTransport("http://localhost:" + PORT));
89+
90+
// Get the transport from Spring context
91+
mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class);
92+
12793
}
12894

12995
@AfterEach
13096
public void after() {
13197
if (mcpServerTransportProvider != null) {
13298
mcpServerTransportProvider.closeGracefully().block();
13399
}
134-
if (appContext != null) {
135-
appContext.close();
100+
if (tomcatServer.appContext() != null) {
101+
tomcatServer.appContext().close();
136102
}
137-
if (tomcat != null) {
103+
if (tomcatServer.tomcat() != null) {
138104
try {
139-
tomcat.stop();
140-
tomcat.destroy();
105+
tomcatServer.tomcat().stop();
106+
tomcatServer.tomcat().destroy();
141107
}
142108
catch (LifecycleException e) {
143109
throw new RuntimeException("Failed to stop Tomcat", e);

mcp-test/pom.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
<artifactId>logback-classic</artifactId>
8181
<version>${logback.version}</version>
8282
</dependency>
83+
8384
</dependencies>
8485

8586

0 commit comments

Comments
 (0)