Skip to content

Commit 77333d1

Browse files
author
Zachary German
committed
Enforce "String | Number" for JSONRPC message IDs
1 parent 590e68c commit 77333d1

File tree

9 files changed

+99
-65
lines changed

9 files changed

+99
-65
lines changed

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import com.fasterxml.jackson.databind.ObjectMapper;
1313
import io.modelcontextprotocol.spec.McpSchema;
1414
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
15+
import io.modelcontextprotocol.spec.McpSchema.McpId;
16+
1517
import org.junit.jupiter.api.AfterEach;
1618
import org.junit.jupiter.api.BeforeEach;
1719
import org.junit.jupiter.api.Test;
@@ -161,7 +163,7 @@ void testBuilderPattern() {
161163
@Test
162164
void testMessageProcessing() {
163165
// Create a test message
164-
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
166+
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"),
165167
Map.of("key", "value"));
166168

167169
// Simulate receiving the message
@@ -192,7 +194,7 @@ void testResponseMessageProcessing() {
192194
""");
193195

194196
// Create and send a request message
195-
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
197+
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"),
196198
Map.of("key", "value"));
197199

198200
// Verify message handling
@@ -216,7 +218,7 @@ void testErrorMessageProcessing() {
216218
""");
217219

218220
// Create and send a request message
219-
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
221+
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"),
220222
Map.of("key", "value"));
221223

222224
// Verify message handling
@@ -246,7 +248,7 @@ void testGracefulShutdown() {
246248
StepVerifier.create(transport.closeGracefully()).verifyComplete();
247249

248250
// Create a test message
249-
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
251+
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"),
250252
Map.of("key", "value"));
251253

