Skip to content

Commit 83b0979

Browse files
author
joaosaffran
committed
addressing pr comments
1 parent b7f2716 commit 83b0979

File tree

5 files changed

+32
-54
lines changed

5 files changed

+32
-54
lines changed

llvm/include/llvm/BinaryFormat/DXContainer.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,9 +548,7 @@ static_assert(sizeof(ProgramSignatureElement) == 32,
548548

549549
struct RootSignatureValidations {
550550

551-
static bool validateRootFlag(uint32_t Flags) {
552-
return (Flags & ~0x80000fff) != 0;
553-
}
551+
static bool validateRootFlag(uint32_t Flags) { return (Flags & ~0xfff) != 0; }
554552

555553
static bool validateVersion(uint32_t Version) {
556554
return !(Version == 1 || Version == 2);

llvm/lib/Target/DirectX/DXContainerGlobals.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "llvm/Pass.h"
2929
#include "llvm/Support/MD5.h"
3030
#include "llvm/Transforms/Utils/ModuleUtils.h"
31+
#include <optional>
3132

3233
using namespace llvm;
3334
using namespace llvm::dxil;
@@ -153,11 +154,12 @@ void DXContainerGlobals::addRootSignature(Module &M,
153154
SmallVector<GlobalValue *> &Globals) {
154155

155156
auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>();
157+
std::optional<ModuleRootSignature> MaybeRootSignature = RSA.getResult();
156158

157-
if (!RSA.hasRootSignature())
159+
if (!MaybeRootSignature.has_value())
158160
return;
159161

160-
ModuleRootSignature MRS = RSA.getRootSignature();
162+
ModuleRootSignature MRS = MaybeRootSignature.value();
161163

162164
SmallString<256> Data;
163165
raw_svector_ostream OS(Data);

llvm/lib/Target/DirectX/DXILRootSignature.cpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,26 @@
2929
using namespace llvm;
3030
using namespace llvm::dxil;
3131

32-
bool ModuleRootSignature::reportError(Twine Message,
33-
DiagnosticSeverity Severity) {
32+
LLVMContext *Ctx;
33+
34+
static bool reportError(Twine Message, DiagnosticSeverity Severity = DS_Error) {
3435
Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
3536
return true;
3637
}
3738

38-
bool ModuleRootSignature::parseRootFlags(MDNode *RootFlagNode) {
39+
static bool parseRootFlags(ModuleRootSignature *MRS, MDNode *RootFlagNode) {
3940

4041
if (RootFlagNode->getNumOperands() != 2)
4142
return reportError("Invalid format for RootFlag Element");
4243

4344
auto *Flag = mdconst::extract<ConstantInt>(RootFlagNode->getOperand(1));
44-
this->Flags = Flag->getZExtValue();
45+
MRS->Flags = Flag->getZExtValue();
4546

4647
return false;
4748
}
4849

49-
bool ModuleRootSignature::parseRootSignatureElement(MDNode *Element) {
50+
static bool parseRootSignatureElement(ModuleRootSignature *MRS,
51+
MDNode *Element) {
5052
MDString *ElementText = cast<MDString>(Element->getOperand(0));
5153
if (ElementText == nullptr)
5254
return reportError("Invalid format for Root Element");
@@ -65,24 +67,21 @@ bool ModuleRootSignature::parseRootSignatureElement(MDNode *Element) {
6567

6668
switch (ElementKind) {
6769

68-
case RootSignatureElementKind::RootFlags: {
69-
return parseRootFlags(Element);
70-
break;
71-
}
72-
70+
case RootSignatureElementKind::RootFlags:
71+
return parseRootFlags(MRS, Element);
7372
case RootSignatureElementKind::RootConstants:
7473
case RootSignatureElementKind::RootDescriptor:
7574
case RootSignatureElementKind::DescriptorTable:
7675
case RootSignatureElementKind::StaticSampler:
7776
case RootSignatureElementKind::None:
7877
return reportError("Invalid Root Element: " + ElementText->getString());
79-
break;
8078
}
8179

8280
return true;
8381
}
8482

85-
bool ModuleRootSignature::parse(NamedMDNode *Root, const Function *EF) {
83+
static bool parse(ModuleRootSignature *MRS, NamedMDNode *Root,
84+
const Function *EF) {
8685
bool HasError = false;
8786

8887
/** Root Signature are specified as following in the metadata:
@@ -93,7 +92,7 @@ bool ModuleRootSignature::parse(NamedMDNode *Root, const Function *EF) {
9392
9493
So for each MDNode inside dx.rootsignatures NamedMDNode
9594
(the Root parameter of this function), the parsing process needs
96-
to loop through each of it's operand and process the pairs function
95+
to loop through each of its operands and process the function,
9796
signature pair.
9897
*/
9998

@@ -126,35 +125,36 @@ bool ModuleRootSignature::parse(NamedMDNode *Root, const Function *EF) {
126125
if (Element == nullptr)
127126
return reportError("Missing Root Element Metadata Node.");
128127

129-
HasError = HasError || parseRootSignatureElement(Element);
128+
HasError = HasError || parseRootSignatureElement(MRS, Element);
130129
}
131130
}
132131
return HasError;
133132
}
134133

135-
bool ModuleRootSignature::validate() {
136-
if (dxbc::RootSignatureValidations::validateRootFlag(Flags)) {
134+
static bool validate(ModuleRootSignature *MRS) {
135+
if (dxbc::RootSignatureValidations::validateRootFlag(MRS->Flags)) {
137136
return reportError("Invalid Root Signature flag value");
138137
}
139138
return false;
140139
}
141140

142-
OptionalRootSignature ModuleRootSignature::analyzeModule(Module &M,
143-
const Function *F) {
144-
ModuleRootSignature MRS(&M.getContext());
141+
std::optional<ModuleRootSignature>
142+
ModuleRootSignature::analyzeModule(Module &M, const Function *F) {
143+
ModuleRootSignature MRS;
144+
Ctx = &M.getContext();
145145

146146
NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
147-
if (RootSignatureNode == nullptr || MRS.parse(RootSignatureNode, F) ||
148-
MRS.validate())
147+
if (RootSignatureNode == nullptr || parse(&MRS, RootSignatureNode, F) ||
148+
validate(&MRS))
149149
return std::nullopt;
150150

151151
return MRS;
152152
}
153153

154154
AnalysisKey RootSignatureAnalysis::Key;
155155

156-
OptionalRootSignature RootSignatureAnalysis::run(Module &M,
157-
ModuleAnalysisManager &AM) {
156+
std::optional<ModuleRootSignature>
157+
RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
158158
auto MMI = AM.getResult<DXILMetadataAnalysis>(M);
159159

160160
if (MMI.ShaderProfile == Triple::Library)

llvm/lib/Target/DirectX/DXILRootSignature.h

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,36 +33,20 @@ enum class RootSignatureElementKind {
3333

3434
struct ModuleRootSignature {
3535
uint32_t Flags = 0;
36-
ModuleRootSignature() { Ctx = nullptr; };
3736
static std::optional<ModuleRootSignature> analyzeModule(Module &M,
3837
const Function *F);
39-
40-
private:
41-
LLVMContext *Ctx;
42-
43-
ModuleRootSignature(LLVMContext *Ctx) : Ctx(Ctx) {}
44-
45-
bool parse(NamedMDNode *Root, const Function *F);
46-
bool parseRootSignatureElement(MDNode *Element);
47-
bool parseRootFlags(MDNode *RootFlagNode);
48-
49-
bool validate();
50-
51-
bool reportError(Twine Message, DiagnosticSeverity Severity = DS_Error);
5238
};
5339

54-
using OptionalRootSignature = std::optional<ModuleRootSignature>;
55-
5640
class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
5741
friend AnalysisInfoMixin<RootSignatureAnalysis>;
5842
static AnalysisKey Key;
5943

6044
public:
6145
RootSignatureAnalysis() = default;
6246

63-
using Result = OptionalRootSignature;
47+
using Result = std::optional<ModuleRootSignature>;
6448

65-
OptionalRootSignature run(Module &M, ModuleAnalysisManager &AM);
49+
std::optional<ModuleRootSignature> run(Module &M, ModuleAnalysisManager &AM);
6650
};
6751

6852
/// Wrapper pass for the legacy pass manager.
@@ -71,16 +55,14 @@ class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
7155
/// passes which run through the legacy pass manager.
7256
class RootSignatureAnalysisWrapper : public ModulePass {
7357
private:
74-
OptionalRootSignature MRS;
58+
std::optional<ModuleRootSignature> MRS;
7559

7660
public:
7761
static char ID;
7862

7963
RootSignatureAnalysisWrapper() : ModulePass(ID) {}
8064

81-
const ModuleRootSignature &getRootSignature() { return MRS.value(); }
82-
83-
bool hasRootSignature() { return MRS.has_value(); }
65+
std::optional<ModuleRootSignature> getResult() const { return MRS; }
8466

8567
bool runOnModule(Module &M) override;
8668

llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags.ll

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@ define void @main() #0 {
99
entry:
1010
ret void
1111
}
12-
13-
14-
15-
1612
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
1713

1814

0 commit comments

Comments
 (0)