Skip to content

Commit f64c608

Browse files
author
joaosaffran
committed
refactoring root signature analysis to return a map instead
1 parent 5a3be7c commit f64c608

11 files changed

+140
-125
lines changed

llvm/lib/Target/DirectX/DXContainerGlobals.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "llvm/MC/DXContainerRootSignature.h"
2828
#include "llvm/Pass.h"
2929
#include "llvm/Support/MD5.h"
30+
#include "llvm/TargetParser/Triple.h"
3031
#include "llvm/Transforms/Utils/ModuleUtils.h"
3132
#include <optional>
3233

@@ -153,12 +154,23 @@ void DXContainerGlobals::addSignature(Module &M,
153154
void DXContainerGlobals::addRootSignature(Module &M,
154155
SmallVector<GlobalValue *> &Globals) {
155156

157+
dxil::ModuleMetadataInfo &MMI =
158+
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
159+
160+
// Root Signature in Library shaders are different,
161+
// since they don't use DXContainer to share it.
162+
if (MMI.ShaderProfile == llvm::Triple::Library)
163+
return;
164+
165+
assert(MMI.EntryPropertyVec.size() == 1);
166+
156167
auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>();
168+
const Function *&EntryFunction = MMI.EntryPropertyVec[0].Entry;
157169

158-
if (!RSA.getResult())
170+
if (!RSA.hasForFunction(EntryFunction))
159171
return;
160172

161-
const ModuleRootSignature &MRS = RSA.getResult().value();
173+
const ModuleRootSignature &MRS = RSA.getForFunction(EntryFunction);
162174
SmallString<256> Data;
163175
raw_svector_ostream OS(Data);
164176

llvm/lib/Target/DirectX/DXILRootSignature.cpp

Lines changed: 86 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
#include "llvm/Pass.h"
2727
#include "llvm/Support/Error.h"
2828
#include "llvm/Support/ErrorHandling.h"
29+
#include <cstdint>
2930
#include <optional>
31+
#include <utility>
3032

3133
using namespace llvm;
3234
using namespace llvm::dxil;
@@ -37,20 +39,20 @@ static bool reportError(LLVMContext *Ctx, Twine Message,
3739
return true;
3840
}
3941

40-
static bool parseRootFlags(LLVMContext *Ctx, ModuleRootSignature *MRS,
42+
static bool parseRootFlags(LLVMContext *Ctx, ModuleRootSignature &MRS,
4143
MDNode *RootFlagNode) {
4244

4345
if (RootFlagNode->getNumOperands() != 2)
4446
return reportError(Ctx, "Invalid format for RootFlag Element");
4547

4648
auto *Flag = mdconst::extract<ConstantInt>(RootFlagNode->getOperand(1));
47-
MRS->Flags = Flag->getZExtValue();
49+
MRS.Flags = Flag->getZExtValue();
4850

4951
return false;
5052
}
5153

5254
static bool parseRootSignatureElement(LLVMContext *Ctx,
53-
ModuleRootSignature *MRS,
55+
ModuleRootSignature &MRS,
5456
MDNode *Element) {
5557
MDString *ElementText = cast<MDString>(Element->getOperand(0));
5658
if (ElementText == nullptr)
@@ -73,8 +75,7 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
7375
llvm_unreachable("Root signature element kind not expected.");
7476
}
7577

76-
static bool parse(LLVMContext *Ctx, ModuleRootSignature *MRS, NamedMDNode *Root,
77-
const Function *EntryFunction) {
78+
static bool parse(LLVMContext *Ctx, ModuleRootSignature &MRS, MDNode *Node) {
7879
bool HasError = false;
7980

8081
/** Root Signature are specified as following in the metadata:
@@ -89,15 +90,46 @@ static bool parse(LLVMContext *Ctx, ModuleRootSignature *MRS, NamedMDNode *Root,
8990
signature pair.
9091
*/
9192

92-
for (const MDNode *Node : Root->operands()) {
93-
if (Node->getNumOperands() != 2) {
94-
HasError = reportError(
95-
Ctx, "Invalid format for Root Signature Definition. Pairs "
96-
"of function, root signature expected.");
93+
// Get the Root Signature Description from the function signature pair.
94+
95+
// Loop through the Root Elements of the root signature.
96+
for (const auto &Operand : Node->operands()) {
97+
MDNode *Element = dyn_cast<MDNode>(Operand);
98+
if (Element == nullptr)
99+
return reportError(Ctx, "Missing Root Element Metadata Node.");
100+
101+
HasError = HasError || parseRootSignatureElement(Ctx, MRS, Element);
102+
}
103+
104+
return HasError;
105+
}
106+
107+
static bool validate(LLVMContext *Ctx, const ModuleRootSignature &MRS) {
108+
if (!dxbc::RootSignatureValidations::isValidRootFlag(MRS.Flags)) {
109+
return reportError(Ctx, "Invalid Root Signature flag value");
110+
}
111+
return false;
112+
}
113+
114+
static SmallDenseMap<const Function *, ModuleRootSignature>
115+
analyzeModule(Module &M) {
116+
117+
LLVMContext *Ctx = &M.getContext();
118+
119+
SmallDenseMap<const Function *, ModuleRootSignature> MRSMap;
120+
121+
NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
122+
if (RootSignatureNode == nullptr)
123+
return MRSMap;
124+
125+
for (const auto &RSDefNode : RootSignatureNode->operands()) {
126+
if (RSDefNode->getNumOperands() != 2) {
127+
reportError(Ctx, "Invalid format for Root Signature Definition. Pairs "
128+
"of function, root signature expected.");
97129
continue;
98130
}
99131

100-
const MDOperand &FunctionPointerMdNode = Node->getOperand(0);
132+
const MDOperand &FunctionPointerMdNode = RSDefNode->getOperand(0);
101133
if (FunctionPointerMdNode == nullptr) {
102134
// Function was pruned during compilation.
103135
continue;
@@ -106,97 +138,76 @@ static bool parse(LLVMContext *Ctx, ModuleRootSignature *MRS, NamedMDNode *Root,
106138
ValueAsMetadata *VAM =
107139
llvm::dyn_cast<ValueAsMetadata>(FunctionPointerMdNode.get());
108140
if (VAM == nullptr) {
109-
HasError =
110-
reportError(Ctx, "First element of root signature is not a value");
141+
reportError(Ctx, "First element of root signature is not a value");
111142
continue;
112143
}
113144

114145
Function *F = dyn_cast<Function>(VAM->getValue());
115146
if (F == nullptr) {
116-
HasError =
117-
reportError(Ctx, "First element of root signature is not a function");
147+
reportError(Ctx, "First element of root signature is not a function");
118148
continue;
119149
}
120150

121-
if (F != EntryFunction)
122-
continue;
151+
MDNode *RootElementListNode =
152+
dyn_cast<MDNode>(RSDefNode->getOperand(1).get());
123153

124-
// Get the Root Signature Description from the function signature pair.
125-
MDNode *RS = dyn_cast<MDNode>(Node->getOperand(1).get());
126-
127-
if (RS == nullptr) {
154+
if (RootElementListNode == nullptr) {
128155
reportError(Ctx, "Missing Root Element List Metadata node.");
129-
continue;
130156
}
131157

132-
// Loop through the Root Elements of the root signature.
133-
for (const auto &Operand : RS->operands()) {
134-
MDNode *Element = dyn_cast<MDNode>(Operand);
135-
if (Element == nullptr)
136-
return reportError(Ctx, "Missing Root Element Metadata Node.");
158+
ModuleRootSignature MRS;
137159

138-
HasError = HasError || parseRootSignatureElement(Ctx, MRS, Element);
160+
if (parse(Ctx, MRS, RootElementListNode) || validate(Ctx, MRS)) {
161+
return MRSMap;
139162
}
140-
}
141-
return HasError;
142-
}
143163

144-
static bool validate(LLVMContext *Ctx, ModuleRootSignature *MRS) {
145-
if (!dxbc::RootSignatureValidations::isValidRootFlag(MRS->Flags)) {
146-
return reportError(Ctx, "Invalid Root Signature flag value");
164+
MRSMap.insert(std::make_pair(F, MRS));
147165
}
148-
return false;
149-
}
150166

151-
static const Function *getEntryFunction(Module &M, ModuleMetadataInfo MMI) {
152-
153-
LLVMContext *Ctx = &M.getContext();
154-
if (MMI.EntryPropertyVec.size() != 1) {
155-
reportError(Ctx, "More than one entry function defined.");
156-
// needed to stop compilation
157-
report_fatal_error("Invalid Root Signature Definition", false);
158-
return nullptr;
159-
}
160-
return MMI.EntryPropertyVec[0].Entry;
167+
return MRSMap;
161168
}
162169

163-
std::optional<ModuleRootSignature>
164-
ModuleRootSignature::analyzeModule(Module &M, const Function *F) {
165-
166-
LLVMContext *Ctx = &M.getContext();
170+
AnalysisKey RootSignatureAnalysis::Key;
167171

168-
ModuleRootSignature MRS;
172+
SmallDenseMap<const Function *, ModuleRootSignature>
173+
RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
174+
return analyzeModule(M);
175+
}
169176

170-
NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
171-
if (RootSignatureNode == nullptr)
172-
return std::nullopt;
177+
//===----------------------------------------------------------------------===//
173178

174-
if (parse(Ctx, &MRS, RootSignatureNode, F) || validate(Ctx, &MRS)) {
175-
// needed to stop compilation
176-
report_fatal_error("Invalid Root Signature Definition", false);
177-
return std::nullopt;
179+
static void printSpaces(raw_ostream &Stream, unsigned int Count) {
180+
for (unsigned int I = 0; I < Count; ++I) {
181+
Stream << ' ';
178182
}
179-
180-
return MRS;
181183
}
182184

183-
AnalysisKey RootSignatureAnalysis::Key;
185+
PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
186+
ModuleAnalysisManager &AM) {
187+
188+
SmallDenseMap<const Function *, ModuleRootSignature> &MRSMap =
189+
AM.getResult<RootSignatureAnalysis>(M);
190+
OS << "Root Signature Definitions"
191+
<< "\n";
192+
uint8_t Space = 0;
193+
for (const auto &P : MRSMap) {
194+
const auto &[Function, MRS] = P;
195+
OS << "Definition for '" << Function->getName() << "':\n";
196+
197+
// start root signature header
198+
Space++;
199+
printSpaces(OS, Space);
200+
OS << "Flags: " << format_hex(MRS.Flags, 8) << ":\n";
201+
Space--;
202+
// end root signature header
203+
}
184204

185-
std::optional<ModuleRootSignature>
186-
RootSignatureAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
187-
ModuleMetadataInfo MMI = AM.getResult<DXILMetadataAnalysis>(M);
188-
if (MMI.ShaderProfile == Triple::Library)
189-
return std::nullopt;
190-
return ModuleRootSignature::analyzeModule(M, getEntryFunction(M, MMI));
205+
return PreservedAnalyses::all();
191206
}
192207

193208
//===----------------------------------------------------------------------===//
194209
bool RootSignatureAnalysisWrapper::runOnModule(Module &M) {
195-
dxil::ModuleMetadataInfo &MMI =
196-
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
197-
if (MMI.ShaderProfile == Triple::Library)
198-
return false;
199-
MRS = ModuleRootSignature::analyzeModule(M, getEntryFunction(M, MMI));
210+
MRS = analyzeModule(M);
200211
return false;
201212
}
202213

@@ -208,8 +219,8 @@ void RootSignatureAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
208219
char RootSignatureAnalysisWrapper::ID = 0;
209220

210221
INITIALIZE_PASS_BEGIN(RootSignatureAnalysisWrapper,
211-
"dx-root-signature-analysis",
222+
"dxil-root-signature-analysis",
212223
"DXIL Root Signature Analysis", true, true)
213-
INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
214-
INITIALIZE_PASS_END(RootSignatureAnalysisWrapper, "dx-root-signature-analysis",
224+
INITIALIZE_PASS_END(RootSignatureAnalysisWrapper,
225+
"dxil-root-signature-analysis",
215226
"DXIL Root Signature Analysis", true, true)

llvm/lib/Target/DirectX/DXILRootSignature.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
///
1212
//===----------------------------------------------------------------------===//
1313

14+
#include "llvm/ADT/DenseMap.h"
1415
#include "llvm/Analysis/DXILMetadataAnalysis.h"
1516
#include "llvm/IR/DiagnosticInfo.h"
1617
#include "llvm/IR/Metadata.h"
@@ -25,9 +26,8 @@ namespace dxil {
2526
enum class RootSignatureElementKind { None = 0, RootFlags = 1 };
2627

2728
struct ModuleRootSignature {
29+
ModuleRootSignature() = default;
2830
uint32_t Flags = 0;
29-
static std::optional<ModuleRootSignature> analyzeModule(Module &M,
30-
const Function *F);
3131
};
3232

3333
class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
@@ -37,9 +37,10 @@ class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
3737
public:
3838
RootSignatureAnalysis() = default;
3939

40-
using Result = std::optional<ModuleRootSignature>;
40+
using Result = SmallDenseMap<const Function *, ModuleRootSignature>;
4141

42-
std::optional<ModuleRootSignature> run(Module &M, ModuleAnalysisManager &AM);
42+
SmallDenseMap<const Function *, ModuleRootSignature>
43+
run(Module &M, ModuleAnalysisManager &AM);
4344
};
4445

4546
/// Wrapper pass for the legacy pass manager.
@@ -48,19 +49,34 @@ class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
4849
/// passes which run through the legacy pass manager.
4950
class RootSignatureAnalysisWrapper : public ModulePass {
5051
private:
51-
std::optional<ModuleRootSignature> MRS;
52+
SmallDenseMap<const Function *, ModuleRootSignature> MRS;
5253

5354
public:
5455
static char ID;
5556

5657
RootSignatureAnalysisWrapper() : ModulePass(ID) {}
5758

58-
const std::optional<ModuleRootSignature> &getResult() const { return MRS; }
59+
bool hasForFunction(const Function *F) { return MRS.find(F) != MRS.end(); }
60+
61+
ModuleRootSignature getForFunction(const Function *F) {
62+
assert(hasForFunction(F));
63+
return MRS[F];
64+
}
5965

6066
bool runOnModule(Module &M) override;
6167

6268
void getAnalysisUsage(AnalysisUsage &AU) const override;
6369
};
6470

71+
/// Printer pass for RootSignatureAnalysis results.
72+
class RootSignatureAnalysisPrinter
73+
: public PassInfoMixin<RootSignatureAnalysisPrinter> {
74+
raw_ostream &OS;
75+
76+
public:
77+
explicit RootSignatureAnalysisPrinter(raw_ostream &OS) : OS(OS) {}
78+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
79+
};
80+
6581
} // namespace dxil
6682
} // namespace llvm

llvm/lib/Target/DirectX/DirectXPassRegistry.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#endif
1919
MODULE_ANALYSIS("dx-shader-flags", dxil::ShaderFlagsAnalysis())
2020
MODULE_ANALYSIS("dxil-resource-md", DXILResourceMDAnalysis())
21+
MODULE_ANALYSIS("dxil-root-signature-analysis", dxil::RootSignatureAnalysis())
2122
#undef MODULE_ANALYSIS
2223

2324
#ifndef MODULE_PASS
@@ -31,6 +32,7 @@ MODULE_PASS("dxil-pretty-printer", DXILPrettyPrinterPass(dbgs()))
3132
MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata())
3233
// TODO: rename to print<foo> after NPM switch
3334
MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs()))
35+
MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbgs()))
3436
#undef MODULE_PASS
3537

3638
#ifndef FUNCTION_PASS

llvm/lib/Target/DirectX/DirectXTargetMachine.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "DXILPrettyPrinter.h"
2020
#include "DXILResourceAccess.h"
2121
#include "DXILResourceAnalysis.h"
22+
#include "DXILRootSignature.h"
2223
#include "DXILShaderFlags.h"
2324
#include "DXILTranslateMetadata.h"
2425
#include "DXILWriter/DXILWriterPass.h"

llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Error.ll

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
; RUN: not llc %s --filetype=obj -o - 2>&1 | FileCheck %s
1+
; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
22

33
target triple = "dxil-unknown-shadermodel6.0-compute"
44

55
; CHECK: error: Invalid format for Root Signature Definition. Pairs of function, root signature expected.
6+
; CHECK-NO: Root Signature Definitions
67

78

89
define void @main() #0 {

llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Flags-Error.ll

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
; RUN: not llc %s --filetype=obj -o - 2>&1 | FileCheck %s
1+
; RUN: not opt -passes='print<dxil-root-signature>' %s -S -o - 2>&1 | FileCheck %s
22

33
target triple = "dxil-unknown-shadermodel6.0-compute"
44

55
; CHECK: error: Invalid Root Element: NOTRootFlags
6+
; CHECK-NO: Root Signature Definitions
67

78

89
define void @main() #0 {

0 commit comments

Comments
 (0)