Skip to content

[HLSL][RootSignature] Implement diagnostic for missed comma #147350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Jul 7, 2025

This pr fixes a bug that allows parameters to be specified without an intermediate comma.

After this pr, we will correctly produce a diagnostic for (eg):

RootFlags(0) CBV(b0)

This pr updates the problematic code pattern containing a chain of 'if' statements to a chain of 'else if' statements, to prevent parsing of an element before checking for a comma.

This pr also does 2 small updates, while in the region:

  1. Simplify the do loop that these if statements are contained in. This helps code readability and makes it easier to improve the diagnostics further
  2. Moves the consumeExpectedToken function calls to be right after the parse.*Params invocation. This will ensure that the comma or invalid token error is presented before a "missed mandatory param" diagnostic.
  • Updates all occurrences of the if chains with an else-if chain
  • Simplifies the surrounding do loop to be an easier to understand while loop
  • Moves the consumeExpectedToken diagnostic right after the loop so that the missing comma diagnostic is produce before checking for any missed mandatory arguments
  • Adds unit tests for this scenario
  • Small fix to the diagnostic of RootDescriptors to use their respective Token instead of RootConstants

Resolves: #147337

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels Jul 7, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-clang

Author: Finn Plummer (inbelic)

Changes

TODO: will fill in.

Resolves: #147337


Patch is 47.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147350.diff

5 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticParseKinds.td (+1-2)
  • (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+10-4)
  • (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+198-289)
  • (modified) clang/test/SemaHLSL/RootSignature-err.hlsl (+13-5)
  • (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+164-2)
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 6c30da376dafb..7c636d9049762 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1857,8 +1857,7 @@ def err_hlsl_virtual_inheritance
     : Error<"virtual inheritance is unsupported in HLSL">;
 
 // HLSL Root Signature Parser Diagnostics
-def err_hlsl_unexpected_end_of_params
-    : Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
+def err_hlsl_rootsig_invalid_param : Error<"invalid parameter of %0">;
 def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
 def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
 def err_hlsl_number_literal_overflow : Error<
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 9ef5b64d7b4a5..2846b83ee959b 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -100,7 +100,8 @@ class RootSignatureParser {
     std::optional<llvm::dxbc::RootDescriptorFlags> Flags;
   };
   std::optional<ParsedRootDescriptorParams>
-  parseRootDescriptorParams(RootSignatureToken::Kind RegType);
+  parseRootDescriptorParams(RootSignatureToken::Kind DescType,
+                            RootSignatureToken::Kind RegType);
 
   struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
@@ -110,7 +111,8 @@ class RootSignatureParser {
     std::optional<llvm::dxbc::DescriptorRangeFlags> Flags;
   };
   std::optional<ParsedClauseParams>
-  parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
+  parseDescriptorTableClauseParams(RootSignatureToken::Kind DescType,
+                                   RootSignatureToken::Kind RegType);
 
   struct ParsedStaticSamplerParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
@@ -178,8 +180,8 @@ class RootSignatureParser {
   ///
   /// Returns true if there was an error reported.
   bool consumeExpectedToken(
-      RootSignatureToken::Kind Expected, unsigned DiagID = diag::err_expected,
-      RootSignatureToken::Kind Context = RootSignatureToken::Kind::invalid);
+      RootSignatureToken::Kind Expected,
+      std::optional<RootSignatureToken::Kind> Context = std::nullopt);
 
   /// Peek if the next token is of the expected kind and if it is then consume
   /// it.
