Skip to content

Commit 0936b58

Browse files
authored
Merge pull request #106 from CommunityToolkit/dev/icommand-attribute-generated-can-execute
Allow using [ICommand(CanExecute)] on generated observable properties
2 parents ea10105 + b861efc commit 0936b58

File tree

2 files changed

+141
-1
lines changed

2 files changed

+141
-1
lines changed

CommunityToolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.Execute.cs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Immutable;
77
using System.Diagnostics.CodeAnalysis;
8+
using System.Linq;
89
using CommunityToolkit.Mvvm.SourceGenerators.Diagnostics;
910
using CommunityToolkit.Mvvm.SourceGenerators.Extensions;
1011
using CommunityToolkit.Mvvm.SourceGenerators.Input.Models;
@@ -435,6 +436,14 @@ private static bool TryGetCanExecuteExpressionType(
435436

436437
if (canExecuteSymbols.IsEmpty)
437438
{
439+
// Special case for when the target member is a generated property from [ObservableProperty]
440+
if (TryGetCanExecuteMemberFromGeneratedProperty(memberName, methodSymbol.ContainingType, commandTypeArguments, out canExecuteExpressionType))
441+
{
442+
canExecuteMemberName = memberName;
443+
444+
return true;
445+
}
446+
438447
diagnostics.Add(InvalidCanExecuteMemberName, methodSymbol, memberName, methodSymbol.ContainingType);
439448
}
440449
else if (canExecuteSymbols.Length > 1)
@@ -531,5 +540,63 @@ private static bool TryGetCanExecuteExpressionFromSymbol(
531540

532541
return false;
533542
}
543+
544+
/// <summary>
545+
/// Gets the expression type for the can execute logic, if possible.
546+
/// </summary>
547+
/// <param name="memberName">The member name passed to <c>[ICommand(CanExecute = ...)]</c>.</param>
548+
/// <param name="containingType">The containing type for the method annotated with <c>[ICommand]</c>.</param>
549+
/// <param name="commandTypeArguments">The type arguments for the command interface, if any.</param>
550+
/// <param name="canExecuteExpressionType">The resulting can execute expression type, if available.</param>
551+
/// <returns>Whether or not <paramref name="canExecuteExpressionType"/> was set and the input symbol was valid.</returns>
552+
private static bool TryGetCanExecuteMemberFromGeneratedProperty(
553+
string memberName,
554+
INamedTypeSymbol containingType,
555+
ImmutableArray<string> commandTypeArguments,
556+
[NotNullWhen(true)] out CanExecuteExpressionType? canExecuteExpressionType)
557+
{
558+
foreach (ISymbol memberSymbol in containingType.GetMembers())
559+
{
560+
// Only look for instance fields of bool type
561+
if (memberSymbol is not IFieldSymbol fieldSymbol ||
562+
fieldSymbol is { IsStatic: true } ||
563+
!fieldSymbol.Type.HasFullyQualifiedName("bool"))
564+
{
565+
continue;
566+
}
567+
568+
ImmutableArray<AttributeData> attributes = memberSymbol.GetAttributes();
569+
570+
// Only filter fields with the [ObservableProperty] attribute
571+
if (memberSymbol is IFieldSymbol &&
572+
!attributes.Any(static a => a.AttributeClass?.HasFullyQualifiedName(
573+
"global::CommunityToolkit.Mvvm.ComponentModel.ObservablePropertyAttribute") == true))
574+
{
575+
continue;
576+
}
577+
578+
// Get the target property name either directly or matching the generated one
579+
string propertyName = ObservablePropertyGenerator.Execute.GetGeneratedPropertyName(fieldSymbol);
580+
581+
// If the generated property name matches, get the right expression type
582+
if (memberName == propertyName)
583+
{
584+
if (commandTypeArguments.Length > 0)
585+
{
586+
canExecuteExpressionType = CanExecuteExpressionType.PropertyAccessLambdaWithDiscard;
587+
}
588+
else
589+
{
590+
canExecuteExpressionType = CanExecuteExpressionType.PropertyAccessLambda;
591+
}
592+
593+
return true;
594+
}
595+
}
596+
597+
canExecuteExpressionType = null;
598+
599+
return false;
600+
}
534601
}
535602
}

tests/CommunityToolkit.Mvvm.UnitTests/Test_ICommandAttribute.cs

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Linq;
77
using System.Threading;
88
using System.Threading.Tasks;
9+
using CommunityToolkit.Mvvm.ComponentModel;
910
using CommunityToolkit.Mvvm.Input;
1011
using Microsoft.VisualStudio.TestTools.UnitTesting;
1112

