Skip to content

Commit 7ec7e32

Browse files
committed
[HLSL][RootSignature] Implement multiple diagnostics in RootSignatureParser
1 parent 70448a3 commit 7ec7e32

File tree

3 files changed

+113
-20
lines changed

3 files changed

+113
-20
lines changed

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,14 @@ class RootSignatureParser {
198198
bool tryConsumeExpectedToken(RootSignatureToken::Kind Expected);
199199
bool tryConsumeExpectedToken(ArrayRef<RootSignatureToken::Kind> Expected);
200200

201+
/// Consume tokens until the expected token has been peeked to be next
202+
/// or we have reached the end of the stream. Note that this means the
203+
/// expected token will be the next token not CurToken.
204+
///
205+
/// Returns true if it found a token of the given type.
206+
bool skipUntilExpectedToken(RootSignatureToken::Kind Expected);
207+
bool skipUntilExpectedToken(ArrayRef<RootSignatureToken::Kind> Expected);
208+
201209
/// Convert the token's offset in the signature string to its SourceLocation
202210
///
203211
/// This allows to currently retrieve the location for multi-token

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ namespace hlsl {
1717

1818
using TokenKind = RootSignatureToken::Kind;
1919

20+
static const TokenKind RootElementKeywords[] = {
21+
TokenKind::kw_RootFlags,
22+
TokenKind::kw_CBV,
23+
TokenKind::kw_UAV,
24+
TokenKind::kw_SRV,
25+
TokenKind::kw_DescriptorTable,
26+
TokenKind::kw_StaticSampler,
27+
};
28+
2029
RootSignatureParser::RootSignatureParser(
2130
llvm::dxbc::RootSignatureVersion Version,
2231
SmallVector<RootSignatureElement> &Elements, StringLiteral *Signature,
@@ -27,51 +36,68 @@ RootSignatureParser::RootSignatureParser(
2736
bool RootSignatureParser::parse() {
2837
// Iterate as many RootSignatureElements as possible, until we hit the
2938
// end of the stream
39+
bool HadError = false;
3040
while (!peekExpectedToken(TokenKind::end_of_stream)) {
41+
bool HadLocalError = false;
3142
if (tryConsumeExpectedToken(TokenKind::kw_RootFlags)) {
3243
SourceLocation ElementLoc = getTokenLocation(CurToken);
3344
auto Flags = parseRootFlags();
34-
if (!Flags.has_value())
35-
return true;
36-
Elements.emplace_back(RootSignatureElement(ElementLoc, *Flags));
45+
if (Flags.has_value())
46+
Elements.emplace_back(RootSignatureElement(ElementLoc, *Flags));
47+
else
48+
HadLocalError = true;
3749
} else if (tryConsumeExpectedToken(TokenKind::kw_RootConstants)) {
3850
SourceLocation ElementLoc = getTokenLocation(CurToken);
3951
auto Constants = parseRootConstants();
40-
if (!Constants.has_value())
41-
return true;
42-
Elements.emplace_back(RootSignatureElement(ElementLoc, *Constants));
52+
if (Constants.has_value())
53+
Elements.emplace_back(RootSignatureElement(ElementLoc, *Constants));
54+
else
55+
HadLocalError = true;
4356
} else if (tryConsumeExpectedToken(TokenKind::kw_DescriptorTable)) {
4457
SourceLocation ElementLoc = getTokenLocation(CurToken);
4558
auto Table = parseDescriptorTable();
46-
if (!Table.has_value())
47-
return true;
48-
Elements.emplace_back(RootSignatureElement(ElementLoc, *Table));
59+
if (Table.has_value())
60+
Elements.emplace_back(RootSignatureElement(ElementLoc, *Table));
61+
else {
62+
HadLocalError = true;
63+
// We are within a DescriptorTable, we will do our best to recover
64+
// by skipping until we encounter the expected closing ')'.
65+
skipUntilExpectedToken(TokenKind::pu_r_paren);
66+
consumeNextToken();
67+
}
4968
} else if (tryConsumeExpectedToken(
5069
{TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
5170
SourceLocation ElementLoc = getTokenLocation(CurToken);
5271
auto Descriptor = parseRootDescriptor();
53-
if (!Descriptor.has_value())
54-
return true;
55-
Elements.emplace_back(RootSignatureElement(ElementLoc, *Descriptor));
72+
if (Descriptor.has_value())
73+
Elements.emplace_back(RootSignatureElement(ElementLoc, *Descriptor));
74+
else
75+
HadLocalError = true;
5676
} else if (tryConsumeExpectedToken(TokenKind::kw_StaticSampler)) {
5777
SourceLocation ElementLoc = getTokenLocation(CurToken);
5878
auto Sampler = parseStaticSampler();
59-
if (!Sampler.has_value())
60-
return true;
61-
Elements.emplace_back(RootSignatureElement(ElementLoc, *Sampler));
79+
if (Sampler.has_value())
80+
Elements.emplace_back(RootSignatureElement(ElementLoc, *Sampler));
81+
else
82+
HadLocalError = true;
6283
} else {
84+
HadLocalError = true;
6385
consumeNextToken(); // let diagnostic be at the start of invalid token
6486
reportDiag(diag::err_hlsl_invalid_token)
6587
<< /*parameter=*/0 << /*param of*/ TokenKind::kw_RootSignature;
66-
return true;
6788
}
6889

69-
// ',' denotes another element, otherwise, expected to be at end of stream
70-
if (!tryConsumeExpectedToken(TokenKind::pu_comma))
90+
if (HadLocalError) {
91+
HadError = true;
92+
skipUntilExpectedToken(RootElementKeywords);
93+
} else if (!tryConsumeExpectedToken(TokenKind::pu_comma)) {
94+
// ',' denotes another element, otherwise, expected to be at end of stream
7195
break;
96+
}
7297
}
7398

74-
return consumeExpectedToken(TokenKind::end_of_stream,
99+
return HadError ||
100+
consumeExpectedToken(TokenKind::end_of_stream,
75101
diag::err_expected_either, TokenKind::pu_comma);
76102
}
77103

@@ -262,8 +288,13 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
262288
// DescriptorTableClause - CBV, SRV, UAV, or Sampler
263289
SourceLocation ElementLoc = getTokenLocation(CurToken);
264290
auto Clause = parseDescriptorTableClause();
265-
if (!Clause.has_value())
291+
if (!Clause.has_value()) {
292+
// We are within a DescriptorTableClause, we will do our best to recover
293+
// by skipping until we encounter the expected closing ')'
294+
skipUntilExpectedToken(TokenKind::pu_r_paren);
295+
consumeNextToken();
266296
return std::nullopt;
297+
}
267298
Elements.emplace_back(RootSignatureElement(ElementLoc, *Clause));
268299
Table.NumClauses++;
269300
} else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
@@ -1371,6 +1402,22 @@ bool RootSignatureParser::tryConsumeExpectedToken(
13711402
return true;
13721403
}
13731404

1405+
bool RootSignatureParser::skipUntilExpectedToken(TokenKind Expected) {
1406+
return skipUntilExpectedToken(ArrayRef{Expected});
1407+
}
1408+
1409+
bool RootSignatureParser::skipUntilExpectedToken(
1410+
ArrayRef<TokenKind> AnyExpected) {
1411+
1412+
while (!peekExpectedToken(AnyExpected)) {
1413+
if (peekExpectedToken(TokenKind::end_of_stream))
1414+
return false;
1415+
consumeNextToken();
1416+
}
1417+
1418+
return true;
1419+
}
1420+
13741421
SourceLocation RootSignatureParser::getTokenLocation(RootSignatureToken Tok) {
13751422
return Signature->getLocationOfByte(Tok.LocOffset, PP.getSourceManager(),
13761423
PP.getLangOpts(), PP.getTargetInfo());

clang/test/SemaHLSL/RootSignature-err.hlsl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,41 @@ void bad_root_signature_22() {}
103103
// expected-error@+1 {{invalid value of RootFlags}}
104104
[RootSignature("RootFlags(local_root_signature | root_flag_typo)")]
105105
void bad_root_signature_23() {}
106+
107+
#define DemoMultipleErrorsRootSignature \
108+
"CBV(b0, space = invalid)," \
109+
"StaticSampler()" \
110+
"DescriptorTable(" \
111+
" visibility = SHADER_VISIBILITY_ALL," \
112+
" visibility = SHADER_VISIBILITY_DOMAIN," \
113+
")," \
114+
"SRV(t0, space = 28947298374912374098172)" \
115+
"UAV(u0, flags = 3)" \
116+
"DescriptorTable(Sampler(s0 flags = DATA_VOLATILE))," \
117+
"CBV(b0),,"
118+
119+
// expected-error@+7 {{expected integer literal after '='}}
120+
// expected-error@+6 {{did not specify mandatory parameter 's register'}}
121+
// expected-error@+5 {{specified the same parameter 'visibility' multiple times}}
122+
// expected-error@+4 {{integer literal is too large to be represented as a 32-bit signed integer type}}
123+
// expected-error@+3 {{flag value is neither a literal 0 nor a named value}}
124+
// expected-error@+2 {{expected ')' or ','}}
125+
// expected-error@+1 {{invalid parameter of RootSignature}}
126+
[RootSignature(DemoMultipleErrorsRootSignature)]
127+
void multiple_errors() {}
128+
129+
#define DemoGranularityRootSignature \
130+
"CBV(b0, reported_diag, flags = skipped_diag)," \
131+
"DescriptorTable( " \
132+
" UAV(u0, reported_diag), " \
133+
" SRV(t0, skipped_diag), " \
134+
")," \
135+
"StaticSampler(s0, reported_diag, SRV(t0, reported_diag)" \
136+
""
137+
138+
// expected-error@+4 {{invalid parameter of CBV}}
139+
// expected-error@+3 {{invalid parameter of UAV}}
140+
// expected-error@+2 {{invalid parameter of StaticSampler}}
141+
// expected-error@+1 {{invalid parameter of SRV}}
142+
[RootSignature(DemoGranularityRootSignature)]
143+
void granularity_errors() {}

0 commit comments

Comments
 (0)