Skip to content

Commit 8fe96bf

Browse files
authored
Implement cancellation notifications (#146)
* Implement cancellation notifications * Address PR feedback * Address feedback
1 parent 5d3fb65 commit 8fe96bf

File tree

5 files changed

+201
-43
lines changed

5 files changed

+201
-43
lines changed

src/ModelContextProtocol/Logging/Log.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ internal static partial class Log
6868
[LoggerMessage(Level = LogLevel.Error, Message = "Request failed for {endpointName} with method {method}: {message} ({code})")]
6969
internal static partial void RequestFailed(this ILogger logger, string endpointName, string method, string message, int code);
7070

71+
[LoggerMessage(Level = LogLevel.Information, Message = "Request '{requestId}' canceled via client notification with reason '{Reason}'.")]
72+
internal static partial void RequestCanceled(this ILogger logger, RequestId requestId, string? reason);
73+
7174
[LoggerMessage(Level = LogLevel.Information, Message = "Request response received payload for {endpointName}: {payload}")]
7275
internal static partial void RequestResponseReceivedPayload(this ILogger logger, string endpointName, string payload);
7376

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System.Text.Json.Serialization;
2+
3+
namespace ModelContextProtocol.Protocol.Messages;
4+
5+
/// <summary>
6+
/// This notification indicates that the result will be unused, so any associated processing SHOULD cease.
7+
/// </summary>
8+
public sealed class CancelledNotification
9+
{
10+
/// <summary>
11+
/// The ID of the request to cancel.
12+
/// </summary>
13+
[JsonPropertyName("requestId")]
14+
public RequestId RequestId { get; set; }
15+
16+
/// <summary>
17+
/// An optional string describing the reason for the cancellation.
18+
/// </summary>
19+
[JsonPropertyName("reason")]
20+
public string? Reason { get; set; }
21+
}

src/ModelContextProtocol/Shared/McpSession.cs

Lines changed: 127 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,13 @@ internal sealed class McpSession : IDisposable
2020
private readonly RequestHandlers _requestHandlers;
2121
private readonly NotificationHandlers _notificationHandlers;
2222

23+
/// <summary>Collection of requests sent on this session and waiting for responses.</summary>
2324
private readonly ConcurrentDictionary<RequestId, TaskCompletionSource<IJsonRpcMessage>> _pendingRequests = [];
25+
/// <summary>
26+
/// Collection of requests received on this session and currently being handled. The value provides a <see cref="CancellationTokenSource"/>
27+
/// that can be used to request cancellation of the in-flight handler.
28+
/// </summary>
29+
private readonly ConcurrentDictionary<RequestId, CancellationTokenSource> _handlingRequests = new();
2430
private readonly JsonSerializerOptions _jsonOptions;
2531
private readonly ILogger _logger;
2632

@@ -69,25 +75,70 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken)
6975
{
7076
_logger.TransportMessageRead(EndpointName, message.GetType().Name);
7177

72-
// Fire and forget the message handling task to avoid blocking the transport
73-
// If awaiting the task, the transport will not be able to read more messages,
74-
// which could lead to a deadlock if the handler sends a message back
7578
_ = ProcessMessageAsync();
7679
async Task ProcessMessageAsync()
7780
{
81+
IJsonRpcMessageWithId? messageWithId = message as IJsonRpcMessageWithId;
82+
CancellationTokenSource? combinedCts = null;
83+
try
84+
{
85+
// Register before we yield, so that the tracking is guaranteed to be there
86+
// when subsequent messages arrive, even if the asynchronous processing happens
87+
// out of order.
88+
if (messageWithId is not null)
89+
{
90+
combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
91+
_handlingRequests[messageWithId.Id] = combinedCts;
92+
}
93+
94+
// Fire and forget the message handling to avoid blocking the transport
95+
// If awaiting the task, the transport will not be able to read more messages,
96+
// which could lead to a deadlock if the handler sends a message back
97+
7898
#if NET
79-
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
99+
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
80100
#else
81-
await default(ForceYielding);
101+
await default(ForceYielding);
82102
#endif
83-
try
84-
{
85-
await HandleMessageAsync(message, cancellationToken).ConfigureAwait(false);
103+
104+
// Handle the message.
105+
await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false);
86106
}
87107
catch (Exception ex)
88108
{
89-
var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
90-
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
109+
// Only send responses for request errors that aren't user-initiated cancellation.
110+
bool isUserCancellation =
111+
ex is OperationCanceledException &&
112+
!cancellationToken.IsCancellationRequested &&
113+
combinedCts?.IsCancellationRequested is true;
114+
115+
if (!isUserCancellation && message is JsonRpcRequest request)
116+
{
117+
_logger.RequestHandlerError(EndpointName, request.Method, ex);
118+
await _transport.SendMessageAsync(new JsonRpcError
119+
{
120+
Id = request.Id,
121+
JsonRpc = "2.0",
122+
Error = new JsonRpcErrorDetail
123+
{
124+
Code = ErrorCodes.InternalError,
125+
Message = ex.Message
126+
}
127+
}, cancellationToken).ConfigureAwait(false);
128+
}
129+
else if (ex is not OperationCanceledException)
130+
{
131+
var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
132+
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
133+
}
134+
}
135+
finally
136+
{
137+
if (messageWithId is not null)
138+
{
139+
_handlingRequests.TryRemove(messageWithId.Id, out _);
140+
combinedCts!.Dispose();
141+
}
91142
}
92143
}
93144
}
@@ -123,6 +174,25 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken
123174

