Skip to content

Move the SseResponseStreamTransport out of sample #47

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<PackageVersion Include="Microsoft.Extensions.AI.Abstractions" Version="$(MicrosoftExtensionsAIVersion)" />
<PackageVersion Include="Microsoft.Extensions.Hosting.Abstractions" Version="$(MicrosoftExtensionsVersion)" />
<PackageVersion Include="Microsoft.Extensions.Logging.Abstractions" Version="$(MicrosoftExtensionsVersion)" />
<PackageVersion Include="System.Net.ServerSentEvents" Version="$(SystemVersion)" />
<PackageVersion Include="System.Net.ServerSentEvents" Version="$(System10Version)" />
<PackageVersion Include="System.Text.Json" Version="$(SystemVersion)" />
<PackageVersion Include="System.Threading.Channels" Version="$(SystemVersion)" />

Expand Down
4 changes: 0 additions & 4 deletions samples/AspNetCoreSseServer/AspNetCoreSseServer.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,4 @@
<ProjectReference Include="..\..\src\ModelContextProtocol\ModelContextProtocol.csproj" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="System.Net.ServerSentEvents" VersionOverride="10.0.0-preview.1.25080.5" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils.Json;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Protocol.Transport;

namespace AspNetCoreSseServer;

Expand All @@ -10,15 +11,15 @@ public static class McpEndpointRouteBuilderExtensions
public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder endpoints)
{
IMcpServer? server = null;
SseServerStreamTransport? transport = null;
SseResponseStreamTransport? transport = null;
var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();

var routeGroup = endpoints.MapGroup("");

routeGroup.MapGet("/sse", async (HttpResponse response, CancellationToken requestAborted) =>
{
await using var localTransport = transport = new SseServerStreamTransport(response.Body);
await using var localTransport = transport = new SseResponseStreamTransport(response.Body);
await using var localServer = server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);

await localServer.StartAsync(requestAborted);
Expand All @@ -37,7 +38,7 @@ public static IEndpointConventionBuilder MapMcpSse(this IEndpointRouteBuilder en
}
});

