Skip to content

Commit 8dae33e

Browse files
committed
fix for descriptor table clauses
1 parent 1584490 commit 8dae33e

File tree

3 files changed

+53
-27
lines changed

3 files changed

+53
-27
lines changed

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ class RootSignatureParser {
110110
std::optional<llvm::dxbc::DescriptorRangeFlags> Flags;
111111
};
112112
std::optional<ParsedClauseParams>
113-
parseDescriptorTableClauseParams(RootSignatureToken::Kind RegType);
113+
parseDescriptorTableClauseParams(RootSignatureToken::Kind DescType,
114+
RootSignatureToken::Kind RegType);
114115

115116
struct ParsedStaticSamplerParams {
116117
std::optional<llvm::hlsl::rootsig::Register> Reg;

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -320,10 +320,15 @@ RootSignatureParser::parseDescriptorTableClause() {
320320
}
321321
Clause.setDefaultFlags(Version);
322322

323-
auto Params = parseDescriptorTableClauseParams(ExpectedReg);
323+
auto Params = parseDescriptorTableClauseParams(ParamKind, ExpectedReg);
324324
if (!Params.has_value())
325325
return std::nullopt;
326326

327+
if (consumeExpectedToken(TokenKind::pu_r_paren,
328+
diag::err_hlsl_unexpected_end_of_params,
329+
/*param of=*/ParamKind))
330+
return std::nullopt;
331+
327332
// Check mandatory parameters were provided
328333
if (!Params->Reg.has_value()) {
329334
reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
@@ -345,11 +350,6 @@ RootSignatureParser::parseDescriptorTableClause() {
345350
if (Params->Flags.has_value())
346351
Clause.Flags = Params->Flags.value();
347352

348-
if (consumeExpectedToken(TokenKind::pu_r_paren,
349-
diag::err_hlsl_unexpected_end_of_params,
350-
/*param of=*/ParamKind))
351-
return std::nullopt;
352-
353353
return Clause;
354354
}
355355

@@ -563,14 +563,15 @@ RootSignatureParser::parseRootDescriptorParams(TokenKind DescType,
563563
}
564564

565565
std::optional<RootSignatureParser::ParsedClauseParams>
566-
RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
566+
RootSignatureParser::parseDescriptorTableClauseParams(TokenKind DescType,
567+
TokenKind RegType) {
567568
assert(CurToken.TokKind == TokenKind::pu_l_paren &&
568569
"Expects to only be invoked starting at given token");
569570

570571
ParsedClauseParams Params;
571-
do {
572-
// ( `b` | `t` | `u` | `s`) POS_INT
572+
while (!peekExpectedToken(TokenKind::pu_r_paren)) {
573573
if (tryConsumeExpectedToken(RegType)) {
574+
// ( `b` | `t` | `u` | `s`) POS_INT
574575
if (Params.Reg.has_value()) {
575576
reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
576577
return std::nullopt;
@@ -579,10 +580,8 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
579580
if (!Reg.has_value())
580581
return std::nullopt;
581582
Params.Reg = Reg;
582-
}
583-
584-
// `numDescriptors` `=` POS_INT | unbounded
585-
if (tryConsumeExpectedToken(TokenKind::kw_numDescriptors)) {
583+
} else if (tryConsumeExpectedToken(TokenKind::kw_numDescriptors)) {
584+
// `numDescriptors` `=` POS_INT | unbounded
586585
if (Params.NumDescriptors.has_value()) {
587586
reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
588587
return std::nullopt;
@@ -601,10 +600,8 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
601600
}
602601

603602
Params.NumDescriptors = NumDescriptors;
604-
}
605-
606-
// `space` `=` POS_INT
607-
if (tryConsumeExpectedToken(TokenKind::kw_space)) {
603+
} else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
604+
// `space` `=` POS_INT
608605
if (Params.Space.has_value()) {
609606
reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
610607
return std::nullopt;
@@ -617,10 +614,8 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
617614
if (!Space.has_value())
618615
return std::nullopt;
619616
Params.Space = Space;
620-
}
621-
622-
// `offset` `=` POS_INT | DESCRIPTOR_RANGE_OFFSET_APPEND
623-
if (tryConsumeExpectedToken(TokenKind::kw_offset)) {
617+
} else if (tryConsumeExpectedToken(TokenKind::kw_offset)) {
618+
// `offset` `=` POS_INT | DESCRIPTOR_RANGE_OFFSET_APPEND
624619
if (Params.Offset.has_value()) {
625620
reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
626621
return std::nullopt;
@@ -639,10 +634,8 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
639634
}
640635

641636
Params.Offset = Offset;
642-
}
643-
644-
// `flags` `=` DESCRIPTOR_RANGE_FLAGS
645-
if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
637+
} else if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
638+
// `flags` `=` DESCRIPTOR_RANGE_FLAGS
646639
if (Params.Flags.has_value()) {
647640
reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
648641
return std::nullopt;
@@ -657,7 +650,10 @@ RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
657650
Params.Flags = Flags;
658651
}
659652

660-
} while (tryConsumeExpectedToken(TokenKind::pu_comma));
653+
// ',' denotes another element, otherwise, expected to be at ')'
654+
if (!tryConsumeExpectedToken(TokenKind::pu_comma))
655+
break;
656+
}
661657

662658
return Params;
663659
}

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,4 +1344,33 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidRootDescriptorParamsCommaTest) {
13441344
ASSERT_TRUE(Consumer->isSatisfied());
13451345
}
13461346

1347+
TEST_F(ParseHLSLRootSignatureTest, InvalidDescriptorClauseParamsCommaTest) {
1348+
// This test will check that an error is produced when there is a missing
1349+
// comma between parameters
1350+
const llvm::StringLiteral Source = R"cc(
1351+
DescriptorTable(
1352+
UAV(
1353+
u0
1354+
flags = 0
1355+
)
1356+
)
1357+
)cc";
1358+
1359+
auto Ctx = createMinimalASTContext();
1360+
StringLiteral *Signature = wrapSource(Ctx, Source);
1361+
1362+
TrivialModuleLoader ModLoader;
1363+
auto PP = createPP(Source, ModLoader);
1364+
1365+
SmallVector<RootElement> Elements;
1366+
hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Elements,
1367+
Signature, *PP);
1368+
1369+
// Test correct diagnostic produced
1370+
Consumer->setExpected(diag::err_hlsl_unexpected_end_of_params);
1371+
ASSERT_TRUE(Parser.parse());
1372+
1373+
ASSERT_TRUE(Consumer->isSatisfied());
1374+
}
1375+
13471376
} // anonymous namespace

0 commit comments

Comments
 (0)