Skip to content

Commit 633df1d

Browse files
committed
fix for root descriptors
1 parent 945aa9e commit 633df1d

File tree

3 files changed

+50
-23
lines changed

3 files changed

+50
-23
lines changed

clang/include/clang/Parse/ParseHLSLRootSignature.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ class RootSignatureParser {
9999
std::optional<llvm::dxbc::RootDescriptorFlags> Flags;
100100
};
101101
std::optional<ParsedRootDescriptorParams>
102-
parseRootDescriptorParams(RootSignatureToken::Kind RegType);
102+
parseRootDescriptorParams(RootSignatureToken::Kind DescType,
103+
RootSignatureToken::Kind RegType);
103104

104105
struct ParsedClauseParams {
105106
std::optional<llvm::hlsl::rootsig::Register> Reg;

clang/lib/Parse/ParseHLSLRootSignature.cpp

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,15 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
199199
}
200200
Descriptor.setDefaultFlags(Version);
201201

202-
auto Params = parseRootDescriptorParams(ExpectedReg);
202+
auto Params = parseRootDescriptorParams(DescriptorKind, ExpectedReg);
203203
if (!Params.has_value())
204204
return std::nullopt;
205205

206+
if (consumeExpectedToken(TokenKind::pu_r_paren,
207+
diag::err_hlsl_unexpected_end_of_params,
208+
/*param of=*/DescriptorKind))
209+
return std::nullopt;
210+
206211
// Check mandatory parameters were provided
207212
if (!Params->Reg.has_value()) {
208213
reportDiag(diag::err_hlsl_rootsig_missing_param) << ExpectedReg;
@@ -221,11 +226,6 @@ std::optional<RootDescriptor> RootSignatureParser::parseRootDescriptor() {
221226
if (Params->Flags.has_value())
222227
Descriptor.Flags = Params->Flags.value();
223228

224-
if (consumeExpectedToken(TokenKind::pu_r_paren,
225-
diag::err_hlsl_unexpected_end_of_params,
226-
/*param of=*/TokenKind::kw_RootConstants))
227-
return std::nullopt;
228-
229229
return Descriptor;
230230
}
231231

@@ -493,14 +493,15 @@ RootSignatureParser::parseRootConstantParams() {
493493
}
494494

495495
std::optional<RootSignatureParser::ParsedRootDescriptorParams>
496-
RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
496+
RootSignatureParser::parseRootDescriptorParams(TokenKind DescType,
497+
TokenKind RegType) {
497498
assert(CurToken.TokKind == TokenKind::pu_l_paren &&
498499
"Expects to only be invoked starting at given token");
499500

500501
ParsedRootDescriptorParams Params;
501-
do {
502-
// ( `b` | `t` | `u`) POS_INT
502+
while (!peekExpectedToken(TokenKind::pu_r_paren)) {
503503
if (tryConsumeExpectedToken(RegType)) {
504+
// ( `b` | `t` | `u`) POS_INT
504505
if (Params.Reg.has_value()) {
505506
reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
506507
return std::nullopt;
@@ -509,10 +510,8 @@ RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
509510
if (!Reg.has_value())
510511
return std::nullopt;
511512
Params.Reg = Reg;
512-
}
513-
514-
// `space` `=` POS_INT
515-
if (tryConsumeExpectedToken(TokenKind::kw_space)) {
513+
} else if (tryConsumeExpectedToken(TokenKind::kw_space)) {
514+
// `space` `=` POS_INT
516515
if (Params.Space.has_value()) {
517516
reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
518517
return std::nullopt;
@@ -525,10 +524,8 @@ RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
525524
if (!Space.has_value())
526525
return std::nullopt;
527526
Params.Space = Space;
528-
}
529-
530-
// `visibility` `=` SHADER_VISIBILITY
531-
if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
527+
} else if (tryConsumeExpectedToken(TokenKind::kw_visibility)) {
528+
// `visibility` `=` SHADER_VISIBILITY
532529
if (Params.Visibility.has_value()) {
533530
reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
534531
return std::nullopt;
@@ -541,10 +538,8 @@ RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
541538
if (!Visibility.has_value())
542539
return std::nullopt;
543540
Params.Visibility = Visibility;
544-
}
545-
546-
// `flags` `=` ROOT_DESCRIPTOR_FLAGS
547-
if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
541+
} else if (tryConsumeExpectedToken(TokenKind::kw_flags)) {
542+
// `flags` `=` ROOT_DESCRIPTOR_FLAGS
548543
if (Params.Flags.has_value()) {
549544
reportDiag(diag::err_hlsl_rootsig_repeat_param) << CurToken.TokKind;
550545
return std::nullopt;
@@ -558,7 +553,11 @@ RootSignatureParser::parseRootDescriptorParams(TokenKind RegType) {
558553
return std::nullopt;
559554
Params.Flags = Flags;
560555
}
561-
} while (tryConsumeExpectedToken(TokenKind::pu_comma));
556+
557+
// ',' denotes another element, otherwise, expected to be at ')'
558+
if (!tryConsumeExpectedToken(TokenKind::pu_comma))
559+
break;
560+
}
562561

563562
return Params;
564563
}

clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,4 +1317,31 @@ TEST_F(ParseHLSLRootSignatureTest, InvalidRootConstantParamsCommaTest) {
13171317
ASSERT_TRUE(Consumer->isSatisfied());
13181318
}
13191319

1320+
TEST_F(ParseHLSLRootSignatureTest, InvalidRootDescriptorParamsCommaTest) {
1321+
// This test will check that an error is produced when there is a missing
1322+
// comma between parameters
1323+
const llvm::StringLiteral Source = R"cc(
1324+
CBV(
1325+
b0
1326+
flags = 0
1327+
)
1328+
)cc";
1329+
1330+
auto Ctx = createMinimalASTContext();
1331+
StringLiteral *Signature = wrapSource(Ctx, Source);
1332+
1333+
TrivialModuleLoader ModLoader;
1334+
auto PP = createPP(Source, ModLoader);
1335+
1336+
SmallVector<RootElement> Elements;
1337+
hlsl::RootSignatureParser Parser(RootSignatureVersion::V1_1, Elements,
1338+
Signature, *PP);
1339+
1340+
// Test correct diagnostic produced
1341+
Consumer->setExpected(diag::err_hlsl_unexpected_end_of_params);
1342+
ASSERT_TRUE(Parser.parse());
1343+
1344+
ASSERT_TRUE(Consumer->isSatisfied());
1345+
}
1346+
13201347
} // anonymous namespace

0 commit comments

Comments
 (0)