Skip to content

Commit 046903c

Browse files
authored
Don't conflate request and response IDs in Streamable HTTP transports (#475)
1 parent e0b058f commit 046903c

File tree

7 files changed

+147
-26
lines changed

7 files changed

+147
-26
lines changed

src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,30 +96,30 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
9696
}
9797

9898
var rpcRequest = message as JsonRpcRequest;
99-
JsonRpcMessage? rpcResponseCandidate = null;
99+
JsonRpcMessageWithId? rpcResponseOrError = null;
100100

101101
if (response.Content.Headers.ContentType?.MediaType == "application/json")
102102
{
103103
var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
104-
rpcResponseCandidate = await ProcessMessageAsync(responseContent, cancellationToken).ConfigureAwait(false);
104+
rpcResponseOrError = await ProcessMessageAsync(responseContent, rpcRequest, cancellationToken).ConfigureAwait(false);
105105
}
106106
else if (response.Content.Headers.ContentType?.MediaType == "text/event-stream")
107107
{
108108
using var responseBodyStream = await response.Content.ReadAsStreamAsync(cancellationToken);
109-
rpcResponseCandidate = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false);
109+
rpcResponseOrError = await ProcessSseResponseAsync(responseBodyStream, rpcRequest, cancellationToken).ConfigureAwait(false);
110110
}
111111

112112
if (rpcRequest is null)
113113
{
114114
return response;
115115
}
116116

117-
if (rpcResponseCandidate is not JsonRpcMessageWithId messageWithId || messageWithId.Id != rpcRequest.Id)
117+
if (rpcResponseOrError is null)
118118
{
119119
throw new McpException($"Streamable HTTP POST response completed without a reply to request with ID: {rpcRequest.Id}");
120120
}
121121

122-
if (rpcRequest.Method == RequestMethods.Initialize && rpcResponseCandidate is JsonRpcResponse)
122+
if (rpcRequest.Method == RequestMethods.Initialize && rpcResponseOrError is JsonRpcResponse)
123123
{
124124
// We've successfully initialized! Copy session-id and start GET request if any.
125125
if (response.Headers.TryGetValues("mcp-session-id", out var sessionIdValues))
@@ -193,20 +193,20 @@ private async Task ReceiveUnsolicitedMessagesAsync()
193193
continue;
194194
}
195195

196-
var message = await ProcessMessageAsync(sseEvent.Data, cancellationToken).ConfigureAwait(false);
196+
var rpcResponseOrError = await ProcessMessageAsync(sseEvent.Data, relatedRpcRequest, cancellationToken).ConfigureAwait(false);
197197

198-
// The server SHOULD end the response here anyway, but we won't leave it to chance. This transport makes
198+
// The server SHOULD end the HTTP response body here anyway, but we won't leave it to chance. This transport makes
199199
// a GET request for any notifications that might need to be sent after the completion of each POST.
200-
if (message is JsonRpcMessageWithId messageWithId && relatedRpcRequest?.Id == messageWithId.Id)
200+
if (rpcResponseOrError is not null)
201201
{
202-
return messageWithId;
202+
return rpcResponseOrError;
203203
}
204204
}
205205

206206
return null;
207207
}
208208

209-
private async Task<JsonRpcMessage?> ProcessMessageAsync(string data, CancellationToken cancellationToken)
209+
private async Task<JsonRpcMessageWithId?> ProcessMessageAsync(string data, JsonRpcRequest? relatedRpcRequest, CancellationToken cancellationToken)
210210
{
211211
try
212212
{
@@ -218,7 +218,12 @@ private async Task ReceiveUnsolicitedMessagesAsync()
218218
}
219219

220220
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
221-
return message;
221+
if (message is JsonRpcResponse or JsonRpcError &&
222+
message is JsonRpcMessageWithId rpcResponseOrError &&
223+
rpcResponseOrError.Id == relatedRpcRequest?.Id)
224+
{
225+
return rpcResponseOrError;
226+
}
222227
}
223228
catch (JsonException ex)
224229
{

src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public async ValueTask DisposeAsync()
6060
{
6161
yield return message;
6262

63-
if (message.Data is JsonRpcMessageWithId response && response.Id == _pendingRequest)
63+
if (message.Data is JsonRpcResponse or JsonRpcError && ((JsonRpcMessageWithId)message.Data).Id == _pendingRequest)
6464
{
6565
// Complete the SSE response stream now that all pending requests have been processed.
6666
break;

tests/Common/Utils/LoggedTest.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@ public LoggedTest(ITestOutputHelper testOutputHelper)
1212
{
1313
CurrentTestOutputHelper = testOutputHelper,
1414
};
15-
LoggerProvider = new XunitLoggerProvider(_delegatingTestOutputHelper);
15+
XunitLoggerProvider = new XunitLoggerProvider(_delegatingTestOutputHelper);
1616
LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder =>
1717
{
18-
builder.AddProvider(LoggerProvider);
18+
builder.AddProvider(XunitLoggerProvider);
1919
});
2020
}
2121

2222
public ITestOutputHelper TestOutputHelper => _delegatingTestOutputHelper;
23-
public ILoggerFactory LoggerFactory { get; }
24-
public ILoggerProvider LoggerProvider { get; }
23+
public ILoggerFactory LoggerFactory { get; set; }
24+
public ILoggerProvider XunitLoggerProvider { get; }
2525

2626
public virtual void Dispose()
2727
{

tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat
5252

5353
await app.StartAsync(TestContext.Current.CancellationToken);
5454

55-
var mcpClient = await ConnectAsync(requestPath);
55+
await using var mcpClient = await ConnectAsync(requestPath);
5656

5757
Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name);
5858
}