routeGroup.MapPost("/message", async (HttpContext context) =>
routeGroup.MapPost("/message", async context =>
{
if (transport is null)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
using ModelContextProtocol.Server;
using System.Diagnostics.CodeAnalysis;
using System.Net;
using System.Text;
using System.Net;
using ModelContextProtocol.Server;

namespace ModelContextProtocol.Protocol.Transport;

/// <summary>
/// HTTP server provider using HttpListener.
/// </summary>
[ExcludeFromCodeCoverage]
internal class HttpListenerServerProvider : IDisposable
{
private static readonly byte[] s_accepted = "Accepted"u8.ToArray();
Expand All @@ -19,8 +16,6 @@ internal class HttpListenerServerProvider : IDisposable
private readonly int _port;
private HttpListener? _listener;
private CancellationTokenSource? _cts;
private Func<string, CancellationToken, bool>? _messageHandler;
private StreamWriter? _streamWriter;
private bool _isRunning;

/// <summary>
Expand All @@ -32,39 +27,16 @@ public HttpListenerServerProvider(int port)
_port = port;
}

public Task<string> GetSseEndpointUri()
{
return Task.FromResult($"http://localhost:{_port}{SseEndpoint}");
}

public Task InitializeMessageHandler(Func<string, CancellationToken, bool> messageHandler)
{
_messageHandler = messageHandler;
return Task.CompletedTask;
}

public async Task SendEvent(string data, string eventId)
{
if (_streamWriter == null)
{
throw new McpServerException("Stream writer not initialized");
}
if (eventId != null)
{
await _streamWriter.WriteLineAsync($"id: {eventId}").ConfigureAwait(false);
}
await _streamWriter.WriteLineAsync($"data: {data}").ConfigureAwait(false);
await _streamWriter.WriteLineAsync().ConfigureAwait(false); // Empty line to finish the event
await _streamWriter.FlushAsync().ConfigureAwait(false);
}
public required Func<Stream, CancellationToken, Task> OnSseConnectionAsync { get; set; }
public required Func<Stream, CancellationToken, Task<bool>> OnMessageAsync { get; set; }

/// <inheritdoc/>
public Task StartAsync(CancellationToken cancellationToken = default)
{
if (_isRunning)
return Task.CompletedTask;

_cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_cts = new CancellationTokenSource();
_listener = new HttpListener();
_listener.Prefixes.Add($"http://localhost:{_port}/");
_listener.Start();
Expand All @@ -84,8 +56,6 @@ public Task StopAsync(CancellationToken cancellationToken = default)
_cts?.Cancel();
_listener?.Stop();

_streamWriter?.Close();

_isRunning = false;
return Task.CompletedTask;
}
Expand Down Expand Up @@ -170,28 +140,10 @@ private async Task HandleSseConnectionAsync(HttpListenerContext context, Cancell
response.Headers.Add("Cache-Control", "no-cache");
response.Headers.Add("Connection", "keep-alive");

// Get the output stream and create a StreamWriter
var outputStream = response.OutputStream;
_streamWriter = new StreamWriter(outputStream, Encoding.UTF8) { AutoFlush = true };

// Keep the connection open until cancelled
try
{
// Immediately send the "endpoint" event with the POST URL
await _streamWriter.WriteLineAsync("event: endpoint").ConfigureAwait(false);
await _streamWriter.WriteLineAsync($"data: {MessageEndpoint}").ConfigureAwait(false);
await _streamWriter.WriteLineAsync().ConfigureAwait(false); // blank line to end an SSE message
await _streamWriter.FlushAsync(cancellationToken).ConfigureAwait(false);

// Keep the connection open by "pinging" or just waiting
// until the client disconnects or the server is canceled.
while (!cancellationToken.IsCancellationRequested && response.OutputStream.CanWrite)
{
// Do a periodic no-op to keep connection alive:
await _streamWriter.WriteLineAsync(": keep-alive").ConfigureAwait(false);
await _streamWriter.FlushAsync(cancellationToken).ConfigureAwait(false);
await Task.Delay(TimeSpan.FromSeconds(5), cancellationToken).ConfigureAwait(false);
}
await OnSseConnectionAsync(response.OutputStream, cancellationToken).ConfigureAwait(false);
}
catch (TaskCanceledException)
{
Expand All @@ -206,7 +158,6 @@ private async Task HandleSseConnectionAsync(HttpListenerContext context, Cancell
// Remove client on disconnect
try
{
_streamWriter.Close();
response.Close();
}
catch { /* Ignore errors during cleanup */ }
Expand All @@ -218,15 +169,8 @@ private async Task HandleMessageAsync(HttpListenerContext context, CancellationT
var request = context.Request;
var response = context.Response;

// Read the request body
string requestBody;
using (var reader = new StreamReader(request.InputStream, request.ContentEncoding))
{
requestBody = await reader.ReadToEndAsync(cancellationToken).ConfigureAwait(false);
}

// Process the message asynchronously
if (_messageHandler != null && _messageHandler(requestBody, cancellationToken))
if (await OnMessageAsync(request.InputStream, cancellationToken))
{
// Return 202 Accepted
response.StatusCode = 202;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Text.Json;
using System.Net;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Utils.Json;
Expand All @@ -9,16 +10,15 @@
namespace ModelContextProtocol.Protocol.Transport;

/// <summary>
/// Implements the MCP transport protocol over standard input/output streams.
/// Implements the MCP transport protocol using <see cref="HttpListener"/>.
/// </summary>
public sealed class HttpListenerSseServerTransport : TransportBase, IServerTransport
{
private readonly string _serverName;
private readonly HttpListenerServerProvider _httpServerProvider;
private readonly ILogger<HttpListenerSseServerTransport> _logger;
private readonly JsonSerializerOptions _jsonOptions;
private CancellationTokenSource? _shutdownCts;

private SseResponseStreamTransport? _sseResponseStreamTransport;

private string EndpointName => $"Server (SSE) ({_serverName})";

/// <summary>
Expand All @@ -43,28 +43,23 @@ public HttpListenerSseServerTransport(string serverName, int port, ILoggerFactor
{
_serverName = serverName;
_logger = loggerFactory.CreateLogger<HttpListenerSseServerTransport>();
_jsonOptions = McpJsonUtilities.DefaultOptions;
_httpServerProvider = new HttpListenerServerProvider(port);
_httpServerProvider = new HttpListenerServerProvider(port)
{
OnSseConnectionAsync = OnSseConnectionAsync,
OnMessageAsync = OnMessageAsync,
};
}

/// <inheritdoc/>
public Task StartListeningAsync(CancellationToken cancellationToken = default)
{
_shutdownCts = new CancellationTokenSource();

_httpServerProvider.InitializeMessageHandler(HttpMessageHandler);
_httpServerProvider.StartAsync(cancellationToken);

SetConnected(true);

return Task.CompletedTask;
return _httpServerProvider.StartAsync(cancellationToken);
}


/// <inheritdoc/>
public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
if (!IsConnected)
if (!IsConnected || _sseResponseStreamTransport is null)
{
_logger.TransportNotConnected(EndpointName);
throw new McpTransportException("Transport is not connected");
Expand All @@ -78,10 +73,13 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio

try
{
var json = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
_logger.TransportSendingMessage(EndpointName, id, json);
if (_logger.IsEnabled(LogLevel.Debug))
{
var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions.GetTypeInfo<IJsonRpcMessage>());
_logger.TransportSendingMessage(EndpointName, id, json);
}

await _httpServerProvider.SendEvent(json, "message").ConfigureAwait(false);
await _sseResponseStreamTransport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);

_logger.TransportSentMessage(EndpointName, id);
}
Expand All @@ -99,49 +97,61 @@ public override async ValueTask DisposeAsync()
GC.SuppressFinalize(this);
}

private async Task CleanupAsync(CancellationToken cancellationToken)
private Task CleanupAsync(CancellationToken cancellationToken)
{
_logger.TransportCleaningUp(EndpointName);

if (_shutdownCts != null)
{
await _shutdownCts.CancelAsync().ConfigureAwait(false);
_shutdownCts.Dispose();
_shutdownCts = null;
}

_httpServerProvider.Dispose();

SetConnected(false);

_logger.TransportCleanedUp(EndpointName);
return Task.CompletedTask;
}

private async Task OnSseConnectionAsync(Stream responseStream, CancellationToken cancellationToken)
{
await using var sseResponseStreamTransport = new SseResponseStreamTransport(responseStream);
_sseResponseStreamTransport = sseResponseStreamTransport;
SetConnected(true);
await sseResponseStreamTransport.RunAsync(cancellationToken);
}

/// <summary>
/// Handles HTTP messages received by the HTTP server provider.
/// </summary>
/// <returns>true if the message was accepted (return 202), false otherwise (return 400)</returns>
private bool HttpMessageHandler(string request, CancellationToken cancellationToken)
private async Task<bool> OnMessageAsync(Stream requestStream, CancellationToken cancellationToken)
{
_logger.TransportReceivedMessage(EndpointName, request);
string request;
IJsonRpcMessage? message = null;

if (_logger.IsEnabled(LogLevel.Information))
{
using var reader = new StreamReader(requestStream);
request = await reader.ReadToEndAsync(cancellationToken).ConfigureAwait(false);
message = JsonSerializer.Deserialize(request, McpJsonUtilities.DefaultOptions.GetTypeInfo<IJsonRpcMessage>());

_logger.TransportReceivedMessage(EndpointName, request);
}
else
{
request = "(Enable information-level logs to see the request)";
}

try
{
var message = JsonSerializer.Deserialize(request, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
message ??= await JsonSerializer.DeserializeAsync(requestStream, McpJsonUtilities.DefaultOptions.GetTypeInfo<IJsonRpcMessage>());
if (message != null)
{
// Fire-and-forget the message to the message channel
Task.Run(async () =>
string messageId = "(no id)";
if (message is IJsonRpcMessageWithId messageWithId)
{
string messageId = "(no id)";
if (message is IJsonRpcMessageWithId messageWithId)
{
messageId = messageWithId.Id.ToString();
}

_logger.TransportReceivedMessageParsed(EndpointName, messageId);
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
_logger.TransportMessageWritten(EndpointName, messageId);
}, cancellationToken);
messageId = messageWithId.Id.ToString();
}

_logger.TransportReceivedMessageParsed(EndpointName, messageId);
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
_logger.TransportMessageWritten(EndpointName, messageId);

return true;
}
Expand Down
Loading