Skip to content

Commit 1b1c7fa

Browse files
authored
Move the SseResponseStreamTransport out of sample (#47)
1 parent e1a5828 commit 1b1c7fa

File tree

6 files changed

+103
-125
lines changed

6 files changed

+103
-125
lines changed

Directory.Packages.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
<PackageVersion Include="Microsoft.Extensions.AI.Abstractions" Version="$(MicrosoftExtensionsAIVersion)" />
1414
<PackageVersion Include="Microsoft.Extensions.Hosting.Abstractions" Version="$(MicrosoftExtensionsVersion)" />
1515
<PackageVersion Include="Microsoft.Extensions.Logging.Abstractions" Version="$(MicrosoftExtensionsVersion)" />
16-
<PackageVersion Include="System.Net.ServerSentEvents" Version="$(SystemVersion)" />
16+
<PackageVersion Include="System.Net.ServerSentEvents" Version="$(System10Version)" />
1717
<PackageVersion Include="System.Text.Json" Version="$(SystemVersion)" />
1818
<PackageVersion Include="System.Threading.Channels" Version="$(SystemVersion)" />
1919

samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,4 @@
1010
<ProjectReference Include="..\..\src\ModelContextProtocol\ModelContextProtocol.csproj" />
1111
</ItemGroup>
1212

13-
<ItemGroup>
14-
<PackageReference Include="System.Net.ServerSentEvents" VersionOverride="10.0.0-preview.1.25080.5" />
15-
</ItemGroup>
16-
1713
</Project>

samples/AspNetCoreSseServer/McpEndpointRouteBuilderExtensions.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using ModelContextProtocol.Server;
33
using ModelContextProtocol.Utils.Json;
44
using Microsoft.Extensions.Options;
5+
using ModelContextProtocol.Protocol.Transport;
56

67
namespace AspNetCoreSseServer;
78

@@ -10,15 +11,15 @@ public static class McpEndpointRouteBuilderExtensions
1011
public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder endpoints)
1112
{
1213
IMcpServer? server = null;
13-
SseServerStreamTransport? transport = null;
14+
SseResponseStreamTransport? transport = null;
1415
var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
1516
var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
1617

1718
var routeGroup = endpoints.MapGroup("");
1819

1920
routeGroup.MapGet("/sse", async (HttpResponse response, CancellationToken requestAborted) =>
2021
{
21-
await using var localTransport = transport = new SseServerStreamTransport(response.Body);
22+
await using var localTransport = transport = new SseResponseStreamTransport(response.Body);
2223
await using var localServer = server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
2324

2425
await localServer.StartAsync(requestAborted);
@@ -37,7 +38,7 @@ public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder en
3738
}
3839
});
3940