tests/ModelContextProtocol.AspNetCore.Tests/MapMcpStreamableHttpTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public async Task CanConnect_WithMcpClient_AfterCustomizingRoute(string routePat
3131

3232
await app.StartAsync(TestContext.Current.CancellationToken);
3333

34-
var mcpClient = await ConnectAsync(requestPath);
34+
await using var mcpClient = await ConnectAsync(requestPath);
3535

3636
Assert.Equal("TestCustomRouteServer", mcpClient.ServerInfo.Name);
3737
}
@@ -135,7 +135,7 @@ public async Task SseMode_Works_WithSseEndpoint()
135135

136136
await app.StartAsync(TestContext.Current.CancellationToken);
137137

138-
await using var mcpClient = await ConnectAsync(options: new()
138+
await using var mcpClient = await ConnectAsync(transportOptions: new()
139139
{
140140
Endpoint = new Uri("http://localhost/sse"),
141141
TransportMode = HttpTransportMode.Sse

tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs

Lines changed: 122 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
using Microsoft.AspNetCore.Builder;
22
using Microsoft.AspNetCore.Http;
33
using Microsoft.Extensions.DependencyInjection;
4+
using Microsoft.Extensions.Logging;
45
using ModelContextProtocol.AspNetCore.Tests.Utils;
56
using ModelContextProtocol.Client;
7+
using ModelContextProtocol.Protocol;
68
using ModelContextProtocol.Server;
9+
using ModelContextProtocol.Tests.Utils;
710
using System.ComponentModel;
811
using System.Net;
912
using System.Security.Claims;
@@ -20,18 +23,21 @@ protected void ConfigureStateless(HttpServerTransportOptions options)
2023
options.Stateless = Stateless;
2124
}
2225

23-
protected async Task<IMcpClient> ConnectAsync(string? path = null, SseClientTransportOptions? options = null)
26+
protected async Task<IMcpClient> ConnectAsync(
27+
string? path = null,
28+
SseClientTransportOptions? transportOptions = null,
29+
McpClientOptions? clientOptions = null)
2430
{
2531
// Default behavior when no options are provided
2632
path ??= UseStreamableHttp ? "/" : "/sse";
2733

28-
await using var transport = new SseClientTransport(options ?? new SseClientTransportOptions()
34+
await using var transport = new SseClientTransport(transportOptions ?? new SseClientTransportOptions()
2935
{
3036
Endpoint = new Uri($"http://localhost{path}"),
3137
TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse,
3238
}, HttpClient, LoggerFactory);
3339

34-
return await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken);
40+
return await McpClientFactory.CreateAsync(transport, clientOptions, LoggerFactory, TestContext.Current.CancellationToken);
3541
}
3642

3743
[Fact]
@@ -71,7 +77,7 @@ IHttpContextAccessor is not currently supported with non-stateless Streamable HT
7177

7278
await app.StartAsync(TestContext.Current.CancellationToken);
7379

74-
var mcpClient = await ConnectAsync();
80+
await using var mcpClient = await ConnectAsync();
7581

7682
var response = await mcpClient.CallToolAsync(
7783
"EchoWithUserName",
@@ -111,13 +117,90 @@ public async Task Messages_FromNewUser_AreRejected()
111117
Assert.Equal(HttpStatusCode.Forbidden, httpRequestException.StatusCode);
112118
}
113119

114-
protected ClaimsPrincipal CreateUser(string name)
120+
[Fact]
121+
public async Task Sampling_DoesNotCloseStream_Prematurely()
122+
{
123+
Assert.SkipWhen(Stateless, "Sampling is not supported in stateless mode.");
124+
125+
Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools<SamplingRegressionTools>();
126+
127+
var mockLoggerProvider = new MockLoggerProvider();
128+
Builder.Logging.AddProvider(mockLoggerProvider);
129+
Builder.Logging.SetMinimumLevel(LogLevel.Debug);
130+
131+
await using var app = Builder.Build();
132+
133+
// Reset the LoggerFactory used by the client to use the MockLoggerProvider as well.
134+
LoggerFactory = app.Services.GetRequiredService<ILoggerFactory>();
135+
136+
app.MapMcp();
137+
138+
await app.StartAsync(TestContext.Current.CancellationToken);
139+
140+
var sampleCount = 0;
141+
var clientOptions = new McpClientOptions
142+
{
143+
Capabilities = new()
144+
{
145+
Sampling = new()
146+
{
147+
SamplingHandler = async (parameters, _, _) =>
148+
{
149+
Assert.NotNull(parameters?.Messages);
150+
var message = Assert.Single(parameters.Messages);
151+
Assert.Equal(Role.User, message.Role);
152+
Assert.Equal("text", message.Content.Type);
153+
Assert.Equal("Test prompt for sampling", message.Content.Text);
154+
155+
sampleCount++;
156+
return new CreateMessageResult
157+
{
158+
Model = "test-model",
159+
Role = Role.Assistant,
160+
Content = new Content
161+
{
162+
Type = "text",
163+
Text = "Sampling response from client"
164+
}
165+
};
166+
},
167+
},
168+
},
169+
};
170+
171+
await using var mcpClient = await ConnectAsync(clientOptions: clientOptions);
172+
173+
var result = await mcpClient.CallToolAsync("sampling-tool", new Dictionary<string, object?>
174+
{
175+
["prompt"] = "Test prompt for sampling"
176+
}, cancellationToken: TestContext.Current.CancellationToken);
177+
178+
Assert.NotNull(result);
179+
Assert.False(result.IsError);
180+
var textContent = Assert.Single(result.Content);
181+
Assert.Equal("text", textContent.Type);
182+
Assert.Equal("Sampling completed successfully. Client responded: Sampling response from client", textContent.Text);
183+
184+
Assert.Equal(2, sampleCount);
185+
186+
// Verify that the tool call and the sampling request both used the same ID to ensure we cover against regressions.
187+
// https://github.com/modelcontextprotocol/csharp-sdk/issues/464
188+
Assert.Single(mockLoggerProvider.LogMessages, m =>
189+
m.Category == "ModelContextProtocol.Client.McpClient" &&
190+
m.Message.Contains("request '2' for method 'tools/call'"));
191+
192+
Assert.Single(mockLoggerProvider.LogMessages, m =>
193+
m.Category == "ModelContextProtocol.Server.McpServer" &&
194+
m.Message.Contains("request '2' for method 'sampling/createMessage'"));
195+
}
196+
197+
private ClaimsPrincipal CreateUser(string name)
115198
=> new ClaimsPrincipal(new ClaimsIdentity(
116199
[new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name)],
117200
"TestAuthType", "name", "role"));
118201

119202
[McpServerToolType]
120-
protected class EchoHttpContextUserTools(IHttpContextAccessor contextAccessor)
203+
private class EchoHttpContextUserTools(IHttpContextAccessor contextAccessor)
121204
{
122205
[McpServerTool, Description("Echoes the input back to the client with their user name.")]
123206
public string EchoWithUserName(string message)
@@ -127,4 +210,37 @@ public string EchoWithUserName(string message)
127210
return $"{userName}: {message}";
128211
}
129212
}
213+
214+
[McpServerToolType]
215+
private class SamplingRegressionTools
216+
{
217+
[McpServerTool(Name = "sampling-tool")]
218+
public static async Task<string> SamplingToolAsync(IMcpServer server, string prompt, CancellationToken cancellationToken)
219+
{
220+
// This tool reproduces the scenario described in https://github.com/modelcontextprotocol/csharp-sdk/issues/464
221+
// 1. The client calls tool with request ID 2, because it's the first request after the initialize request.
222+
// 2. This tool makes two sampling requests which use IDs 1 and 2.
223+
// 3. In the old buggy Streamable HTTP transport code, this would close the SSE response stream,
224+
// because the second sampling request used an ID matching the tool call.
225+
var samplingRequest = new CreateMessageRequestParams
226+
{
227+
Messages = [
228+
new SamplingMessage
229+
{
230+
Role = Role.User,
231+
Content = new Content
232+
{
233+
Type = "text",
234+
Text = prompt
235+
},
236+
}
237+
],
238+
};
239+
240+
await server.SampleAsync(samplingRequest, cancellationToken);
241+
var samplingResult = await server.SampleAsync(samplingRequest, cancellationToken);
242+
243+
return $"Sampling completed successfully. Client responded: {samplingResult.Content.Text}";
244+
}
245+
}
130246
}

tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public KestrelInMemoryTest(ITestOutputHelper testOutputHelper)
1818
Builder = WebApplication.CreateSlimBuilder();
1919
Builder.Services.RemoveAll<IConnectionListenerFactory>();
2020
Builder.Services.AddSingleton<IConnectionListenerFactory>(_inMemoryTransport);
21-
Builder.Services.AddSingleton(LoggerProvider);
21+
Builder.Services.AddSingleton(XunitLoggerProvider);
2222

2323
HttpClient = new HttpClient(new SocketsHttpHandler()
2424
{

0 commit comments

Comments
 (0)