@@ -195,6 +197,10 @@ class RootSignatureParser {
   /// StringLiterals
   SourceLocation getTokenLocation(RootSignatureToken Tok);
 
+  DiagnosticBuilder reportDiag(unsigned DiagID) {
+    return getDiags().Report(getTokenLocation(CurToken), DiagID);
+  }
+
 private:
   llvm::dxbc::RootSignatureVersion Version;
   SmallVector<RootSignatureElement> &Elements;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 86dd01c1a2841..582445b876c3f 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -25,58 +25,62 @@ RootSignatureParser::RootSignatureParser(
       Lexer(Signature->getString()), PP(PP), CurToken(0) {}
 
 bool RootSignatureParser::parse() {
-  // Iterate as many RootSignatureElements as possible
-  do {
+  // Iterate as many RootSignatureElements as possible, until we hit the
+  // end of the stream
+  while (!peekExpectedToken(TokenKind::end_of_stream)) {
     std::optional<RootSignatureElement> Element = std::nullopt;
     if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+      // RootFlags
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Flags = parseRootFlags();
       if (!Flags.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Flags);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+      // RootConstants
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Constants = parseRootConstants();
       if (!Constants.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Constants);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+      // DescriptorTable
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Table = parseDescriptorTable();
       if (!Table.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Table);
-    }
-
-    if (tryConsumeExpectedToken(
-            {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+    } else if (tryConsumeExpectedToken(
+                   {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+      // RootDescriptor - CBV, SRV, UAV
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Descriptor = parseRootDescriptor();
       if (!Descriptor.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Descriptor);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+      // StaticSampler
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Sampler = parseStaticSampler();
       if (!Sampler.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Sampler);
+    } else {
+      consumeNextToken(); // position to start of invalid token
+      reportDiag(diag::err_hlsl_rootsig_invalid_param)
+          << /*param of=*/TokenKind::kw_RootSignature;
+      return true;
     }
 
     if (Element.has_value())
       Elements.push_back(*Element);
 
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+    // ',' denotes another element, otherwise, expected to be at end of stream
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
-  return consumeExpectedToken(TokenKind::end_of_stream,
-                              diag::err_hlsl_unexpected_end_of_params,
-                              /*param of=*/TokenKind::kw_RootSignature);
+  return consumeExpectedToken(TokenKind::end_of_stream);
 }
 
 template <typename FlagType>
@@ -92,8 +96,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
   assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   std::optional<llvm::dxbc::RootFlags> Flags = llvm::dxbc::RootFlags::None;
@@ -101,8 +104,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
   // Handle the edge-case of '0' to specify no flags set
   if (tryConsumeExpectedToken(TokenKind::int_literal)) {
     if (!verifyZeroFlag()) {
-      getDiags().Report(getTokenLocation(CurToken),
-                        diag::err_hlsl_rootsig_non_zero_flag);
+      reportDiag(diag::err_hlsl_rootsig_non_zero_flag);
       return std::nullopt;
     }
   } else {
@@ -128,9 +130,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
     } while (tryConsumeExpectedToken(TokenKind::pu_or));
   }
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootFlags))
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
     return std::nullopt;
 
   return Flags;
@@ -140,8 +140,7 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   RootConstants Constants;
@@ -150,10 +149,12 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters where provided
   if (!Params->Num32BitConstants.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
+    reportDiag(diag::err_hlsl_rootsig_missing_param)
         << TokenKind::kw_num32BitConstants;
     return std::nullopt;
   }
@@ -161,9 +162,7 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   Constants.Num32BitConstants = Params->Num32BitConstants.value();
 
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << TokenKind::bReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::bReg;
     return std::nullopt;
   }
 
@@ -176,11 +175,6 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   if (Params->Space.has_value())
     Constants.Space = Params->Space.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootConstants))
-    return std::nullopt;
-
   return Constants;
 }
 
@@ -192,8 +186,7 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
 
   TokenKind DescriptorKind = CurToken.TokKind;
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   RootDescriptor Descriptor;
@@ -216,15 +209,16 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
   }
   Descriptor.setDefaultFlags(Version);
 
-  auto Params = parseRootDescriptorParams(ExpectedReg);
+  auto Params = parseRootDescriptorParams(DescriptorKind, ExpectedReg);
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << ExpectedReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
     return std::nullopt;
   }
 
@@ -240,11 +234,6 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
   if (Params->Flags.has_value())
     Descriptor.Flags = Params->Flags.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootConstants))
-    return std::nullopt;
-
   return Descriptor;
 }
 