@@ -76,6 +77,24 @@ public void Test_ICommandAttribute_CanExecute_NoParameters_Property()
7677
Assert.AreEqual(model.Counter, 1);
7778
}
7879

80+
[TestMethod]
81+
public void Test_ICommandAttribute_CanExecute_NoParameters_GeneratedProperty()
82+
{
83+
CanExecuteViewModel model = new();
84+
85+
model.SetGeneratedFlag(true);
86+
87+
model.IncrementCounter_NoParameters_GeneratedPropertyCommand.Execute(null);
88+
89+
Assert.AreEqual(model.Counter, 1);
90+
91+
model.SetGeneratedFlag(false);
92+
93+
model.IncrementCounter_NoParameters_GeneratedPropertyCommand.Execute(null);
94+
95+
Assert.AreEqual(model.Counter, 1);
96+
}
97+
7998
[TestMethod]
8099
public void Test_ICommandAttribute_CanExecute_WithParameter_Property()
81100
{
@@ -94,6 +113,24 @@ public void Test_ICommandAttribute_CanExecute_WithParameter_Property()
94113
Assert.AreEqual(model.Counter, 1);
95114
}
96115

116+
[TestMethod]
117+
public void Test_ICommandAttribute_CanExecute_WithParameter_GeneratedProperty()
118+
{
119+
CanExecuteViewModel model = new();
120+
121+
model.SetGeneratedFlag(true);
122+
123+
model.IncrementCounter_WithParameter_GeneratedPropertyCommand.Execute(null);
124+
125+
Assert.AreEqual(model.Counter, 1);
126+
127+
model.SetGeneratedFlag(false);
128+
129+
model.IncrementCounter_WithParameter_GeneratedPropertyCommand.Execute(null);
130+
131+
Assert.AreEqual(model.Counter, 1);
132+
}
133+
97134
[TestMethod]
98135
public void Test_ICommandAttribute_CanExecute_NoParameters_MethodWithNoParameters()
99136
{
@@ -384,12 +421,20 @@ private async Task AwaitForInputTaskAsync(Task task)
384421
}
385422
}
386423

387-
public sealed partial class CanExecuteViewModel
424+
public sealed partial class CanExecuteViewModel : ObservableObject
388425
{
389426
public int Counter { get; private set; }
390427

391428
public bool Flag { get; set; }
392429

430+
public void SetGeneratedFlag(bool flag)
431+
{
432+
GeneratedFlag = flag;
433+
}
434+
435+
[ObservableProperty]
436+
private bool generatedFlag;
437+
393438
private bool GetFlag1() => Flag;
394439

395440
private bool GetFlag2(User user) => user.Name == nameof(CanExecuteViewModel);
@@ -406,6 +451,18 @@ private void IncrementCounter_WithParameter_Property(User user)
406451
Counter++;
407452
}
408453

454+
[ICommand(CanExecute = nameof(GeneratedFlag))]
455+
private void IncrementCounter_NoParameters_GeneratedProperty()
456+
{
457+
Counter++;
458+
}
459+
460+
[ICommand(CanExecute = nameof(GeneratedFlag))]
461+
private void IncrementCounter_WithParameter_GeneratedProperty(User user)
462+
{
463+
Counter++;
464+
}
465+
409466
[ICommand(CanExecute = nameof(GetFlag1))]
410467
private void IncrementCounter_NoParameters_MethodWithNoParameters()
411468
{
@@ -440,6 +497,22 @@ private async Task IncrementCounter_Async_WithParameter_Property(User user)
440497
await Task.Delay(100);
441498
}
442499

500+
[ICommand(CanExecute = nameof(GeneratedFlag))]
501+
private async Task IncrementCounter_Async_NoParameters_GeneratedProperty()
502+
{
503+
Counter++;
504+
505+
await Task.Delay(100);
506+
}
507+
508+
[ICommand(CanExecute = nameof(GeneratedFlag))]
509+
private async Task IncrementCounter_Async_WithParameter_GeneratedProperty(User user)
510+
{
511+
Counter++;
512+
513+
await Task.Delay(100);
514+
}
515+
443516
[ICommand(CanExecute = nameof(GetFlag1))]
444517
private async Task IncrementCounter_Async_NoParameters_MethodWithNoParameters()
445518
{

0 commit comments

Comments
 (0)