Skip to content

Commit 25ff3dc

Browse files
author
Zachary German
committed
Enforce "String | Number" for JSONRPC message IDs
1 parent 09d47b2 commit 25ff3dc

File tree

10 files changed

+213
-77
lines changed

10 files changed

+213
-77
lines changed

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import com.fasterxml.jackson.databind.ObjectMapper;
1313
import io.modelcontextprotocol.spec.McpSchema;
1414
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
15+
import io.modelcontextprotocol.spec.McpSchema.McpId;
16+
1517
import org.junit.jupiter.api.AfterEach;
1618
import org.junit.jupiter.api.BeforeEach;
1719
import org.junit.jupiter.api.Test;
@@ -161,7 +163,7 @@ void testBuilderPattern() {
161163
@Test
162164
void testMessageProcessing() {
163165
// Create a test message
164-
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
166+
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"),
165167
Map.of("key", "value"));
166168

167169
// Simulate receiving the message
@@ -192,7 +194,7 @@ void testResponseMessageProcessing() {
192194
""");
193195

194196
// Create and send a request message
195-
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
197+
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"),
196198
Map.of("key", "value"));
197199

198200
// Verify message handling
@@ -216,7 +218,7 @@ void testErrorMessageProcessing() {
216218
""");
217219

218220
// Create and send a request message
219-
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
221+
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"),
220222
Map.of("key", "value"));
221223

222224
// Verify message handling
@@ -246,7 +248,7 @@ void testGracefulShutdown() {
246248
StepVerifier.create(transport.closeGracefully()).verifyComplete();
247249

248250
// Create a test message
249-
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
251+
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", McpId.of("test-id"),
250252
Map.of("key", "value"));
251253