40-
routeGroup.MapPost("/message", async (HttpContext context) =>
41+
routeGroup.MapPost("/message", async context =>
4142
{
4243
if (transport is null)
4344
{

src/ModelContextProtocol/Protocol/Transport/HttpListenerServerProvider.cs

Lines changed: 7 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
using ModelContextProtocol.Server;
2-
using System.Diagnostics.CodeAnalysis;
3-
using System.Net;
4-
using System.Text;
1+
using System.Net;
2+
using ModelContextProtocol.Server;
53

64
namespace ModelContextProtocol.Protocol.Transport;
75

86
/// <summary>
97
/// HTTP server provider using HttpListener.
108
/// </summary>
11-
[ExcludeFromCodeCoverage]
129
internal class HttpListenerServerProvider : IDisposable
1310
{
1411
private static readonly byte[] s_accepted = "Accepted"u8.ToArray();
@@ -19,8 +16,6 @@ internal class HttpListenerServerProvider : IDisposable
1916
private readonly int _port;
2017
private HttpListener? _listener;
2118
private CancellationTokenSource? _cts;
22-
private Func<string, CancellationToken, bool>? _messageHandler;
23-
private StreamWriter? _streamWriter;
2419
private bool _isRunning;
2520

2621
/// <summary>
@@ -32,39 +27,16 @@ public HttpListenerServerProvider(int port)
3227
_port = port;
3328
}
3429

35-
public Task<string> GetSseEndpointUri()
36-
{
37-
return Task.FromResult($"http://localhost:{_port}{SseEndpoint}");
38-
}
39-
40-
public Task InitializeMessageHandler(Func<string, CancellationToken, bool> messageHandler)
41-
{
42-
_messageHandler = messageHandler;
43-
return Task.CompletedTask;
44-
}
45-
46-
public async Task SendEvent(string data, string eventId)
47-
{
48-
if (_streamWriter == null)
49-
{
50-
throw new McpServerException("Stream writer not initialized");
51-
}
52-
if (eventId != null)
53-
{
54-
await _streamWriter.WriteLineAsync($"id: {eventId}").ConfigureAwait(false);
55-
}
56-
await _streamWriter.WriteLineAsync($"data: {data}").ConfigureAwait(false);
57-
await _streamWriter.WriteLineAsync().ConfigureAwait(false); // Empty line to finish the event
58-
await _streamWriter.FlushAsync().ConfigureAwait(false);
59-
}
30+
public required Func<Stream, CancellationToken, Task> OnSseConnectionAsync { get; set; }
31+
public required Func<Stream, CancellationToken, Task<bool>> OnMessageAsync { get; set; }
6032

6133
/// <inheritdoc/>
6234
public Task StartAsync(CancellationToken cancellationToken = default)
6335
{
6436
if (_isRunning)
6537
return Task.CompletedTask;
6638

67-
_cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
39+
_cts = new CancellationTokenSource();
6840
_listener = new HttpListener();
6941
_listener.Prefixes.Add($"http://localhost:{_port}/");
7042
_listener.Start();
@@ -84,8 +56,6 @@ public Task StopAsync(CancellationToken cancellationToken = default)
8456
_cts?.Cancel();
8557
_listener?.Stop();
8658

87-
_streamWriter?.Close();
88-
8959
_isRunning = false;
9060
return Task.CompletedTask;
9161
}
@@ -170,28 +140,10 @@ private async Task HandleSseConnectionAsync(HttpListenerContext context, Cancell
170140
response.Headers.Add("Cache-Control", "no-cache");
171141
response.Headers.Add("Connection", "keep-alive");
172142

173-
// Get the output stream and create a StreamWriter
174-
var outputStream = response.OutputStream;
175-
_streamWriter = new StreamWriter(outputStream, Encoding.UTF8) { AutoFlush = true };
176-
177143
// Keep the connection open until cancelled
178144
try
179145
{
180-
// Immediately send the "endpoint" event with the POST URL
181-
await _streamWriter.WriteLineAsync("event: endpoint").ConfigureAwait(false);
182-
await _streamWriter.WriteLineAsync($"data: {MessageEndpoint}").ConfigureAwait(false);
183-
await _streamWriter.WriteLineAsync().ConfigureAwait(false); // blank line to end an SSE message
184-
await _streamWriter.FlushAsync(cancellationToken).ConfigureAwait(false);
185-
186-
// Keep the connection open by "pinging" or just waiting
187-
// until the client disconnects or the server is canceled.
188-
while (!cancellationToken.IsCancellationRequested && response.OutputStream.CanWrite)
189-
{
190-
// Do a periodic no-op to keep connection alive:
191-
await _streamWriter.WriteLineAsync(": keep-alive").ConfigureAwait(false);
192-
await _streamWriter.FlushAsync(cancellationToken).ConfigureAwait(false);
193-
await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false);
194-
}
146+
await OnSseConnectionAsync(response.OutputStream, cancellationToken).ConfigureAwait(false);
195147
}
196148
catch (TaskCanceledException)
197149
{
@@ -206,7 +158,6 @@ private async Task HandleSseConnectionAsync(HttpListenerContext context, Cancell
206158
// Remove client on disconnect
207159
try
208160
{
209-
_streamWriter.Close();
210161
response.Close();
211162
}
212163
catch { /* Ignore errors during cleanup */ }
@@ -218,15 +169,8 @@ private async Task HandleMessageAsync(HttpListenerContext context, CancellationT
218169
var request = context.Request;
219170
var response = context.Response;
220171

221-
// Read the request body
222-
string requestBody;
223-
using (var reader = new StreamReader(request.InputStream, request.ContentEncoding))
224-
{
225-
requestBody = await reader.ReadToEndAsync(cancellationToken).ConfigureAwait(false);
226-
}
227-
228172
// Process the message asynchronously
229-
if (_messageHandler != null && _messageHandler(requestBody, cancellationToken))
173+
if (await OnMessageAsync(request.InputStream, cancellationToken))
230174
{
231175
// Return 202 Accepted
232176
response.StatusCode = 202;

src/ModelContextProtocol/Protocol/Transport/HttpListenerSseServerTransport.cs

Lines changed: 54 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Text.Json;
1+
using System.Net;
2+
using System.Text.Json;
23
using Microsoft.Extensions.Logging;
34
using ModelContextProtocol.Protocol.Messages;
45
using ModelContextProtocol.Utils.Json;
@@ -9,16 +10,15 @@
910
namespace ModelContextProtocol.Protocol.Transport;
1011

1112
/// <summary>
12-
/// Implements the MCP transport protocol over standard input/output streams.
13+
/// Implements the MCP transport protocol using <see cref="HttpListener"/>.
1314
/// </summary>
1415
public sealed class HttpListenerSseServerTransport : TransportBase, IServerTransport
1516
{
1617
private readonly string _serverName;
1718
private readonly HttpListenerServerProvider _httpServerProvider;
1819
private readonly ILogger<HttpListenerSseServerTransport> _logger;
19-
private readonly JsonSerializerOptions _jsonOptions;
20-
private CancellationTokenSource? _shutdownCts;
21-
20+
private SseResponseStreamTransport? _sseResponseStreamTransport;
21+
2222
private string EndpointName => $"Server (SSE) ({_serverName})";
2323

2424
/// <summary>
@@ -43,28 +43,23 @@ public HttpListenerSseServerTransport(string serverName, int port, ILoggerFactor
4343
{
4444
_serverName = serverName;
4545
_logger = loggerFactory.CreateLogger<HttpListenerSseServerTransport>();
46-
_jsonOptions = McpJsonUtilities.DefaultOptions;
47-
_httpServerProvider = new HttpListenerServerProvider(port);
46+
_httpServerProvider = new HttpListenerServerProvider(port)
47+
{
48+
OnSseConnectionAsync = OnSseConnectionAsync,
49+
OnMessageAsync = OnMessageAsync,
50+
};
4851
}
4952

5053
/// <inheritdoc/>
5154
public Task StartListeningAsync(CancellationToken cancellationToken = default)
5255
{
53-
_shutdownCts = new CancellationTokenSource();
54-
55-
_httpServerProvider.InitializeMessageHandler(HttpMessageHandler);
56-
_httpServerProvider.StartAsync(cancellationToken);
57-
58-
SetConnected(true);
59-
60-
return Task.CompletedTask;
56+
return _httpServerProvider.StartAsync(cancellationToken);
6157
}
6258

63-
6459
/// <inheritdoc/>
6560
public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
6661
{
67-
if (!IsConnected)
62+
if (!IsConnected || _sseResponseStreamTransport is null)
6863
{
6964
_logger.TransportNotConnected(EndpointName);
7065
throw new McpTransportException("Transport is not connected");
@@ -78,10 +73,13 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio
7873

7974
try
8075
{
81-
var json = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
82-
_logger.TransportSendingMessage(EndpointName, id, json);
76+
if (_logger.IsEnabled(LogLevel.Debug))
77+
{
78+
var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo<IJsonRpcMessage>());
79+
_logger.TransportSendingMessage(EndpointName, id, json);
80+
}
8381

84-
await _httpServerProvider.SendEvent(json, "message").ConfigureAwait(false);
82+
await _sseResponseStreamTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
8583

8684
_logger.TransportSentMessage(EndpointName, id);
8785
}
@@ -99,49 +97,61 @@ public override async ValueTask DisposeAsync()
9997
GC.SuppressFinalize(this);
10098
}
10199

102-
private async Task CleanupAsync(CancellationToken cancellationToken)
100+
private Task CleanupAsync(CancellationToken cancellationToken)
103101
{
104102
_logger.TransportCleaningUp(EndpointName);
105103

106-
if (_shutdownCts != null)
107-
{
108-
await _shutdownCts.CancelAsync().ConfigureAwait(false);
109-
_shutdownCts.Dispose();
110-
_shutdownCts = null;
111-
}
112-
113104
_httpServerProvider.Dispose();
114-
115105
SetConnected(false);
106+
116107
_logger.TransportCleanedUp(EndpointName);
108+
return Task.CompletedTask;
109+
}
110+
111+
private async Task OnSseConnectionAsync(Stream responseStream, CancellationToken cancellationToken)
112+
{
113+
await using var sseResponseStreamTransport = new SseResponseStreamTransport(responseStream);
114+
_sseResponseStreamTransport = sseResponseStreamTransport;
115+
SetConnected(true);
116+
await sseResponseStreamTransport.RunAsync(cancellationToken);
117117
}
118118

119119
/// <summary>
120120
/// Handles HTTP messages received by the HTTP server provider.
121121
/// </summary>
122122
/// <returns>true if the message was accepted (return 202), false otherwise (return 400)</returns>
123-
private bool HttpMessageHandler(string request, CancellationToken cancellationToken)
123+
private async Task<bool> OnMessageAsync(Stream requestStream, CancellationToken cancellationToken)
124124
{
125-
_logger.TransportReceivedMessage(EndpointName, request);
125+
string request;
126+
IJsonRpcMessage? message = null;
127+
128+
if (_logger.IsEnabled(LogLevel.Information))
129+
{
130+
using var reader = new StreamReader(requestStream);
131+
request = await reader.ReadToEndAsync(cancellationToken).ConfigureAwait(false);
132+
message = JsonSerializer.Deserialize(request, McpJsonUtilities.DefaultOptions.GetTypeInfo<IJsonRpcMessage>());
133+
134+
_logger.TransportReceivedMessage(EndpointName, request);
135+
}
136+
else
137+
{
138+
request = "(Enable information-level logs to see the request)";
139+
}
126140

127141
try
128142
{
129-
var message = JsonSerializer.Deserialize(request, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
143+
message ??= await JsonSerializer.DeserializeAsync(requestStream, McpJsonUtilities.DefaultOptions.GetTypeInfo<IJsonRpcMessage>());
130144
if (message != null)
131145
{
132-
// Fire-and-forget the message to the message channel
133-
Task.Run(async () =>
146+
string messageId = "(no id)";
147+
if (message is IJsonRpcMessageWithId messageWithId)
134148
{
135-
string messageId = "(no id)";
136-
if (message is IJsonRpcMessageWithId messageWithId)
137-
{
138-
messageId = messageWithId.Id.ToString();
139-
}
140-
141-
_logger.TransportReceivedMessageParsed(EndpointName, messageId);
142-
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
143-
_logger.TransportMessageWritten(EndpointName, messageId);
144-
}, cancellationToken);
149+
messageId = messageWithId.Id.ToString();
150+
}
151+
152+
_logger.TransportReceivedMessageParsed(EndpointName, messageId);
153+
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
154+
_logger.TransportMessageWritten(EndpointName, messageId);
145155

146156
return true;
147157
}

0 commit comments

Comments
 (0)