-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[HLSL][RootSignature] Implement diagnostic for missed comma #147350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-clang Author: Finn Plummer (inbelic) ChangesTODO: will fill in. Resolves: #147337 Patch is 47.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147350.diff 5 Files Affected:
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 6c30da376dafb..7c636d9049762 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1857,8 +1857,7 @@ def err_hlsl_virtual_inheritance
: Error<"virtual inheritance is unsupported in HLSL">;
// HLSL Root Signature Parser Diagnostics
-def err_hlsl_unexpected_end_of_params
- : Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
+def err_hlsl_rootsig_invalid_param : Error<"invalid parameter of %0">;
def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
def err_hlsl_number_literal_overflow : Error<
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 9ef5b64d7b4a5..2846b83ee959b 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -100,7 +100,8 @@ class RootSignatureParser {
std::optional<llvm::dxbc::RootDescriptorFlags> Flags;
};
std::optional<ParsedRootDescriptorParams>
- parseRootDescriptorParams(RootSignatureToken::Kind RegType);
+ parseRootDescriptorParams(RootSignatureToken::Kind DescType,
+ RootSignatureToken::Kind RegType);
struct ParsedClauseParams {
std::optional<llvm::hlsl::rootsig::Register> Reg;
@@ -110,7 +111,8 @@ class RootSignatureParser {
std::optional<llvm::dxbc::DescriptorRangeFlags> Flags;
};
std::optional<ParsedClauseParams>
- parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
+ parseDescriptorTableClauseParams(RootSignatureToken::Kind DescType,
+ RootSignatureToken::Kind RegType);
struct ParsedStaticSamplerParams {
std::optional<llvm::hlsl::rootsig::Register> Reg;
@@ -178,8 +180,8 @@ class RootSignatureParser {
///
/// Returns true if there was an error reported.
bool consumeExpectedToken(
- RootSignatureToken::Kind Expected, unsigned DiagID = diag::err_expected,
- RootSignatureToken::Kind Context = RootSignatureToken::Kind::invalid);
+ RootSignatureToken::Kind Expected,
+ std::optional<RootSignatureToken::Kind> Context = std::nullopt);
/// Peek if the next token is of the expected kind and if it is then consume
/// it.
@@ -195,6 +197,10 @@ class RootSignatureParser {
/// StringLiterals
SourceLocation getTokenLocation(RootSignatureToken Tok);
+ DiagnosticBuilder reportDiag(unsigned DiagID) {
+ return getDiags().Report(getTokenLocation(CurToken), DiagID);
+ }
+
private:
llvm::dxbc::RootSignatureVersion Version;
SmallVector<RootSignatureElement> &Elements;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 86dd01c1a2841..582445b876c3f 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -25,58 +25,62 @@ RootSignatureParser::RootSignatureParser(
Lexer(Signature->getString()), PP(PP), CurToken(0) {}
bool RootSignatureParser::parse() {
- // Iterate as many RootSignatureElements as possible
- do {
+ // Iterate as many RootSignatureElements as possible, until we hit the
+ // end of the stream
+ while (!peekExpectedToken(TokenKind::end_of_stream)) {
std::optional<RootSignatureElement> Element = std::nullopt;
if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+ // RootFlags
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Flags = parseRootFlags();
if (!Flags.has_value())
return true;
Element = RootSignatureElement(ElementLoc, *Flags);
- }
-
- if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+ } else if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+ // RootConstants
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Constants = parseRootConstants();
if (!Constants.has_value())
return true;
Element = RootSignatureElement(ElementLoc, *Constants);
- }
-
- if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+ } else if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+ // DescriptorTable
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Table = parseDescriptorTable();
if (!Table.has_value())
return true;
Element = RootSignatureElement(ElementLoc, *Table);
- }
-
- if (tryConsumeExpectedToken(
- {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+ } else if (tryConsumeExpectedToken(
+ {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+ // RootDescriptor - CBV, SRV, UAV
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Descriptor = parseRootDescriptor();
if (!Descriptor.has_value())
return true;
Element = RootSignatureElement(ElementLoc, *Descriptor);
- }
-
- if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+ } else if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+ // StaticSampler
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Sampler = parseStaticSampler();
if (!Sampler.has_value())
return true;
Element = RootSignatureElement(ElementLoc, *Sampler);
+ } else {
+ consumeNextToken(); // position to start of invalid token
+ reportDiag(diag::err_hlsl_rootsig_invalid_param)
+ << /*param of=*/TokenKind::kw_RootSignature;
+ return true;
}
if (Element.has_value())
Elements.push_back(*Element);
- } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+ // ',' denotes another element, otherwise, expected to be at end of stream
+ if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+ break;
+ }
- return consumeExpectedToken(TokenKind::end_of_stream,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/TokenKind::kw_RootSignature);
+ return consumeExpectedToken(TokenKind::end_of_stream);
}
template <typename FlagType>
@@ -92,8 +96,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
"Expects to only be invoked starting at given keyword");
- if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.TokKind))
+ if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
return std::nullopt;
std::optional<llvm::dxbc::RootFlags> Flags = llvm::dxbc::RootFlags::None;
@@ -101,8 +104,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
// Handle the edge-case of '0' to specify no flags set
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
if (!verifyZeroFlag()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_non_zero_flag);
+ reportDiag(diag::err_hlsl_rootsig_non_zero_flag);
return std::nullopt;
}
} else {
@@ -128,9 +130,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
} while (tryConsumeExpectedToken(TokenKind::pu_or));
}
- if (consumeExpectedToken(TokenKind::pu_r_paren,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/TokenKind::kw_RootFlags))
+ if (consumeExpectedToken(TokenKind::pu_r_paren))
return std::nullopt;
return Flags;
@@ -140,8 +140,7 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
"Expects to only be invoked starting at given keyword");
- if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.TokKind))
+ if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
return std::nullopt;
RootConstants Constants;
@@ -150,10 +149,12 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
if (!Params.has_value())
return std::nullopt;
+ if (consumeExpectedToken(TokenKind::pu_r_paren))
+ return std::nullopt;
+
// Check mandatory parameters where provided
if (!Params->Num32BitConstants.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_missing_param)
+ reportDiag(diag::err_hlsl_rootsig_missing_param)
<< TokenKind::kw_num32BitConstants;
return std::nullopt;
}
@@ -161,9 +162,7 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
Constants.Num32BitConstants = Params->Num32BitConstants.value();
if (!Params->Reg.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_missing_param)
- << TokenKind::bReg;
+ reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::bReg;
return std::nullopt;
}
@@ -176,11 +175,6 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
if (Params->Space.has_value())
Constants.Space = Params->Space.value();
- if (consumeExpectedToken(TokenKind::pu_r_paren,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/TokenKind::kw_RootConstants))
- return std::nullopt;
-
return Constants;
}
@@ -192,8 +186,7 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
TokenKind DescriptorKind = CurToken.TokKind;
- if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.TokKind))
+ if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
return std::nullopt;
RootDescriptor Descriptor;
@@ -216,15 +209,16 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
}
Descriptor.setDefaultFlags(Version);
- auto Params = parseRootDescriptorParams(ExpectedReg);
+ auto Params = parseRootDescriptorParams(DescriptorKind, ExpectedReg);
if (!Params.has_value())
return std::nullopt;
+ if (consumeExpectedToken(TokenKind::pu_r_paren))
+ return std::nullopt;
+
// Check mandatory parameters were provided
if (!Params->Reg.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_missing_param)
- << ExpectedReg;
+ reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
return std::nullopt;
}
@@ -240,11 +234,6 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
if (Params->Flags.has_value())
Descriptor.Flags = Params->Flags.value();
- if (consumeExpectedToken(TokenKind::pu_r_paren,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/TokenKind::kw_RootConstants))
- return std::nullopt;
-
return Descriptor;
}
@@ -252,30 +241,27 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
"Expects to only be invoked starting at given keyword");
- if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.TokKind))
+ if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
return std::nullopt;
DescriptorTable Table;
std::optional<llvm::dxbc::ShaderVisibility> Visibility;
- // Iterate as many Clauses as possible
- do {
+ // Iterate as many Clauses as possible, until we hit ')'
+ while (!peekExpectedToken(TokenKind::pu_r_paren)) {
if (tryConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
TokenKind::kw_UAV, TokenKind::kw_Sampler})) {
+ // DescriptorTableClause - CBV, SRV, UAV, or Sampler
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Clause = parseDescriptorTableClause();
if (!Clause.has_value())
return std::nullopt;
Elements.push_back(RootSignatureElement(ElementLoc, *Clause));
Table.NumClauses++;
- }
-
- if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+ } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+ // visibility = SHADER_VISIBILITY
if (Visibility.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_repeat_param)
- << CurToken.TokKind;
+ reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
return std::nullopt;
}
@@ -285,18 +271,25 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
Visibility = parseShaderVisibility();
if (!Visibility.has_value())
return std::nullopt;
+ } else {
+ consumeNextToken(); // position to start of invalid token
+ reportDiag(diag::err_hlsl_rootsig_invalid_param)
+ << /*param of=*/TokenKind::kw_DescriptorTable;
+ return std::nullopt;
}
- } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+ // ',' denotes another element, otherwise, expected to be at ')'
+ if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+ break;
+ }
+
+ if (consumeExpectedToken(TokenKind::pu_r_paren))
+ return std::nullopt;
// Fill in optional visibility
if (Visibility.has_value())
Table.Visibility = Visibility.value();
- if (consumeExpectedToken(TokenKind::pu_r_paren,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/TokenKind::kw_DescriptorTable))
- return std::nullopt;
-
return Table;
}
@@ -310,8 +303,7 @@ RootSignatureParser::parseDescriptorTableClause() {
TokenKind ParamKind = CurToken.TokKind;
- if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.TokKind))
+ if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
return std::nullopt;
DescriptorTableClause Clause;
@@ -338,15 +330,16 @@ RootSignatureParser::parseDescriptorTableClause() {
}
Clause.setDefaultFlags(Version);
- auto Params = parseDescriptorTableClauseParams(ExpectedReg);
+ auto Params = parseDescriptorTableClauseParams(ParamKind, ExpectedReg);
if (!Params.has_value())
return std::nullopt;
+ if (consumeExpectedToken(TokenKind::pu_r_paren))
+ return std::nullopt;
+
// Check mandatory parameters were provided
if (!Params->Reg.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_missing_param)
- << ExpectedReg;
+ reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
return std::nullopt;
}
@@ -365,11 +358,6 @@ RootSignatureParser::parseDescriptorTableClause() {
if (Params->Flags.has_value())
Clause.Flags = Params->Flags.value();
- if (consumeExpectedToken(TokenKind::pu_r_paren,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/ParamKind))
- return std::nullopt;
-
return Clause;
}
@@ -377,8 +365,7 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
assert(CurToken.TokKind == TokenKind::kw_StaticSampler &&
"Expects to only be invoked starting at given keyword");
- if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.TokKind))
+ if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
return std::nullopt;
StaticSampler Sampler;
@@ -387,11 +374,12 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
if (!Params.has_value())
return std::nullopt;
+ if (consumeExpectedToken(TokenKind::pu_r_paren))
+ return std::nullopt;
+
// Check mandatory parameters were provided
if (!Params->Reg.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_missing_param)
- << TokenKind::sReg;
+ reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::sReg;
return std::nullopt;
}
@@ -434,11 +422,6 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
if (Params->Visibility.has_value())
Sampler.Visibility = Params->Visibility.value();
- if (consumeExpectedToken(TokenKind::pu_r_paren,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/TokenKind::kw_StaticSampler))
- return std::nullopt;
-
return Sampler;
}
@@ -451,13 +434,11 @@ RootSignatureParser::parseRootConstantParams() {
"Expects to only be invoked starting at given token");
ParsedConstantParams Params;
- do {
- // `num32BitConstants` `=` POS_INT
+ while (!peekExpectedToken(TokenKind::pu_r_paren)) {
if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
+ // `num32BitConstants` `=` POS_INT
if (Params.Num32BitConstants.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_repeat_param)
- << CurToken.TokKind;
+ reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
return std::nullopt;
}
@@ -468,28 +449,20 @@ RootSignatureParser::parseRootConstantParams() {
if (!Num32BitConstants.has_value())
return std::nullopt;
Params.Num32BitConstants = Num32BitConstants;
- }
-
- // `b` POS_INT
- if (tryConsumeExpectedToken(TokenKind::bReg)) {
+ } else if (tryConsumeExpectedToken(TokenKind::bReg)) {
+ // `b` POS_INT
if (Params.Reg.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_repeat_param)
- << CurToken.TokKind;
+ reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
return std::nullopt;
}
auto Reg = parseRegister();
if (!Reg.has_value())
return std::nullopt;
Params.Reg = Reg;
- }
-
- // `space` `=` POS_INT
- if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+ } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+ // `space` `=` POS_INT
if (Params.Space.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_repeat_param)
- << CurToken.TokKind;
+ reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
return std::nullopt;
}
@@ -500,14 +473,10 @@ RootSignatureParser::parseRootConstantParams() {
if (!Space.has_value())
return std::nullopt;
Params.Space = Space;
- }
-
- // `visibility` `=` SHADER_VISIBILITY
- if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+ } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+ // `visibility` `=` SHADER_VISIBILITY
if (Params.Visibility.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_repeat_param)
- << CurToken.TokKind;
+ reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
return std::nullopt;
}
@@ -518,39 +487,43 @@ RootSignatureParser::parseRootConstantParams() {
if (!Visibility.has_value())
return std::nullopt;
Params.Visibility = Visibility;
+ } else {
+ consumeNextToken(); // position to start of invalid token
+ reportDiag(diag::err_hlsl_rootsig_invalid_param)
+ << /*param of=*/TokenKind::kw_RootConstants;
+ return std::nullopt;
}
- } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+ // ',' denotes another element, otherwise, expected to be at ')'
+ if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+ break;
+ }
return Params;
}
std::optional<RootSignatureParser::ParsedRootDescriptorParams>
-RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
+RootSignatureParser::parseRootDescriptorParams(To...
[truncated]
|
@llvm/pr-subscribers-hlsl Author: Finn Plummer (inbelic) ChangesTODO: will fill in. Resolves: #147337 Patch is 47.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147350.diff 5 Files Affected:
diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 6c30da376dafb..7c636d9049762 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1857,8 +1857,7 @@ def err_hlsl_virtual_inheritance
: Error<"virtual inheritance is unsupported in HLSL">;
// HLSL Root Signature Parser Diagnostics
-def err_hlsl_unexpected_end_of_params
- : Error<"expected %0 to denote end of parameters, or, another valid parameter of %1">;
+def err_hlsl_rootsig_invalid_param : Error<"invalid parameter of %0">;
def err_hlsl_rootsig_repeat_param : Error<"specified the same parameter '%0' multiple times">;
def err_hlsl_rootsig_missing_param : Error<"did not specify mandatory parameter '%0'">;
def err_hlsl_number_literal_overflow : Error<
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index 9ef5b64d7b4a5..2846b83ee959b 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -100,7 +100,8 @@ class RootSignatureParser {
std::optional<llvm::dxbc::RootDescriptorFlags> Flags;
};
std::optional<ParsedRootDescriptorParams>
- parseRootDescriptorParams(RootSignatureToken::Kind RegType);
+ parseRootDescriptorParams(RootSignatureToken::Kind DescType,
+ RootSignatureToken::Kind RegType);
struct ParsedClauseParams {
std::optional<llvm::hlsl::rootsig::Register> Reg;
@@ -110,7 +111,8 @@ class RootSignatureParser {
std::optional<llvm::dxbc::DescriptorRangeFlags> Flags;
};
std::optional<ParsedClauseParams>
- parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
+ parseDescriptorTableClauseParams(RootSignatureToken::Kind DescType,
+ RootSignatureToken::Kind RegType);
struct ParsedStaticSamplerParams {
std::optional<llvm::hlsl::rootsig::Register> Reg;
@@ -178,8 +180,8 @@ class RootSignatureParser {
///
/// Returns true if there was an error reported.
bool consumeExpectedToken(
- RootSignatureToken::Kind Expected, unsigned DiagID = diag::err_expected,
- RootSignatureToken::Kind Context = RootSignatureToken::Kind::invalid);
+ RootSignatureToken::Kind Expected,
+ std::optional<RootSignatureToken::Kind> Context = std::nullopt);
/// Peek if the next token is of the expected kind and if it is then consume
/// it.
@@ -195,6 +197,10 @@ class RootSignatureParser {
/// StringLiterals
SourceLocation getTokenLocation(RootSignatureToken Tok);
+ DiagnosticBuilder reportDiag(unsigned DiagID) {
+ return getDiags().Report(getTokenLocation(CurToken), DiagID);
+ }
+
private:
llvm::dxbc::RootSignatureVersion Version;
SmallVector<RootSignatureElement> &Elements;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 86dd01c1a2841..582445b876c3f 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -25,58 +25,62 @@ RootSignatureParser::RootSignatureParser(
Lexer(Signature->getString()), PP(PP), CurToken(0) {}
bool RootSignatureParser::parse() {
- // Iterate as many RootSignatureElements as possible
- do {
+ // Iterate as many RootSignatureElements as possible, until we hit the
+ // end of the stream
+ while (!peekExpectedToken(TokenKind::end_of_stream)) {
std::optional<RootSignatureElement> Element = std::nullopt;
if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
+ // RootFlags
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Flags = parseRootFlags();
if (!Flags.has_value())
return true;
Element = RootSignatureElement(ElementLoc, *Flags);
- }
-
- if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+ } else if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
+ // RootConstants
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Constants = parseRootConstants();
if (!Constants.has_value())
return true;
Element = RootSignatureElement(ElementLoc, *Constants);
- }
-
- if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+ } else if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
+ // DescriptorTable
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Table = parseDescriptorTable();
if (!Table.has_value())
return true;
Element = RootSignatureElement(ElementLoc, *Table);
- }
-
- if (tryConsumeExpectedToken(
- {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+ } else if (tryConsumeExpectedToken(
+ {TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
+ // RootDescriptor - CBV, SRV, UAV
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Descriptor = parseRootDescriptor();
if (!Descriptor.has_value())
return true;
Element = RootSignatureElement(ElementLoc, *Descriptor);
- }
-
- if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+ } else if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
+ // StaticSampler
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Sampler = parseStaticSampler();
if (!Sampler.has_value())
return true;
Element = RootSignatureElement(ElementLoc, *Sampler);
+ } else {
+ consumeNextToken(); // position to start of invalid token
+ reportDiag(diag::err_hlsl_rootsig_invalid_param)
+ << /*param of=*/TokenKind::kw_RootSignature;
+ return true;
}
if (Element.has_value())
Elements.push_back(*Element);
- } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+ // ',' denotes another element, otherwise, expected to be at end of stream
+ if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+ break;
+ }
- return consumeExpectedToken(TokenKind::end_of_stream,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/TokenKind::kw_RootSignature);
+ return consumeExpectedToken(TokenKind::end_of_stream);
}
template <typename FlagType>
@@ -92,8 +96,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
assert(CurToken.TokKind == TokenKind::kw_RootFlags &&
"Expects to only be invoked starting at given keyword");
- if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.TokKind))
+ if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
return std::nullopt;
std::optional<llvm::dxbc::RootFlags> Flags = llvm::dxbc::RootFlags::None;
@@ -101,8 +104,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
// Handle the edge-case of '0' to specify no flags set
if (tryConsumeExpectedToken(TokenKind::int_literal)) {
if (!verifyZeroFlag()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_non_zero_flag);
+ reportDiag(diag::err_hlsl_rootsig_non_zero_flag);
return std::nullopt;
}
} else {
@@ -128,9 +130,7 @@ std::optional<llvm::dxbc::RootFlags> RootSignatureParser::parseRootFlags() {
} while (tryConsumeExpectedToken(TokenKind::pu_or));
}
- if (consumeExpectedToken(TokenKind::pu_r_paren,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/TokenKind::kw_RootFlags))
+ if (consumeExpectedToken(TokenKind::pu_r_paren))
return std::nullopt;
return Flags;
@@ -140,8 +140,7 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
assert(CurToken.TokKind == TokenKind::kw_RootConstants &&
"Expects to only be invoked starting at given keyword");
- if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.TokKind))
+ if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
return std::nullopt;
RootConstants Constants;
@@ -150,10 +149,12 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
if (!Params.has_value())
return std::nullopt;
+ if (consumeExpectedToken(TokenKind::pu_r_paren))
+ return std::nullopt;
+
// Check mandatory parameters where provided
if (!Params->Num32BitConstants.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_missing_param)
+ reportDiag(diag::err_hlsl_rootsig_missing_param)
<< TokenKind::kw_num32BitConstants;
return std::nullopt;
}
@@ -161,9 +162,7 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
Constants.Num32BitConstants = Params->Num32BitConstants.value();
if (!Params->Reg.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_missing_param)
- << TokenKind::bReg;
+ reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::bReg;
return std::nullopt;
}
@@ -176,11 +175,6 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
if (Params->Space.has_value())
Constants.Space = Params->Space.value();
- if (consumeExpectedToken(TokenKind::pu_r_paren,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/TokenKind::kw_RootConstants))
- return std::nullopt;
-
return Constants;
}
@@ -192,8 +186,7 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
TokenKind DescriptorKind = CurToken.TokKind;
- if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.TokKind))
+ if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
return std::nullopt;
RootDescriptor Descriptor;
@@ -216,15 +209,16 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
}
Descriptor.setDefaultFlags(Version);
- auto Params = parseRootDescriptorParams(ExpectedReg);
+ auto Params = parseRootDescriptorParams(DescriptorKind, ExpectedReg);
if (!Params.has_value())
return std::nullopt;
+ if (consumeExpectedToken(TokenKind::pu_r_paren))
+ return std::nullopt;
+
// Check mandatory parameters were provided
if (!Params->Reg.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_missing_param)
- << ExpectedReg;
+ reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
return std::nullopt;
}
@@ -240,11 +234,6 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
if (Params->Flags.has_value())
Descriptor.Flags = Params->Flags.value();
- if (consumeExpectedToken(TokenKind::pu_r_paren,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/TokenKind::kw_RootConstants))
- return std::nullopt;
-
return Descriptor;
}
@@ -252,30 +241,27 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
assert(CurToken.TokKind == TokenKind::kw_DescriptorTable &&
"Expects to only be invoked starting at given keyword");
- if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.TokKind))
+ if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
return std::nullopt;
DescriptorTable Table;
std::optional<llvm::dxbc::ShaderVisibility> Visibility;
- // Iterate as many Clauses as possible
- do {
+ // Iterate as many Clauses as possible, until we hit ')'
+ while (!peekExpectedToken(TokenKind::pu_r_paren)) {
if (tryConsumeExpectedToken({TokenKind::kw_CBV, TokenKind::kw_SRV,
TokenKind::kw_UAV, TokenKind::kw_Sampler})) {
+ // DescriptorTableClause - CBV, SRV, UAV, or Sampler
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Clause = parseDescriptorTableClause();
if (!Clause.has_value())
return std::nullopt;
Elements.push_back(RootSignatureElement(ElementLoc, *Clause));
Table.NumClauses++;
- }
-
- if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+ } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+ // visibility = SHADER_VISIBILITY
if (Visibility.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_repeat_param)
- << CurToken.TokKind;
+ reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
return std::nullopt;
}
@@ -285,18 +271,25 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
Visibility = parseShaderVisibility();
if (!Visibility.has_value())
return std::nullopt;
+ } else {
+ consumeNextToken(); // position to start of invalid token
+ reportDiag(diag::err_hlsl_rootsig_invalid_param)
+ << /*param of=*/TokenKind::kw_DescriptorTable;
+ return std::nullopt;
}
- } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+ // ',' denotes another element, otherwise, expected to be at ')'
+ if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+ break;
+ }
+
+ if (consumeExpectedToken(TokenKind::pu_r_paren))
+ return std::nullopt;
// Fill in optional visibility
if (Visibility.has_value())
Table.Visibility = Visibility.value();
- if (consumeExpectedToken(TokenKind::pu_r_paren,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/TokenKind::kw_DescriptorTable))
- return std::nullopt;
-
return Table;
}
@@ -310,8 +303,7 @@ RootSignatureParser::parseDescriptorTableClause() {
TokenKind ParamKind = CurToken.TokKind;
- if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.TokKind))
+ if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
return std::nullopt;
DescriptorTableClause Clause;
@@ -338,15 +330,16 @@ RootSignatureParser::parseDescriptorTableClause() {
}
Clause.setDefaultFlags(Version);
- auto Params = parseDescriptorTableClauseParams(ExpectedReg);
+ auto Params = parseDescriptorTableClauseParams(ParamKind, ExpectedReg);
if (!Params.has_value())
return std::nullopt;
+ if (consumeExpectedToken(TokenKind::pu_r_paren))
+ return std::nullopt;
+
// Check mandatory parameters were provided
if (!Params->Reg.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_missing_param)
- << ExpectedReg;
+ reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
return std::nullopt;
}
@@ -365,11 +358,6 @@ RootSignatureParser::parseDescriptorTableClause() {
if (Params->Flags.has_value())
Clause.Flags = Params->Flags.value();
- if (consumeExpectedToken(TokenKind::pu_r_paren,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/ParamKind))
- return std::nullopt;
-
return Clause;
}
@@ -377,8 +365,7 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
assert(CurToken.TokKind == TokenKind::kw_StaticSampler &&
"Expects to only be invoked starting at given keyword");
- if (consumeExpectedToken(TokenKind::pu_l_paren, diag::err_expected_after,
- CurToken.TokKind))
+ if (consumeExpectedToken(TokenKind::pu_l_paren, CurToken.TokKind))
return std::nullopt;
StaticSampler Sampler;
@@ -387,11 +374,12 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
if (!Params.has_value())
return std::nullopt;
+ if (consumeExpectedToken(TokenKind::pu_r_paren))
+ return std::nullopt;
+
// Check mandatory parameters were provided
if (!Params->Reg.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_missing_param)
- << TokenKind::sReg;
+ reportDiag(diag::err_hlsl_rootsig_missing_param) << TokenKind::sReg;
return std::nullopt;
}
@@ -434,11 +422,6 @@ std::optional<StaticSampler> RootSignatureParser::parseStaticSampler() {
if (Params->Visibility.has_value())
Sampler.Visibility = Params->Visibility.value();
- if (consumeExpectedToken(TokenKind::pu_r_paren,
- diag::err_hlsl_unexpected_end_of_params,
- /*param of=*/TokenKind::kw_StaticSampler))
- return std::nullopt;
-
return Sampler;
}
@@ -451,13 +434,11 @@ RootSignatureParser::parseRootConstantParams() {
"Expects to only be invoked starting at given token");
ParsedConstantParams Params;
- do {
- // `num32BitConstants` `=` POS_INT
+ while (!peekExpectedToken(TokenKind::pu_r_paren)) {
if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
+ // `num32BitConstants` `=` POS_INT
if (Params.Num32BitConstants.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_repeat_param)
- << CurToken.TokKind;
+ reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
return std::nullopt;
}
@@ -468,28 +449,20 @@ RootSignatureParser::parseRootConstantParams() {
if (!Num32BitConstants.has_value())
return std::nullopt;
Params.Num32BitConstants = Num32BitConstants;
- }
-
- // `b` POS_INT
- if (tryConsumeExpectedToken(TokenKind::bReg)) {
+ } else if (tryConsumeExpectedToken(TokenKind::bReg)) {
+ // `b` POS_INT
if (Params.Reg.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_repeat_param)
- << CurToken.TokKind;
+ reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
return std::nullopt;
}
auto Reg = parseRegister();
if (!Reg.has_value())
return std::nullopt;
Params.Reg = Reg;
- }
-
- // `space` `=` POS_INT
- if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+ } else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
+ // `space` `=` POS_INT
if (Params.Space.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_repeat_param)
- << CurToken.TokKind;
+ reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
return std::nullopt;
}
@@ -500,14 +473,10 @@ RootSignatureParser::parseRootConstantParams() {
if (!Space.has_value())
return std::nullopt;
Params.Space = Space;
- }
-
- // `visibility` `=` SHADER_VISIBILITY
- if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+ } else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
+ // `visibility` `=` SHADER_VISIBILITY
if (Params.Visibility.has_value()) {
- getDiags().Report(getTokenLocation(CurToken),
- diag::err_hlsl_rootsig_repeat_param)
- << CurToken.TokKind;
+ reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
return std::nullopt;
}
@@ -518,39 +487,43 @@ RootSignatureParser::parseRootConstantParams() {
if (!Visibility.has_value())
return std::nullopt;
Params.Visibility = Visibility;
+ } else {
+ consumeNextToken(); // position to start of invalid token
+ reportDiag(diag::err_hlsl_rootsig_invalid_param)
+ << /*param of=*/TokenKind::kw_RootConstants;
+ return std::nullopt;
}
- } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+ // ',' denotes another element, otherwise, expected to be at ')'
+ if (!tryConsumeExpectedToken(TokenKind::pu_comma))
+ break;
+ }
return Params;
}
std::optional<RootSignatureParser::ParsedRootDescriptorParams>
-RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
+RootSignatureParser::parseRootDescriptorParams(To...
[truncated]
|
Contemplating if I should split this into two prs. Will see if there is a nice way to de-couple the improve and fix error portions of this. |
a561510
to
dfde6d4
Compare
Updated to rebase onto main so that it will merge before #147115. Removes the 'improve diag' portion. I will create a follow-up issue for that to track the improvement of diagnostic. |
@@ -34,3 +34,7 @@ void bad_root_signature_5() {} | |||
// expected-error@+1 {{expected ')' to denote end of parameters, or, another valid parameter of RootConstants}} | |||
[RootSignature(MultiLineRootSignature)] | |||
void bad_root_signature_6() {} | |||
|
|||
// expected-error@+1 {{expected end of stream to denote end of parameters, or, another valid parameter of RootSignature}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this diagnostic just be expected ','
? It seems like all the tests flag cases where a comma is expected but not found.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A similar parsing error in C++ would result in expected ')'
:
https://godbolt.org/z/z4Gf1Tar6
I think simplifying to expected ','
and/or expected ')'
where appropriate will be more understandable to users.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I think we can simplify the diagnostic here quite a bit.
A similar concern was also noted here: #145827 (comment)
I will create a follow-up issue tomorrow to track this work and do so in a follow-up pr, but will leave this pr to just focus on the bug fix as it has a dependency here: #147115 (comment).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- these are lingering from the improve diag related changes
29f7bad
to
b9cf614
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this do when there are commas at the ends of lists of elements?
Interestingly, DXC seems inconsistent on its behaviour for those:
// Unexpected token ')'
[RootSignature("CBV(b0), CBV(b1,)")]
// valid
[RootSignature("CBV(b0), CBV(b1),")]
I don't know that we need to match this exactly - we should probably be consistent about it. In any case, please do add some tests that make sure we do something sensible.
this worked before because we returned on the first error found
Added a test to show that it is consistent in allowing a trailing comma after parameter/values |
// - a single trailing comma is allowed after any parameter | ||
// - a trailing comma is not required | ||
|
||
[RootSignature("CBV(b0, flags = DATA_VOLATILE,), DescriptorTable(Sampler(s0,),),")] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we reject multiple trailing commas? Something like:
[RootSignature("CBV(b0, flags = DATA_VOLATILE,), DescriptorTable(Sampler(s0)),,")]
This pr fixes a bug that allows parameters to be specified without an intermediate comma.
After this pr, we will correctly produce a diagnostic for (eg):
This pr updates the problematic code pattern containing a chain of 'if' statements to a chain of 'else if' statements, to prevent parsing of an element before checking for a comma.
This pr also does 2 small updates, while in the region:
do
loop that theseif
statements are contained in. This helps code readability and makes it easier to improve the diagnostics furtherconsumeExpectedToken
function calls to be right after theparse.*Params
invocation. This will ensure that the comma or invalid token error is presented before a "missed mandatory param" diagnostic.do
loop to be an easier to understandwhile
loopconsumeExpectedToken
diagnostic right after the loop so that the missing comma diagnostic is produce before checking for any missed mandatory argumentsRootDescriptors
to use their respectiveToken
instead ofRootConstants
Resolves: #147337