@@ -252,30 +241,27 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
   assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   DescriptorTable Table;
   std::optional<llvm::dxbc::ShaderVisibility> Visibility;
 
-  // Iterate as many Clauses as possible
-  do {
+  // Iterate as many Clauses as possible, until we hit ')'
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
                                  TokenKind::kw_UAV, TokenKind::kw_Sampler})) {
+      // DescriptorTableClause - CBV, SRV, UAV, or Sampler
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Clause = parseDescriptorTableClause();
       if (!Clause.has_value())
         return std::nullopt;
       Elements.push_back(RootSignatureElement(ElementLoc, *Clause));
       Table.NumClauses++;
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // visibility = SHADER_VISIBILITY
       if (Visibility.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -285,18 +271,25 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
       Visibility = parseShaderVisibility();
       if (!Visibility.has_value())
         return std::nullopt;
+    } else {
+      consumeNextToken(); // position to start of invalid token
+      reportDiag(diag::err_hlsl_rootsig_invalid_param)
+          << /*param of=*/TokenKind::kw_DescriptorTable;
+      return std::nullopt;
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
+
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
 
   // Fill in optional visibility
   if (Visibility.has_value())
     Table.Visibility = Visibility.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_DescriptorTable))
-    return std::nullopt;
-
   return Table;
 }
 
@@ -310,8 +303,7 @@ RootSignatureParser::parseDescriptorTableClause() {
 
   TokenKind ParamKind = CurToken.TokKind;
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   DescriptorTableClause Clause;
@@ -338,15 +330,16 @@ RootSignatureParser::parseDescriptorTableClause() {
   }
   Clause.setDefaultFlags(Version);
 
-  auto Params = parseDescriptorTableClauseParams(ExpectedReg);
+  auto Params = parseDescriptorTableClauseParams(ParamKind, ExpectedReg);
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << ExpectedReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
     return std::nullopt;
   }
 
@@ -365,11 +358,6 @@ RootSignatureParser::parseDescriptorTableClause() {
   if (Params->Flags.has_value())
     Clause.Flags = Params->Flags.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/ParamKind))
-    return std::nullopt;
-
   return Clause;
 }
 
@@ -377,8 +365,7 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   assert(CurToken.TokKind == TokenKind::kw_StaticSampler &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   StaticSampler Sampler;
@@ -387,11 +374,12 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << TokenKind::sReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::sReg;
     return std::nullopt;
   }
 
@@ -434,11 +422,6 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   if (Params->Visibility.has_value())
     Sampler.Visibility = Params->Visibility.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_StaticSampler))
-    return std::nullopt;
-
   return Sampler;
 }
 
@@ -451,13 +434,11 @@ RootSignatureParser::parseRootConstantParams() {
          "Expects to only be invoked starting at given token");
 
   ParsedConstantParams Params;
-  do {
-    // `num32BitConstants` `=` POS_INT
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
+      // `num32BitConstants` `=` POS_INT
       if (Params.Num32BitConstants.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -468,28 +449,20 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Num32BitConstants.has_value())
         return std::nullopt;
       Params.Num32BitConstants = Num32BitConstants;
-    }
-
-    // `b` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::bReg)) {
+    } else if (tryConsumeExpectedToken(TokenKind::bReg)) {
+      // `b` POS_INT
       if (Params.Reg.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
       auto Reg = parseRegister();
       if (!Reg.has_value())
         return std::nullopt;
       Params.Reg = Reg;
-    }
-
-    // `space` `=` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+      // `space` `=` POS_INT
       if (Params.Space.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -500,14 +473,10 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Space.has_value())
         return std::nullopt;
       Params.Space = Space;
-    }
-
-    // `visibility` `=` SHADER_VISIBILITY
-    if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // `visibility` `=` SHADER_VISIBILITY
       if (Params.Visibility.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -518,39 +487,43 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Visibility.has_value())
         return std::nullopt;
       Params.Visibility = Visibility;
