1
1
using Microsoft . AspNetCore . Builder ;
2
2
using Microsoft . AspNetCore . Http ;
3
3
using Microsoft . Extensions . DependencyInjection ;
4
+ using Microsoft . Extensions . Logging ;
4
5
using ModelContextProtocol . AspNetCore . Tests . Utils ;
5
6
using ModelContextProtocol . Client ;
7
+ using ModelContextProtocol . Protocol ;
6
8
using ModelContextProtocol . Server ;
9
+ using ModelContextProtocol . Tests . Utils ;
7
10
using System . ComponentModel ;
8
11
using System . Net ;
9
12
using System . Security . Claims ;
@@ -20,18 +23,21 @@ protected void ConfigureStateless(HttpServerTransportOptions options)
20
23
options . Stateless = Stateless ;
21
24
}
22
25
23
- protected async Task < IMcpClient > ConnectAsync ( string ? path = null , SseClientTransportOptions ? options = null )
26
+ protected async Task < IMcpClient > ConnectAsync (
27
+ string ? path = null ,
28
+ SseClientTransportOptions ? transportOptions = null ,
29
+ McpClientOptions ? clientOptions = null )
24
30
{
25
31
// Default behavior when no options are provided
26
32
path ??= UseStreamableHttp ? "/" : "/sse" ;
27
33
28
- await using var transport = new SseClientTransport ( options ?? new SseClientTransportOptions ( )
34
+ await using var transport = new SseClientTransport ( transportOptions ?? new SseClientTransportOptions ( )
29
35
{
30
36
Endpoint = new Uri ( $ "http://localhost{ path } ") ,
31
37
TransportMode = UseStreamableHttp ? HttpTransportMode . StreamableHttp : HttpTransportMode . Sse ,
32
38
} , HttpClient , LoggerFactory ) ;
33
39
34
- return await McpClientFactory . CreateAsync ( transport , loggerFactory : LoggerFactory , cancellationToken : TestContext . Current . CancellationToken ) ;
40
+ return await McpClientFactory . CreateAsync ( transport , clientOptions , LoggerFactory , TestContext . Current . CancellationToken ) ;
35
41
}
36
42
37
43
[ Fact ]
@@ -71,7 +77,7 @@ IHttpContextAccessor is not currently supported with non-stateless Streamable HT
71
77
72
78
await app . StartAsync ( TestContext . Current . CancellationToken ) ;
73
79
74
- var mcpClient = await ConnectAsync ( ) ;
80
+ await using var mcpClient = await ConnectAsync ( ) ;
75
81
76
82
var response = await mcpClient . CallToolAsync (
77
83
"EchoWithUserName" ,
@@ -111,13 +117,90 @@ public async Task Messages_FromNewUser_AreRejected()
111
117
Assert . Equal ( HttpStatusCode . Forbidden , httpRequestException . StatusCode ) ;
112
118
}
113
119
114
- protected ClaimsPrincipal CreateUser ( string name )
120
+ [ Fact ]
121
+ public async Task Sampling_DoesNotCloseStream_Prematurely ( )
122
+ {
123
+ Assert . SkipWhen ( Stateless , "Sampling is not supported in stateless mode." ) ;
124
+
125
+ Builder . Services . AddMcpServer ( ) . WithHttpTransport ( ConfigureStateless ) . WithTools < SamplingRegressionTools > ( ) ;
126
+
127
+ var mockLoggerProvider = new MockLoggerProvider ( ) ;
128
+ Builder . Logging . AddProvider ( mockLoggerProvider ) ;
129
+ Builder . Logging . SetMinimumLevel ( LogLevel . Debug ) ;
130
+
131
+ await using var app = Builder . Build ( ) ;
132
+
133
+ // Reset the LoggerFactory used by the client to use the MockLoggerProvider as well.
134
+ LoggerFactory = app . Services . GetRequiredService < ILoggerFactory > ( ) ;
135
+
136
+ app . MapMcp ( ) ;
137
+
138
+ await app . StartAsync ( TestContext . Current . CancellationToken ) ;
139
+
140
+ var sampleCount = 0 ;
141
+ var clientOptions = new McpClientOptions
142
+ {
143
+ Capabilities = new ( )
144
+ {
145
+ Sampling = new ( )
146
+ {
147
+ SamplingHandler = async ( parameters , _ , _ ) =>
148
+ {
149
+ Assert . NotNull ( parameters ? . Messages ) ;
150
+ var message = Assert . Single ( parameters . Messages ) ;
151
+ Assert . Equal ( Role . User , message . Role ) ;
152
+ Assert . Equal ( "text" , message . Content . Type ) ;
153
+ Assert . Equal ( "Test prompt for sampling" , message . Content . Text ) ;
154
+
155
+ sampleCount ++ ;
156
+ return new CreateMessageResult
157
+ {
158
+ Model = "test-model" ,
159
+ Role = Role . Assistant ,
160
+ Content = new Content
161
+ {
162
+ Type = "text" ,
163
+ Text = "Sampling response from client"
164
+ }
165
+ } ;
166
+ } ,
167
+ } ,
168
+ } ,
169
+ } ;
170
+
171
+ await using var mcpClient = await ConnectAsync ( clientOptions : clientOptions ) ;
172
+
173
+ var result = await mcpClient . CallToolAsync ( "sampling-tool" , new Dictionary < string , object ? >
174
+ {
175
+ [ "prompt" ] = "Test prompt for sampling"
176
+ } , cancellationToken : TestContext . Current . CancellationToken ) ;
177
+
178
+ Assert . NotNull ( result ) ;
179
+ Assert . False ( result . IsError ) ;
180
+ var textContent = Assert . Single ( result . Content ) ;
181
+ Assert . Equal ( "text" , textContent . Type ) ;
182
+ Assert . Equal ( "Sampling completed successfully. Client responded: Sampling response from client" , textContent . Text ) ;
183
+
184
+ Assert . Equal ( 2 , sampleCount ) ;
185
+
186
+ // Verify that the tool call and the sampling request both used the same ID to ensure we cover against regressions.
187
+ // https://github.com/modelcontextprotocol/csharp-sdk/issues/464
188
+ Assert . Single ( mockLoggerProvider . LogMessages , m =>
189
+ m . Category == "ModelContextProtocol.Client.McpClient" &&
190
+ m . Message . Contains ( "request '2' for method 'tools/call'" ) ) ;
191
+
192
+ Assert . Single ( mockLoggerProvider . LogMessages , m =>
193
+ m . Category == "ModelContextProtocol.Server.McpServer" &&
194
+ m . Message . Contains ( "request '2' for method 'sampling/createMessage'" ) ) ;
195
+ }
196
+
197
+ private ClaimsPrincipal CreateUser ( string name )
115
198
=> new ClaimsPrincipal ( new ClaimsIdentity (
116
199
[ new Claim ( "name" , name ) , new Claim ( ClaimTypes . NameIdentifier , name ) ] ,
117
200
"TestAuthType" , "name" , "role" ) ) ;
118
201
119
202
[ McpServerToolType ]
120
- protected class EchoHttpContextUserTools ( IHttpContextAccessor contextAccessor )
203
+ private class EchoHttpContextUserTools ( IHttpContextAccessor contextAccessor )
121
204
{
122
205
[ McpServerTool , Description ( "Echoes the input back to the client with their user name." ) ]
123
206
public string EchoWithUserName ( string message )
@@ -127,4 +210,37 @@ public string EchoWithUserName(string message)
127
210
return $ "{ userName } : { message } ";
128
211
}
129
212
}
213
+
214
+ [ McpServerToolType ]
215
+ private class SamplingRegressionTools
216
+ {
217
+ [ McpServerTool ( Name = "sampling-tool" ) ]
218
+ public static async Task < string > SamplingToolAsync ( IMcpServer server , string prompt , CancellationToken cancellationToken )
219
+ {
220
+ // This tool reproduces the scenario described in https://github.com/modelcontextprotocol/csharp-sdk/issues/464
221
+ // 1. The client calls tool with request ID 2, because it's the first request after the initialize request.
222
+ // 2. This tool makes two sampling requests which use IDs 1 and 2.
223
+ // 3. In the old buggy Streamable HTTP transport code, this would close the SSE response stream,
224
+ // because the second sampling request used an ID matching the tool call.
225
+ var samplingRequest = new CreateMessageRequestParams
226
+ {
227
+ Messages = [
228
+ new SamplingMessage
229
+ {
230
+ Role = Role . User ,
231
+ Content = new Content
232
+ {
233
+ Type = "text" ,
234
+ Text = prompt
235
+ } ,
236
+ }
237
+ ] ,
238
+ } ;
239
+
240
+ await server . SampleAsync ( samplingRequest , cancellationToken ) ;
241
+ var samplingResult = await server . SampleAsync ( samplingRequest , cancellationToken ) ;
242
+
243
+ return $ "Sampling completed successfully. Client responded: { samplingResult . Content . Text } ";
244
+ }
245
+ }
130
246
}
0 commit comments