Skip to content

Commit c2e7150

Browse files
committed
feat(mcp): webflux support filter
1 parent 261554b commit c2e7150

File tree

4 files changed

+313
-7
lines changed

4 files changed

+313
-7
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package io.modelcontextprotocol.server.filter;
2+
3+
import com.fasterxml.jackson.core.JsonProcessingException;
4+
import com.fasterxml.jackson.databind.ObjectMapper;
5+
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
6+
import io.modelcontextprotocol.spec.McpSchema;
7+
import io.modelcontextprotocol.spec.McpServerSession;
8+
import org.slf4j.Logger;
9+
import org.slf4j.LoggerFactory;
10+
import org.springframework.web.reactive.function.server.HandlerFilterFunction;
11+
import org.springframework.web.reactive.function.server.HandlerFunction;
12+
import org.springframework.web.reactive.function.server.ServerRequest;
13+
import org.springframework.web.reactive.function.server.ServerResponse;
14+
import reactor.core.publisher.Mono;
15+
16+
import java.util.Objects;
17+
import java.util.function.Function;
18+
19+
public abstract class CallToolHandlerFilter implements HandlerFilterFunction<ServerResponse, ServerResponse> {
20+
21+
private static final Logger logger = LoggerFactory.getLogger(CallToolHandlerFilter.class);
22+
23+
private final ObjectMapper objectMapper = new ObjectMapper();
24+
25+
private static final String DEFAULT_MESSAGE_PATH = "/mcp/message";
26+
27+
public final static McpSchema.CallToolResult PASS = null;
28+
29+
private Function<String, McpServerSession> sessionFunction;
30+
31+
/**
32+
* Filter incoming requests to handle tool calls.
33+
* Processes {@linkplain io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest JSONRPCRequest} that match
34+
* the configured path and method.
35+
*
36+
* @param request The incoming server request
37+
* @param next The next handler in the chain
38+
* @return The filtered response
39+
*/
40+
@Override
41+
public Mono<ServerResponse> filter(ServerRequest request, HandlerFunction<ServerResponse> next) {
42+
if (!Objects.equals(request.path(), matchPath())) {
43+
return next.handle(request);
44+
}
45+
46+
return request.bodyToMono(McpSchema.JSONRPCRequest.class)
47+
.flatMap(jsonrpcRequest -> handleJsonRpcRequest(request, jsonrpcRequest, next));
48+
}
49+
50+
private Mono<ServerResponse> handleJsonRpcRequest(ServerRequest request, McpSchema.JSONRPCRequest jsonrpcRequest,
51+
HandlerFunction<ServerResponse> next) {
52+
ServerRequest newRequest;
53+
try {
54+
newRequest = ServerRequest.from(request)
55+
.body(objectMapper.writeValueAsString(jsonrpcRequest))
56+
.build();
57+
} catch (JsonProcessingException e) {
58+
return Mono.error(e);
59+
}
60+
61+
if (skipFilter(jsonrpcRequest)) {
62+
return next.handle(newRequest);
63+
}
64+
65+
return handleToolCallRequest(newRequest, jsonrpcRequest, next);
66+
}
67+
68+
private Mono<ServerResponse> handleToolCallRequest(ServerRequest newRequest, McpSchema.JSONRPCRequest jsonrpcRequest,
69+
HandlerFunction<ServerResponse> next) {
70+
McpServerSession session = newRequest.queryParam("sessionId")
71+
.map(sessionId -> sessionFunction.apply(sessionId))
72+
.orElse(null);
73+
74+
if (Objects.isNull(session)) {
75+
return next.handle(newRequest);
76+
}
77+
78+
McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(jsonrpcRequest.params(), McpSchema.CallToolRequest.class);
79+
McpSchema.CallToolResult callToolResult = doFilter(newRequest, callToolRequest);
80+
if (Objects.equals(PASS, callToolResult)) {
81+
return next.handle(newRequest);
82+
} else {
83+
return session.sendResponse(jsonrpcRequest.id(), callToolResult, null).then(ServerResponse.ok().build());
84+
}
85+
}
86+
87+
private boolean skipFilter(McpSchema.JSONRPCRequest jsonrpcRequest) {
88+
if (!Objects.equals(jsonrpcRequest.method(), matchMethod())) {
89+
return true;
90+
}
91+
92+
if (Objects.isNull(sessionFunction)) {
93+
logger.error("No session function provided, skip CallToolRequest filter");
94+
return true;
95+
}
96+
97+
return false;
98+
}
99+
100+
/**
101+
* Abstract method to be implemented by subclasses to handle tool call requests.
102+
*
103+
* @param request The incoming server request. Contains HTTP information such as: request path,
104+
* request headers, request parameters.
105+
* Note that the request body has already been extracted and deserialized into the callToolRequest
106+
* parameter, so there's no need to extract the body from the ServerRequest again.
107+
* @param callToolRequest The deserialized call tool request object
108+
* @return A CallToolResult object if the current filter handles the request (subsequent filters will not be executed),
109+
* or {@linkplain CallToolHandlerFilter#PASS PASS} if the current filter does not handle the request
110+
* (execution will continue to subsequent filters in the chain).
111+
*/
112+
public abstract McpSchema.CallToolResult doFilter(ServerRequest request, McpSchema.CallToolRequest callToolRequest);
113+
114+
/**
115+
* Returns the method name to match for handling tool calls.
116+
*
117+
* @return The method name to match
118+
*/
119+
public String matchMethod() {
120+
return McpSchema.METHOD_TOOLS_CALL;
121+
}
122+
123+
/**
124+
* Returns the path to match for handling tool calls.
125+
*
126+
* @return The path to match
127+
*/
128+
public String matchPath() {
129+
return DEFAULT_MESSAGE_PATH;
130+
}
131+
132+
/**
133+
* Set the session provider function.
134+
*
135+
* @param transportProvider The SSE server transport provider used to obtain sessions
136+
*/
137+
public void applySession(WebFluxSseServerTransportProvider transportProvider) {
138+
this.sessionFunction = transportProvider::getSession;
139+
}
140+
}

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
package io.modelcontextprotocol.server.transport;
22