+    } else {
+      consumeNextToken(); // position to start of invalid token
+      reportDiag(diag::err_hlsl_rootsig_invalid_param)
+          << /*param of=*/TokenKind::kw_RootConstants;
+      return std::nullopt;
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
   return Params;
 }
 
 std::optional<RootSignatureParser::ParsedRootDescriptorParams>
-RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
+RootSignatureParser::parseRootDescriptorParams(To...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 7, 2025

@llvm/pr-subscribers-hlsl

Author: Finn Plummer (inbelic)

Changes

TODO: will fill in.

Resolves: #147337


Patch is 47.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147350.diff

5 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticParseKinds.td (+1-2)
  • (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+10-4)
  • (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+198-289)
  • (modified) clang/test/SemaHLSL/RootSignature-err.hlsl (+13-5)
  • (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+164-2)
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 6c30da376dafb..7c636d9049762 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1857,8 +1857,7 @@ def err_hlsl_virtual_inheritance
     : Error<"virtual inheritance is unsupported in HLSL">;
 
 // HLSL Root Signature Parser Diagnostics
-def err_hlsl_unexpected_end_of_params
-    : Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
+def err_hlsl_rootsig_invalid_param : Error<"invalid parameter of %0">;
 def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
 def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
 def err_hlsl_number_literal_overflow : Error<
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 9ef5b64d7b4a5..2846b83ee959b 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -100,7 +100,8 @@ class RootSignatureParser {
     std::optional<llvm::dxbc::RootDescriptorFlags> Flags;
   };
   std::optional<ParsedRootDescriptorParams>
-  parseRootDescriptorParams(RootSignatureToken::Kind RegType);
+  parseRootDescriptorParams(RootSignatureToken::Kind DescType,
+                            RootSignatureToken::Kind RegType);
 
   struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
@@ -110,7 +111,8 @@ class RootSignatureParser {
     std::optional<llvm::dxbc::DescriptorRangeFlags> Flags;
   };
   std::optional<ParsedClauseParams>
-  parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
+  parseDescriptorTableClauseParams(RootSignatureToken::Kind DescType,
+                                   RootSignatureToken::Kind RegType);
 
   struct ParsedStaticSamplerParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
@@ -178,8 +180,8 @@ class RootSignatureParser {
   ///
   /// Returns true if there was an error reported.
   bool consumeExpectedToken(
-      RootSignatureToken::Kind Expected, unsigned DiagID = diag::err_expected,
-      RootSignatureToken::Kind Context = RootSignatureToken::Kind::invalid);
+      RootSignatureToken::Kind Expected,
+      std::optional<RootSignatureToken::Kind> Context = std::nullopt);
 
   /// Peek if the next token is of the expected kind and if it is then consume
   /// it.
@@ -195,6 +197,10 @@ class RootSignatureParser {
   /// StringLiterals
   SourceLocation getTokenLocation(RootSignatureToken Tok);
 
+  DiagnosticBuilder reportDiag(unsigned DiagID) {
+    return getDiags().Report(getTokenLocation(CurToken), DiagID);
+  }
+
 private:
   llvm::dxbc::RootSignatureVersion Version;
   SmallVector<RootSignatureElement> &Elements;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 86dd01c1a2841..582445b876c3f 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -25,58 +25,62 @@ RootSignatureParser::RootSignatureParser(
       Lexer(Signature->getString()), PP(PP), CurToken(0) {}
 
 bool RootSignatureParser::parse() {
-  // Iterate as many RootSignatureElements as possible
-  do {
+  // Iterate as many RootSignatureElements as possible, until we hit the
+  // end of the stream
+  while (!peekExpectedToken(TokenKind::end_of_stream)) {
     std::optional<RootSignatureElement> Element = std::nullopt;
     if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+      // RootFlags
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Flags = parseRootFlags();
       if (!Flags.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Flags);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+      // RootConstants
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Constants = parseRootConstants();
       if (!Constants.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Constants);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+      // DescriptorTable
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Table = parseDescriptorTable();
       if (!Table.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Table);
-    }
-
-    if (tryConsumeExpectedToken(
-            {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+    } else if (tryConsumeExpectedToken(
+                   {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+      // RootDescriptor - CBV, SRV, UAV
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Descriptor = parseRootDescriptor();
       if (!Descriptor.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Descriptor);
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+      // StaticSampler
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Sampler = parseStaticSampler();
       if (!Sampler.has_value())
         return true;
       Element = RootSignatureElement(ElementLoc, *Sampler);
+    } else {
+      consumeNextToken(); // position to start of invalid token
+      reportDiag(diag::err_hlsl_rootsig_invalid_param)
+          << /*param of=*/TokenKind::kw_RootSignature;
+      return true;
     }
 
     if (Element.has_value())
       Elements.push_back(*Element);
 
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+    // ',' denotes another element, otherwise, expected to be at end of stream
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
-  return consumeExpectedToken(TokenKind::end_of_stream,
-                              diag::err_hlsl_unexpected_end_of_params,
-                              /*param of=*/TokenKind::kw_RootSignature);
+  return consumeExpectedToken(TokenKind::end_of_stream);
 }
 
 template <typename FlagType>
@@ -92,8 +96,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
   assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   std::optional<llvm::dxbc::RootFlags> Flags = llvm::dxbc::RootFlags::None;
@@ -101,8 +104,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
   // Handle the edge-case of '0' to specify no flags set
   if (tryConsumeExpectedToken(TokenKind::int_literal)) {
     if (!verifyZeroFlag()) {
-      getDiags().Report(getTokenLocation(CurToken),
-                        diag::err_hlsl_rootsig_non_zero_flag);
+      reportDiag(diag::err_hlsl_rootsig_non_zero_flag);
       return std::nullopt;
     }
   } else {
@@ -128,9 +130,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
     } while (tryConsumeExpectedToken(TokenKind::pu_or));
   }
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootFlags))
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
     return std::nullopt;
 
   return Flags;
@@ -140,8 +140,7 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   RootConstants Constants;
@@ -150,10 +149,12 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters where provided
   if (!Params->Num32BitConstants.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
+    reportDiag(diag::err_hlsl_rootsig_missing_param)
         << TokenKind::kw_num32BitConstants;
     return std::nullopt;
   }
@@ -161,9 +162,7 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   Constants.Num32BitConstants = Params->Num32BitConstants.value();
 
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << TokenKind::bReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::bReg;
     return std::nullopt;
   }
 
@@ -176,11 +175,6 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
   if (Params->Space.has_value())
     Constants.Space = Params->Space.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootConstants))
-    return std::nullopt;
-
   return Constants;
 }
 
@@ -192,8 +186,7 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
 
   TokenKind DescriptorKind = CurToken.TokKind;
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   RootDescriptor Descriptor;
@@ -216,15 +209,16 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
   }
   Descriptor.setDefaultFlags(Version);
 
