26
26
#include " llvm/Pass.h"
27
27
#include " llvm/Support/Error.h"
28
28
#include " llvm/Support/ErrorHandling.h"
29
+ #include < cstdint>
29
30
#include < optional>
31
+ #include < utility>
30
32
31
33
using namespace llvm ;
32
34
using namespace llvm ::dxil;
@@ -37,20 +39,20 @@ static bool reportError(LLVMContext *Ctx, Twine Message,
37
39
return true ;
38
40
}
39
41
40
- static bool parseRootFlags (LLVMContext *Ctx, ModuleRootSignature * MRS,
42
+ static bool parseRootFlags (LLVMContext *Ctx, ModuleRootSignature & MRS,
41
43
MDNode *RootFlagNode) {
42
44
43
45
if (RootFlagNode->getNumOperands () != 2 )
44
46
return reportError (Ctx, " Invalid format for RootFlag Element" );
45
47
46
48
auto *Flag = mdconst::extract<ConstantInt>(RootFlagNode->getOperand (1 ));
47
- MRS-> Flags = Flag->getZExtValue ();
49
+ MRS. Flags = Flag->getZExtValue ();
48
50
49
51
return false ;
50
52
}
51
53
52
54
static bool parseRootSignatureElement (LLVMContext *Ctx,
53
- ModuleRootSignature * MRS,
55
+ ModuleRootSignature & MRS,
54
56
MDNode *Element) {
55
57
MDString *ElementText = cast<MDString>(Element->getOperand (0 ));
56
58
if (ElementText == nullptr )
@@ -73,8 +75,7 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
73
75
llvm_unreachable (" Root signature element kind not expected." );
74
76
}
75
77
76
- static bool parse (LLVMContext *Ctx, ModuleRootSignature *MRS, NamedMDNode *Root,
77
- const Function *EntryFunction) {
78
+ static bool parse (LLVMContext *Ctx, ModuleRootSignature &MRS, MDNode *Node) {
78
79
bool HasError = false ;
79
80
80
81
/* * Root Signature are specified as following in the metadata:
@@ -89,15 +90,46 @@ static bool parse(LLVMContext *Ctx, ModuleRootSignature *MRS, NamedMDNode *Root,
89
90
signature pair.
90
91
*/
91
92
92
- for (const MDNode *Node : Root->operands ()) {
93
- if (Node->getNumOperands () != 2 ) {
94
- HasError = reportError (
95
- Ctx, " Invalid format for Root Signature Definition. Pairs "
96
- " of function, root signature expected." );
93
+ // Get the Root Signature Description from the function signature pair.
94
+
95
+ // Loop through the Root Elements of the root signature.
96
+ for (const auto &Operand : Node->operands ()) {
97
+ MDNode *Element = dyn_cast<MDNode>(Operand);
98
+ if (Element == nullptr )
99
+ return reportError (Ctx, " Missing Root Element Metadata Node." );
100
+
101
+ HasError = HasError || parseRootSignatureElement (Ctx, MRS, Element);
102
+ }
103
+
104
+ return HasError;
105
+ }
106
+
107
+ static bool validate (LLVMContext *Ctx, const ModuleRootSignature &MRS) {
108
+ if (!dxbc::RootSignatureValidations::isValidRootFlag (MRS.Flags )) {
109
+ return reportError (Ctx, " Invalid Root Signature flag value" );
110
+ }
111
+ return false ;
112
+ }
113
+
114
+ static SmallDenseMap<const Function *, ModuleRootSignature>
115
+ analyzeModule (Module &M) {
116
+
117
+ LLVMContext *Ctx = &M.getContext ();
118
+
119
+ SmallDenseMap<const Function *, ModuleRootSignature> MRSMap;
120
+
121
+ NamedMDNode *RootSignatureNode = M.getNamedMetadata (" dx.rootsignatures" );
122
+ if (RootSignatureNode == nullptr )
123
+ return MRSMap;
124
+
125
+ for (const auto &RSDefNode : RootSignatureNode->operands ()) {
126
+ if (RSDefNode->getNumOperands () != 2 ) {
127
+ reportError (Ctx, " Invalid format for Root Signature Definition. Pairs "
128
+ " of function, root signature expected." );
97
129
continue ;
98
130
}
99
131
100
- const MDOperand &FunctionPointerMdNode = Node ->getOperand (0 );
132
+ const MDOperand &FunctionPointerMdNode = RSDefNode ->getOperand (0 );
101
133
if (FunctionPointerMdNode == nullptr ) {
102
134
// Function was pruned during compilation.
103
135
continue ;
@@ -106,97 +138,76 @@ static bool parse(LLVMContext *Ctx, ModuleRootSignature *MRS, NamedMDNode *Root,
106
138
ValueAsMetadata *VAM =
107
139
llvm::dyn_cast<ValueAsMetadata>(FunctionPointerMdNode.get ());
108
140
if (VAM == nullptr ) {
109
- HasError =
110
- reportError (Ctx, " First element of root signature is not a value" );
141
+ reportError (Ctx, " First element of root signature is not a value" );
111
142
continue ;
112
143
}
113
144
114
145
Function *F = dyn_cast<Function>(VAM->getValue ());
115
146
if (F == nullptr ) {
116
- HasError =
117
- reportError (Ctx, " First element of root signature is not a function" );
147
+ reportError (Ctx, " First element of root signature is not a function" );
118
148
continue ;
119
149
}
120
150
121
- if (F != EntryFunction)
122
- continue ;
151
+ MDNode *RootElementListNode =
152
+ dyn_cast<MDNode>(RSDefNode-> getOperand ( 1 ). get ()) ;
123
153
124
- // Get the Root Signature Description from the function signature pair.
125
- MDNode *RS = dyn_cast<MDNode>(Node->getOperand (1 ).get ());
126
-
127
- if (RS == nullptr ) {
154
+ if (RootElementListNode == nullptr ) {
128
155
reportError (Ctx, " Missing Root Element List Metadata node." );
129
- continue ;
130
156
}
131
157
132
- // Loop through the Root Elements of the root signature.
133
- for (const auto &Operand : RS->operands ()) {
134
- MDNode *Element = dyn_cast<MDNode>(Operand);
135
- if (Element == nullptr )
136
- return reportError (Ctx, " Missing Root Element Metadata Node." );
158
+ ModuleRootSignature MRS;
137
159
138
- HasError = HasError || parseRootSignatureElement (Ctx, MRS, Element);
160
+ if (parse (Ctx, MRS, RootElementListNode) || validate (Ctx, MRS)) {
161
+ return MRSMap;
139
162
}
140
- }
141
- return HasError;
142
- }
143
163
144
- static bool validate (LLVMContext *Ctx, ModuleRootSignature *MRS) {
145
- if (!dxbc::RootSignatureValidations::isValidRootFlag (MRS->Flags )) {
146
- return reportError (Ctx, " Invalid Root Signature flag value" );
164
+ MRSMap.insert (std::make_pair (F, MRS));
147
165
}
148
- return false ;
149
- }
150
166
151
- static const Function *getEntryFunction (Module &M, ModuleMetadataInfo MMI) {
152
-
153
- LLVMContext *Ctx = &M.getContext ();
154
- if (MMI.EntryPropertyVec .size () != 1 ) {
155
- reportError (Ctx, " More than one entry function defined." );
156
- // needed to stop compilation
157
- report_fatal_error (" Invalid Root Signature Definition" , false );
158
- return nullptr ;
159
- }
160
- return MMI.EntryPropertyVec [0 ].Entry ;
167
+ return MRSMap;
161
168
}
162
169
163
- std::optional<ModuleRootSignature>
164
- ModuleRootSignature::analyzeModule (Module &M, const Function *F) {
165
-
166
- LLVMContext *Ctx = &M.getContext ();
170
+ AnalysisKey RootSignatureAnalysis::Key;
167
171
168
- ModuleRootSignature MRS;
172
+ SmallDenseMap<const Function *, ModuleRootSignature>
173
+ RootSignatureAnalysis::run (Module &M, ModuleAnalysisManager &AM) {
174
+ return analyzeModule (M);
175
+ }
169
176
170
- NamedMDNode *RootSignatureNode = M.getNamedMetadata (" dx.rootsignatures" );
171
- if (RootSignatureNode == nullptr )
172
- return std::nullopt;
177
+ // ===----------------------------------------------------------------------===//
173
178
174
- if (parse (Ctx, &MRS, RootSignatureNode, F) || validate (Ctx, &MRS)) {
175
- // needed to stop compilation
176
- report_fatal_error (" Invalid Root Signature Definition" , false );
177
- return std::nullopt;
179
+ static void printSpaces (raw_ostream &Stream, unsigned int Count) {
180
+ for (unsigned int I = 0 ; I < Count; ++I) {
181
+ Stream << ' ' ;
178
182
}
179
-
180
- return MRS;
181
183
}
182
184
183
- AnalysisKey RootSignatureAnalysis::Key;
185
+ PreservedAnalyses RootSignatureAnalysisPrinter::run (Module &M,
186
+ ModuleAnalysisManager &AM) {
187
+
188
+ SmallDenseMap<const Function *, ModuleRootSignature> &MRSMap =
189
+ AM.getResult <RootSignatureAnalysis>(M);
190
+ OS << " Root Signature Definitions"
191
+ << " \n " ;
192
+ uint8_t Space = 0 ;
193
+ for (const auto &P : MRSMap) {
194
+ const auto &[Function, MRS] = P;
195
+ OS << " Definition for '" << Function->getName () << " ':\n " ;
196
+
197
+ // start root signature header
198
+ Space++;
199
+ printSpaces (OS, Space);
200
+ OS << " Flags: " << format_hex (MRS.Flags , 8 ) << " :\n " ;
201
+ Space--;
202
+ // end root signature header
203
+ }
184
204
185
- std::optional<ModuleRootSignature>
186
- RootSignatureAnalysis::run (Module &M, ModuleAnalysisManager &AM) {
187
- ModuleMetadataInfo MMI = AM.getResult <DXILMetadataAnalysis>(M);
188
- if (MMI.ShaderProfile == Triple::Library)
189
- return std::nullopt;
190
- return ModuleRootSignature::analyzeModule (M, getEntryFunction (M, MMI));
205
+ return PreservedAnalyses::all ();
191
206
}
192
207
193
208
// ===----------------------------------------------------------------------===//
194
209
bool RootSignatureAnalysisWrapper::runOnModule (Module &M) {
195
- dxil::ModuleMetadataInfo &MMI =
196
- getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata ();
197
- if (MMI.ShaderProfile == Triple::Library)
198
- return false ;
199
- MRS = ModuleRootSignature::analyzeModule (M, getEntryFunction (M, MMI));
210
+ MRS = analyzeModule (M);
200
211
return false ;
201
212
}
202
213
@@ -208,8 +219,8 @@ void RootSignatureAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
208
219
char RootSignatureAnalysisWrapper::ID = 0 ;
209
220
210
221
INITIALIZE_PASS_BEGIN (RootSignatureAnalysisWrapper,
211
- " dx -root-signature-analysis" ,
222
+ " dxil -root-signature-analysis" ,
212
223
" DXIL Root Signature Analysis" , true , true )
213
- INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
214
- INITIALIZE_PASS_END(RootSignatureAnalysisWrapper, " dx -root-signature-analysis" ,
224
+ INITIALIZE_PASS_END(RootSignatureAnalysisWrapper,
225
+ " dxil -root-signature-analysis" ,
215
226
" DXIL Root Signature Analysis" , true , true )
0 commit comments