Skip to content

[HLSL][RootSignature] Allow for multiple parsing errors in RootSignatureParser #147832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/inbelic/pr-147800
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions clang/include/clang/Parse/ParseHLSLRootSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,14 @@ class RootSignatureParser {
bool tryConsumeExpectedToken(RootSignatureToken::Kind Expected);
bool tryConsumeExpectedToken(ArrayRef<RootSignatureToken::Kind> Expected);

/// Consume tokens until the expected token has been peeked to be next
/// or we have reached the end of the stream. Note that this means the
/// expected token will be the next token not CurToken.
///
/// Returns true if it found a token of the given type.
bool skipUntilExpectedToken(RootSignatureToken::Kind Expected);
bool skipUntilExpectedToken(ArrayRef<RootSignatureToken::Kind> Expected);

/// Convert the token's offset in the signature string to its SourceLocation
///
/// This allows to currently retrieve the location for multi-token
Expand Down
87 changes: 67 additions & 20 deletions clang/lib/Parse/ParseHLSLRootSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ namespace hlsl {

using TokenKind = RootSignatureToken::Kind;

static const TokenKind RootElementKeywords[] = {
TokenKind::kw_RootFlags,
TokenKind::kw_CBV,
TokenKind::kw_UAV,
TokenKind::kw_SRV,
TokenKind::kw_DescriptorTable,
TokenKind::kw_StaticSampler,
};

RootSignatureParser::RootSignatureParser(
llvm::dxbc::RootSignatureVersion Version,
SmallVector<RootSignatureElement> &Elements, StringLiteral *Signature,
Expand All @@ -27,51 +36,68 @@ RootSignatureParser::RootSignatureParser(
bool RootSignatureParser::parse() {
// Iterate as many RootSignatureElements as possible, until we hit the
// end of the stream
bool HadError = false;
while (!peekExpectedToken(TokenKind::end_of_stream)) {
bool HadLocalError = false;
if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Flags = parseRootFlags();
if (!Flags.has_value())
return true;
Elements.emplace_back(RootSignatureElement(ElementLoc, *Flags));
if (Flags.has_value())
Elements.emplace_back(RootSignatureElement(ElementLoc, *Flags));
else
HadLocalError = true;
} else if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Constants = parseRootConstants();
if (!Constants.has_value())
return true;
Elements.emplace_back(RootSignatureElement(ElementLoc, *Constants));
if (Constants.has_value())
Elements.emplace_back(RootSignatureElement(ElementLoc, *Constants));
else
HadLocalError = true;
} else if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Table = parseDescriptorTable();
if (!Table.has_value())
return true;
Elements.emplace_back(RootSignatureElement(ElementLoc, *Table));
if (Table.has_value())
Elements.emplace_back(RootSignatureElement(ElementLoc, *Table));
else {
HadLocalError = true;
// We are within a DescriptorTable, we will do our best to recover
// by skipping until we encounter the expected closing ')'.
skipUntilExpectedToken(TokenKind::pu_r_paren);
consumeNextToken();
}
} else if (tryConsumeExpectedToken(
{TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Descriptor = parseRootDescriptor();
if (!Descriptor.has_value())
return true;
Elements.emplace_back(RootSignatureElement(ElementLoc, *Descriptor));
if (Descriptor.has_value())
Elements.emplace_back(RootSignatureElement(ElementLoc, *Descriptor));
else
HadLocalError = true;
} else if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Sampler = parseStaticSampler();
if (!Sampler.has_value())
return true;
Elements.emplace_back(RootSignatureElement(ElementLoc, *Sampler));
if (Sampler.has_value())
Elements.emplace_back(RootSignatureElement(ElementLoc, *Sampler));
else
HadLocalError = true;
} else {
HadLocalError = true;
consumeNextToken(); // let diagnostic be at the start of invalid token
reportDiag(diag::err_hlsl_invalid_token)
<< /*parameter=*/0 << /*param of*/ TokenKind::kw_RootSignature;
return true;
}

// ',' denotes another element, otherwise, expected to be at end of stream
if (!tryConsumeExpectedToken(TokenKind::pu_comma))
if (HadLocalError) {
HadError = true;
skipUntilExpectedToken(RootElementKeywords);
} else if (!tryConsumeExpectedToken(TokenKind::pu_comma)) {
// ',' denotes another element, otherwise, expected to be at end of stream
break;
}
}

return consumeExpectedToken(TokenKind::end_of_stream,
return HadError ||
consumeExpectedToken(TokenKind::end_of_stream,
diag::err_expected_either, TokenKind::pu_comma);
}

Expand Down Expand Up @@ -262,8 +288,13 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
// DescriptorTableClause - CBV, SRV, UAV, or Sampler
SourceLocation ElementLoc = getTokenLocation(CurToken);
auto Clause = parseDescriptorTableClause();
if (!Clause.has_value())
if (!Clause.has_value()) {
// We are within a DescriptorTableClause, we will do our best to recover
// by skipping until we encounter the expected closing ')'
skipUntilExpectedToken(TokenKind::pu_r_paren);
consumeNextToken();
return std::nullopt;
}
Elements.emplace_back(RootSignatureElement(ElementLoc, *Clause));
Table.NumClauses++;
} else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
Expand Down Expand Up @@ -1371,6 +1402,22 @@ bool RootSignatureParser::tryConsumeExpectedToken(
return true;
}

bool RootSignatureParser::skipUntilExpectedToken(TokenKind Expected) {
return skipUntilExpectedToken(ArrayRef{Expected});
}

bool RootSignatureParser::skipUntilExpectedToken(
ArrayRef<TokenKind> AnyExpected) {

while (!peekExpectedToken(AnyExpected)) {
if (peekExpectedToken(TokenKind::end_of_stream))
return false;
consumeNextToken();
}

return true;
}

SourceLocation RootSignatureParser::getTokenLocation(RootSignatureToken Tok) {
return Signature->getLocationOfByte(Tok.LocOffset, PP.getSourceManager(),
PP.getLangOpts(), PP.getTargetInfo());
Expand Down
38 changes: 38 additions & 0 deletions clang/test/SemaHLSL/RootSignature-err.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,41 @@ void bad_root_signature_22() {}
// expected-error@+1 {{invalid value of RootFlags}}
[RootSignature("RootFlags(local_root_signature | root_flag_typo)")]
void bad_root_signature_23() {}

#define DemoMultipleErrorsRootSignature \
"CBV(b0, space = invalid)," \
"StaticSampler()" \
"DescriptorTable(" \
" visibility = SHADER_VISIBILITY_ALL," \
" visibility = SHADER_VISIBILITY_DOMAIN," \
")," \
"SRV(t0, space = 28947298374912374098172)" \
"UAV(u0, flags = 3)" \
"DescriptorTable(Sampler(s0 flags = DATA_VOLATILE))," \
"CBV(b0),,"

// expected-error@+7 {{expected integer literal after '='}}
// expected-error@+6 {{did not specify mandatory parameter 's register'}}
// expected-error@+5 {{specified the same parameter 'visibility' multiple times}}
// expected-error@+4 {{integer literal is too large to be represented as a 32-bit signed integer type}}
// expected-error@+3 {{flag value is neither a literal 0 nor a named value}}
// expected-error@+2 {{expected ')' or ','}}
// expected-error@+1 {{invalid parameter of RootSignature}}
[RootSignature(DemoMultipleErrorsRootSignature)]
void multiple_errors() {}

#define DemoGranularityRootSignature \
"CBV(b0, reported_diag, flags = skipped_diag)," \
"DescriptorTable( " \
" UAV(u0, reported_diag), " \
" SRV(t0, skipped_diag), " \
")," \
"StaticSampler(s0, reported_diag, SRV(t0, reported_diag)" \
""

// expected-error@+4 {{invalid parameter of CBV}}
// expected-error@+3 {{invalid parameter of UAV}}
// expected-error@+2 {{invalid parameter of StaticSampler}}
// expected-error@+1 {{invalid parameter of SRV}}
[RootSignature(DemoGranularityRootSignature)]
void granularity_errors() {}
Loading