@@ -29,15 +29,17 @@ private static partial class HlslSource
2929 /// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
3030 /// <param name="compilation">The input <see cref="Compilation"/> object currently in use.</param>
3131 /// <param name="structDeclarationSymbol">The <see cref="INamedTypeSymbol"/> for the shader type.</param>
32+ /// <param name="shaderInterfaceType">The shader interface type implemented by the shader type.</param>
3233 /// <param name="inputCount">The number of inputs for the shader.</param>
33- /// <param name="inputSimpleIndices">The indicess of the simple shader inputs.</param>
34- /// <param name="inputComplexIndices">The indicess of the complex shader inputs.</param>
34+ /// <param name="inputSimpleIndices">The indices of the simple shader inputs.</param>
35+ /// <param name="inputComplexIndices">The indices of the complex shader inputs.</param>
3536 /// <param name="token">The <see cref="CancellationToken"/> used to cancel the operation, if needed.</param>
3637 /// <returns>The HLSL source for the shader.</returns>
3738 public static string GetHlslSource (
3839 ImmutableArrayBuilder < DiagnosticInfo > diagnostics ,
3940 Compilation compilation ,
4041 INamedTypeSymbol structDeclarationSymbol ,
42+ INamedTypeSymbol shaderInterfaceType ,
4143 int inputCount ,
4244 ImmutableArray < int > inputSimpleIndices ,
4345 ImmutableArray < int > inputComplexIndices ,
@@ -46,6 +48,8 @@ public static string GetHlslSource(
4648 // Detect any invalid properties
4749 HlslDefinitionsSyntaxProcessor . DetectAndReportInvalidPropertyDeclarations ( diagnostics , structDeclarationSymbol ) ;
4850
51+ token . ThrowIfCancellationRequested ( ) ;
52+
4953 // We need to sets to track all discovered custom types and static methods
5054 HashSet < INamedTypeSymbol > discoveredTypes = new ( SymbolEqualityComparer . Default ) ;
5155 Dictionary < IMethodSymbol , MethodDeclarationSyntax > staticMethods = new ( SymbolEqualityComparer . Default ) ;
@@ -72,6 +76,7 @@ public static string GetHlslSource(
7276 ( string entryPoint , ImmutableArray < HlslMethod > processedMethods ) = GetProcessedMethods (
7377 diagnostics ,
7478 structDeclarationSymbol ,
79+ shaderInterfaceType ,
7580 semanticModelProvider ,
7681 discoveredTypes ,
7782 staticMethods ,
@@ -302,6 +307,7 @@ private static ImmutableArray<HlslStaticField> GetStaticFields(
302307 /// </summary>
303308 /// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
304309 /// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
310+ /// <param name="shaderInterfaceType">The shader interface type implemented by the shader type.</param>
305311 /// <param name="semanticModel">The <see cref="SemanticModelProvider"/> instance for the type to process.</param>
306312 /// <param name="discoveredTypes">The collection of currently discovered types.</param>
307313 /// <param name="staticMethods">The set of discovered and processed static methods.</param>
@@ -315,6 +321,7 @@ private static ImmutableArray<HlslStaticField> GetStaticFields(
315321 private static ( string EntryPoint , ImmutableArray < HlslMethod > Methods ) GetProcessedMethods (
316322 ImmutableArrayBuilder < DiagnosticInfo > diagnostics ,
317323 INamedTypeSymbol structDeclarationSymbol ,
324+ INamedTypeSymbol shaderInterfaceType ,
318325 SemanticModelProvider semanticModel ,
319326 ICollection < INamedTypeSymbol > discoveredTypes ,
320327 IDictionary < IMethodSymbol , MethodDeclarationSyntax > staticMethods ,
@@ -327,6 +334,7 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods) GetProces
327334 {
328335 using ImmutableArrayBuilder < HlslMethod > methods = new ( ) ;
329336
337+ IMethodSymbol entryPointInterfaceMethod = shaderInterfaceType . GetMethod ( "Execute" ) ! ;
330338 string ? entryPoint = null ;
331339
332340 // By default, the scene position is not required. We will set this while
@@ -341,16 +349,17 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods) GetProces
341349 continue ;
342350 }
343351
352+ // Ensure that we have accessible source information
344353 if ( ! methodSymbol . TryGetSyntaxNode ( token , out MethodDeclarationSyntax ? methodDeclaration ) )
345354 {
346355 continue ;
347356 }
348357
349- bool isShaderEntryPoint =
350- methodSymbol . Name == "Execute" &&
351- methodSymbol . ReturnType . HasFullyQualifiedMetadataName ( "ComputeSharp.Float4" ) &&
352- methodSymbol . TypeParameters . Length == 0 &&
353- methodSymbol . Parameters . Length == 0 ;
358+ // Check whether the current method is the entry point (ie. it's implementing 'Execute').
359+ // This is the same logic as in the DX12 generator for compute shaders and pixel shaders.
360+ bool isShaderEntryPoint = SymbolEqualityComparer . Default . Equals (
361+ structDeclarationSymbol . FindImplementationForInterfaceMember ( entryPointInterfaceMethod ) ,
362+ methodSymbol ) ;
354363
355364 // Except for the entry point, ignore explicit interface implementations
356365 if ( ! isShaderEntryPoint && ! methodSymbol . ExplicitInterfaceImplementations . IsDefaultOrEmpty )
@@ -460,8 +469,8 @@ private static bool GetD2DRequiresScenePositionInfo(INamedTypeSymbol structDecla
460469 /// <param name="typeMethodDeclarations"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteMethodDeclarations" path="/param[@name='typeMethodDeclarations']/node()"/></param>
461470 /// <param name="executeMethod">The body of the entry point of the shader.</param>
462471 /// <param name="inputCount">The number of shader inputs to declare.</param>
463- /// <param name="inputSimpleIndices">The indicess of the simple shader inputs.</param>
464- /// <param name="inputComplexIndices">The indicess of the complex shader inputs.</param>
472+ /// <param name="inputSimpleIndices">The indices of the simple shader inputs.</param>
473+ /// <param name="inputComplexIndices">The indices of the complex shader inputs.</param>
465474 /// <param name="requiresScenePosition">Whether the shader requires the scene position.</param>
466475 /// <returns>The series of statements to build the HLSL source to compile to execute the current shader.</returns>
467476 private static string GetHlslSource (
0 commit comments