Skip to content

Commit 57d3de4

Browse files
authored
Merge pull request #633 from CommunityToolkit/dev/partial-commands
Support partial methods and forwarded attributes with [RelayCommand]
2 parents 7ec9495 + 66f582b commit 57d3de4

File tree

8 files changed

+323
-5
lines changed

8 files changed

+323
-5
lines changed

src/CommunityToolkit.Mvvm.SourceGenerators/CommunityToolkit.Mvvm.SourceGenerators.projitems

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
<Compile Include="$(MSBuildThisFileDirectory)Extensions\INamedTypeSymbolExtensions.cs" />
5656
<Compile Include="$(MSBuildThisFileDirectory)Extensions\IncrementalGeneratorInitializationContextExtensions.cs" />
5757
<Compile Include="$(MSBuildThisFileDirectory)Extensions\IncrementalValuesProviderExtensions.cs" />
58+
<Compile Include="$(MSBuildThisFileDirectory)Extensions\MethodDeclarationSyntaxExtensions.cs" />
5859
<Compile Include="$(MSBuildThisFileDirectory)Extensions\SymbolInfoExtensions.cs" />
5960
<Compile Include="$(MSBuildThisFileDirectory)Extensions\ISymbolExtensions.cs" />
6061
<Compile Include="$(MSBuildThisFileDirectory)Extensions\SourceProductionContextExtensions.cs" />
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.CodeAnalysis;
6+
using Microsoft.CodeAnalysis.CSharp;
7+
using Microsoft.CodeAnalysis.CSharp.Syntax;
8+
9+
namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions;
10+
11+
/// <summary>
12+
/// Extension methods for the <see cref="MethodDeclarationSyntax"/> type.
13+
/// </summary>
14+
internal static class MethodDeclarationSyntaxExtensions
15+
{
16+
/// <summary>
17+
/// Checks whether a given <see cref="MethodDeclarationSyntax"/> has or could potentially have any attribute lists.
18+
/// </summary>
19+
/// <param name="methodDeclaration">The input <see cref="MethodDeclarationSyntax"/> to check.</param>
20+
/// <returns>Whether <paramref name="methodDeclaration"/> has or potentially has any attribute lists.</returns>
21+
public static bool HasOrPotentiallyHasAttributeLists(this MethodDeclarationSyntax methodDeclaration)
22+
{
23+
// If the declaration has any attribute lists, there's nothing left to do
24+
if (methodDeclaration.AttributeLists.Count > 0)
25+
{
26+
return true;
27+
}
28+
29+
// If there are no attributes, check whether the method declaration has the partial keyword. If it
30+
// does, there could potentially be attribute lists on the other partial definition/implementation.
31+
return methodDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword);
32+
}
33+
}

src/CommunityToolkit.Mvvm.SourceGenerators/Extensions/SyntaxNodeExtensions.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ internal static class SyntaxNodeExtensions
2626
public static bool IsFirstSyntaxDeclarationForSymbol(this SyntaxNode syntaxNode, ISymbol symbol)
2727
{
2828
return
29-
symbol.DeclaringSyntaxReferences.Length > 0 &&
30-
symbol.DeclaringSyntaxReferences[0] is SyntaxReference syntaxReference &&
29+
symbol.DeclaringSyntaxReferences is [SyntaxReference syntaxReference, ..] &&
3130
syntaxReference.SyntaxTree == syntaxNode.SyntaxTree &&
3231
syntaxReference.Span == syntaxNode.Span;
3332
}

src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.Execute.cs

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,15 @@ private static bool IsCommandDefinitionUnique(IMethodSymbol methodSymbol, in Imm
430430
return true;
431431
}
432432