-  auto Params = parseRootDescriptorParams(ExpectedReg);
+  auto Params = parseRootDescriptorParams(DescriptorKind, ExpectedReg);
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << ExpectedReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
     return std::nullopt;
   }
 
@@ -240,11 +234,6 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
   if (Params->Flags.has_value())
     Descriptor.Flags = Params->Flags.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_RootConstants))
-    return std::nullopt;
-
   return Descriptor;
 }
 
@@ -252,30 +241,27 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
   assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   DescriptorTable Table;
   std::optional<llvm::dxbc::ShaderVisibility> Visibility;
 
-  // Iterate as many Clauses as possible
-  do {
+  // Iterate as many Clauses as possible, until we hit ')'
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
                                  TokenKind::kw_UAV, TokenKind::kw_Sampler})) {
+      // DescriptorTableClause - CBV, SRV, UAV, or Sampler
       SourceLocation ElementLoc = getTokenLocation(CurToken);
       auto Clause = parseDescriptorTableClause();
       if (!Clause.has_value())
         return std::nullopt;
       Elements.push_back(RootSignatureElement(ElementLoc, *Clause));
       Table.NumClauses++;
-    }
-
-    if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // visibility = SHADER_VISIBILITY
       if (Visibility.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -285,18 +271,25 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
       Visibility = parseShaderVisibility();
       if (!Visibility.has_value())
         return std::nullopt;
+    } else {
+      consumeNextToken(); // position to start of invalid token
+      reportDiag(diag::err_hlsl_rootsig_invalid_param)
+          << /*param of=*/TokenKind::kw_DescriptorTable;
+      return std::nullopt;
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
+
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
 
   // Fill in optional visibility
   if (Visibility.has_value())
     Table.Visibility = Visibility.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_DescriptorTable))
