Skip to content

Commit bf98ede

Browse files
committed
Don't propagate symbols in RelayCommandGenerator
1 parent 5876242 commit bf98ede

File tree

2 files changed

+43
-23
lines changed

2 files changed

+43
-23
lines changed

CommunityToolkit.Mvvm.SourceGenerators/Extensions/ISymbolExtensions.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System.Collections.Immutable;
6+
using System.Diagnostics.CodeAnalysis;
67
using Microsoft.CodeAnalysis;
78

89
namespace CommunityToolkit.Mvvm.SourceGenerators.Extensions;
@@ -64,6 +65,32 @@ public static bool HasAttributeWithFullyQualifiedName(this ISymbol symbol, strin
6465
return false;
6566
}
6667

68+
/// <summary>
69+
/// Tries to get an attribute with the specified full name.
70+
/// </summary>
71+
/// <param name="symbol">The input <see cref="ISymbol"/> instance to check.</param>
72+
/// <param name="name">The attribute name to look for.</param>
73+
/// <param name="attributeData">The resulting attribute, if it was found.</param>
74+
/// <returns>Whether or not <paramref name="symbol"/> has an attribute with the specified name.</returns>
75+
public static bool TryGetAttributeWithFullyQualifiedName(this ISymbol symbol, string name, [NotNullWhen(true)] out AttributeData? attributeData)
76+
{
77+
ImmutableArray<AttributeData> attributes = symbol.GetAttributes();
78+
79+
foreach (AttributeData attribute in attributes)
80+
{
81+
if (attribute.AttributeClass?.HasFullyQualifiedName(name) == true)
82+
{
83+
attributeData = attribute;
84+
85+
return true;
86+
}
87+
}
88+
89+
attributeData = null;
90+
91+
return false;
92+
}
93+
6794
/// <summary>
6895
/// Calculates the effective accessibility for a given symbol.
6996
/// </summary>

CommunityToolkit.Mvvm.SourceGenerators/Input/RelayCommandGenerator.cs

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,40 +23,33 @@ public sealed partial class RelayCommandGenerator : IIncrementalGenerator
2323
/// <inheritdoc/>
2424
public void Initialize(IncrementalGeneratorInitializationContext context)
2525
{
26-
// Get all method declarations with at least one attribute
27-
IncrementalValuesProvider<IMethodSymbol> methodSymbols =
26+
// Gather info for all annotated command methods (starting from method declarations with at least one attribute)
27+
IncrementalValuesProvider<(HierarchyInfo Hierarchy, Result<CommandInfo?> Info)> commandInfoWithErrors =
2828
context.SyntaxProvider
2929
.CreateSyntaxProvider(
3030
static (node, _) => node is MethodDeclarationSyntax { Parent: ClassDeclarationSyntax, AttributeLists.Count: > 0 },
31-
static (context, _) =>
31+
static (context, token) =>
3232
{
3333
if (!context.SemanticModel.Compilation.HasLanguageVersionAtLeastEqualTo(LanguageVersion.CSharp8))
3434
{
3535
return default;
3636
}
3737

38-
return (IMethodSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node)!;
39-
})
40-
.Where(static item => item is not null)!;
38+
IMethodSymbol methodSymbol = (IMethodSymbol)context.SemanticModel.GetDeclaredSymbol(context.Node, token)!;
4139

42-
// Filter the methods using [RelayCommand]
43-
IncrementalValuesProvider<(IMethodSymbol Symbol, AttributeData Attribute)> methodSymbolsWithAttributeData =
44-
methodSymbols
45-
.Select(static (item, _) => (
46-
item,
47-
Attribute: item.GetAttributes().FirstOrDefault(a => a.AttributeClass?.HasFullyQualifiedName("global::CommunityToolkit.Mvvm.Input.RelayCommandAttribute") == true)))
48-
.Where(static item => item.Attribute is not null)!;
40+
// Filter the methods using [RelayCommand]
41+
if (!methodSymbol.TryGetAttributeWithFullyQualifiedName("global::CommunityToolkit.Mvvm.Input.RelayCommandAttribute", out AttributeData? attribute))
42+
{
43+
return default;
44+
}
4945

50-
// Gather info for all annotated command methods
51-
IncrementalValuesProvider<(HierarchyInfo Hierarchy, Result<CommandInfo?> Info)> commandInfoWithErrors =
52-
methodSymbolsWithAttributeData
53-
.Select(static (item, _) =>
54-
{
55-
HierarchyInfo hierarchy = HierarchyInfo.From(item.Symbol.ContainingType);
56-
CommandInfo? commandInfo = Execute.GetInfo(item.Symbol, item.Attribute, out ImmutableArray<Diagnostic> diagnostics);
46+
// Produce the incremental models
47+
HierarchyInfo hierarchy = HierarchyInfo.From(methodSymbol.ContainingType);
48+
CommandInfo? commandInfo = Execute.GetInfo(methodSymbol, attribute, out ImmutableArray<Diagnostic> diagnostics);
5749

58-
return (hierarchy, new Result<CommandInfo?>(commandInfo, diagnostics));
59-
});
50+
return (Hierarchy: hierarchy, new Result<CommandInfo?>(commandInfo, diagnostics));
51+
})
52+
.Where(static item => item.Hierarchy is not null);
6053

6154
// Output the diagnostics
6255
context.ReportDiagnostics(commandInfoWithErrors.Select(static (item, _) => item.Info.Errors));
@@ -66,7 +59,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
6659
commandInfoWithErrors
6760
.Where(static item => item.Info.Value is not null)
6861
.Select(static (item, _) => (item.Hierarchy, item.Info.Value!))
69-
.WithComparers(HierarchyInfo.Comparer.Default, CommandInfo.Comparer.Default);
62+
.WithComparers(HierarchyInfo.Comparer.Default, CommandInfo.Comparer.Default);
7063

7164
// Generate the commands
7265
context.RegisterSourceOutput(commandInfo, static (context, item) =>

0 commit comments

Comments
 (0)