Skip to content

Commit 21ba24a

Browse files
authored
Fix handling of parsable types in validations generator (#61728)
1 parent bce4dd5 commit 21ba24a

10 files changed

+391
-73
lines changed

src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ITypeSymbolExtensions.cs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System.Collections.Immutable;
55
using System.Linq;
6+
using Microsoft.AspNetCore.App.Analyzers.Infrastructure;
67
using Microsoft.CodeAnalysis;
78

89
namespace Microsoft.AspNetCore.Http.ValidationsGenerator;
@@ -90,17 +91,17 @@ internal static bool ImplementsInterface(this ITypeSymbol type, ITypeSymbol inte
9091

9192
// Types exempted here have special binding rules in RDF and RDG and are not validatable
9293
// types themselves so we short-circuit on them.
93-
internal static bool IsExemptType(this ITypeSymbol type, RequiredSymbols requiredSymbols)
94+
internal static bool IsExemptType(this ITypeSymbol type, WellKnownTypes wellKnownTypes)
9495
{
95-
return SymbolEqualityComparer.Default.Equals(type, requiredSymbols.HttpContext)
96-
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.HttpRequest)
97-
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.HttpResponse)
98-
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.CancellationToken)
99-
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.IFormCollection)
100-
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.IFormFileCollection)
101-
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.IFormFile)
102-
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.Stream)
103-
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.PipeReader);
96+
return SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.Microsoft_AspNetCore_Http_HttpContext))
97+
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.Microsoft_AspNetCore_Http_HttpRequest))
98+
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.Microsoft_AspNetCore_Http_HttpResponse))
99+
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_Threading_CancellationToken))
100+
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.Microsoft_AspNetCore_Http_IFormCollection))
101+
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.Microsoft_AspNetCore_Http_IFormFileCollection))
102+
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.Microsoft_AspNetCore_Http_IFormFile))
103+
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_IO_Stream))
104+
|| SymbolEqualityComparer.Default.Equals(type, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_IO_Pipelines_PipeReader));
104105
}
105106

106107
internal static IPropertySymbol? FindPropertyIncludingBaseTypes(this INamedTypeSymbol typeSymbol, string propertyName)

src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Microsoft.AspNetCore.Http.ValidationsGenerator.csproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
<Compile Include="$(SharedSourceRoot)RoslynUtils\CodeWriter.cs" LinkBase="Shared" />
2929
<Compile Include="$(RepoRoot)\src\Http\Http.Extensions\gen\Microsoft.AspNetCore.Http.RequestDelegateGenerator\StaticRouteHandlerModel\InvocationOperationExtensions.cs" LinkBase="Shared" />
3030
<Compile Include="$(SharedSourceRoot)Diagnostics\AnalyzerDebug.cs" LinkBase="Shared" />
31+
<Compile Include="$(SharedSourceRoot)RoslynUtils\ParsabilityHelper.cs" LinkBase="Shared" />
32+
<Compile Include="$(SharedSourceRoot)RoslynUtils\SymbolExtensions.cs" LinkBase="Shared" />
3133
</ItemGroup>
3234

3335
</Project>

src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AttributeParser.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Collections.Generic;
55
using System.Collections.Immutable;
66
using System.Threading;
7+
using Microsoft.AspNetCore.App.Analyzers.Infrastructure;
78
using Microsoft.CodeAnalysis;
89
using Microsoft.CodeAnalysis.CSharp.Syntax;
910