252254
// Verify message is not processed after shutdown
@@ -292,10 +294,10 @@ void testMultipleMessageProcessing() {
292294
""");
293295

294296
// Create and send corresponding messages
295-
JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", "id1",
297+
JSONRPCRequest message1 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method1", McpId.of("id1"),
296298
Map.of("key", "value1"));
297299

298-
JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", "id2",
300+
JSONRPCRequest message2 = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method2", McpId.of("id2"),
299301
Map.of("key", "value2"));
300302

301303
// Verify both messages are processed

mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import java.util.HashMap;
1212
import java.util.List;
1313
import java.util.Map;
14-
import java.util.Objects;
1514
import java.util.UUID;
1615
import java.util.concurrent.ConcurrentHashMap;
1716
import java.util.concurrent.atomic.AtomicBoolean;
@@ -49,6 +48,8 @@
4948
import reactor.core.publisher.Sinks;
5049
import reactor.util.context.Context;
5150

51+
import static java.util.Objects.requireNonNullElse;
52+
5253
/**
5354
* MCP Streamable HTTP transport provider that uses a single session class to manage all
5455
* streams and transports.
@@ -147,9 +148,9 @@ public class StreamableHttpServerTransportProvider extends HttpServlet implement
147148
*/
148149
public StreamableHttpServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint,
149150
Supplier<String> sessionIdProvider) {
150-
this.objectMapper = Objects.requireNonNullElse(objectMapper, DEFAULT_OBJECT_MAPPER);
151-
this.mcpEndpoint = Objects.requireNonNullElse(mcpEndpoint, DEFAULT_MCP_ENDPOINT);
152-
this.sessionIdProvider = Objects.requireNonNullElse(sessionIdProvider, DEFAULT_SESSION_ID_PROVIDER);
151+
this.objectMapper = requireNonNullElse(objectMapper, DEFAULT_OBJECT_MAPPER);
152+
this.mcpEndpoint = requireNonNullElse(mcpEndpoint, DEFAULT_MCP_ENDPOINT);
153+
this.sessionIdProvider = requireNonNullElse(sessionIdProvider, DEFAULT_SESSION_ID_PROVIDER);
153154
}
154155

155156
/**
@@ -276,7 +277,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
276277
SseTransport sseTransport = new SseTransport(objectMapper, response, asyncContext, lastEventId,
277278
request.getRequestId(), sessionId);
278279
session.registerTransport(request.getRequestId(), sseTransport);
279-
logger.debug("Registered SSE transport {} for session {}", session.LISTENING_TRANSPORT, sessionId);
280+
logger.debug("Registered SSE transport {} for session {}", request.getRequestId(), sessionId);
280281
}
281282
}
282283

@@ -364,21 +365,11 @@ public void onAllDataRead() throws IOException {
364365
// Determine response type and create appropriate transport if needed
365366
ResponseType responseType = detectResponseType(message, session);
366367
final String transportId;
367-
final Object id;
368368
if (message instanceof JSONRPCRequest req) {
369-
id = req.id();
369+
transportId = req.id().toString();
370370
}
371371
else if (message instanceof JSONRPCResponse resp) {
372-
id = resp.id();
373-
}
374-
else {
375-
id = null;
376-
}
377-
if (id instanceof String) {
378-
transportId = (String) id;
379-
}
380-
else if (id instanceof Integer) {
381-
transportId = id.toString();
372+
transportId = resp.id().toString();
382373
}
383374
else {
384375
transportId = null;
@@ -761,6 +752,8 @@ private void replayEventsAfter(String lastEventId) {
761752
eventSink.tryEmitNext(event);
762753
}
763754
}
755+
logger.debug("Completing SSE stream after replaying events");
756+
eventSink.tryEmitComplete();
764757
}
765758
catch (NumberFormatException e) {
766759
logger.warn("Invalid last event ID: {}", lastEventId);

mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
package io.modelcontextprotocol.spec;
66

77
import com.fasterxml.jackson.core.type.TypeReference;
8+
9+
import io.modelcontextprotocol.spec.McpSchema.McpId;
810
import io.modelcontextprotocol.util.Assert;
911
import org.reactivestreams.Publisher;
1012
import org.slf4j.Logger;
@@ -47,7 +49,7 @@ public class McpClientSession implements McpSession {
4749
private final McpClientTransport transport;
4850

4951
/** Map of pending responses keyed by request ID */
50-
private final ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();
52+
private final ConcurrentHashMap<McpId, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();
5153

5254
/** Map of request handlers keyed by method name */
5355
private final ConcurrentHashMap<String, RequestHandler<?>> requestHandlers = new ConcurrentHashMap<>();
@@ -231,10 +233,10 @@ private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification noti
231233
/**
232234
* Generates a unique request ID in a non-blocking way. Combines a session-specific
233235
* prefix with an atomic counter to ensure uniqueness.
234-
* @return A unique request ID string
236+
* @return A unique request ID from String
235237
*/
236-
private String generateRequestId() {
237-
return this.sessionPrefix + "-" + this.requestCounter.getAndIncrement();
238+
private McpId generateRequestId() {
239+
return McpId.of(this.sessionPrefix + "-" + this.requestCounter.getAndIncrement());
238240
}
239241

240242
/**
@@ -247,7 +249,7 @@ private String generateRequestId() {
247249
*/
248250
@Override
249251
public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReference<T> typeRef) {
250-
String requestId = this.generateRequestId();
252+
McpId requestId = this.generateRequestId();
251253

252254
return Mono.deferContextual(ctx -> Mono.<McpSchema.JSONRPCResponse>create(pendingResponseSink -> {
253255
logger.debug("Sending message for method {}", method);

mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,25 @@
1818
import com.fasterxml.jackson.annotation.JsonSubTypes;
1919
import com.fasterxml.jackson.annotation.JsonTypeInfo;
2020
import com.fasterxml.jackson.annotation.JsonTypeInfo.As;
21+
import com.fasterxml.jackson.core.JsonGenerator;
22+
import com.fasterxml.jackson.core.JsonParser;
23+
import com.fasterxml.jackson.core.JsonToken;
2124
import com.fasterxml.jackson.core.type.TypeReference;
25+
import com.fasterxml.jackson.databind.DeserializationContext;
26+
import com.fasterxml.jackson.databind.JsonDeserializer;
27+
import com.fasterxml.jackson.databind.JsonMappingException;
28+
import com.fasterxml.jackson.databind.JsonSerializer;
2229
import com.fasterxml.jackson.databind.ObjectMapper;
30+
import com.fasterxml.jackson.databind.SerializerProvider;
31+
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
32+
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
33+
2334
import io.modelcontextprotocol.util.Assert;
2435
import org.slf4j.Logger;
2536
import org.slf4j.LoggerFactory;
2637

38+
import static java.util.Objects.requireNonNull;
39+
2740
/**
2841
* Based on the <a href="http://www.jsonrpc.org/specification">JSON-RPC 2.0
2942
* specification</a> and the <a href=
@@ -32,6 +45,7 @@
3245
*
3346
* @author Christian Tzolov
3447
* @author Luca Chang
48+
* @author Zachary German
3549
*/
3650
public final class McpSchema {
3751

@@ -142,6 +156,103 @@ public static final class ErrorCodes {
142156

143157
}
144158

159+
/**
160+
* MCP ID wrapper: MUST be non-null String or Number.
161+
*/
162+
@JsonSerialize(using = McpId.Serializer.class)
163+
@JsonDeserialize(using = McpId.Deserializer.class)
164+
public static final class McpId {
165+
166+
private final Object value;
167+
168+
public McpId(String value) {
169+
this.value = requireNonNull(value, "'id' must not be null");
170+
}
171+
172+
public McpId(Number value) {
173+
this.value = requireNonNull(value, "'id' must not be null");
174+
}
175+
176+
public static McpId of(Object raw) {
177+
if (raw instanceof String s)
178+
return new McpId(s);
179+
if (raw instanceof Number n)
180+
return new McpId(n);
181+
throw new IllegalArgumentException("MCP 'id' must be String or Number");
182+
}
183+
184+
public boolean isString() {
185+
return value instanceof String;
186+
}
187+
188+
public boolean isNumber() {
189+
return value instanceof Number;
190+
}
191+
192+
public String asString() {
193+
return (String) value;
194+
}
195+
196+
public Number asNumber() {
197+
return (Number) value;
198+
}
199+
200+
public Object raw() {
201+
return value;
202+
}
203+
204+
@Override
205+
public String toString() {
206+
return String.valueOf(value);
207+
}
208+
209+
@Override
210+
public boolean equals(Object o) {
211+
if (this == o)
212+
return true;
213+
if (o == null || getClass() != o.getClass())
214+
return false;
215+
McpId mcpId = (McpId) o;
216+
return value.equals(mcpId.value);
217+
}
218+
219+
@Override
220+
public int hashCode() {
221+
return value.hashCode();
222+
}
223+
224+
public static class Deserializer extends JsonDeserializer<McpId> {
225+
226+
@Override
227+
public McpId deserialize(JsonParser p, DeserializationContext ctxt) throws IOException {
228+
JsonToken t = p.getCurrentToken();
229+
if (t == JsonToken.VALUE_STRING) {
230+
return new McpId(p.getText());
231+
}
232+
else if (t.isNumeric()) {
233+
return new McpId(p.getNumberValue());
234+
}
235+
throw JsonMappingException.from(p, "MCP 'id' must be a non-null String or Number");
236+
}
237+
238+
}
239+
240+
public static class Serializer extends JsonSerializer<McpId> {
241+
242+
@Override
243+
public void serialize(McpId id, JsonGenerator gen, SerializerProvider serializers) throws IOException {
244+
if (id.isString()) {
245+
gen.writeString(id.asString());
246+
}
247+
else {
248+
gen.writeNumber(id.asNumber().toString());
249+
}
250+
}
251+
252+
}
253+
254+
}
255+
145256
public sealed interface Request permits InitializeRequest, CallToolRequest, CreateMessageRequest, ElicitRequest,
146257
CompleteRequest, GetPromptRequest, PaginatedRequest, ReadResourceRequest {
147258

@@ -205,7 +316,7 @@ public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotificati
205316
public record JSONRPCRequest( // @formatter:off
206317
@JsonProperty("jsonrpc") String jsonrpc,
207318
@JsonProperty("method") String method,
208-
@JsonProperty("id") Object id,
319+
@JsonProperty("id") McpId id,
209320
@JsonProperty("params") Object params) implements JSONRPCMessage {
210321
} // @formatter:on
211322

@@ -223,7 +334,7 @@ public record JSONRPCNotification( // @formatter:off
223334
// @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY)
224335
public record JSONRPCResponse( // @formatter:off
225336
@JsonProperty("jsonrpc") String jsonrpc,
226-
@JsonProperty("id") Object id,
337+
@JsonProperty("id") McpId id,
227338
@JsonProperty("result") Object result,
228339
@JsonProperty("error") JSONRPCError error) implements JSONRPCMessage {
229340

0 commit comments

Comments
 (0)