Skip to content

Annotate SignalR server for native AOT #56460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions eng/TrimmableProjects.props
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@
<TrimmableProject Include="Microsoft.AspNetCore.Http.Connections" />
<TrimmableProject Include="Microsoft.AspNetCore.SignalR.Protocols.Json" />
<TrimmableProject Include="Microsoft.AspNetCore.SignalR.Common" />
<TrimmableProject Include="Microsoft.AspNetCore.SignalR.Core" />
<TrimmableProject Include="Microsoft.AspNetCore.SignalR" />
<TrimmableProject Include="Microsoft.AspNetCore.StaticAssets" />
<TrimmableProject Include="Microsoft.AspNetCore.Components.Authorization" />
<TrimmableProject Include="Microsoft.AspNetCore.Components" />
Expand Down
194 changes: 191 additions & 3 deletions src/Shared/ObjectMethodExecutor/ObjectMethodExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;

namespace Microsoft.Extensions.Internal;

[RequiresUnreferencedCode("ObjectMethodExecutor performs reflection on arbitrary types.")]
[RequiresDynamicCode("ObjectMethodExecutor performs reflection on arbitrary types.")]
internal sealed class ObjectMethodExecutor
{
private readonly object?[]? _parameterDefaultValues;
Expand All @@ -28,15 +27,21 @@ internal sealed class ObjectMethodExecutor
typeof(Action<object, Action>) // unsafeOnCompletedMethod
})!;

private ObjectMethodExecutor(MethodInfo methodInfo, TypeInfo targetTypeInfo, object?[]? parameterDefaultValues)
private ObjectMethodExecutor(MethodInfo methodInfo, TypeInfo targetTypeInfo)
{
ArgumentNullException.ThrowIfNull(methodInfo);

MethodInfo = methodInfo;
MethodParameters = methodInfo.GetParameters();
TargetTypeInfo = targetTypeInfo;
MethodReturnType = methodInfo.ReturnType;
}

