Skip to content

Removed unsafe indexing from StrongReferenceMessenger #3513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 17 additions & 23 deletions Microsoft.Toolkit.Mvvm/Messaging/StrongReferenceMessenger.cs
Original file line number Diff line number Diff line change
Expand Up @@ -325,14 +325,12 @@ public void Unregister<TMessage, TToken>(object recipient, TToken token)
}

/// <inheritdoc/>
public unsafe TMessage Send<TMessage, TToken>(TMessage message, TToken token)
public TMessage Send<TMessage, TToken>(TMessage message, TToken token)
where TMessage : class
where TToken : IEquatable<TToken>
{
object[] handlers;
object[] recipients;
ref object handlersRef = ref Unsafe.AsRef<object>(null);
ref object recipientsRef = ref Unsafe.AsRef<object>(null);
object[] rentedArray;
Span<object> pairs;
int i = 0;

lock (this.recipientsMap)
Expand All @@ -358,10 +356,9 @@ public unsafe TMessage Send<TMessage, TToken>(TMessage message, TToken token)
return message;
}

handlers = ArrayPool<object>.Shared.Rent(totalHandlersCount);
recipients = ArrayPool<object>.Shared.Rent(totalHandlersCount);
handlersRef = ref handlers[0];
recipientsRef = ref recipients[0];
// Rent the array and also assign it to a span, which will be used to access values.
// We're doing this to avoid the array covariance checks slowdown in the loops below.
pairs = rentedArray = ArrayPool<object>.Shared.Rent(2 * totalHandlersCount);

// Copy the handlers to the local collection.
// The array is oversized at this point, since it also includes
Expand All @@ -379,10 +376,14 @@ public unsafe TMessage Send<TMessage, TToken>(TMessage message, TToken token)
// Pick the target handler, if the token is a match for the recipient
if (mappingEnumerator.Value.TryGetValue(token, out object? handler))
{
// We can manually offset here to skip the bounds checks in this inner loop when
// indexing the array (the size is already verified and guaranteed to be enough).
Unsafe.Add(ref handlersRef, (IntPtr)(void*)(uint)i) = handler!;
Unsafe.Add(ref recipientsRef, (IntPtr)(void*)(uint)i++) = recipient;
// This span access should always guaranteed to be valid due to the size of the
// array being set according to the current total number of registered handlers,
// which will always be greater or equal than the ones matching the previous test.
// We're still using a checked span accesses here though to make sure an out of
// bounds write can never happen even if an error was present in the logic above.
pairs[2 * i] = handler!;
pairs[(2 * i) + 1] = recipient;
i++;
}
}
}
Expand All @@ -392,27 +393,20 @@ public unsafe TMessage Send<TMessage, TToken>(TMessage message, TToken token)
// Invoke all the necessary handlers on the local copy of entries
for (int j = 0; j < i; j++)
{
// We're doing an unsafe cast to skip the type checks again.
// See the comments in the UnregisterAll method for more info.
object handler = Unsafe.Add(ref handlersRef, (IntPtr)(void*)(uint)j);
object recipient = Unsafe.Add(ref recipientsRef, (IntPtr)(void*)(uint)j);

// Here we perform an unsafe cast to enable covariance for delegate types.
// We know that the input recipient will always respect the type constraints
// of each original input delegate, and doing so allows us to still invoke
// them all from here without worrying about specific generic type arguments.
Unsafe.As<MessageHandler<object, TMessage>>(handler)(recipient, message);
Unsafe.As<MessageHandler<object, TMessage>>(pairs[2 * j])(pairs[(2 * j) + 1], message);
}
}
finally
{
// As before, we also need to clear it first to avoid having potentially long
// lasting memory leaks due to leftover references being stored in the pool.
handlers.AsSpan(0, i).Clear();
recipients.AsSpan(0, i).Clear();
Array.Clear(rentedArray, 0, 2 * i);

ArrayPool<object>.Shared.Return(handlers);
ArrayPool<object>.Shared.Return(recipients);
ArrayPool<object>.Shared.Return(rentedArray);
}

return message;
Expand Down
23 changes: 10 additions & 13 deletions Microsoft.Toolkit.Mvvm/Messaging/WeakReferenceMessenger.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ public TMessage Send<TMessage, TToken>(TMessage message, TToken token)
where TMessage : class
where TToken : IEquatable<TToken>
{
ArrayPoolBufferWriter<object> recipients;
ArrayPoolBufferWriter<object> handlers;
ArrayPoolBufferWriter<object> bufferWriter;
int i = 0;

lock (this.recipientsMap)
{
Expand All @@ -182,8 +182,7 @@ public TMessage Send<TMessage, TToken>(TMessage message, TToken token)
return message;
}

recipients = ArrayPoolBufferWriter<object>.Create();
handlers = ArrayPoolBufferWriter<object>.Create();
bufferWriter = ArrayPoolBufferWriter<object>.Create();

// We need a local, temporary copy of all the pending recipients and handlers to
// invoke, to avoid issues with handlers unregistering from messages while we're
Expand All @@ -197,32 +196,30 @@ public TMessage Send<TMessage, TToken>(TMessage message, TToken token)

if (map.TryGetValue(token, out object? handler))
{
recipients.Add(pair.Key);
handlers.Add(handler!);
bufferWriter.Add(handler!);
bufferWriter.Add(pair.Key);
i++;
}
}
}

try
{
ReadOnlySpan<object>
recipientsSpan = recipients.Span,
handlersSpan = handlers.Span;
ReadOnlySpan<object> pairs = bufferWriter.Span;

for (int i = 0; i < recipientsSpan.Length; i++)
for (int j = 0; j < i; j++)
{
// Just like in the other messenger, here we need an unsafe cast to be able to
// invoke a generic delegate with a contravariant input argument, with a less
// derived reference, without reflection. This is guaranteed to work by how the
// messenger tracks registered recipients and their associated handlers, so the
// type conversion will always be valid (the recipients are the rigth instances).
Unsafe.As<MessageHandler<object, TMessage>>(handlersSpan[i])(recipientsSpan[i], message);
Unsafe.As<MessageHandler<object, TMessage>>(pairs[2 * j])(pairs[(2 * j) + 1], message);
}
}
finally
{
recipients.Dispose();
handlers.Dispose();
bufferWriter.Dispose();
}

return message;
Expand Down