@@ -17,6 +17,15 @@ namespace hlsl {
17
17
18
18
using TokenKind = RootSignatureToken::Kind;
19
19
20
+ static const TokenKind RootElementKeywords[] = {
21
+ TokenKind::kw_RootFlags,
22
+ TokenKind::kw_CBV,
23
+ TokenKind::kw_UAV,
24
+ TokenKind::kw_SRV,
25
+ TokenKind::kw_DescriptorTable,
26
+ TokenKind::kw_StaticSampler,
27
+ };
28
+
20
29
RootSignatureParser::RootSignatureParser (
21
30
llvm::dxbc::RootSignatureVersion Version,
22
31
SmallVector<RootSignatureElement> &Elements, StringLiteral *Signature,
@@ -27,51 +36,68 @@ RootSignatureParser::RootSignatureParser(
27
36
bool RootSignatureParser::parse () {
28
37
// Iterate as many RootSignatureElements as possible, until we hit the
29
38
// end of the stream
39
+ bool HadError = false ;
30
40
while (!peekExpectedToken (TokenKind::end_of_stream)) {
41
+ bool HadLocalError = false ;
31
42
if (tryConsumeExpectedToken (TokenKind::kw_RootFlags)) {
32
43
SourceLocation ElementLoc = getTokenLocation (CurToken);
33
44
auto Flags = parseRootFlags ();
34
- if (!Flags.has_value ())
35
- return true ;
36
- Elements.emplace_back (RootSignatureElement (ElementLoc, *Flags));
45
+ if (Flags.has_value ())
46
+ Elements.emplace_back (RootSignatureElement (ElementLoc, *Flags));
47
+ else
48
+ HadLocalError = true ;
37
49
} else if (tryConsumeExpectedToken (TokenKind::kw_RootConstants)) {
38
50
SourceLocation ElementLoc = getTokenLocation (CurToken);
39
51
auto Constants = parseRootConstants ();
40
- if (!Constants.has_value ())
41
- return true ;
42
- Elements.emplace_back (RootSignatureElement (ElementLoc, *Constants));
52
+ if (Constants.has_value ())
53
+ Elements.emplace_back (RootSignatureElement (ElementLoc, *Constants));
54
+ else
55
+ HadLocalError = true ;
43
56
} else if (tryConsumeExpectedToken (TokenKind::kw_DescriptorTable)) {
44
57
SourceLocation ElementLoc = getTokenLocation (CurToken);
45
58
auto Table = parseDescriptorTable ();
46
- if (!Table.has_value ())
47
- return true ;
48
- Elements.emplace_back (RootSignatureElement (ElementLoc, *Table));
59
+ if (Table.has_value ())
60
+ Elements.emplace_back (RootSignatureElement (ElementLoc, *Table));
61
+ else {
62
+ HadLocalError = true ;
63
+ // We are within a DescriptorTable, we will do our best to recover
64
+ // by skipping until we encounter the expected closing ')'.
65
+ skipUntilExpectedToken (TokenKind::pu_r_paren);
66
+ consumeNextToken ();
67
+ }
49
68
} else if (tryConsumeExpectedToken (
50
69
{TokenKind::kw_CBV, TokenKind::kw_SRV, TokenKind::kw_UAV})) {
51
70
SourceLocation ElementLoc = getTokenLocation (CurToken);
52
71
auto Descriptor = parseRootDescriptor ();
53
- if (!Descriptor.has_value ())
54
- return true ;
55
- Elements.emplace_back (RootSignatureElement (ElementLoc, *Descriptor));
72
+ if (Descriptor.has_value ())
73
+ Elements.emplace_back (RootSignatureElement (ElementLoc, *Descriptor));
74
+ else
75
+ HadLocalError = true ;
56
76
} else if (tryConsumeExpectedToken (TokenKind::kw_StaticSampler)) {
57
77
SourceLocation ElementLoc = getTokenLocation (CurToken);
58
78
auto Sampler = parseStaticSampler ();
59
- if (!Sampler.has_value ())
60
- return true ;
61
- Elements.emplace_back (RootSignatureElement (ElementLoc, *Sampler));
79
+ if (Sampler.has_value ())
80
+ Elements.emplace_back (RootSignatureElement (ElementLoc, *Sampler));
81
+ else
82
+ HadLocalError = true ;
62
83
} else {
84
+ HadLocalError = true ;
63
85
consumeNextToken (); // let diagnostic be at the start of invalid token
64
86
reportDiag (diag::err_hlsl_invalid_token)
65
87
<< /* parameter=*/ 0 << /* param of*/ TokenKind::kw_RootSignature;
66
- return true ;
67
88
}
68
89
69
- // ',' denotes another element, otherwise, expected to be at end of stream
70
- if (!tryConsumeExpectedToken (TokenKind::pu_comma))
90
+ if (HadLocalError) {
91
+ HadError = true ;
92
+ skipUntilExpectedToken (RootElementKeywords);
93
+ } else if (!tryConsumeExpectedToken (TokenKind::pu_comma)) {
94
+ // ',' denotes another element, otherwise, expected to be at end of stream
71
95
break ;
96
+ }
72
97
}
73
98
74
- return consumeExpectedToken (TokenKind::end_of_stream,
99
+ return HadError ||
100
+ consumeExpectedToken (TokenKind::end_of_stream,
75
101
diag::err_expected_either, TokenKind::pu_comma);
76
102
}
77
103
@@ -262,8 +288,13 @@ std::optional<DescriptorTable> RootSignatureParser::parseDescriptorTable() {
262
288
// DescriptorTableClause - CBV, SRV, UAV, or Sampler
263
289
SourceLocation ElementLoc = getTokenLocation (CurToken);
264
290
auto Clause = parseDescriptorTableClause ();
265
- if (!Clause.has_value ())
291
+ if (!Clause.has_value ()) {
292
+ // We are within a DescriptorTableClause, we will do our best to recover
293
+ // by skipping until we encounter the expected closing ')'
294
+ skipUntilExpectedToken (TokenKind::pu_r_paren);
295
+ consumeNextToken ();
266
296
return std::nullopt;
297
+ }
267
298
Elements.emplace_back (RootSignatureElement (ElementLoc, *Clause));
268
299
Table.NumClauses ++;
269
300
} else if (tryConsumeExpectedToken (TokenKind::kw_visibility)) {
@@ -1371,6 +1402,22 @@ bool RootSignatureParser::tryConsumeExpectedToken(
1371
1402
return true ;
1372
1403
}
1373
1404
1405
+ bool RootSignatureParser::skipUntilExpectedToken (TokenKind Expected) {
1406
+ return skipUntilExpectedToken (ArrayRef{Expected});
1407
+ }
1408
+
1409
+ bool RootSignatureParser::skipUntilExpectedToken (
1410
+ ArrayRef<TokenKind> AnyExpected) {
1411
+
1412
+ while (!peekExpectedToken (AnyExpected)) {
1413
+ if (peekExpectedToken (TokenKind::end_of_stream))
1414
+ return false ;
1415
+ consumeNextToken ();
1416
+ }
1417
+
1418
+ return true ;
1419
+ }
1420
+
1374
1421
SourceLocation RootSignatureParser::getTokenLocation (RootSignatureToken Tok) {
1375
1422
return Signature->getLocationOfByte (Tok.LocOffset , PP.getSourceManager (),
1376
1423
PP.getLangOpts (), PP.getTargetInfo ());
0 commit comments