@@ -20,7 +20,13 @@ internal sealed class McpSession : IDisposable
20
20
private readonly RequestHandlers _requestHandlers ;
21
21
private readonly NotificationHandlers _notificationHandlers ;
22
22
23
+ /// <summary>Collection of requests sent on this session and waiting for responses.</summary>
23
24
private readonly ConcurrentDictionary < RequestId , TaskCompletionSource < IJsonRpcMessage > > _pendingRequests = [ ] ;
25
+ /// <summary>
26
+ /// Collection of requests received on this session and currently being handled. The value provides a <see cref="CancellationTokenSource"/>
27
+ /// that can be used to request cancellation of the in-flight handler.
28
+ /// </summary>
29
+ private readonly ConcurrentDictionary < RequestId , CancellationTokenSource > _handlingRequests = new ( ) ;
24
30
private readonly JsonSerializerOptions _jsonOptions ;
25
31
private readonly ILogger _logger ;
26
32
@@ -69,25 +75,70 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken)
69
75
{
70
76
_logger . TransportMessageRead ( EndpointName , message . GetType ( ) . Name ) ;
71
77
72
- // Fire and forget the message handling task to avoid blocking the transport
73
- // If awaiting the task, the transport will not be able to read more messages,
74
- // which could lead to a deadlock if the handler sends a message back
75
78
_ = ProcessMessageAsync ( ) ;
76
79
async Task ProcessMessageAsync ( )
77
80
{
81
+ IJsonRpcMessageWithId ? messageWithId = message as IJsonRpcMessageWithId ;
82
+ CancellationTokenSource ? combinedCts = null ;
83
+ try
84
+ {
85
+ // Register before we yield, so that the tracking is guaranteed to be there
86
+ // when subsequent messages arrive, even if the asynchronous processing happens
87
+ // out of order.
88
+ if ( messageWithId is not null )
89
+ {
90
+ combinedCts = CancellationTokenSource . CreateLinkedTokenSource ( cancellationToken ) ;
91
+ _handlingRequests [ messageWithId . Id ] = combinedCts ;
92
+ }
93
+
94
+ // Fire and forget the message handling to avoid blocking the transport
95
+ // If awaiting the task, the transport will not be able to read more messages,
96
+ // which could lead to a deadlock if the handler sends a message back
97
+
78
98
#if NET
79
- await Task . CompletedTask . ConfigureAwait ( ConfigureAwaitOptions . ForceYielding ) ;
99
+ await Task . CompletedTask . ConfigureAwait ( ConfigureAwaitOptions . ForceYielding ) ;
80
100
#else
81
- await default ( ForceYielding ) ;
101
+ await default ( ForceYielding ) ;
82
102
#endif
83
- try
84
- {
85
- await HandleMessageAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
103
+
104
+ // Handle the message.
105
+ await HandleMessageAsync ( message , combinedCts ? . Token ?? cancellationToken ) . ConfigureAwait ( false ) ;
86
106
}
87
107
catch ( Exception ex )
88
108
{
89
- var payload = JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ;
90
- _logger . MessageHandlerError ( EndpointName , message . GetType ( ) . Name , payload , ex ) ;
109
+ // Only send responses for request errors that aren't user-initiated cancellation.
110
+ bool isUserCancellation =
111
+ ex is OperationCanceledException &&
112
+ ! cancellationToken . IsCancellationRequested &&
113
+ combinedCts ? . IsCancellationRequested is true ;
114
+
115
+ if ( ! isUserCancellation && message is JsonRpcRequest request )
116
+ {
117
+ _logger . RequestHandlerError ( EndpointName , request . Method , ex ) ;
118
+ await _transport . SendMessageAsync ( new JsonRpcError
119
+ {
120
+ Id = request . Id ,
121
+ JsonRpc = "2.0" ,
122
+ Error = new JsonRpcErrorDetail
123
+ {
124
+ Code = ErrorCodes . InternalError ,
125
+ Message = ex . Message
126
+ }
127
+ } , cancellationToken ) . ConfigureAwait ( false ) ;
128
+ }
129
+ else if ( ex is not OperationCanceledException )
130
+ {
131
+ var payload = JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ;
132
+ _logger . MessageHandlerError ( EndpointName , message . GetType ( ) . Name , payload , ex ) ;
133
+ }
134
+ }
135
+ finally
136
+ {
137
+ if ( messageWithId is not null )
138
+ {
139
+ _handlingRequests . TryRemove ( messageWithId . Id , out _ ) ;
140
+ combinedCts ! . Dispose ( ) ;
141
+ }
91
142
}
92
143
}
93
144
}
@@ -123,6 +174,25 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken
123
174
124
175
private async Task HandleNotification ( JsonRpcNotification notification )
125
176
{
177
+ // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
178
+ if ( notification . Method == NotificationMethods . CancelledNotification )
179
+ {
180
+ try
181
+ {
182
+ if ( GetCancelledNotificationParams ( notification . Params ) is CancelledNotification cn &&
183
+ _handlingRequests . TryGetValue ( cn . RequestId , out var cts ) )
184
+ {
185
+ await cts . CancelAsync ( ) . ConfigureAwait ( false ) ;
186
+ _logger . RequestCanceled ( cn . RequestId , cn . Reason ) ;
187
+ }
188
+ }
189
+ catch
190
+ {
191
+ // "Invalid cancellation notifications SHOULD be ignored"
192
+ }
193
+ }
194
+
195
+ // Handle user-defined notifications.
126
196
if ( _notificationHandlers . TryGetValue ( notification . Method , out var handlers ) )
127
197
{
128
198
foreach ( var notificationHandler in handlers )
@@ -161,33 +231,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance
161
231
{
162
232
if ( _requestHandlers . TryGetValue ( request . Method , out var handler ) )
163
233
{
164
- try
165
- {
166
- _logger . RequestHandlerCalled ( EndpointName , request . Method ) ;
167
- var result = await handler ( request , cancellationToken ) . ConfigureAwait ( false ) ;
168
- _logger . RequestHandlerCompleted ( EndpointName , request . Method ) ;
169
- await _transport . SendMessageAsync ( new JsonRpcResponse
170
- {
171
- Id = request . Id ,
172
- JsonRpc = "2.0" ,
173
- Result = result
174
- } , cancellationToken ) . ConfigureAwait ( false ) ;
175
- }
176
- catch ( Exception ex )
234
+ _logger . RequestHandlerCalled ( EndpointName , request . Method ) ;
235
+ var result = await handler ( request , cancellationToken ) . ConfigureAwait ( false ) ;
236
+ _logger . RequestHandlerCompleted ( EndpointName , request . Method ) ;
237
+ await _transport . SendMessageAsync ( new JsonRpcResponse
177
238
{
178
- _logger . RequestHandlerError ( EndpointName , request . Method , ex ) ;
179
- // Send error response
180
- await _transport . SendMessageAsync ( new JsonRpcError
181
- {
182
- Id = request . Id ,
183
- JsonRpc = "2.0" ,
184
- Error = new JsonRpcErrorDetail
185
- {
186
- Code = - 32000 , // Implementation defined error
187
- Message = ex . Message
188
- }
189
- } , cancellationToken ) . ConfigureAwait ( false ) ;
190
- }
239
+ Id = request . Id ,
240
+ JsonRpc = "2.0" ,
241
+ Result = result
242
+ } , cancellationToken ) . ConfigureAwait ( false ) ;
191
243
}
192
244
else
193
245
{
@@ -273,7 +325,7 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
273
325
}
274
326
}
275
327
276
- public Task SendMessageAsync ( IJsonRpcMessage message , CancellationToken cancellationToken = default )
328
+ public async Task SendMessageAsync ( IJsonRpcMessage message , CancellationToken cancellationToken = default )
277
329
{
278
330
Throw . IfNull ( message ) ;
279
331
@@ -288,7 +340,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
288
340
_logger . SendingMessage ( EndpointName , JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ) ;
289
341
}
290
342
291
- return _transport . SendMessageAsync ( message , cancellationToken ) ;
343
+ await _transport . SendMessageAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
344
+
345
+ // If the sent notification was a cancellation notification, cancel the pending request's await, as either the
346
+ // server won't be sending a response, or per the specification, the response should be ignored. There are inherent
347
+ // race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
348
+ if ( message is JsonRpcNotification { Method : NotificationMethods . CancelledNotification } notification &&
349
+ GetCancelledNotificationParams ( notification . Params ) is CancelledNotification cn &&
350
+ _pendingRequests . TryRemove ( cn . RequestId , out var tcs ) )
351
+ {
352
+ tcs . TrySetCanceled ( default ) ;
353
+ }
354
+ }
355
+
356
+ private static CancelledNotification ? GetCancelledNotificationParams ( object ? notificationParams )
357
+ {
358
+ try
359
+ {
360
+ switch ( notificationParams )
361
+ {
362
+ case null :
363
+ return null ;
364
+
365
+ case CancelledNotification cn :
366
+ return cn ;
367
+
368
+ case JsonElement je :
369
+ return JsonSerializer . Deserialize ( je , McpJsonUtilities . DefaultOptions . GetTypeInfo < CancelledNotification > ( ) ) ;
370
+
371
+ default :
372
+ return JsonSerializer . Deserialize (
373
+ JsonSerializer . Serialize ( notificationParams , McpJsonUtilities . DefaultOptions . GetTypeInfo < object ? > ( ) ) ,
374
+ McpJsonUtilities . DefaultOptions . GetTypeInfo < CancelledNotification > ( ) ) ;
375
+ }
376
+ }
377
+ catch
378
+ {
379
+ return null ;
380
+ }
292
381
}
293
382
294
383
public void Dispose ( )
0 commit comments