17
17
import java .util .concurrent .atomic .AtomicBoolean ;
18
18
import java .util .concurrent .atomic .AtomicLong ;
19
19
import java .util .function .Supplier ;
20
+ import java .util .stream .Collectors ;
20
21
21
22
import com .fasterxml .jackson .core .type .TypeReference ;
22
23
import com .fasterxml .jackson .databind .ObjectMapper ;
30
31
import io .modelcontextprotocol .spec .McpServerSession ;
31
32
import io .modelcontextprotocol .spec .McpServerTransport ;
32
33
import io .modelcontextprotocol .spec .McpServerTransportProvider ;
34
+ import io .modelcontextprotocol .spec .SseEvent ;
33
35
import io .modelcontextprotocol .util .Assert ;
34
36
import jakarta .servlet .AsyncContext ;
35
37
import jakarta .servlet .ReadListener ;
@@ -89,6 +91,8 @@ public class StreamableHttpServerTransportProvider extends HttpServlet implement
89
91
90
92
public static final String ALLOW_ORIGIN_DEFAULT_VALUE = "*" ;
91
93
94
+ public static final String PROTOCOL_VERSION_HEADER = "MCP-Protocol-Version" ;
95
+
92
96
public static final String CACHE_CONTROL_HEADER = "Cache-Control" ;
93
97
94
98
public static final String CONNECTION_HEADER = "Connection" ;
@@ -117,7 +121,7 @@ public class StreamableHttpServerTransportProvider extends HttpServlet implement
117
121
private final Supplier <String > sessionIdProvider ;
118
122
119
123
/** Sessions map, keyed by Session ID */
120
- private final Map <String , McpServerSession > sessions = new ConcurrentHashMap <>();
124
+ private static final Map <String , McpServerSession > sessions = new ConcurrentHashMap <>();
121
125
122
126
/** Flag indicating if the transport is in the process of shutting down */
123
127
private final AtomicBoolean isClosing = new AtomicBoolean (false );
@@ -128,6 +132,7 @@ public class StreamableHttpServerTransportProvider extends HttpServlet implement
128
132
/** Callback interface for session lifecycle and errors */
129
133
private SessionHandler sessionHandler ;
130
134
135
+ /** Factory for McpServerSession takes session IDs */
131
136
private McpServerSession .StreamableHttpSessionFactory streamableHttpSessionFactory ;
132
137
133
138
/**
@@ -242,6 +247,13 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
242
247
return ;
243
248
}
244
249
250
+ // Delayed until version negotiation is implemented.
251
+ /*
252
+ * if (session.getState().equals(session.STATE_INITIALIZED) &&
253
+ * request.getHeader(PROTOCOL_VERSION_HEADER) == null) {
254
+ * sendErrorResponse(response, "Protocol version missing in request header"); }
255
+ */
256
+
245
257
// Set up SSE connection
246
258
response .setContentType (TEXT_EVENT_STREAM );
247
259
response .setCharacterEncoding (UTF_8 );
@@ -254,10 +266,18 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
254
266
255
267
String lastEventId = request .getHeader (LAST_EVENT_ID_HEADER );
256
268
257
- SseTransport sseTransport = new SseTransport (objectMapper , response , asyncContext , lastEventId );
258
- session .registerTransport (session .LISTENING_TRANSPORT , sseTransport );
259
-
260
- logger .debug ("Registered SSE transport {} for session {}" , session .LISTENING_TRANSPORT , sessionId );
269
+ if (lastEventId == null ) { // Just opening a listening stream
270
+ SseTransport sseTransport = new SseTransport (objectMapper , response , asyncContext , lastEventId ,
271
+ session .LISTENING_TRANSPORT , sessionId );
272
+ session .registerTransport (session .LISTENING_TRANSPORT , sseTransport );
273
+ logger .debug ("Registered SSE transport {} for session {}" , session .LISTENING_TRANSPORT , sessionId );
274
+ }
275
+ else { // Asking for a stream to replay events from a previous request
276
+ SseTransport sseTransport = new SseTransport (objectMapper , response , asyncContext , lastEventId ,
277
+ request .getRequestId (), sessionId );
278
+ session .registerTransport (request .getRequestId (), sseTransport );
279
+ logger .debug ("Registered SSE transport {} for session {}" , session .LISTENING_TRANSPORT , sessionId );
280
+ }
261
281
}
262
282
263
283
@ Override
@@ -328,6 +348,15 @@ public void onAllDataRead() throws IOException {
328
348
asyncContext .complete ();
329
349
return ;
330
350
}
351
+
352
+ // Delayed until version negotiation is implemented.
353
+ /*
354
+ * if (session.getState().equals(session.STATE_INITIALIZED) &&
355
+ * request.getHeader(PROTOCOL_VERSION_HEADER) == null) {
356
+ * sendErrorResponse(response,
357
+ * "Protocol version missing in request header"); }
358
+ */
359
+
331
360
logger .debug ("Using session: {}" , sessionId );
332
361
333
362
response .setHeader (SESSION_ID_HEADER , sessionId );
@@ -362,7 +391,8 @@ else if (id instanceof Integer) {
362
391
response .setHeader (CACHE_CONTROL_HEADER , CACHE_CONTROL_NO_CACHE );
363
392
response .setHeader (CONNECTION_HEADER , CONNECTION_KEEP_ALIVE );
364
393
365
- SseTransport sseTransport = new SseTransport (objectMapper , response , asyncContext , null );
394
+ SseTransport sseTransport = new SseTransport (objectMapper , response , asyncContext , null ,
395
+ transportId , sessionId );
366
396
session .registerTransport (transportId , sseTransport );
367
397
}
368
398
else {
@@ -650,13 +680,17 @@ private static class SseTransport implements McpServerTransport {
650
680
651
681
private final Map <String , SseEvent > eventHistory = new ConcurrentHashMap <>();
652
682
653
- private final AtomicLong eventCounter = new AtomicLong (0 );
683
+ private final String id ;
684
+
685
+ private final String sessionId ;
654
686
655
687
public SseTransport (ObjectMapper objectMapper , HttpServletResponse response , AsyncContext asyncContext ,
656
- String lastEventId ) {
688
+ String lastEventId , String transportId , String sessionId ) {
657
689
this .objectMapper = objectMapper ;
658
690
this .response = response ;
659
691
this .asyncContext = asyncContext ;
692
+ this .id = transportId ;
693
+ this .sessionId = sessionId ;
660
694
661
695
setupSseStream (lastEventId );
662
696
}
@@ -710,9 +744,19 @@ private void setupSseStream(String lastEventId) {
710
744
711
745
private void replayEventsAfter (String lastEventId ) {
712
746
try {
713
- long lastId = Long .parseLong (lastEventId );
714
- for (long i = lastId + 1 ; i <= eventCounter .get (); i ++) {
715
- SseEvent event = eventHistory .get (String .valueOf (i ));
747
+ McpServerSession session = sessions .get (sessionId );
748
+ String transportIdOfLastEventId = session .getTransportIdForEvent (lastEventId );
749
+ Map <String , SseEvent > transportEventHistory = session
750
+ .getTransportEventHistory (transportIdOfLastEventId );
751
+ List <String > eventIds = transportEventHistory .keySet ()
752
+ .stream ()
753
+ .map (Long ::parseLong )
754
+ .filter (key -> key > Long .parseLong (lastEventId ))
755
+ .sorted ()
756
+ .map (String ::valueOf )
757
+ .collect (Collectors .toList ());
758
+ for (String eventId : eventIds ) {
759
+ SseEvent event = transportEventHistory .get (eventId );
716
760
if (event != null ) {
717
761
eventSink .tryEmitNext (event );
718
762
}
@@ -727,7 +771,7 @@ private void replayEventsAfter(String lastEventId) {
727
771
public Mono <Void > sendMessage (JSONRPCMessage message ) {
728
772
try {
729
773
String jsonText = objectMapper .writeValueAsString (message );
730
- String eventId = String . valueOf ( eventCounter . incrementAndGet () );
774
+ String eventId = sessions . get ( sessionId ). incrementAndGetEventId ( id );
731
775
SseEvent event = new SseEvent (eventId , MESSAGE_EVENT_TYPE , jsonText );
732
776
733
777
eventHistory .put (eventId , event );
@@ -737,6 +781,7 @@ public Mono<Void> sendMessage(JSONRPCMessage message) {
737
781
if (message instanceof McpSchema .JSONRPCResponse ) {
738
782
logger .debug ("Completing SSE stream after sending response" );
739
783
eventSink .tryEmitComplete ();
784
+ sessions .get (sessionId ).setTransportEventHistory (id , eventHistory );
740
785
}
741
786
742
787
return Mono .empty ();
@@ -754,7 +799,7 @@ public Mono<Void> sendMessageStream(Flux<JSONRPCMessage> messageStream) {
754
799
return messageStream .doOnNext (message -> {
755
800
try {
756
801
String jsonText = objectMapper .writeValueAsString (message );
757
- String eventId = String . valueOf ( eventCounter . incrementAndGet () );
802
+ String eventId = sessions . get ( sessionId ). incrementAndGetEventId ( id );
758
803
SseEvent event = new SseEvent (eventId , MESSAGE_EVENT_TYPE , jsonText );
759
804
760
805
eventHistory .put (eventId , event );
@@ -768,6 +813,7 @@ public Mono<Void> sendMessageStream(Flux<JSONRPCMessage> messageStream) {
768
813
}).doOnComplete (() -> {
769
814
logger .debug ("Completing SSE stream after sending all stream messages" );
770
815
eventSink .tryEmitComplete ();
816
+ sessions .get (sessionId ).setTransportEventHistory (id , eventHistory );
771
817
}).then ();
772
818
}
773
819
@@ -780,13 +826,11 @@ public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
780
826
public Mono <Void > closeGracefully () {
781
827
return Mono .fromRunnable (() -> {
782
828
eventSink .tryEmitComplete ();
829
+ sessions .get (sessionId ).setTransportEventHistory (id , eventHistory );
783
830
logger .debug ("SSE transport closed gracefully" );
784
831
});
785
832
}
786
833
787
- private record SseEvent (String id , String event , String data ) {
788
- }
789
-
790
834
}
791
835
792
836
/**
0 commit comments