Skip to content

Address some static analysis and TODOs #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using ModelContextProtocol.Utils;

namespace System.Collections.Generic;

internal static class CollectionExtensions
Expand All @@ -9,10 +11,7 @@ internal static class CollectionExtensions

public static TValue GetValueOrDefault<TKey, TValue>(this IReadOnlyDictionary<TKey, TValue> dictionary, TKey key, TValue defaultValue)
{
if (dictionary is null)
{
throw new ArgumentNullException(nameof(dictionary));
}
Throw.IfNull(dictionary);

return dictionary.TryGetValue(key, out TValue? value) ? value : defaultValue;
}
Expand Down
36 changes: 36 additions & 0 deletions src/Common/Polyfills/System/IO/StreamExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using ModelContextProtocol.Utils;
using System.Buffers;
using System.Runtime.InteropServices;

namespace System.IO;

internal static class StreamExtensions
{
public static ValueTask WriteAsync(this Stream stream, ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
{
Throw.IfNull(stream);

if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte> segment))
{
return new ValueTask(stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken));
}
else
{
return WriteAsyncCore(stream, buffer, cancellationToken);

static async ValueTask WriteAsyncCore(Stream stream, ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
{
byte[] array = ArrayPool<byte>.Shared.Rent(buffer.Length);
try
{
buffer.Span.CopyTo(array);
await stream.WriteAsync(array, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
}
finally
{
ArrayPool<byte>.Shared.Return(array);
}
}
}
}
}
16 changes: 15 additions & 1 deletion src/Common/Polyfills/System/IO/TextReaderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,21 @@ internal static class TextReaderExtensions
{
public static Task<string> ReadLineAsync(this TextReader reader, CancellationToken cancellationToken)
{
_ = cancellationToken;
if (cancellationToken.IsCancellationRequested)
{
return Task.FromCanceled<string>(cancellationToken);
}

return reader.ReadLineAsync();
}

public static Task<string> ReadToEndAsync(this TextReader reader, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
return Task.FromCanceled<string>(cancellationToken);
}

return reader.ReadToEndAsync();
}
}
6 changes: 2 additions & 4 deletions src/Common/Polyfills/System/IO/TextWriterExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using ModelContextProtocol.Utils;
using System.Runtime.InteropServices;