@@ -20,8 +21,8 @@ internal ImmutableArray<ValidatableType> TransformValidatableTypeWithAttribute(G
2021
{
2122
var validatableTypes = new HashSet<ValidatableType>(ValidatableTypeComparer.Instance);
2223
List<ITypeSymbol> visitedTypes = [];
23-
var requiredSymbols = ExtractRequiredSymbols(context.SemanticModel.Compilation, cancellationToken);
24-
if (TryExtractValidatableType((ITypeSymbol)context.TargetSymbol, requiredSymbols, ref validatableTypes, ref visitedTypes))
24+
var wellKnownTypes = WellKnownTypes.GetOrCreate(context.SemanticModel.Compilation);
25+
if (TryExtractValidatableType((ITypeSymbol)context.TargetSymbol, wellKnownTypes, ref validatableTypes, ref visitedTypes))
2526
{
2627
return [..validatableTypes];
2728
}

src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.EndpointsParser.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Linq;
77
using System.Threading;
88
using Microsoft.AspNetCore.Analyzers.Infrastructure;
9+
using Microsoft.AspNetCore.App.Analyzers.Infrastructure;
910
using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel;
1011
using Microsoft.CodeAnalysis;
1112
using Microsoft.CodeAnalysis.CSharp.Syntax;
@@ -38,10 +39,12 @@ internal bool FindEndpoints(SyntaxNode syntaxNode, CancellationToken cancellatio
3839
: null;
3940
}
4041

41-
internal ImmutableArray<ValidatableType> ExtractValidatableEndpoint((IInvocationOperation? Operation, RequiredSymbols RequiredSymbols) input, CancellationToken cancellationToken)
42+
internal ImmutableArray<ValidatableType> ExtractValidatableEndpoint(IInvocationOperation? operation, CancellationToken cancellationToken)
4243
{
43-
AnalyzerDebug.Assert(input.Operation != null, "Operation should not be null.");
44-
var validatableTypes = ExtractValidatableTypes(input.Operation, input.RequiredSymbols);
44+
AnalyzerDebug.Assert(operation != null, "Operation should not be null.");
45+
AnalyzerDebug.Assert(operation.SemanticModel != null, "Operation should have a semantic model.");
46+
var wellKnownTypes = WellKnownTypes.GetOrCreate(operation.SemanticModel.Compilation);
47+
var validatableTypes = ExtractValidatableTypes(operation, wellKnownTypes);
4548
return validatableTypes;
4649
}
4750
}

src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.RequiredSymbolsParser.cs

Lines changed: 0 additions & 32 deletions
This file was deleted.

src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.TypesParser.cs

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Collections.Immutable;
66
using System.Linq;
77
using Microsoft.AspNetCore.Analyzers.Infrastructure;
8+
using Microsoft.AspNetCore.App.Analyzers.Infrastructure;
89
using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel;
910
using Microsoft.CodeAnalysis;
1011
using Microsoft.CodeAnalysis.CSharp;
@@ -18,7 +19,7 @@ public sealed partial class ValidationsGenerator : IIncrementalGenerator
1819
globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Included,
1920
typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces);
2021

21-
internal ImmutableArray<ValidatableType> ExtractValidatableTypes(IInvocationOperation operation, RequiredSymbols requiredSymbols)
22+
internal ImmutableArray<ValidatableType> ExtractValidatableTypes(IInvocationOperation operation, WellKnownTypes wellKnownTypes)
2223
{
2324
AnalyzerDebug.Assert(operation.SemanticModel != null, "SemanticModel should not be null.");
2425
var parameters = operation.TryGetRouteHandlerMethod(operation.SemanticModel, out var method)
@@ -28,12 +29,12 @@ internal ImmutableArray<ValidatableType> ExtractValidatableTypes(IInvocationOper
2829
List<ITypeSymbol> visitedTypes = [];
2930
foreach (var parameter in parameters)
3031
{
31-
_ = TryExtractValidatableType(parameter.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes);
32+
_ = TryExtractValidatableType(parameter.Type.UnwrapType(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_Collections_IEnumerable)), wellKnownTypes, ref validatableTypes, ref visitedTypes);
3233
}
3334
return [.. validatableTypes];
3435
}
3536