124175
private async Task HandleNotification(JsonRpcNotification notification)
125176
{
177+
// Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
178+
if (notification.Method == NotificationMethods.CancelledNotification)
179+
{
180+
try
181+
{
182+
if (GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
183+
_handlingRequests.TryGetValue(cn.RequestId, out var cts))
184+
{
185+
await cts.CancelAsync().ConfigureAwait(false);
186+
_logger.RequestCanceled(cn.RequestId, cn.Reason);
187+
}
188+
}
189+
catch
190+
{
191+
// "Invalid cancellation notifications SHOULD be ignored"
192+
}
193+
}
194+
195+
// Handle user-defined notifications.
126196
if (_notificationHandlers.TryGetValue(notification.Method, out var handlers))
127197
{
128198
foreach (var notificationHandler in handlers)
@@ -161,33 +231,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance
161231
{
162232
if (_requestHandlers.TryGetValue(request.Method, out var handler))
163233
{
164-
try
165-
{
166-
_logger.RequestHandlerCalled(EndpointName, request.Method);
167-
var result = await handler(request, cancellationToken).ConfigureAwait(false);
168-
_logger.RequestHandlerCompleted(EndpointName, request.Method);
169-
await _transport.SendMessageAsync(new JsonRpcResponse
170-
{
171-
Id = request.Id,
172-
JsonRpc = "2.0",
173-
Result = result
174-
}, cancellationToken).ConfigureAwait(false);
175-
}
176-
catch (Exception ex)
234+
_logger.RequestHandlerCalled(EndpointName, request.Method);
235+
var result = await handler(request, cancellationToken).ConfigureAwait(false);
236+
_logger.RequestHandlerCompleted(EndpointName, request.Method);
237+
await _transport.SendMessageAsync(new JsonRpcResponse
177238
{
178-
_logger.RequestHandlerError(EndpointName, request.Method, ex);
179-
// Send error response
180-
await _transport.SendMessageAsync(new JsonRpcError
181-
{
182-
Id = request.Id,
183-
JsonRpc = "2.0",
184-
Error = new JsonRpcErrorDetail
185-
{
186-
Code = -32000, // Implementation defined error
187-
Message = ex.Message
188-
}
189-
}, cancellationToken).ConfigureAwait(false);
190-
}
239+
Id = request.Id,
240+
JsonRpc = "2.0",
241+
Result = result
242+
}, cancellationToken).ConfigureAwait(false);
191243
}
192244
else
193245
{
@@ -273,7 +325,7 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
273325
}
274326
}
275327

276-
public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
328+
public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
277329
{
278330
Throw.IfNull(message);
279331

@@ -288,7 +340,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
288340
_logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>()));
289341
}
290342

291-
return _transport.SendMessageAsync(message, cancellationToken);
343+
await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
344+
345+
// If the sent notification was a cancellation notification, cancel the pending request's await, as either the
346+
// server won't be sending a response, or per the specification, the response should be ignored. There are inherent
347+
// race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
348+
if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification &&
349+
GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
350+
_pendingRequests.TryRemove(cn.RequestId, out var tcs))
351+
{
352+
tcs.TrySetCanceled(default);
353+
}
354+
}
355+
356+
private static CancelledNotification? GetCancelledNotificationParams(object? notificationParams)
357+
{
358+
try
359+
{
360+
switch (notificationParams)
361+
{
362+
case null:
363+
return null;
364+
365+
case CancelledNotification cn:
366+
return cn;
367+
368+
case JsonElement je:
369+
return JsonSerializer.Deserialize(je, McpJsonUtilities.DefaultOptions.GetTypeInfo<CancelledNotification>());
370+
371+
default:
372+
return JsonSerializer.Deserialize(
373+
JsonSerializer.Serialize(notificationParams, McpJsonUtilities.DefaultOptions.GetTypeInfo<object?>()),
374+
McpJsonUtilities.DefaultOptions.GetTypeInfo<CancelledNotification>());
375+
}
376+
}
377+
catch
378+
{
379+
return null;
380+
}
292381
}
293382