433+
// If the two method symbols are partial and either is the implementation of the other one, this is allowed
434+
if ((methodSymbol is { IsPartialDefinition: true, PartialImplementationPart: { } partialImplementation } &&
435+
SymbolEqualityComparer.Default.Equals(otherSymbol, partialImplementation)) ||
436+
(otherSymbol is { IsPartialDefinition: true, PartialImplementationPart: { } otherPartialImplementation } &&
437+
SymbolEqualityComparer.Default.Equals(methodSymbol, otherPartialImplementation)))
438+
{
439+
continue;
440+
}
441+
433442
diagnostics.Add(
434443
MultipleRelayCommandMethodOverloadsError,
435444
methodSymbol,
@@ -952,12 +961,24 @@ private static void GatherForwardedAttributes(
952961
using ImmutableArrayBuilder<AttributeInfo> fieldAttributesInfo = ImmutableArrayBuilder<AttributeInfo>.Rent();
953962
using ImmutableArrayBuilder<AttributeInfo> propertyAttributesInfo = ImmutableArrayBuilder<AttributeInfo>.Rent();
954963

955-
foreach (SyntaxReference syntaxReference in methodSymbol.DeclaringSyntaxReferences)
964+
static void GatherForwardedAttributes(
965+
IMethodSymbol methodSymbol,
966+
SemanticModel semanticModel,
967+
CancellationToken token,
968+
in ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
969+
in ImmutableArrayBuilder<AttributeInfo> fieldAttributesInfo,
970+
in ImmutableArrayBuilder<AttributeInfo> propertyAttributesInfo)
956971
{
972+
// Get the single syntax reference for the input method symbol (there should be only one)
973+
if (methodSymbol.DeclaringSyntaxReferences is not [SyntaxReference syntaxReference])
974+
{
975+
return;
976+
}
977+
957978
// Try to get the target method declaration syntax node
958979
if (syntaxReference.GetSyntax(token) is not MethodDeclarationSyntax methodDeclaration)
959980
{
960-
continue;
981+
return;
961982
}
962983

963984
// Gather explicit forwarded attributes info
@@ -998,6 +1019,22 @@ private static void GatherForwardedAttributes(
9981019
}
9991020
}
10001021

1022+
// If the method is a partial definition, also gather attributes from the implementation part
1023+
if (methodSymbol is { IsPartialDefinition: true } or { PartialDefinitionPart: not null })
1024+
{
1025+
IMethodSymbol partialDefinition = methodSymbol.PartialDefinitionPart ?? methodSymbol;
1026+
IMethodSymbol partialImplementation = methodSymbol.PartialImplementationPart ?? methodSymbol;
1027+
1028+
// We always give priority to the partial definition, to ensure a predictable and testable ordering
1029+
GatherForwardedAttributes(partialDefinition, semanticModel, token, in diagnostics, in fieldAttributesInfo, in propertyAttributesInfo);
1030+
GatherForwardedAttributes(partialImplementation, semanticModel, token, in diagnostics, in fieldAttributesInfo, in propertyAttributesInfo);
1031+
}
1032+
else
1033+
{
1034+
// If the method is not a partial definition/implementation, just gather attributes from the method with no modifications
1035+
GatherForwardedAttributes(methodSymbol, semanticModel, token, in diagnostics, in fieldAttributesInfo, in propertyAttributesInfo);
1036+
}
1037+
10011038
fieldAttributes = fieldAttributesInfo.ToImmutable();
10021039
propertyAttributes = propertyAttributesInfo.ToImmutable();
10031040
}

src/CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
2727
context.SyntaxProvider
2828
.ForAttributeWithMetadataName(
2929
"CommunityToolkit.Mvvm.Input.RelayCommandAttribute",
30-
static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax, AttributeLists.Count: > 0 },
30+
static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax } methodDeclaration && methodDeclaration.HasOrPotentiallyHasAttributeLists(),
3131
static (context, token) =>
3232
{
3333
if (!context.SemanticModel.Compilation.HasLanguageVersionAtLeastEqualTo(LanguageVersion.CSharp8))

src/CommunityToolkit.Mvvm.SourceGenerators/Polyfills/SyntaxValueProviderExtensions.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@ public static IncrementalValuesProvider<T> ForAttributeWithMetadataName<T>(
5959
return null;
6060
}
6161

62+
// Edge case: if the symbol is a partial method, skip the implementation part and only process the partial method
63+
// definition. This is needed because attributes will be reported as available on both the definition and the
64+
// implementation part. To avoid generating duplicate files, we only give priority to the definition part.
65+
// On Roslyn 4.3+, ForAttributeWithMetadataName will already only return the symbol the attribute was located on.
66+
if (symbol is IMethodSymbol { IsPartialDefinition: false, PartialDefinitionPart: not null })
67+
{
68+
return null;
69+
}
70+
6271
// Create the GeneratorAttributeSyntaxContext value to pass to the input transform. The attributes array
6372
// will only ever have a single value, but that's fine with the attributes the various generators look for.
6473
GeneratorAttributeSyntaxContext syntaxContext = new(

tests/CommunityToolkit.Mvvm.SourceGenerators.UnitTests/Test_SourceGeneratorsCodegen.cs

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,169 @@ partial class MyViewModel
268268
VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.MyViewModel.Test.g.cs", result));
269269
}
270270

271+
// See https://github.com/CommunityToolkit/dotnet/issues/632
272+
[TestMethod]
273+
public void RelayCommandMethodWithPartialDeclarations_TriggersCorrectly()
274+
{
275+
string source = """
276+
using CommunityToolkit.Mvvm.Input;
277+
278+
#nullable enable
279+
280+
namespace MyApp;
281+
282+
partial class MyViewModel
283+
{
284+
[RelayCommand]
285+
private partial void Test1()
286+
{
287+
}
288+
289+
private partial void Test1();
290+
291+
private partial void Test2()
292+
{
293+
}
294+
295+
[RelayCommand]
296+
private partial void Test2();
297+
}
298+
""";
299+
300+
string result1 = """
301+
// <auto-generated/>
302+
#pragma warning disable
303+
#nullable enable
304+
namespace MyApp
305+
{
306+
partial class MyViewModel
307+
{
308+
/// <summary>The backing field for <see cref="Test1Command"/>.</summary>
309+
[global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
310+
private global::CommunityToolkit.Mvvm.Input.RelayCommand? test1Command;
311+
/// <summary>Gets an <see cref="global::CommunityToolkit.Mvvm.Input.IRelayCommand"/> instance wrapping <see cref="Test1"/>.</summary>
312+
[global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
313+
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
314+
public global::CommunityToolkit.Mvvm.Input.IRelayCommand Test1Command => test1Command ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(Test1));
315+
}
316+
}
317+
""";
318+
319+
string result2 = """
320+
// <auto-generated/>
321+
#pragma warning disable
322+
#nullable enable
323+
namespace MyApp
324+
{
325+
partial class MyViewModel
326+
{
327+
/// <summary>The backing field for <see cref="Test2Command"/>.</summary>
328+
[global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
329+
private global::CommunityToolkit.Mvvm.Input.RelayCommand? test2Command;
330+
/// <summary>Gets an <see cref="global::CommunityToolkit.Mvvm.Input.IRelayCommand"/> instance wrapping <see cref="Test2"/>.</summary>
331+
[global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
332+
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
333+
public global::CommunityToolkit.Mvvm.Input.IRelayCommand Test2Command => test2Command ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(Test2));
334+
}
335+
}
336+
""";
337+
338+
VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.MyViewModel.Test1.g.cs", result1), ("MyApp.MyViewModel.Test2.g.cs", result2));
339+
}
340+
341+
// See https://github.com/CommunityToolkit/dotnet/issues/632
342+
[TestMethod]
343+
public void RelayCommandMethodWithForwardedAttributesOverPartialDeclarations_MergesAttributes()
344+
{
345+
string source = """
346+
using CommunityToolkit.Mvvm.Input;
347+
348+
#nullable enable
349+
350+
namespace MyApp;
351+
352+
partial class MyViewModel
353+
{
354+
[RelayCommand]
355+
[field: Value(0)]
356+
[property: Value(1)]
357+
private partial void Test1()
358+
{
359+
}
360+
361+
[field: Value(2)]
362+
[property: Value(3)]
363+
private partial void Test1();
364+
365+
[field: Value(0)]
366+
[property: Value(1)]
367+
private partial void Test2()
368+
{
369+
}
370+
371+
[RelayCommand]
372+
[field: Value(2)]
373+
[property: Value(3)]
374+
private partial void Test2();
375+
}
376+
377+
public class ValueAttribute : Attribute
378+
{
379+
public ValueAttribute(object value)
380+
{
381+
}
382+
}
383+
""";
384+
385+
string result1 = """
386+
// <auto-generated/>
387+
#pragma warning disable
388+
#nullable enable
389+
namespace MyApp
390+
{
391+
partial class MyViewModel
392+
{
393+
/// <summary>The backing field for <see cref="Test1Command"/>.</summary>
394+
[global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
395+
[global::MyApp.ValueAttribute(2)]
396+
[global::MyApp.ValueAttribute(0)]
397+
private global::CommunityToolkit.Mvvm.Input.RelayCommand? test1Command;
398+
/// <summary>Gets an <see cref="global::CommunityToolkit.Mvvm.Input.IRelayCommand"/> instance wrapping <see cref="Test1"/>.</summary>
399+
[global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
400+
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
401+
[global::MyApp.ValueAttribute(3)]
402+
[global::MyApp.ValueAttribute(1)]
403+
public global::CommunityToolkit.Mvvm.Input.IRelayCommand Test1Command => test1Command ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(Test1));
404+
}
405+
}
406+
""";
407+
408+
string result2 = """
409+
// <auto-generated/>
410+
#pragma warning disable
411+
#nullable enable
412+
namespace MyApp
413+
{
414+
partial class MyViewModel
415+
{
416+
/// <summary>The backing field for <see cref="Test2Command"/>.</summary>
417+
[global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
418+
[global::MyApp.ValueAttribute(2)]
419+
[global::MyApp.ValueAttribute(0)]
420+
private global::CommunityToolkit.Mvvm.Input.RelayCommand? test2Command;
421+
/// <summary>Gets an <see cref="global::CommunityToolkit.Mvvm.Input.IRelayCommand"/> instance wrapping <see cref="Test2"/>.</summary>
422+
[global::System.CodeDom.Compiler.GeneratedCode("CommunityToolkit.Mvvm.SourceGenerators.RelayCommandGenerator", "8.1.0.0")]
423+
[global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
424+
[global::MyApp.ValueAttribute(3)]
425+
[global::MyApp.ValueAttribute(1)]
426+
public global::CommunityToolkit.Mvvm.Input.IRelayCommand Test2Command => test2Command ??= new global::CommunityToolkit.Mvvm.Input.RelayCommand(new global::System.Action(Test2));
427+
}
428+
}
429+
""";
430+
431+
VerifyGenerateSources(source, new[] { new RelayCommandGenerator() }, ("MyApp.MyViewModel.Test1.g.cs", result1), ("MyApp.MyViewModel.Test2.g.cs", result2));
432+
}
433+
271434
[TestMethod]
272435
public void ObservablePropertyWithinGenericAndNestedTypes()
273436
{

0 commit comments

Comments
 (0)