Skip to content

Commit f077308

Browse files
authored
Merge pull request #856 from Sergio0694/dev/shader-explicit-entry-points
Support explicit interface implementations for shader entry points
2 parents 57f0a9b + f7b03ea commit f077308

File tree

11 files changed

+249
-32
lines changed

11 files changed

+249
-32
lines changed

src/ComputeSharp.D2D1.SourceGenerators/D2DPixelShaderDescriptorGenerator.HlslSource.cs

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,17 @@ private static partial class HlslSource
2929
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
3030
/// <param name="compilation">The input <see cref="Compilation"/> object currently in use.</param>
3131
/// <param name="structDeclarationSymbol">The <see cref="INamedTypeSymbol"/> for the shader type.</param>
32+
/// <param name="shaderInterfaceType">The shader interface type implemented by the shader type.</param>
3233
/// <param name="inputCount">The number of inputs for the shader.</param>
33-
/// <param name="inputSimpleIndices">The indicess of the simple shader inputs.</param>
34-
/// <param name="inputComplexIndices">The indicess of the complex shader inputs.</param>
34+
/// <param name="inputSimpleIndices">The indices of the simple shader inputs.</param>
35+
/// <param name="inputComplexIndices">The indices of the complex shader inputs.</param>
3536
/// <param name="token">The <see cref="CancellationToken"/> used to cancel the operation, if needed.</param>
3637
/// <returns>The HLSL source for the shader.</returns>
3738
public static string GetHlslSource(
3839
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
3940
Compilation compilation,
4041
INamedTypeSymbol structDeclarationSymbol,
42+
INamedTypeSymbol shaderInterfaceType,
4143
int inputCount,
4244
ImmutableArray<int> inputSimpleIndices,
4345
ImmutableArray<int> inputComplexIndices,
@@ -46,6 +48,8 @@ public static string GetHlslSource(
4648
// Detect any invalid properties
4749
HlslDefinitionsSyntaxProcessor.DetectAndReportInvalidPropertyDeclarations(diagnostics, structDeclarationSymbol);
4850

51+
token.ThrowIfCancellationRequested();
52+
4953
// We need to sets to track all discovered custom types and static methods
5054
HashSet<INamedTypeSymbol> discoveredTypes = new(SymbolEqualityComparer.Default);
5155
Dictionary<IMethodSymbol, MethodDeclarationSyntax> staticMethods = new(SymbolEqualityComparer.Default);
@@ -72,6 +76,7 @@ public static string GetHlslSource(
7276
(string entryPoint, ImmutableArray<HlslMethod> processedMethods) = GetProcessedMethods(
7377
diagnostics,
7478
structDeclarationSymbol,
79+
shaderInterfaceType,
7580
semanticModelProvider,
7681
discoveredTypes,
7782
staticMethods,
@@ -302,6 +307,7 @@ private static ImmutableArray<HlslStaticField> GetStaticFields(
302307
/// </summary>
303308
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
304309
/// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
310+
/// <param name="shaderInterfaceType">The shader interface type implemented by the shader type.</param>
305311
/// <param name="semanticModel">The <see cref="SemanticModelProvider"/> instance for the type to process.</param>
306312
/// <param name="discoveredTypes">The collection of currently discovered types.</param>
307313
/// <param name="staticMethods">The set of discovered and processed static methods.</param>
@@ -315,6 +321,7 @@ private static ImmutableArray<HlslStaticField> GetStaticFields(
315321
private static (string EntryPoint, ImmutableArray<HlslMethod> Methods) GetProcessedMethods(
316322
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
317323
INamedTypeSymbol structDeclarationSymbol,
324+
INamedTypeSymbol shaderInterfaceType,
318325
SemanticModelProvider semanticModel,
319326
ICollection<INamedTypeSymbol> discoveredTypes,
320327
IDictionary<IMethodSymbol, MethodDeclarationSyntax> staticMethods,
@@ -327,6 +334,7 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods) GetProces
327334
{
328335
using ImmutableArrayBuilder<HlslMethod> methods = new();
329336

337+
IMethodSymbol entryPointInterfaceMethod = shaderInterfaceType.GetMethod("Execute")!;
330338
string? entryPoint = null;
331339

332340
// By default, the scene position is not required. We will set this while
@@ -341,16 +349,17 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods) GetProces
341349
continue;
342350
}
343351

352+
// Ensure that we have accessible source information
344353
if (!methodSymbol.TryGetSyntaxNode(token, out MethodDeclarationSyntax? methodDeclaration))
345354
{
346355
continue;
347356
}
348357

349-
bool isShaderEntryPoint =
350-
methodSymbol.Name == "Execute" &&
351-
methodSymbol.ReturnType.HasFullyQualifiedMetadataName("ComputeSharp.Float4") &&
352-
methodSymbol.TypeParameters.Length == 0 &&
353-
methodSymbol.Parameters.Length == 0;
358+
// Check whether the current method is the entry point (ie. it's implementing 'Execute').
359+
// This is the same logic as in the DX12 generator for compute shaders and pixel shaders.
360+
bool isShaderEntryPoint = SymbolEqualityComparer.Default.Equals(
361+
structDeclarationSymbol.FindImplementationForInterfaceMember(entryPointInterfaceMethod),
362+
methodSymbol);
354363

355364
// Except for the entry point, ignore explicit interface implementations
356365
if (!isShaderEntryPoint && !methodSymbol.ExplicitInterfaceImplementations.IsDefaultOrEmpty)
@@ -460,8 +469,8 @@ private static bool GetD2DRequiresScenePositionInfo(INamedTypeSymbol structDecla
460469
/// <param name="typeMethodDeclarations"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteMethodDeclarations" path="/param[@name='typeMethodDeclarations']/node()"/></param>
461470
/// <param name="executeMethod">The body of the entry point of the shader.</param>
462471
/// <param name="inputCount">The number of shader inputs to declare.</param>
463-
/// <param name="inputSimpleIndices">The indicess of the simple shader inputs.</param>
464-
/// <param name="inputComplexIndices">The indicess of the complex shader inputs.</param>
472+
/// <param name="inputSimpleIndices">The indices of the simple shader inputs.</param>
473+
/// <param name="inputComplexIndices">The indices of the complex shader inputs.</param>
465474
/// <param name="requiresScenePosition">Whether the shader requires the scene position.</param>
466475
/// <returns>The series of statements to build the HLSL source to compile to execute the current shader.</returns>
467476
private static string GetHlslSource(

src/ComputeSharp.D2D1.SourceGenerators/D2DPixelShaderDescriptorGenerator.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
5353
return default;
5454
}
5555

56+
INamedTypeSymbol shaderInterfaceType = context.SemanticModel.Compilation.GetTypeByMetadataName("ComputeSharp.D2D1.ID2D1PixelShader")!;
57+
5658
// Check that the shader implements the ID2D1PixelShader interface
57-
if (!typeSymbol.HasInterfaceWithType(context.SemanticModel.Compilation.GetTypeByMetadataName("ComputeSharp.D2D1.ID2D1PixelShader")!))
59+
if (!typeSymbol.HasInterfaceWithType(shaderInterfaceType))
5860
{
5961
return default;
6062
}
@@ -142,6 +144,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
142144
diagnostics,
143145
context.SemanticModel.Compilation,
144146
typeSymbol,
147+
shaderInterfaceType,
145148
inputCount,
146149
inputSimpleIndices,
147150
inputComplexIndices,

src/ComputeSharp.SourceGeneration.Hlsl/SyntaxRewriters/ShaderSourceRewriter.cs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,23 @@ internal sealed partial class ShaderSourceRewriter(
154154
Diagnostics.Add(UnsafeModifierOnMethodOrFunction, node);
155155
}
156156

157-
if (updatedNode is not null)
157+
// Add the tracked implicit declarations (at the start of the body).
158+
// To optimize, we only do this if we do have any implicit variables.
159+
if (this.implicitVariables.Count > 0)
158160
{
159161
BlockSyntax implicitBlock = Block(this.implicitVariables.Select(static v => LocalDeclarationStatement(v)).ToArray());
160162

161-
// Add the tracked implicit declarations (at the start of the body)
162163
updatedNode = updatedNode.WithBody(implicitBlock).AddBodyStatements([.. updatedNode.Body!.Statements]);
163164
}
164165

166+
// The entry point might be an explicit interface method implementation. In that case,
167+
// the transpiled method will have the rewritten interface name as a prefix for the
168+
// method name, which we don't want (it's invalid HLSL). So in that case, remove it.
169+
if (this.isEntryPoint && updatedNode.ExplicitInterfaceSpecifier is not null)
170+
{
171+
updatedNode = updatedNode.WithExplicitInterfaceSpecifier(null);
172+
}
173+
165174
return updatedNode;
166175
}
167176

src/ComputeSharp.SourceGeneration/Extensions/ITypeSymbolExtensions.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,26 @@ namespace ComputeSharp.SourceGeneration.Extensions;
99
/// </summary>
1010
internal static class ITypeSymbolExtensions
1111
{
12+
/// <summary>
13+
/// Gets the method of this symbol that have a particular name.
14+
/// </summary>
15+
/// <param name="symbol">The input <see cref="ITypeSymbol"/> instance to check.</param>
16+
/// <param name="name">The name of the method to find.</param>
17+
/// <returns>The target method, if present.</returns>
18+
public static IMethodSymbol? GetMethod(this ITypeSymbol symbol, string name)
19+
{
20+
foreach (ISymbol memberSymbol in symbol.GetMembers(name))
21+
{
22+
if (memberSymbol is IMethodSymbol methodSymbol &&
23+
memberSymbol.Name == name)
24+
{
25+
return methodSymbol;
26+
}
27+
}
28+
29+
return null;
30+
}
31+
1232
/// <summary>
1333
/// Checks whether or not a given type symbol has a specified fully qualified metadata name.
1434
/// </summary>
@@ -28,7 +48,7 @@ public static bool HasFullyQualifiedMetadataName(this ITypeSymbol symbol, string
2848
/// Checks whether or not a given <see cref="ITypeSymbol"/> implements an interface of a specified type.
2949
/// </summary>
3050
/// <param name="typeSymbol">The target <see cref="ITypeSymbol"/> instance to check.</param>
31-
/// <param name="interfaceSymbol">The <see cref="ITypeSymbol"/> instane to check for inheritance from.</param>
51+
/// <param name="interfaceSymbol">The <see cref="ITypeSymbol"/> instance to check for inheritance from.</param>
3252
/// <returns>Whether or not <paramref name="typeSymbol"/> has an interface of type <paramref name="interfaceSymbol"/>.</returns>
3353
public static bool HasInterfaceWithType(this ITypeSymbol typeSymbol, ITypeSymbol interfaceSymbol)
3454
{

src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.Helpers.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System.Diagnostics.CodeAnalysis;
12
using Microsoft.CodeAnalysis;
23

34
namespace ComputeSharp.SourceGenerators;
@@ -10,9 +11,14 @@ partial class ComputeShaderDescriptorGenerator
1011
/// </summary>
1112
/// <param name="typeSymbol">The input <see cref="INamedTypeSymbol"/> instance to check.</param>
1213
/// <param name="compilation">The <see cref="Compilation"/> instance currently in use.</param>
14+
/// <param name="shaderInterfaceType">The (constructed) shader interface type implemented by the shader type.</param>
1315
/// <param name="isPixelShaderLike">Whether <paramref name="typeSymbol"/> is a "pixel shader like" type.</param>
1416
/// <returns>Whether <paramref name="typeSymbol"/> is a compute shader type at all.</returns>
15-
private static bool TryGetIsPixelShaderLike(INamedTypeSymbol typeSymbol, Compilation compilation, out bool isPixelShaderLike)
17+
private static bool TryGetIsPixelShaderLike(
18+
INamedTypeSymbol typeSymbol,
19+
Compilation compilation,
20+
[NotNullWhen(true)] out INamedTypeSymbol? shaderInterfaceType,
21+
out bool isPixelShaderLike)
1622
{
1723
INamedTypeSymbol computeShaderSymbol = compilation.GetTypeByMetadataName("ComputeSharp.IComputeShader")!;
1824
INamedTypeSymbol pixelShaderSymbol = compilation.GetTypeByMetadataName("ComputeSharp.IComputeShader`1")!;
@@ -21,18 +27,21 @@ private static bool TryGetIsPixelShaderLike(INamedTypeSymbol typeSymbol, Compila
2127
{
2228
if (SymbolEqualityComparer.Default.Equals(interfaceSymbol, computeShaderSymbol))
2329
{
30+
shaderInterfaceType = interfaceSymbol;
2431
isPixelShaderLike = false;
2532

2633
return true;
2734
}
2835
else if (SymbolEqualityComparer.Default.Equals(interfaceSymbol.ConstructedFrom, pixelShaderSymbol))
2936
{
37+
shaderInterfaceType = interfaceSymbol;
3038
isPixelShaderLike = true;
3139

3240
return true;
3341
}
3442
}
3543

44+
shaderInterfaceType = null;
3645
isPixelShaderLike = false;
3746

3847
return false;

src/ComputeSharp.SourceGenerators/ComputeShaderDescriptorGenerator.HlslSource.cs

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ internal static partial class HlslSource
3131
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
3232
/// <param name="compilation">The input <see cref="Compilation"/> object currently in use.</param>
3333
/// <param name="structDeclarationSymbol">The <see cref="INamedTypeSymbol"/> for the shader type.</param>
34+
/// <param name="shaderInterfaceType">The shader interface type implemented by the shader type.</param>
35+
/// <param name="isPixelShaderLike">Whether <paramref name="structDeclarationSymbol"/> is a "pixel shader like" type.</param>
3436
/// <param name="threadsX">The thread ids value for the X axis.</param>
3537
/// <param name="threadsY">The thread ids value for the Y axis.</param>
3638
/// <param name="threadsZ">The thread ids value for the Z axis.</param>
@@ -42,6 +44,8 @@ public static void GetInfo(
4244
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
4345
Compilation compilation,
4446
INamedTypeSymbol structDeclarationSymbol,
47+
INamedTypeSymbol shaderInterfaceType,
48+
bool isPixelShaderLike,
4549
int threadsX,
4650
int threadsY,
4751
int threadsZ,
@@ -53,6 +57,8 @@ public static void GetInfo(
5357
// Detect any invalid properties
5458
HlslDefinitionsSyntaxProcessor.DetectAndReportInvalidPropertyDeclarations(diagnostics, structDeclarationSymbol);
5559

60+
token.ThrowIfCancellationRequested();
61+
5662
// We need to sets to track all discovered custom types and static methods
5763
HashSet<INamedTypeSymbol> discoveredTypes = new(SymbolEqualityComparer.Default);
5864
Dictionary<IMethodSymbol, MethodDeclarationSyntax> staticMethods = new(SymbolEqualityComparer.Default);
@@ -62,9 +68,8 @@ public static void GetInfo(
6268
Dictionary<IFieldSymbol, HlslStaticField> staticFieldDefinitions = new(SymbolEqualityComparer.Default);
6369

6470
// Setup the semantic model and basic properties
65-
INamedTypeSymbol? pixelShaderSymbol = structDeclarationSymbol.AllInterfaces.FirstOrDefault(static interfaceSymbol => interfaceSymbol is { IsGenericType: true, Name: "IComputeShader" });
66-
bool isComputeShader = pixelShaderSymbol is null;
67-
string? implicitTextureType = isComputeShader ? null : HlslKnownTypes.GetMappedNameForPixelShaderType(pixelShaderSymbol!);
71+
bool isComputeShader = !isPixelShaderLike;
72+
string? implicitTextureType = HlslKnownTypes.GetMappedNameForPixelShaderType(shaderInterfaceType);
6873

6974
token.ThrowIfCancellationRequested();
7075

@@ -90,6 +95,7 @@ public static void GetInfo(
9095
(string entryPoint, ImmutableArray<HlslMethod> processedMethods, isSamplerUsed) = GetProcessedMethods(
9196
diagnostics,
9297
structDeclarationSymbol,
98+
shaderInterfaceType,
9399
semanticModelProvider,
94100
discoveredTypes,
95101
staticMethods,
@@ -360,6 +366,7 @@ private static ImmutableArray<HlslSharedBuffer> GetSharedBuffers(
360366
/// </summary>
361367
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
362368
/// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
369+
/// <param name="shaderInterfaceType">The shader interface type implemented by the shader type.</param>
363370
/// <param name="semanticModel">The <see cref="SemanticModelProvider"/> instance for the type to process.</param>
364371
/// <param name="discoveredTypes">The collection of currently discovered types.</param>
365372
/// <param name="staticMethods">The set of discovered and processed static methods.</param>
@@ -373,6 +380,7 @@ private static ImmutableArray<HlslSharedBuffer> GetSharedBuffers(
373380
private static (string EntryPoint, ImmutableArray<HlslMethod> Methods, bool IsSamplerUser) GetProcessedMethods(
374381
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
375382
INamedTypeSymbol structDeclarationSymbol,
383+
INamedTypeSymbol shaderInterfaceType,
376384
SemanticModelProvider semanticModel,
377385
ICollection<INamedTypeSymbol> discoveredTypes,
378386
IDictionary<IMethodSymbol, MethodDeclarationSyntax> staticMethods,
@@ -385,6 +393,7 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods, bool IsSa
385393
{
386394
using ImmutableArrayBuilder<HlslMethod> methods = new();
387395

396+
IMethodSymbol entryPointInterfaceMethod = shaderInterfaceType.GetMethod("Execute")!;
388397
string? entryPoint = null;
389398
bool isSamplerUsed = false;
390399

@@ -396,22 +405,17 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods, bool IsSa
396405
continue;
397406
}
398407

408+
// Ensure that we have accessible source information
399409
if (!methodSymbol.TryGetSyntaxNode(token, out MethodDeclarationSyntax? methodDeclaration))
400410
{
401411
continue;
402412
}
403413

404-
bool isShaderEntryPoint =
405-
(isComputeShader &&
406-
methodSymbol.Name == "Execute" &&
407-
methodSymbol.ReturnsVoid &&
408-
methodSymbol.TypeParameters.Length == 0 &&
409-
methodSymbol.Parameters.Length == 0) ||
410-
(!isComputeShader &&
411-
methodSymbol.Name == "Execute" &&
412-
methodSymbol.ReturnType is not null && // TODO: match for pixel type
413-
methodSymbol.TypeParameters.Length == 0 &&
414-
methodSymbol.Parameters.Length == 0);
414+
// Check whether the current method is the entry point (ie. it's implementing 'Execute'). We use
415+
// 'FindImplementationForInterfaceMember' to handle explicit interface implementations as well.
416+
bool isShaderEntryPoint = SymbolEqualityComparer.Default.Equals(
417+
structDeclarationSymbol.FindImplementationForInterfaceMember(entryPointInterfaceMethod),
418+
methodSymbol);
415419

416420
// Except for the entry point, ignore explicit interface implementations
417421
if (!isShaderEntryPoint && !methodSymbol.ExplicitInterfaceImplementations.IsDefaultOrEmpty)

0 commit comments

Comments
 (0)