294383
public void Dispose()

src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element)
121121
// MCP Request Params / Results
122122
[JsonSerializable(typeof(CallToolRequestParams))]
123123
[JsonSerializable(typeof(CallToolResponse))]
124+
[JsonSerializable(typeof(CancelledNotification))]
124125
[JsonSerializable(typeof(CompleteRequestParams))]
125126
[JsonSerializable(typeof(CompleteResult))]
126127
[JsonSerializable(typeof(CreateMessageRequestParams))]

tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public async Task Can_List_Registered_Tools()
9191
IMcpClient client = await CreateMcpClientForServer();
9292

9393
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
94-
Assert.Equal(12, tools.Count);
94+
Assert.Equal(13, tools.Count);
9595

9696
McpClientTool echoTool = tools.First(t => t.Name == "Echo");
9797
Assert.Equal("Echo", echoTool.Name);
@@ -138,7 +138,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T
138138
cancellationToken: TestContext.Current.CancellationToken))
139139
{
140140
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
141-
Assert.Equal(12, tools.Count);
141+
Assert.Equal(13, tools.Count);
142142

143143
McpClientTool echoTool = tools.First(t => t.Name == "Echo");
144144
Assert.Equal("Echo", echoTool.Name);
@@ -165,7 +165,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
165165
IMcpClient client = await CreateMcpClientForServer();
166166

167167
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
168-
Assert.Equal(12, tools.Count);
168+
Assert.Equal(13, tools.Count);
169169

170170
Channel<JsonRpcNotification> listChanged = Channel.CreateUnbounded<JsonRpcNotification>();
171171
client.AddNotificationHandler(NotificationMethods.ToolListChangedNotification, notification =>
@@ -186,7 +186,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
186186
await notificationRead;
187187

188188
tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
189-
Assert.Equal(13, tools.Count);
189+
Assert.Equal(14, tools.Count);
190190
Assert.Contains(tools, t => t.Name == "NewTool");
191191

192192
notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken);
@@ -195,7 +195,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
195195
await notificationRead;
196196

197197
tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
198-
Assert.Equal(12, tools.Count);
198+
Assert.Equal(13, tools.Count);
199199
Assert.DoesNotContain(tools, t => t.Name == "NewTool");
200200
}
201201

@@ -560,6 +560,35 @@ public async Task HandlesIProgressParameter()
560560
}
561561
}
562562

563+
[Fact]
564+
public async Task CancellationNotificationsPropagateToToolTokens()
565+
{
566+
IMcpClient client = await CreateMcpClientForServer();
567+
568+
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
569+
Assert.NotNull(tools);
570+
Assert.NotEmpty(tools);
571+
McpClientTool cancelableTool = tools.First(t => t.Name == nameof(EchoTool.InfiniteCancelableOperation));
572+
573+
var requestId = new RequestId(Guid.NewGuid().ToString());
574+
var invokeTask = client.SendRequestAsync<CallToolResponse>(new JsonRpcRequest()
575+
{
576+
Method = RequestMethods.ToolsCall,
577+
Id = requestId,
578+
Params = new CallToolRequestParams() { Name = cancelableTool.ProtocolTool.Name },
579+
}, TestContext.Current.CancellationToken);
580+
581+
await client.SendNotificationAsync(
582+
NotificationMethods.CancelledNotification,
583+
parameters: new CancelledNotification()
584+
{
585+
RequestId = requestId,
586+
},
587+
cancellationToken: TestContext.Current.CancellationToken);
588+
589+
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => invokeTask);
590+
}
591+
563592
[McpServerToolType]
564593
public sealed class EchoTool(ObjectWithId objectFromDI)
565594
{
@@ -625,6 +654,21 @@ public static string EchoComplex(ComplexObject complex)
625654
return complex.Name!;
626655
}
627656

657+
[McpServerTool]
658+
public static async Task<string> InfiniteCancelableOperation(CancellationToken cancellationToken)
659+
{
660+
try
661+
{
662+
await Task.Delay(Timeout.Infinite, cancellationToken);
663+
}
664+
catch (Exception)
665+
{
666+
return "canceled";
667+
}
668+
669+
return "unreachable";
670+
}
671+
628672
[McpServerTool]
629673
public string GetCtorParameter() => $"{_randomValue}:{objectFromDI.Id}";
630674

0 commit comments

Comments
 (0)