Skip to content

Commit 3579e24

Browse files
authored
Merge branch 'master' into feature/mvvm-toolkit-part2
2 parents b5c0272 + f8b78a3 commit 3579e24

File tree

2 files changed

+41
-19
lines changed

2 files changed

+41
-19
lines changed

Microsoft.Toolkit/Extensions/TaskExtensions.cs

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using System.Diagnostics.CodeAnalysis;
76
using System.Diagnostics.Contracts;
87
using System.Reflection;
@@ -41,28 +40,21 @@ public static class TaskExtensions
4140
#endif
4241
)
4342
{
44-
Type taskType = task.GetType();
45-
46-
// Check if the task is actually some Task<T>
47-
if (
48-
#if NETSTANDARD1_4
49-
taskType.GetTypeInfo().IsGenericType &&
50-
#else
51-
taskType.IsGenericType &&
52-
#endif
53-
taskType.GetGenericTypeDefinition() == typeof(Task<>))
54-
{
55-
// Get the Task<T>.Result property
56-
PropertyInfo propertyInfo =
43+
// Try to get the Task<T>.Result property. This method would've
44+
// been called anyway after the type checks, but using that to
45+
// validate the input type saves some additional reflection calls.
46+
// Furthermore, doing this also makes the method flexible enough to
47+
// cases whether the input Task<T> is actually an instance of some
48+
// runtime-specific type that inherits from Task<T>.
49+
PropertyInfo? propertyInfo =
5750
#if NETSTANDARD1_4
58-
taskType.GetRuntimeProperty(nameof(Task<object>.Result));
51+
task.GetType().GetRuntimeProperty(nameof(Task<object>.Result));
5952
#else
60-
taskType.GetProperty(nameof(Task<object>.Result));
53+
task.GetType().GetProperty(nameof(Task<object>.Result));
6154
#endif
6255

63-
// Finally retrieve the result
64-
return propertyInfo!.GetValue(task);
65-
}
56+
// Return the result, if possible
57+
return propertyInfo?.GetValue(task);
6658
}
6759

6860
return null;

UnitTests/UnitTests.Shared/Extensions/Test_TaskExtensions.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,25 @@ public void Test_TaskExtensions_ResultOrDefault()
3737
Assert.AreEqual(42, ((Task)tcs.Task).GetResultOrDefault());
3838
}
3939

40+
[TestCategory("TaskExtensions")]
41+
[TestMethod]
42+
public async Task Test_TaskExtensions_ResultOrDefault_FromAsyncTaskMethodBuilder()
43+
{
44+
var tcs = new TaskCompletionSource<object>();
45+
46+
Task<string> taskFromBuilder = GetTaskFromAsyncMethodBuilder("Test", tcs);
47+
48+
Assert.IsNull(((Task)taskFromBuilder).GetResultOrDefault());
49+
Assert.IsNull(taskFromBuilder.GetResultOrDefault());
50+
51+
tcs.SetResult(null);
52+
53+
await taskFromBuilder;
54+
55+
Assert.AreEqual(((Task)taskFromBuilder).GetResultOrDefault(), "Test");
56+
Assert.AreEqual(taskFromBuilder.GetResultOrDefault(), "Test");
57+
}
58+
4059
[TestCategory("TaskExtensions")]
4160
[TestMethod]
4261
public void Test_TaskExtensions_ResultOrDefault_OfT_Int32()
@@ -86,5 +105,16 @@ public void Test_TaskExtensions_ResultOrDefault_OfT_String()
86105

87106
Assert.AreEqual("Hello world", tcs.Task.GetResultOrDefault());
88107
}
108+
109+
// Creates a Task<T> of a given type which is actually an instance of
110+
// System.Runtime.CompilerServices.AsyncTaskMethodBuilder<TResult>.AsyncStateMachineBox<TStateMachine>.
111+
// See https://source.dot.net/#System.Private.CoreLib/AsyncTaskMethodBuilderT.cs,f8f35fd356112b30.
112+
// This is needed to verify that the extension also works when the input Task<T> is of a derived type.
113+
private static async Task<T> GetTaskFromAsyncMethodBuilder<T>(T result, TaskCompletionSource<object> tcs)
114+
{
115+
await tcs.Task;
116+
117+
return result;
118+
}
89119
}
90120
}

0 commit comments

Comments
 (0)