252254
// Verify message is not processed after shutdown
@@ -292,10 +294,10 @@ void testMultipleMessageProcessing() {
292294
""");
293295

294296
// Create and send corresponding messages
295-
JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", "id1",
297+
JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", McpId.of("id1"),
296298
Map.of("key", "value1"));
297299

298-
JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", "id2",
300+
JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", McpId.of("id2"),
299301
Map.of("key", "value2"));
300302

301303
// Verify both messages are processed

mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import java.util.HashMap;
1212
import java.util.List;
1313
import java.util.Map;
14-
import java.util.Objects;
1514
import java.util.UUID;
1615
import java.util.concurrent.ConcurrentHashMap;
1716
import java.util.concurrent.atomic.AtomicBoolean;
@@ -49,6 +48,8 @@
4948
import reactor.core.publisher.Sinks;
5049
import reactor.util.context.Context;
5150

51+
import static java.util.Objects.requireNonNullElse;
52+
5253
/**
5354
* MCP Streamable HTTP transport provider that uses a single session class to manage all
5455
* streams and transports.
@@ -147,9 +148,9 @@ public class StreamableHttpServerTransportProvider extends HttpServlet implement
147148
*/
148149
public StreamableHttpServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint,
149150
Supplier<String> sessionIdProvider) {
150-
this.objectMapper = Objects.requireNonNullElse(objectMapper, DEFAULT_OBJECT_MAPPER);
151-
this.mcpEndpoint = Objects.requireNonNullElse(mcpEndpoint, DEFAULT_MCP_ENDPOINT);
152-
this.sessionIdProvider = Objects.requireNonNullElse(sessionIdProvider, DEFAULT_SESSION_ID_PROVIDER);
151+
this.objectMapper = requireNonNullElse(objectMapper, DEFAULT_OBJECT_MAPPER);
152+
this.mcpEndpoint = requireNonNullElse(mcpEndpoint, DEFAULT_MCP_ENDPOINT);
153+
this.sessionIdProvider = requireNonNullElse(sessionIdProvider, DEFAULT_SESSION_ID_PROVIDER);
153154
}
154155

155156
/**
@@ -276,7 +277,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
276277
SseTransport sseTransport = new SseTransport(objectMapper, response, asyncContext, lastEventId,
277278
request.getRequestId(), sessionId);
278279
session.registerTransport(request.getRequestId(), sseTransport);
279-
logger.debug("Registered SSE transport {} for session {}", session.LISTENING_TRANSPORT, sessionId);
280+
logger.debug("Registered SSE transport {} for session {}", request.getRequestId(), sessionId);
280281
}
281282
}
282283

@@ -364,21 +365,11 @@ public void onAllDataRead() throws IOException {
364365
// Determine response type and create appropriate transport if needed
365366
ResponseType responseType = detectResponseType(message, session);
366367
final String transportId;
367-
final Object id;
368368
if (message instanceof JSONRPCRequest req) {
369-
id = req.id();
369+
transportId = req.id().toString();
370370
}
371371
else if (message instanceof JSONRPCResponse resp) {
372-
id = resp.id();
373-
}
374-
else {
375-
id = null;
376-
}
377-
if (id instanceof String) {
378-
transportId = (String) id;
379-
}
380-
else if (id instanceof Integer) {
381-
transportId = id.toString();
372+
transportId = resp.id().toString();
382373
}
383374
else {
384375
transportId = null;
@@ -761,6 +752,8 @@ private void replayEventsAfter(String lastEventId) {
761752
eventSink.tryEmitNext(event);
762753
}
763754
}
755+
logger.debug("Completing SSE stream after replaying events");
756+
eventSink.tryEmitComplete();
764757
}
765758
catch (NumberFormatException e) {
766759
logger.warn("Invalid last event ID: {}", lastEventId);

mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
package io.modelcontextprotocol.spec;
66

77
import com.fasterxml.jackson.core.type.TypeReference;
8+
9+
import io.modelcontextprotocol.spec.McpSchema.McpId;
810
import io.modelcontextprotocol.util.Assert;
911
import org.reactivestreams.Publisher;
1012
import org.slf4j.Logger;
@@ -47,7 +49,7 @@ public class McpClientSession implements McpSession {
4749
private final McpClientTransport transport;
4850

4951
/** Map of pending responses keyed by request ID */
50-
private final ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();
52+
private final ConcurrentHashMap<McpId, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();
5153

5254
/** Map of request handlers keyed by method name */
5355
private final ConcurrentHashMap<String, RequestHandler<?>> requestHandlers = new ConcurrentHashMap<>();
@@ -231,10 +233,10 @@ private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification noti
231233
/**
232234
* Generates a unique request ID in a non-blocking way. Combines a session-specific
233235
* prefix with an atomic counter to ensure uniqueness.
234-
* @return A unique request ID string
236+
* @return A unique request ID from String
235237
*/
236-
private String generateRequestId() {
237-
return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement();
238+
private McpId generateRequestId() {
239+
return McpId.of(this.sessionPrefix + "-" + this.requestCounter.getAndIncrement());
238240
}
239241

240242
/**
@@ -247,7 +249,7 @@ private String generateRequestId() {
247249
*/
248250
@Override
249251
public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
250-
String requestId = this.generateRequestId();
252+
McpId requestId = this.generateRequestId();
251253

252254
return Mono.deferContextual(ctx -> Mono.<McpSchema.JSONRPCResponse>create(pendingResponseSink -> {
253255
logger.debug("Sending message for method {}", method);

mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import com.fasterxml.jackson.core.type.TypeReference;
1414
import io.modelcontextprotocol.server.McpAsyncServerExchange;
1515
import io.modelcontextprotocol.spec.SseEvent;
16+
import io.modelcontextprotocol.spec.McpSchema.McpId;
17+
1618
import org.slf4j.Logger;
1719
import org.slf4j.LoggerFactory;
1820
import reactor.core.publisher.Flux;
@@ -28,7 +30,7 @@ public class McpServerSession implements McpSession {
2830

2931
private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class);
3032

31-
private final ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();
33+
private final ConcurrentHashMap<McpId, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();
3234

3335
private final ConcurrentHashMap<String, McpServerTransport> transports = new ConcurrentHashMap<>();
3436

@@ -118,20 +120,42 @@ public String getId() {
118120
return this.id;
119121
}
120122

123+
/**
124+
* Increments the session-specific event counter, maps it to the given transport ID
125+
* for replayability support, then returns the event ID
126+
* @param transportId
127+
* @return an event ID unique to the session
128+
*/
121129
public String incrementAndGetEventId(String transportId) {
122130
final String eventId = String.valueOf(eventCounter.incrementAndGet());
123131
eventTransports.put(eventId, transportId);
124132
return eventId;
125133
}
126134

135+
/**
136+
* Used for replayability support to get the transport ID of a given event ID
137+
* @param eventId
138+
* @return The ID of the transport instance that the given event ID was sent over
139+
*/
127140
public String getTransportIdForEvent(String eventId) {
128141
return eventTransports.get(eventId);
129142
}
130143

144+
/**
145+
* Used for replayability support to set the event history of a given transport ID
146+
* @param transportId
147+
* @param eventHistory
148+
*/
131149
public void setTransportEventHistory(String transportId, Map<String, SseEvent> eventHistory) {
132150
transportEventHistories.put(transportId, eventHistory);
133151
}
134152

153+
/**
154+
* Used for replayability support to retrieve the entire event history for a given
155+
* transport ID
156+
* @param transportId
157+
* @return Map of SseEvent objects, keyed by event ID
158+
*/
135159
public Map<String, SseEvent> getTransportEventHistory(String transportId) {
136160
return transportEventHistories.get(transportId);
137161
}
@@ -203,8 +227,8 @@ public McpSchema.Implementation getClientInfo() {
203227
return this.clientInfo.get();
204228
}
205229

206-
private String generateRequestId() {
207-
return this.id + "-" + this.requestCounter.getAndIncrement();
230+
private McpId generateRequestId() {
231+
return McpId.of(this.id + "-" + this.requestCounter.getAndIncrement());
208232
}
209233

210234
/**
@@ -216,7 +240,7 @@ public RequestHandler<?> getRequestHandler(String method) {
216240

217241
@Override
218242
public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
219-
String requestId = this.generateRequestId();
243+
McpId requestId = this.generateRequestId();
220244

221245
return Mono.<McpSchema.JSONRPCResponse>create(sink -> {
222246
this.pendingResponses.put(requestId, sink);
@@ -276,7 +300,6 @@ public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
276300
}
277301
else if (message instanceof McpSchema.JSONRPCRequest request) {
278302
logger.debug("Received request: {}", request);
279-
final McpServerTransport finalListeningTransport = listeningTransport;
280303
final String transportId;
281304
if (transports.isEmpty()) {
282305
transportId = LISTENING_TRANSPORT;
@@ -327,7 +350,7 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
327350
// TODO handle situation where already initialized!
328351
McpSchema.InitializeRequest initializeRequest = transports.isEmpty() ? listeningTransport
329352
.unmarshalFrom(request.params(), new TypeReference<McpSchema.InitializeRequest>() {
330-
}) : transports.get(request.id())
353+
}) : transports.get(String.valueOf(request.id()))
331354
.unmarshalFrom(request.params(), new TypeReference<McpSchema.InitializeRequest>() {
332355
});
333356

@@ -346,6 +369,9 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
346369
error.message(), error.data())));
347370
}
348371

372+
// We would need to add request.id() as a parameter to handler.handle() if
373+
// we want client-request-driven requests/notifications to go to the
374+
// related stream
349375
resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params()));
350376
}
351377
return resultMono

0 commit comments

Comments
 (0)