-    return std::nullopt;
-
   return Table;
 }
 
@@ -310,8 +303,7 @@ RootSignatureParser::parseDescriptorTableClause() {
 
   TokenKind ParamKind = CurToken.TokKind;
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   DescriptorTableClause Clause;
@@ -338,15 +330,16 @@ RootSignatureParser::parseDescriptorTableClause() {
   }
   Clause.setDefaultFlags(Version);
 
-  auto Params = parseDescriptorTableClauseParams(ExpectedReg);
+  auto Params = parseDescriptorTableClauseParams(ParamKind, ExpectedReg);
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << ExpectedReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
     return std::nullopt;
   }
 
@@ -365,11 +358,6 @@ RootSignatureParser::parseDescriptorTableClause() {
   if (Params->Flags.has_value())
     Clause.Flags = Params->Flags.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/ParamKind))
-    return std::nullopt;
-
   return Clause;
 }
 
@@ -377,8 +365,7 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   assert(CurToken.TokKind == TokenKind::kw_StaticSampler &&
          "Expects to only be invoked starting at given keyword");
 
-  if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
-                           CurToken.TokKind))
+  if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
     return std::nullopt;
 
   StaticSampler Sampler;
@@ -387,11 +374,12 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   if (!Params.has_value())
     return std::nullopt;
 
+  if (consumeExpectedToken(TokenKind::pu_r_paren))
+    return std::nullopt;
+
   // Check mandatory parameters were provided
   if (!Params->Reg.has_value()) {
-    getDiags().Report(getTokenLocation(CurToken),
-                      diag::err_hlsl_rootsig_missing_param)
-        << TokenKind::sReg;
+    reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::sReg;
     return std::nullopt;
   }
 
@@ -434,11 +422,6 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
   if (Params->Visibility.has_value())
     Sampler.Visibility = Params->Visibility.value();
 
-  if (consumeExpectedToken(TokenKind::pu_r_paren,
-                           diag::err_hlsl_unexpected_end_of_params,
-                           /*param of=*/TokenKind::kw_StaticSampler))
-    return std::nullopt;
-
   return Sampler;
 }
 
@@ -451,13 +434,11 @@ RootSignatureParser::parseRootConstantParams() {
          "Expects to only be invoked starting at given token");
 
   ParsedConstantParams Params;
-  do {
-    // `num32BitConstants` `=` POS_INT
+  while (!peekExpectedToken(TokenKind::pu_r_paren)) {
     if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
+      // `num32BitConstants` `=` POS_INT
       if (Params.Num32BitConstants.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -468,28 +449,20 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Num32BitConstants.has_value())
         return std::nullopt;
       Params.Num32BitConstants = Num32BitConstants;
-    }
-
-    // `b` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::bReg)) {
+    } else if (tryConsumeExpectedToken(TokenKind::bReg)) {
+      // `b` POS_INT
       if (Params.Reg.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
       auto Reg = parseRegister();
       if (!Reg.has_value())
         return std::nullopt;
       Params.Reg = Reg;
-    }
-
-    // `space` `=` POS_INT
-    if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+      // `space` `=` POS_INT
       if (Params.Space.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -500,14 +473,10 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Space.has_value())
         return std::nullopt;
       Params.Space = Space;
-    }
-
-    // `visibility` `=` SHADER_VISIBILITY
-    if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+    } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+      // `visibility` `=` SHADER_VISIBILITY
       if (Params.Visibility.has_value()) {
-        getDiags().Report(getTokenLocation(CurToken),
-                          diag::err_hlsl_rootsig_repeat_param)
-            << CurToken.TokKind;
+        reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
         return std::nullopt;
       }
 