[RequiresUnreferencedCode("ObjectMethodExecutor performs reflection on arbitrary types.")]
[RequiresDynamicCode("ObjectMethodExecutor performs reflection on arbitrary types.")]
private ObjectMethodExecutor(MethodInfo methodInfo, TypeInfo targetTypeInfo, object?[]? parameterDefaultValues)
: this(methodInfo, targetTypeInfo)
{
var isAwaitable = CoercedAwaitableInfo.IsTypeAwaitable(MethodReturnType, out var coercedAwaitableInfo);

IsMethodAsync = isAwaitable;
Expand All @@ -55,6 +60,27 @@ private ObjectMethodExecutor(MethodInfo methodInfo, TypeInfo targetTypeInfo, obj
_parameterDefaultValues = parameterDefaultValues;
}

private ObjectMethodExecutor(MethodInfo methodInfo, TypeInfo targetTypeInfo, bool isTrimAotCompatible)
: this(methodInfo, targetTypeInfo)
{
Debug.Assert(isTrimAotCompatible, "isTrimAotCompatible should always be true.");

var isAwaitable = IsTaskType(MethodReturnType, out var resultType);

IsMethodAsync = isAwaitable;
AsyncResultType = isAwaitable ? resultType : null;

// Upstream code may prefer to use the sync-executor even for async methods, because if it knows
// that the result is a specific Task<T> where T is known, then it can directly cast to that type
// and await it without the extra heap allocations involved in the _executorAsync code path.
_executor = methodInfo.Invoke;

if (IsMethodAsync)
{
_executorAsync = GetExecutorAsyncTrimAotCompatible(methodInfo, AsyncResultType!);
}
}

private delegate ObjectMethodExecutorAwaitable MethodExecutorAsync(object target, object?[]? parameters);

private delegate object? MethodExecutor(object target, object?[]? parameters);
Expand All @@ -74,18 +100,35 @@ private ObjectMethodExecutor(MethodInfo methodInfo, TypeInfo targetTypeInfo, obj

public bool IsMethodAsync { get; }

[RequiresUnreferencedCode("ObjectMethodExecutor performs reflection on arbitrary types.")]
[RequiresDynamicCode("ObjectMethodExecutor performs reflection on arbitrary types.")]
public static ObjectMethodExecutor Create(MethodInfo methodInfo, TypeInfo targetTypeInfo)
{
return new ObjectMethodExecutor(methodInfo, targetTypeInfo, null);
}

[RequiresUnreferencedCode("ObjectMethodExecutor performs reflection on arbitrary types.")]
[RequiresDynamicCode("ObjectMethodExecutor performs reflection on arbitrary types.")]
public static ObjectMethodExecutor Create(MethodInfo methodInfo, TypeInfo targetTypeInfo, object?[] parameterDefaultValues)
{
ArgumentNullException.ThrowIfNull(parameterDefaultValues);

return new ObjectMethodExecutor(methodInfo, targetTypeInfo, parameterDefaultValues);
}

/// <summary>
/// Creates an ObjectMethodExecutor that is compatible with trimming and Ahead-of-Time (AOT) compilation.
/// </summary>
/// <remarks>
/// The difference between this method and <see cref="Create(MethodInfo, TypeInfo)"/> is that
/// this method doesn't support custom awaitables and Task{unit} in F#. It only supports Task, Task{T}, ValueTask, and ValueTask{T}
/// as async methods.
/// </remarks>
public static ObjectMethodExecutor CreateTrimAotCompatible(MethodInfo methodInfo, TypeInfo targetTypeInfo)
{
return new ObjectMethodExecutor(methodInfo, targetTypeInfo, isTrimAotCompatible: true);
}

/// <summary>
/// Executes the configured method on <paramref name="target"/>. This can be used whether or not
/// the configured method is asynchronous.
Expand Down Expand Up @@ -123,6 +166,9 @@ public static ObjectMethodExecutor Create(MethodInfo methodInfo, TypeInfo target
/// of it, and if it is, it will have to be boxed so the calling code can reference it as an object).
/// 3. The async result value, if it's a value type (it has to be boxed as an object, since the calling
/// code doesn't know what type it's going to be).
///
/// Note if <see cref="CreateTrimAotCompatible"/> was used to create the ObjectMethodExecutor, only the
/// built-in Task types are supported and not custom awaitables.
/// </remarks>
/// <param name="target">The object whose method is to be executed.</param>
/// <param name="parameters">Parameters to pass to the method.</param>
Expand Down Expand Up @@ -336,4 +382,146 @@ private static MethodExecutorAsync GetExecutorAsync(
var lambda = Expression.Lambda<MethodExecutorAsync>(returnValueExpression, targetParameter, parametersParameter);
return lambda.Compile();
}

private static readonly MethodInfo _taskGetAwaiterMethodInfo = typeof(Task<>).GetMethod("GetAwaiter")!;
private static readonly MethodInfo _taskAwaiterGetIsCompletedMethodInfo = typeof(TaskAwaiter<>).GetMethod("get_IsCompleted")!;
private static readonly MethodInfo _taskAwaiterGetResultMethodInfo = typeof(TaskAwaiter<>).GetMethod("GetResult")!;
private static readonly MethodInfo _taskAwaiterOnCompletedMethodInfo = typeof(TaskAwaiter<>).GetMethod("OnCompleted")!;
private static readonly MethodInfo _taskAwaiterUnsafeOnCompletedMethodInfo = typeof(TaskAwaiter<>).GetMethod("UnsafeOnCompleted")!;

private static readonly MethodInfo _valueTaskGetAwaiterMethodInfo = typeof(ValueTask<>).GetMethod("GetAwaiter")!;
private static readonly MethodInfo _valueTaskAwaiterGetIsCompletedMethodInfo = typeof(ValueTaskAwaiter<>).GetMethod("get_IsCompleted")!;
private static readonly MethodInfo _valueTaskAwaiterGetResultMethodInfo = typeof(ValueTaskAwaiter<>).GetMethod("GetResult")!;
private static readonly MethodInfo _valueTaskAwaiterOnCompletedMethodInfo = typeof(ValueTaskAwaiter<>).GetMethod("OnCompleted")!;
private static readonly MethodInfo _valueTaskAwaiterUnsafeOnCompletedMethodInfo = typeof(ValueTaskAwaiter<>).GetMethod("UnsafeOnCompleted")!;

private static bool IsTaskType(Type methodReturnType, [NotNullWhen(true)] out Type? resultType)
{
if (methodReturnType == typeof(Task) || methodReturnType == typeof(ValueTask))
{
resultType = typeof(void);
return true;
}

if (methodReturnType.IsGenericType && methodReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>))
{
resultType = methodReturnType.GetGenericArguments()[0];
return true;
}

var currentType = methodReturnType;
while (currentType is not null)
{
if (currentType == typeof(Task))
{
resultType = typeof(void);
return true;
}

if (currentType.IsGenericType && currentType.GetGenericTypeDefinition() == typeof(Task<>))
{
var taskGetAwaiterMethodInfo = (MethodInfo)methodReturnType.GetMemberWithSameMetadataDefinitionAs(_taskGetAwaiterMethodInfo);
var taskAwaiterGetResultMethodInfo = (MethodInfo)taskGetAwaiterMethodInfo.ReturnType.GetMemberWithSameMetadataDefinitionAs(_taskAwaiterGetResultMethodInfo);

resultType = taskAwaiterGetResultMethodInfo.ReturnType;
return true;
}

currentType = currentType.BaseType;
}

resultType = null;
return false;
}

private static MethodExecutorAsync? GetExecutorAsyncTrimAotCompatible(MethodInfo methodInfo, Type asyncResultType)
{
var methodReturnType = methodInfo.ReturnType;
if (asyncResultType == typeof(void))
{
if (methodReturnType == typeof(ValueTask))
{
return (target, args) =>
{
return new ObjectMethodExecutorAwaitable(
methodInfo.Invoke(target, args),
(awaitable) => ((ValueTask)awaitable).GetAwaiter(),
(awaiter) => ((ValueTaskAwaiter)awaiter).IsCompleted,
(awaiter) =>
{
((ValueTaskAwaiter)awaiter).GetResult();
return null;
},
(awaiter, continuation) =>
{
((ValueTaskAwaiter)awaiter).OnCompleted(continuation);
},
(awaiter, continuation) =>
{
((ValueTaskAwaiter)awaiter).UnsafeOnCompleted(continuation);
});
};
}

// The method must return Task, or a derived type that isn't Task<T>
return (target, args) =>
{
return new ObjectMethodExecutorAwaitable(
methodInfo.Invoke(target, args),
(awaitable) => ((Task)awaitable).GetAwaiter(),
(awaiter) => ((TaskAwaiter)awaiter).IsCompleted,
(awaiter) =>
{
((TaskAwaiter)awaiter).GetResult();
return null;
},
(awaiter, continuation) =>
{
((TaskAwaiter)awaiter).OnCompleted(continuation);
},
(awaiter, continuation) =>
{
((TaskAwaiter)awaiter).UnsafeOnCompleted(continuation);
});
};
}

if (methodReturnType.IsGenericType && methodReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>))
{
return (target, args) =>
{
return new ObjectMethodExecutorAwaitable(
methodInfo.Invoke(target, args),
(awaitable) => ((MethodInfo)awaitable.GetType().GetMemberWithSameMetadataDefinitionAs(_valueTaskGetAwaiterMethodInfo)).Invoke(awaitable, Array.Empty<object>()),
(awaiter) => (bool)((MethodInfo)awaiter.GetType().GetMemberWithSameMetadataDefinitionAs(_valueTaskAwaiterGetIsCompletedMethodInfo)).Invoke(awaiter, Array.Empty<object>())!,
(awaiter) => ((MethodInfo)awaiter.GetType().GetMemberWithSameMetadataDefinitionAs(_valueTaskAwaiterGetResultMethodInfo)).Invoke(awaiter, Array.Empty<object>())!,
(awaiter, continuation) =>
{
((MethodInfo)awaiter.GetType().GetMemberWithSameMetadataDefinitionAs(_valueTaskAwaiterOnCompletedMethodInfo)).Invoke(awaiter, [continuation]);
},
(awaiter, continuation) =>
{
((MethodInfo)awaiter.GetType().GetMemberWithSameMetadataDefinitionAs(_valueTaskAwaiterUnsafeOnCompletedMethodInfo)).Invoke(awaiter, [continuation]);
});
};
}

// The method must return a Task<T> or a derived type
return (target, args) =>
{
return new ObjectMethodExecutorAwaitable(
methodInfo.Invoke(target, args),
(awaitable) => ((MethodInfo)awaitable.GetType().GetMemberWithSameMetadataDefinitionAs(_taskGetAwaiterMethodInfo)).Invoke(awaitable, Array.Empty<object>()),
(awaiter) => (bool)((MethodInfo)awaiter.GetType().GetMemberWithSameMetadataDefinitionAs(_taskAwaiterGetIsCompletedMethodInfo)).Invoke(awaiter, Array.Empty<object>())!,
(awaiter) => ((MethodInfo)awaiter.GetType().GetMemberWithSameMetadataDefinitionAs(_taskAwaiterGetResultMethodInfo)).Invoke(awaiter, Array.Empty<object>())!,
(awaiter, continuation) =>
{
((MethodInfo)awaiter.GetType().GetMemberWithSameMetadataDefinitionAs(_taskAwaiterOnCompletedMethodInfo)).Invoke(awaiter, [continuation]);
},
(awaiter, continuation) =>
{
((MethodInfo)awaiter.GetType().GetMemberWithSameMetadataDefinitionAs(_taskAwaiterUnsafeOnCompletedMethodInfo)).Invoke(awaiter, [continuation]);
});
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics.CodeAnalysis;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection.Extensions;
Expand All @@ -21,6 +22,7 @@ public static class MessagePackProtocolDependencyInjectionExtensions
/// </remarks>
/// <param name="builder">The <see cref="ISignalRBuilder"/> representing the SignalR server to add MessagePack protocol support to.</param>
/// <returns>The value of <paramref name="builder"/></returns>
[RequiresUnreferencedCode("MessagePack does not currently support trimming or native AOT.", Url = "https://aka.ms/aspnet/trimming")]
public static TBuilder AddMessagePackProtocol<TBuilder>(this TBuilder builder) where TBuilder : ISignalRBuilder
=> AddMessagePackProtocol(builder, _ => { });

Expand All @@ -33,6 +35,7 @@ public static TBuilder AddMessagePackProtocol<TBuilder>(this TBuilder builder) w
/// <param name="builder">The <see cref="ISignalRBuilder"/> representing the SignalR server to add MessagePack protocol support to.</param>
/// <param name="configure">A delegate that can be used to configure the <see cref="MessagePackHubProtocolOptions"/></param>
/// <returns>The value of <paramref name="builder"/></returns>
[RequiresUnreferencedCode("MessagePack does not currently support trimming or native AOT.", Url = "https://aka.ms/aspnet/trimming")]
public static TBuilder AddMessagePackProtocol<TBuilder>(this TBuilder builder, Action<MessagePackHubProtocolOptions> configure) where TBuilder : ISignalRBuilder
{
builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton<IHubProtocol, MessagePackHubProtocol>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
<Compile Include="$(SignalRSharedSourceRoot)MemoryBufferWriter.cs" Link="Internal\MemoryBufferWriter.cs" />
<Compile Include="$(SignalRSharedSourceRoot)TryGetReturnType.cs" Link="TryGetReturnType.cs" />
<Compile Include="$(SharedSourceRoot)ThrowHelpers\ArgumentNullThrowHelper.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)TrimmingAttributes.cs" LinkBase="Shared" />
<Compile Include="$(SharedSourceRoot)CallerArgument\CallerArgumentExpressionAttribute.cs" LinkBase="Shared" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
</PropertyGroup>

<ItemGroup>
<Compile Include="$(SharedSourceRoot)TrimmingAttributes.cs" LinkBase="Shared" />

<Compile Include="$(SignalRSharedSourceRoot)JsonUtils.cs" Link="Internal\JsonUtils.cs" />
<Compile Include="$(SignalRSharedSourceRoot)TextMessageFormatter.cs" Link="TextMessageFormatter.cs" />
<Compile Include="$(SignalRSharedSourceRoot)TextMessageParser.cs" Link="TextMessageParser.cs" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics.CodeAnalysis;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection.Extensions;
Expand All @@ -21,6 +22,7 @@ public static class NewtonsoftJsonProtocolDependencyInjectionExtensions
/// </remarks>
/// <param name="builder">The <see cref="ISignalRBuilder"/> representing the SignalR server to add JSON protocol support to.</param>
/// <returns>The value of <paramref name="builder"/></returns>
[RequiresUnreferencedCode("Newtonsoft.Json does not currently support trimming or native AOT.", Url = "https://aka.ms/aspnet/trimming")]
public static TBuilder AddNewtonsoftJsonProtocol<TBuilder>(this TBuilder builder) where TBuilder : ISignalRBuilder
=> AddNewtonsoftJsonProtocol(builder, _ => { });

Expand All @@ -33,6 +35,7 @@ public static TBuilder AddNewtonsoftJsonProtocol<TBuilder>(this TBuilder builder
/// <param name="builder">The <see cref="ISignalRBuilder"/> representing the SignalR server to add JSON protocol support to.</param>
/// <param name="configure">A delegate that can be used to configure the <see cref="NewtonsoftJsonHubProtocolOptions"/></param>
/// <returns>The value of <paramref name="builder"/></returns>
[RequiresUnreferencedCode("Newtonsoft.Json does not currently support trimming or native AOT.", Url = "https://aka.ms/aspnet/trimming")]
public static TBuilder AddNewtonsoftJsonProtocol<TBuilder>(this TBuilder builder, Action<NewtonsoftJsonHubProtocolOptions> configure) where TBuilder : ISignalRBuilder
{
builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton<IHubProtocol, NewtonsoftJsonHubProtocol>());
Expand Down
4 changes: 1 addition & 3 deletions src/SignalR/common/Shared/ReflectionHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading.Channels;

namespace Microsoft.AspNetCore.SignalR;
Expand All @@ -17,7 +15,7 @@ internal static class ReflectionHelper
// and 'stream' types from the client are allowed to inherit from accepted 'stream' types
public static bool IsStreamingType(Type type, bool mustBeDirectType = false)
{
// TODO #2594 - add Streams here, to make sending files easy
// TODO https://github.com/dotnet/aspnetcore/issues/5316 - add Streams here, to make sending files easy

if (IsIAsyncEnumerable(type))
{
Expand Down
4 changes: 4 additions & 0 deletions src/SignalR/server/Core/src/Hub.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;

namespace Microsoft.AspNetCore.SignalR;

/// <summary>
/// A base class for a SignalR hub.
/// </summary>
public abstract class Hub : IDisposable
{
internal const DynamicallyAccessedMemberTypes DynamicallyAccessedMembers = DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicMethods;

private bool _disposed;
private IHubCallerClients _clients = default!;
private HubCallerContext _context = default!;
Expand Down
Loading
Loading