36-
internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet<ValidatableType> validatableTypes, ref List<ITypeSymbol> visitedTypes)
37+
internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes, ref HashSet<ValidatableType> validatableTypes, ref List<ITypeSymbol> visitedTypes)
3738
{
3839
if (typeSymbol.SpecialType != SpecialType.None)
3940
{
@@ -45,7 +46,7 @@ internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols
4546
return true;
4647
}
4748

48-
if (typeSymbol.IsExemptType(requiredSymbols))
49+
if (typeSymbol.IsExemptType(wellKnownTypes))
4950
{
5051
return false;
5152
}
@@ -57,19 +58,23 @@ internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols
5758
var hasValidatableBaseType = false;
5859
while (current != null && current.SpecialType != SpecialType.System_Object)
5960
{
60-
hasValidatableBaseType |= TryExtractValidatableType(current, requiredSymbols, ref validatableTypes, ref visitedTypes);
61+
hasValidatableBaseType |= TryExtractValidatableType(current, wellKnownTypes, ref validatableTypes, ref visitedTypes);
6162
current = current.BaseType;
6263
}
6364

6465
// Extract validatable types discovered in members of this type and add them to the top-level list.
65-
var members = ExtractValidatableMembers(typeSymbol, requiredSymbols, ref validatableTypes, ref visitedTypes);
66+
ImmutableArray<ValidatableProperty> members = [];
67+
if (ParsabilityHelper.GetParsability(typeSymbol, wellKnownTypes) is Parsability.NotParsable)
68+
{
69+
members = ExtractValidatableMembers(typeSymbol, wellKnownTypes, ref validatableTypes, ref visitedTypes);
70+
}
6671

6772
// Extract the validatable types discovered in the JsonDerivedTypeAttributes of this type and add them to the top-level list.
68-
var derivedTypes = typeSymbol.GetJsonDerivedTypes(requiredSymbols.JsonDerivedTypeAttribute);
73+
var derivedTypes = typeSymbol.GetJsonDerivedTypes(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_Text_Json_Serialization_JsonDerivedTypeAttribute));
6974
var hasValidatableDerivedTypes = false;
7075
foreach (var derivedType in derivedTypes ?? [])
7176
{
72-
hasValidatableDerivedTypes |= TryExtractValidatableType(derivedType, requiredSymbols, ref validatableTypes, ref visitedTypes);
77+
hasValidatableDerivedTypes |= TryExtractValidatableType(derivedType, wellKnownTypes, ref validatableTypes, ref visitedTypes);
7378
}
7479

7580
// No validatable members or derived types found, so we don't need to add this type.
@@ -86,7 +91,7 @@ internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols
8691
return true;
8792
}
8893

89-
internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet<ValidatableType> validatableTypes, ref List<ITypeSymbol> visitedTypes)
94+
internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymbol typeSymbol, WellKnownTypes wellKnownTypes, ref HashSet<ValidatableType> validatableTypes, ref List<ITypeSymbol> visitedTypes)
9095
{
9196
var members = new List<ValidatableProperty>();
9297
var resolvedRecordProperty = new List<IPropertySymbol>();
@@ -121,17 +126,17 @@ internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymb
121126
// Check if the property's type is validatable, this resolves
122127
// validatable types in the inheritance hierarchy
123128
var hasValidatableType = TryExtractValidatableType(
124-
correspondingProperty.Type.UnwrapType(requiredSymbols.IEnumerable),
125-
requiredSymbols,
129+
correspondingProperty.Type.UnwrapType(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_Collections_IEnumerable)),
130+
wellKnownTypes,
126131
ref validatableTypes,
127132
ref visitedTypes);
128133

129134
members.Add(new ValidatableProperty(
130135
ContainingType: correspondingProperty.ContainingType,
131136
Type: correspondingProperty.Type,
132137
Name: correspondingProperty.Name,
133-
DisplayName: parameter.GetDisplayName(requiredSymbols.DisplayAttribute) ??
134-
correspondingProperty.GetDisplayName(requiredSymbols.DisplayAttribute),
138+
DisplayName: parameter.GetDisplayName(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_DisplayAttribute)) ??
139+
correspondingProperty.GetDisplayName(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_DisplayAttribute)),
135140
Attributes: []));
136141
}
137142
}
@@ -148,8 +153,8 @@ internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymb
148153
continue;
149154
}
150155

151-
var hasValidatableType = TryExtractValidatableType(member.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes);
152-
var attributes = ExtractValidationAttributes(member, requiredSymbols, out var isRequired);
156+
var hasValidatableType = TryExtractValidatableType(member.Type.UnwrapType(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_Collections_IEnumerable)), wellKnownTypes, ref validatableTypes, ref visitedTypes);
157+
var attributes = ExtractValidationAttributes(member, wellKnownTypes, out var isRequired);
153158