namespace System.IO;
Expand All @@ -6,10 +7,7 @@ internal static class TextWriterExtensions
{
public static async Task WriteLineAsync(this TextWriter writer, ReadOnlyMemory<char> value, CancellationToken cancellationToken)
{
if (writer is null)
{
throw new ArgumentNullException(nameof(writer));
}
Throw.IfNull(writer);

if (value.IsEmpty)
{
Expand Down
12 changes: 4 additions & 8 deletions src/Common/Polyfills/System/Net/Http/HttpClientExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
using ModelContextProtocol.Utils;

namespace System.Net.Http;

internal static class HttpClientExtensions
{
public static async Task<Stream> ReadAsStreamAsync(this HttpContent content, CancellationToken cancellationToken)
{
if (content is null)
{
throw new ArgumentNullException(nameof(content));
}
Throw.IfNull(content);

cancellationToken.ThrowIfCancellationRequested();
return await content.ReadAsStreamAsync();
}

public static async Task<string> ReadAsStringAsync(this HttpContent content, CancellationToken cancellationToken)
{
if (content is null)
{
throw new ArgumentNullException(nameof(content));
}
Throw.IfNull(content);

cancellationToken.ThrowIfCancellationRequested();
return await content.ReadAsStringAsync();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
using ModelContextProtocol.Utils;

namespace System.Threading.Tasks;

internal static class CancellationTokenSourceExtensions
{
public static Task CancelAsync(this CancellationTokenSource cancellationTokenSource)
{
if (cancellationTokenSource is null)
{
throw new ArgumentNullException(nameof(cancellationTokenSource));
}
Throw.IfNull(cancellationTokenSource);

cancellationTokenSource.Cancel();
return Task.CompletedTask;
Expand Down
17 changes: 17 additions & 0 deletions src/Common/Polyfills/System/Threading/ForceYielding.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using System.Runtime.CompilerServices;

namespace System.Threading;

/// <summary>
/// await default(ForceYielding) to provide the same behavior as
/// await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding).
/// </summary>
internal readonly struct ForceYielding : INotifyCompletion, ICriticalNotifyCompletion
{
public ForceYielding GetAwaiter() => this;

public bool IsCompleted => false;
public void OnCompleted(Action continuation) => ThreadPool.QueueUserWorkItem(a => ((Action)a!)(), continuation);
public void UnsafeOnCompleted(Action continuation) => ThreadPool.UnsafeQueueUserWorkItem(a => ((Action)a!)(), continuation);
public void GetResult() { }
}
7 changes: 3 additions & 4 deletions src/Common/Polyfills/System/Threading/Tasks/TaskExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using ModelContextProtocol.Utils;

namespace System.Threading.Tasks;

internal static class TaskExtensions
Expand All @@ -15,10 +17,7 @@ public static async Task<T> WaitAsync<T>(this Task<T> task, CancellationToken ca

public static async Task WaitAsync(this Task task, TimeSpan timeout, CancellationToken cancellationToken)
{
if (task is null)
{
throw new ArgumentNullException(nameof(task));
}
Throw.IfNull(task);

if (timeout < TimeSpan.Zero && timeout != Timeout.InfiniteTimeSpan)
{
Expand Down
4 changes: 1 addition & 3 deletions src/ModelContextProtocol/Client/McpClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
using System.Runtime.CompilerServices;
using System.Text.Json;

#pragma warning disable CA1508 // Avoid dead conditional code

namespace ModelContextProtocol.Client;

/// <summary>
Expand Down Expand Up @@ -439,7 +437,7 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat
}

private static JsonRpcRequest CreateRequest(string method, Dictionary<string, object?>? parameters) =>
new JsonRpcRequest
new()
{
Method = method,
Params = parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder
/// <param name="builder">The builder instance.</param>
public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServerBuilder builder)
{
if (builder is null)
{
throw new ArgumentNullException(nameof(builder));
}
Throw.IfNull(builder);

builder.Services.AddSingleton<IServerTransport, HttpListenerSseServerTransport>();
builder.Services.AddHostedService<McpServerHostedService>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@
using ModelContextProtocol.Protocol.Types;
using ModelContextProtocol.Server;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Utils;

namespace ModelContextProtocol.Configuration;

internal sealed class McpServerOptionsSetup(IOptions<McpServerHandlers> serverHandlers) : IConfigureOptions<McpServerOptions>
{
public void Configure(McpServerOptions options)
{
if (options is null)
{
throw new ArgumentNullException(nameof(options));
}
Throw.IfNull(options);

var assemblyName = Assembly.GetEntryAssembly()?.GetName();
options.ServerInfo = new Implementation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ namespace ModelContextProtocol.Protocol.Transport;
[ExcludeFromCodeCoverage]
internal class HttpListenerServerProvider : IDisposable
{
private static readonly byte[] s_accepted = "Accepted"u8.ToArray();

private const string SseEndpoint = "/sse";
private const string MessageEndpoint = "/message";

private readonly int _port;
private readonly string _sseEndpoint = "/sse";
private readonly string _messageEndpoint = "/message";
private HttpListener? _listener;
private CancellationTokenSource? _cts;
private Func<string, CancellationToken, bool>? _messageHandler;
Expand All @@ -31,7 +34,7 @@ public HttpListenerServerProvider(int port)

public Task<string> GetSseEndpointUri()
{
return Task.FromResult($"http://localhost:{_port}{_sseEndpoint}");
return Task.FromResult($"http://localhost:{_port}{SseEndpoint}");
}

public Task InitializeMessageHandler(Func<string, CancellationToken, bool> messageHandler)
Expand Down Expand Up @@ -131,12 +134,12 @@ private async Task ProcessRequestAsync(HttpListenerContext context, Cancellation
throw new McpServerException("Request is null");

// Handle SSE connection
if (request.HttpMethod == "GET" && request.Url?.LocalPath == _sseEndpoint)
if (request.HttpMethod == "GET" && request.Url?.LocalPath == SseEndpoint)
{
await HandleSseConnectionAsync(context, cancellationToken).ConfigureAwait(false);
}
// Handle message POST
else if (request.HttpMethod == "POST" && request.Url?.LocalPath == _messageEndpoint)
else if (request.HttpMethod == "POST" && request.Url?.LocalPath == MessageEndpoint)
{
await HandleMessageAsync(context, cancellationToken).ConfigureAwait(false);
}
Expand Down Expand Up @@ -167,9 +170,6 @@ private async Task HandleSseConnectionAsync(HttpListenerContext context, Cancell
response.Headers.Add("Cache-Control", "no-cache");
response.Headers.Add("Connection", "keep-alive");

// Create a unique ID for this client
var clientId = Guid.NewGuid().ToString();

// Get the output stream and create a StreamWriter
var outputStream = response.OutputStream;
_streamWriter = new StreamWriter(outputStream, Encoding.UTF8) { AutoFlush = true };
Expand All @@ -179,7 +179,7 @@ private async Task HandleSseConnectionAsync(HttpListenerContext context, Cancell
{
// Immediately send the "endpoint" event with the POST URL
await _streamWriter.WriteLineAsync("event: endpoint").ConfigureAwait(false);
await _streamWriter.WriteLineAsync($"data: {_messageEndpoint}").ConfigureAwait(false);
await _streamWriter.WriteLineAsync($"data: {MessageEndpoint}").ConfigureAwait(false);
await _streamWriter.WriteLineAsync().ConfigureAwait(false); // blank line to end an SSE message
await _streamWriter.FlushAsync(cancellationToken).ConfigureAwait(false);

Expand Down Expand Up @@ -222,23 +222,17 @@ private async Task HandleMessageAsync(HttpListenerContext context, CancellationT
string requestBody;
using (var reader = new StreamReader(request.InputStream, request.ContentEncoding))
{
// TODO: Add cancellation token and netstandard2.0 support (polyfill?)
#pragma warning disable CA2016 // Forward the 'CancellationToken' parameter to methods
requestBody = await reader.ReadToEndAsync().ConfigureAwait(false);
#pragma warning restore CA2016 // Forward the 'CancellationToken' parameter to methods
requestBody = await reader.ReadToEndAsync(cancellationToken).ConfigureAwait(false);
}

// Process the message asynchronously
if (_messageHandler != null && _messageHandler(requestBody, cancellationToken))
{
// Return 202 Accepted
response.StatusCode = 202;

// Write "accepted" response
// TODO: Use WriteAsync, add cancellation token and netstandard2.0 support (polyfill?)
byte[] buffer = Encoding.UTF8.GetBytes("Accepted");
#pragma warning disable CA1835 // Prefer the 'Memory'-based overloads for 'ReadAsync' and 'WriteAsync'
await response.OutputStream.WriteAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
#pragma warning restore CA1835 // Prefer the 'Memory'-based overloads for 'ReadAsync' and 'WriteAsync'
await response.OutputStream.WriteAsync(s_accepted, cancellationToken).ConfigureAwait(false);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using ModelContextProtocol.Utils.Json;
using ModelContextProtocol.Logging;
using ModelContextProtocol.Server;
using ModelContextProtocol.Utils;

namespace ModelContextProtocol.Protocol.Transport;

Expand Down Expand Up @@ -160,15 +161,8 @@ private bool HttpMessageHandler(string request, CancellationToken cancellationTo
/// <summary>Validates the <paramref name="serverOptions"/> and extracts from it the server name to use.</summary>
private static string GetServerName(McpServerOptions serverOptions)
{
if (serverOptions is null)
{
throw new ArgumentNullException(nameof(serverOptions));
}

if (serverOptions.ServerInfo is null)
{
throw new ArgumentNullException($"{nameof(serverOptions)}.{nameof(serverOptions.ServerInfo)}");
}
Throw.IfNull(serverOptions);
Throw.IfNull(serverOptions.ServerInfo);

return serverOptions.ServerInfo.Name;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

#pragma warning disable CA2213 // Disposable fields should be disposed

namespace ModelContextProtocol.Protocol.Transport;

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;

#pragma warning disable CA2208 // Instantiate argument exceptions correctly
#pragma warning disable CA2213 // Disposable fields should be disposed

namespace ModelContextProtocol.Protocol.Transport;

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/ModelContextProtocol/Shared/McpJsonRpcEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async Task ProcessMessageAsync()
#if NET
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
#else
await Task.Yield(); // TODO: Fix this
await default(ForceYielding);
#endif
try
{
Expand Down
2 changes: 1 addition & 1 deletion src/ModelContextProtocol/Utils/Throw.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public static void IfNullOrWhiteSpace([NotNull] string? arg, [CallerArgumentExpr
{
if (arg is null || arg.AsSpan().IsWhiteSpace())
{
ThrowArgumentNullOrWhiteSpaceException(arg);
ThrowArgumentNullOrWhiteSpaceException(parameterName);
}
}

Expand Down
4 changes: 2 additions & 2 deletions tests/ModelContextProtocol.TestServer/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ private static async Task Main(string[] args)
GetCompletionHandler = ConfigureCompletion(),
};

var loggerFactory = CreateLoggerFactory();
using var loggerFactory = CreateLoggerFactory();
await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("TestServer", loggerFactory), options, loggerFactory);

Log.Logger.Information("Server initialized.");
Expand Down Expand Up @@ -266,7 +266,7 @@ private static PromptsCapability ConfigurePrompts()
};
}

private static HashSet<string> _subscribedResources = new();
private static readonly HashSet<string> _subscribedResources = new();
private static readonly object _subscribedResourcesLock = new();

private static ResourcesCapability ConfigureResources()
Expand Down
Loading
Loading