@@ -518,39 +487,43 @@ RootSignatureParser::parseRootConstantParams() {
       if (!Visibility.has_value())
         return std::nullopt;
       Params.Visibility = Visibility;
+    } else {
+      consumeNextToken(); // position to start of invalid token
+      reportDiag(diag::err_hlsl_rootsig_invalid_param)
+          << /*param of=*/TokenKind::kw_RootConstants;
+      return std::nullopt;
     }
-  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+    // ',' denotes another element, otherwise, expected to be at ')'
+    if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+      break;
+  }
 
   return Params;
 }
 
 std::optional<RootSignatureParser::ParsedRootDescriptorParams>
-RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
+RootSignatureParser::parseRootDescriptorParams(To...
[truncated]

@inbelic inbelic linked an issue Jul 7, 2025 that may be closed by this pull request
2 tasks
@inbelic inbelic marked this pull request as draft July 7, 2025 17:55
@inbelic
Copy link
Contributor Author

inbelic commented Jul 7, 2025

Contemplating if I should split this into two prs. Will see if there is a nice way to de-couple the improve and fix error portions of this.

@inbelic inbelic changed the base branch from users/inbelic/pr-147115 to main July 8, 2025 17:59
@inbelic inbelic force-pushed the inbelic/rs-improve-diags branch from a561510 to dfde6d4 Compare July 8, 2025 17:59
@inbelic inbelic marked this pull request as ready for review July 8, 2025 18:00
@inbelic
Copy link
Contributor Author

inbelic commented Jul 8, 2025

Updated to rebase onto main so that it will merge before #147115. Removes the 'improve diag' portion. I will create a follow-up issue for that to track the improvement of diagnostic.

@@ -34,3 +34,7 @@ void bad_root_signature_5() {}
// expected-error@+1 {{expected ')' to denote end of parameters, or, another valid parameter of RootConstants}}
[RootSignature(MultiLineRootSignature)]
void bad_root_signature_6() {}

// expected-error@+1 {{expected end of stream to denote end of parameters, or, another valid parameter of RootSignature}}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this diagnostic just be expected ','? It seems like all the tests flag cases where a comma is expected but not found.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A similar parsing error in C++ would result in expected ')':

https://godbolt.org/z/z4Gf1Tar6

I think simplifying to expected ',' and/or expected ')' where appropriate will be more understandable to users.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I think we can simplify the diagnostic here quite a bit.

A similar concern was also noted here: #145827 (comment)

I will create a follow-up issue tomorrow to track this work and do so in a follow-up pr, but will leave this pr to just focus on the bug fix as it has a dependency here: #147115 (comment).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the pr and issue to track this.

@inbelic inbelic force-pushed the inbelic/rs-improve-diags branch from 29f7bad to b9cf614 Compare July 9, 2025 15:31
Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this do when there are commas at the ends of lists of elements?

Interestingly, DXC seems inconsistent on its behaviour for those:

// Unexpected token ')'
[RootSignature("CBV(b0), CBV(b1,)")]
// valid
[RootSignature("CBV(b0), CBV(b1),")]

I don't know that we need to match this exactly - we should probably be consistent about it. In any case, please do add some tests that make sure we do something sensible.

inbelic added 2 commits July 9, 2025 17:23
this worked before because we returned on the first error found
@inbelic
Copy link
Contributor Author

inbelic commented Jul 9, 2025

Added a test to show that it is consistent in allowing a trailing comma after parameter/values

// - a single trailing comma is allowed after any parameter
// - a trailing comma is not required

[RootSignature("CBV(b0, flags = DATA_VOLATILE,), DescriptorTable(Sampler(s0,),),")]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we reject multiple trailing commas? Something like:

[RootSignature("CBV(b0, flags = DATA_VOLATILE,), DescriptorTable(Sampler(s0)),,")]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[HLSL][RootSignature] Incorrectly allows specifying parameters without a comma
4 participants