33
import java.io.IOException;
4-
import java.util.Map;
54
import java.util.concurrent.ConcurrentHashMap;
65

76
import com.fasterxml.jackson.core.type.TypeReference;
87
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import io.modelcontextprotocol.server.filter.CallToolHandlerFilter;
99
import io.modelcontextprotocol.spec.McpError;
1010
import io.modelcontextprotocol.spec.McpSchema;
1111
import io.modelcontextprotocol.spec.McpServerSession;
@@ -14,6 +14,7 @@
1414
import io.modelcontextprotocol.util.Assert;
1515
import org.slf4j.Logger;
1616
import org.slf4j.LoggerFactory;
17+
import org.springframework.web.reactive.function.server.*;
1718
import reactor.core.Exceptions;
1819
import reactor.core.publisher.Flux;
1920
import reactor.core.publisher.FluxSink;
@@ -22,10 +23,6 @@
2223
import org.springframework.http.HttpStatus;
2324
import org.springframework.http.MediaType;
2425
import org.springframework.http.codec.ServerSentEvent;
25-
import org.springframework.web.reactive.function.server.RouterFunction;
26-
import org.springframework.web.reactive.function.server.RouterFunctions;
27-
import org.springframework.web.reactive.function.server.ServerRequest;
28-
import org.springframework.web.reactive.function.server.ServerResponse;
2926

3027
/**
3128
* Server-side implementation of the MCP (Model Context Protocol) HTTP transport using
@@ -84,6 +81,11 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv
8481

8582
public static final String DEFAULT_BASE_URL = "";
8683

84+
/**
85+
* Default filter function for handling requests, do nothing
86+
*/
87+
public static final HandlerFilterFunction<ServerResponse, ServerResponse> DEFAULT_REQUEST_FILTER = ((request, next) -> next.handle(request));
88+
8789
private final ObjectMapper objectMapper;
8890

8991
/**
@@ -149,10 +151,28 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
149151
*/
150152
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
151153
String sseEndpoint) {
154+
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, DEFAULT_REQUEST_FILTER);
155+
}
156+
157+
/**
158+
* Constructs a new WebFlux SSE server transport provider instance.
159+
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
160+
* of MCP messages. Must not be null.
161+
* @param baseUrl webflux message base path
162+
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
163+
* messages. This endpoint will be communicated to clients during SSE connection
164+
* setup. Must not be null.
165+
* @param requestFilter The filter function to apply to incoming requests, which may
166+
* be sse or message request.
167+
* @throws IllegalArgumentException if either parameter is null
168+
*/
169+
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
170+
String sseEndpoint, HandlerFilterFunction<ServerResponse, ServerResponse> requestFilter) {
152171
Assert.notNull(objectMapper, "ObjectMapper must not be null");
153172
Assert.notNull(baseUrl, "Message base path must not be null");
154173
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
155174
Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
175+
Assert.notNull(requestFilter, "Request filter must not be null");
156176

157177
this.objectMapper = objectMapper;
158178
this.baseUrl = baseUrl;
@@ -161,6 +181,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
161181
this.routerFunction = RouterFunctions.route()
162182
.GET(this.sseEndpoint, this::handleSseConnection)
163183
.POST(this.messageEndpoint, this::handleMessage)
184+
.filter(requestFilter)
164185
.build();
165186
}
166187

@@ -245,6 +266,14 @@ public RouterFunction<?> getRouterFunction() {
245266
return this.routerFunction;
246267
}
247268

269+
/**
270+
* Returns the McpServerSession associated with the given session ID.
271+
* @return session The McpServerSession associated with the given session ID, or null
272+
*/
273+
public McpServerSession getSession(String sessionId) {
274+
return sessions.get(sessionId);
275+
}
276+
248277
/**
249278
* Handles new SSE connection requests from clients. Creates a new session for each
250279
* connection and sets up the SSE event stream.
@@ -397,6 +426,8 @@ public static class Builder {
397426

398427
private String sseEndpoint = DEFAULT_SSE_ENDPOINT;
399428

429+
private HandlerFilterFunction<ServerResponse, ServerResponse> requestFilter = DEFAULT_REQUEST_FILTER;
430+
400431
/**
401432
* Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
402433
* messages.
@@ -447,6 +478,12 @@ public Builder sseEndpoint(String sseEndpoint) {
447478
return this;
448479
}
449480

481+
public Builder requestFilter(CallToolHandlerFilter requestFilter) {
482+
Assert.notNull(requestFilter, "requestFilter must not be null");
483+
this.requestFilter = requestFilter;
484+
return this;
485+
}
486+
450487
/**
451488
* Builds a new instance of {@link WebFluxSseServerTransportProvider} with the
452489
* configured settings.
@@ -457,7 +494,8 @@ public WebFluxSseServerTransportProvider build() {
457494
Assert.notNull(objectMapper, "ObjectMapper must be set");
458495
Assert.notNull(messageEndpoint, "Message endpoint must be set");
459496

460-
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint);
497+
return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint,
498+
requestFilter);
461499
}
462500

463501
}

0 commit comments

Comments
 (0)