154159
// If the member has no validation attributes or validatable types and is not required, skip it.
155160
if (attributes.IsDefaultOrEmpty && !hasValidatableType && !isRequired)
@@ -161,14 +166,14 @@ internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymb
161166
ContainingType: member.ContainingType,
162167
Type: member.Type,
163168
Name: member.Name,
164-
DisplayName: member.GetDisplayName(requiredSymbols.DisplayAttribute),
169+
DisplayName: member.GetDisplayName(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_DisplayAttribute)),
165170
Attributes: attributes));
166171
}
167172

168173
return [.. members];
169174
}
170175

171-
internal static ImmutableArray<ValidationAttribute> ExtractValidationAttributes(ISymbol symbol, RequiredSymbols requiredSymbols, out bool isRequired)
176+
internal static ImmutableArray<ValidationAttribute> ExtractValidationAttributes(ISymbol symbol, WellKnownTypes wellKnownTypes, out bool isRequired)
172177
{
173178
var attributes = symbol.GetAttributes();
174179
if (attributes.Length == 0)
@@ -179,15 +184,15 @@ internal static ImmutableArray<ValidationAttribute> ExtractValidationAttributes(
179184

180185
var validationAttributes = attributes
181186
.Where(attribute => attribute.AttributeClass != null)
182-
.Where(attribute => attribute.AttributeClass!.ImplementsValidationAttribute(requiredSymbols.ValidationAttribute));
183-
isRequired = validationAttributes.Any(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, requiredSymbols.RequiredAttribute));
187+
.Where(attribute => attribute.AttributeClass!.ImplementsValidationAttribute(wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_ValidationAttribute)));
188+
isRequired = validationAttributes.Any(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_RequiredAttribute)));
184189
return [.. validationAttributes
185-
.Where(attr => !SymbolEqualityComparer.Default.Equals(attr.AttributeClass, requiredSymbols.ValidationAttribute))
190+
.Where(attr => !SymbolEqualityComparer.Default.Equals(attr.AttributeClass, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_ValidationAttribute)))
186191
.Select(attribute => new ValidationAttribute(
187192
Name: symbol.Name + attribute.AttributeClass!.Name,
188193
ClassName: attribute.AttributeClass!.ToDisplayString(_symbolDisplayFormat),
189194
Arguments: [.. attribute.ConstructorArguments.Select(a => a.ToCSharpString())],
190195
NamedArguments: attribute.NamedArguments.ToDictionary(namedArgument => namedArgument.Key, namedArgument => namedArgument.Value.ToCSharpString()),
191-
IsCustomValidationAttribute: SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, requiredSymbols.CustomValidationAttribute)))];
196+
IsCustomValidationAttribute: SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, wellKnownTypes.Get(WellKnownTypeData.WellKnownType.System_ComponentModel_DataAnnotations_CustomValidationAttribute))))];
192197
}
193198
}

src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/ValidationsGenerator.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@ public sealed partial class ValidationsGenerator : IIncrementalGenerator
1111
{
1212
public void Initialize(IncrementalGeneratorInitializationContext context)
1313
{
14-
// Resolve the symbols that will be required when making comparisons
15-
// in future steps.
16-
var requiredSymbols = context.CompilationProvider.Select(ExtractRequiredSymbols);
17-
1814
// Find the builder.Services.AddValidation() call in the application.
1915
var addValidation = context.SyntaxProvider.CreateSyntaxProvider(
2016
predicate: FindAddValidation,
@@ -34,7 +30,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
3430
.Where(endpoint => endpoint is not null);
3531
// Extract validatable types from all endpoints.
3632
var validatableTypesFromEndpoints = endpoints
37-
.Combine(requiredSymbols)
3833
.Select(ExtractValidatableEndpoint);
3934
// Join all validatable types encountered in the type graph.
4035
var validatableTypes = validatableTypesWithAttribute

0 commit comments

Comments
 (0)