Skip to content

Commit b175b65

Browse files
author
joaosaffran
committed
addressing PR comments
1 parent 83b0979 commit b175b65

File tree

2 files changed

+52
-29
lines changed

2 files changed

+52
-29
lines changed

llvm/lib/Target/DirectX/DXILRootSignature.cpp

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,29 +29,30 @@
2929
using namespace llvm;
3030
using namespace llvm::dxil;
3131

32-
LLVMContext *Ctx;
33-
34-
static bool reportError(Twine Message, DiagnosticSeverity Severity = DS_Error) {
32+
static bool reportError(LLVMContext *Ctx, Twine Message,
33+
DiagnosticSeverity Severity = DS_Error) {
3534
Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
3635
return true;
3736
}
3837

39-
static bool parseRootFlags(ModuleRootSignature *MRS, MDNode *RootFlagNode) {
38+
static bool parseRootFlags(LLVMContext *Ctx, ModuleRootSignature *MRS,
39+
MDNode *RootFlagNode) {
4040

4141
if (RootFlagNode->getNumOperands() != 2)
42-
return reportError("Invalid format for RootFlag Element");
42+
return reportError(Ctx, "Invalid format for RootFlag Element");
4343

4444
auto *Flag = mdconst::extract<ConstantInt>(RootFlagNode->getOperand(1));
4545
MRS->Flags = Flag->getZExtValue();
4646

4747
return false;
4848
}
4949

50-
static bool parseRootSignatureElement(ModuleRootSignature *MRS,
50+
static bool parseRootSignatureElement(LLVMContext *Ctx,
51+
ModuleRootSignature *MRS,
5152
MDNode *Element) {
5253
MDString *ElementText = cast<MDString>(Element->getOperand(0));
5354
if (ElementText == nullptr)
54-
return reportError("Invalid format for Root Element");
55+
return reportError(Ctx, "Invalid format for Root Element");
5556

5657
RootSignatureElementKind ElementKind =
5758
StringSwitch<RootSignatureElementKind>(ElementText->getString())
@@ -68,19 +69,20 @@ static bool parseRootSignatureElement(ModuleRootSignature *MRS,
6869
switch (ElementKind) {
6970

7071
case RootSignatureElementKind::RootFlags:
71-
return parseRootFlags(MRS, Element);
72+
return parseRootFlags(Ctx, MRS, Element);
7273
case RootSignatureElementKind::RootConstants:
7374
case RootSignatureElementKind::RootDescriptor:
7475
case RootSignatureElementKind::DescriptorTable:
7576
case RootSignatureElementKind::StaticSampler:
7677
case RootSignatureElementKind::None:
77-
return reportError("Invalid Root Element: " + ElementText->getString());
78+
return reportError(Ctx,
79+
"Invalid Root Element: " + ElementText->getString());
7880
}
7981

8082
return true;
8183
}
8284

83-
static bool parse(ModuleRootSignature *MRS, NamedMDNode *Root,
85+
static bool parse(LLVMContext *Ctx, ModuleRootSignature *MRS, NamedMDNode *Root,
8486
const Function *EF) {
8587
bool HasError = false;
8688

@@ -97,55 +99,66 @@ static bool parse(ModuleRootSignature *MRS, NamedMDNode *Root,
9799
*/
98100

99101
for (const MDNode *Node : Root->operands()) {
100-
if (Node->getNumOperands() != 2)
101-
return reportError("Invalid format for Root Signature Definition. Pairs "
102-
"of function, root signature expected.");
102+
if (Node->getNumOperands() != 2) {
103+
HasError = reportError(
104+
Ctx, "Invalid format for Root Signature Definition. Pairs "
105+
"of function, root signature expected.");
106+
continue;
107+
}
103108

104109
ValueAsMetadata *VAM =
105110
llvm::dyn_cast<ValueAsMetadata>(Node->getOperand(0).get());
106-
if (VAM == nullptr)
107-
return reportError("First element of root signature is not a value");
111+
if (VAM == nullptr) {
112+
HasError =
113+
reportError(Ctx, "First element of root signature is not a value");
114+
continue;
115+
}
108116

109117
Function *F = dyn_cast<Function>(VAM->getValue());
110-
if (F == nullptr)
111-
return reportError("First element of root signature is not a function");
118+
if (F == nullptr) {
119+
HasError =
120+
reportError(Ctx, "First element of root signature is not a function");
121+
continue;
122+
}
112123

113124
if (F != EF)
114125
continue;
115126

116127
// Get the Root Signature Description from the function signature pair.
117128
MDNode *RS = dyn_cast<MDNode>(Node->getOperand(1).get());
118129

119-
if (RS == nullptr)
120-
return reportError("Missing Root Element List Metadata node.");
130+
if (RS == nullptr) {
131+
reportError(Ctx, "Missing Root Element List Metadata node.");
132+
continue;
133+
}
121134

122135
// Loop through the Root Elements of the root signature.
123136
for (unsigned int Eid = 0; Eid < RS->getNumOperands(); Eid++) {
124137
MDNode *Element = dyn_cast<MDNode>(RS->getOperand(Eid));
125138
if (Element == nullptr)
126-
return reportError("Missing Root Element Metadata Node.");
139+
return reportError(Ctx, "Missing Root Element Metadata Node.");
127140

128-
HasError = HasError || parseRootSignatureElement(MRS, Element);
141+
HasError = HasError || parseRootSignatureElement(Ctx, MRS, Element);
129142
}
130143
}
131144
return HasError;
132145
}
133146

134-
static bool validate(ModuleRootSignature *MRS) {
147+
static bool validate(LLVMContext *Ctx, ModuleRootSignature *MRS) {
135148
if (dxbc::RootSignatureValidations::validateRootFlag(MRS->Flags)) {
136-
return reportError("Invalid Root Signature flag value");
149+
return reportError(Ctx, "Invalid Root Signature flag value");
137150
}
138151
return false;
139152
}
140153

141154
std::optional<ModuleRootSignature>
142155
ModuleRootSignature::analyzeModule(Module &M, const Function *F) {
143156
ModuleRootSignature MRS;
144-
Ctx = &M.getContext();
157+
LLVMContext *Ctx = &M.getContext();
145158

146159
NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
147-
if (RootSignatureNode == nullptr || parse(&MRS, RootSignatureNode, F) ||
148-
validate(&MRS))
160+
if (RootSignatureNode == nullptr || parse(Ctx, &MRS, RootSignatureNode, F) ||
161+
validate(Ctx, &MRS))
149162
return std::nullopt;
150163

151164
return MRS;
@@ -160,7 +173,12 @@ RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
160173
if (MMI.ShaderProfile == Triple::Library)
161174
return std::nullopt;
162175

163-
assert(MMI.EntryPropertyVec.size() == 1);
176+
LLVMContext *Ctx = &M.getContext();
177+
178+
if (MMI.EntryPropertyVec.size() != 1) {
179+
reportError(Ctx, "More than one entry function defined.");
180+
return std::nullopt;
181+
}
164182

165183
const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry;
166184
return ModuleRootSignature::analyzeModule(M, EntryFunction);
@@ -174,7 +192,12 @@ bool RootSignatureAnalysisWrapper::runOnModule(Module &M) {
174192

175193
if (MMI.ShaderProfile == Triple::Library)
176194
return false;
177-
assert(MMI.EntryPropertyVec.size() == 1);
195+
196+
LLVMContext *Ctx = &M.getContext();
197+
if (MMI.EntryPropertyVec.size() != 1) {
198+
reportError(Ctx, "More than one entry function defined.");
199+
return false;
200+
}
178201

179202
const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry;
180203
MRS = ModuleRootSignature::analyzeModule(M, EntryFunction);

llvm/lib/Target/DirectX/DXILRootSignature.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class RootSignatureAnalysisWrapper : public ModulePass {
6262

6363
RootSignatureAnalysisWrapper() : ModulePass(ID) {}
6464

65-
std::optional<ModuleRootSignature> getResult() const { return MRS; }
65+
const std::optional<ModuleRootSignature> &getResult() const { return MRS; }
6666

6767
bool runOnModule(Module &M) override;
6868

0 commit comments

Comments
 (0)