@@ -55,6 +55,14 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
55
55
return std::nullopt;
56
56
}
57
57
58
+ static std::optional<StringRef> extractMdStringValue (MDNode *Node,
59
+ unsigned int OpId) {
60
+ MDString *NodeText = dyn_cast<MDString>(Node->getOperand (OpId));
61
+ if (NodeText == nullptr )
62
+ return std::nullopt;
63
+ return NodeText->getString ();
64
+ }
65
+
58
66
static bool parseRootFlags (LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
59
67
MDNode *RootFlagNode) {
60
68
@@ -107,17 +115,79 @@ static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
107
115
return false ;
108
116
}
109
117
118
+ static bool parseRootDescriptors (LLVMContext *Ctx,
119
+ mcdxbc::RootSignatureDesc &RSD,
120
+ MDNode *RootDescriptorNode,
121
+ RootSignatureElementKind ElementKind) {
122
+ assert (ElementKind == RootSignatureElementKind::SRV ||
123
+ ElementKind == RootSignatureElementKind::UAV ||
124
+ ElementKind == RootSignatureElementKind::CBV &&
125
+ " parseRootDescriptors should only be called with RootDescriptor "
126
+ " element kind." );
127
+ if (RootDescriptorNode->getNumOperands () != 5 )
128
+ return reportError (Ctx, " Invalid format for Root Descriptor Element" );
129
+
130
+ dxbc::RTS0::v1::RootParameterHeader Header;
131
+ switch (ElementKind) {
132
+ case RootSignatureElementKind::SRV:
133
+ Header.ParameterType = llvm::to_underlying (dxbc::RootParameterType::SRV);
134
+ break ;
135
+ case RootSignatureElementKind::UAV:
136
+ Header.ParameterType = llvm::to_underlying (dxbc::RootParameterType::UAV);
137
+ break ;
138
+ case RootSignatureElementKind::CBV:
139
+ Header.ParameterType = llvm::to_underlying (dxbc::RootParameterType::CBV);
140
+ break ;
141
+ default :
142
+ llvm_unreachable (" invalid Root Descriptor kind" );
143
+ break ;
144
+ }
145
+
146
+ if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 1 ))
147
+ Header.ShaderVisibility = *Val;
148
+ else
149
+ return reportError (Ctx, " Invalid value for ShaderVisibility" );
150
+
151
+ dxbc::RTS0::v2::RootDescriptor Descriptor;
152
+ if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 2 ))
153
+ Descriptor.ShaderRegister = *Val;
154
+ else
155
+ return reportError (Ctx, " Invalid value for ShaderRegister" );
156
+
157
+ if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 3 ))
158
+ Descriptor.RegisterSpace = *Val;
159
+ else
160
+ return reportError (Ctx, " Invalid value for RegisterSpace" );
161
+
162
+ if (RSD.Version == 1 ) {
163
+ RSD.ParametersContainer .addParameter (Header, Descriptor);
164
+ return false ;
165
+ }
166
+ assert (RSD.Version > 1 );
167
+
168
+ if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 4 ))
169
+ Descriptor.Flags = *Val;
170
+ else
171
+ return reportError (Ctx, " Invalid value for Root Descriptor Flags" );
172
+
173
+ RSD.ParametersContainer .addParameter (Header, Descriptor);
174
+ return false ;
175
+ }
176
+
110
177
static bool parseRootSignatureElement (LLVMContext *Ctx,
111
178
mcdxbc::RootSignatureDesc &RSD,
112
179
MDNode *Element) {
113
- MDString * ElementText = cast<MDString> (Element-> getOperand ( 0 ) );
114
- if (ElementText == nullptr )
180
+ std::optional<StringRef> ElementText = extractMdStringValue (Element, 0 );
181
+ if (! ElementText. has_value () )
115
182
return reportError (Ctx, " Invalid format for Root Element" );
116
183
117
184
RootSignatureElementKind ElementKind =
118
- StringSwitch<RootSignatureElementKind>(ElementText-> getString () )
185
+ StringSwitch<RootSignatureElementKind>(* ElementText)
119
186
.Case (" RootFlags" , RootSignatureElementKind::RootFlags)
120
187
.Case (" RootConstants" , RootSignatureElementKind::RootConstants)
188
+ .Case (" RootCBV" , RootSignatureElementKind::CBV)
189
+ .Case (" RootSRV" , RootSignatureElementKind::SRV)
190
+ .Case (" RootUAV" , RootSignatureElementKind::UAV)
121
191
.Default (RootSignatureElementKind::Error);
122
192
123
193
switch (ElementKind) {
@@ -126,10 +196,12 @@ static bool parseRootSignatureElement(LLVMContext *Ctx,
126
196
return parseRootFlags (Ctx, RSD, Element);
127
197
case RootSignatureElementKind::RootConstants:
128
198
return parseRootConstants (Ctx, RSD, Element);
129
- break ;
199
+ case RootSignatureElementKind::CBV:
200
+ case RootSignatureElementKind::SRV:
201
+ case RootSignatureElementKind::UAV:
202
+ return parseRootDescriptors (Ctx, RSD, Element, ElementKind);
130
203
case RootSignatureElementKind::Error:
131
- return reportError (Ctx, " Invalid Root Signature Element: " +
132
- ElementText->getString ());
204
+ return reportError (Ctx, " Invalid Root Signature Element: " + *ElementText);
133
205
}
134
206
135
207
llvm_unreachable (" Unhandled RootSignatureElementKind enum." );
@@ -157,6 +229,18 @@ static bool verifyVersion(uint32_t Version) {
157
229
return (Version == 1 || Version == 2 );
158
230
}
159
231
232
+ static bool verifyRegisterValue (uint32_t RegisterValue) {
233
+ return RegisterValue != ~0U ;
234
+ }
235
+
236
+ // This Range is reserverved, therefore invalid, according to the spec
237
+ // https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal
238
+ static bool verifyRegisterSpace (uint32_t RegisterSpace) {
239
+ return !(RegisterSpace >= 0xFFFFFFF0 && RegisterSpace <= 0xFFFFFFFF );
240
+ }
241
+
242
+ static bool verifyDescriptorFlag (uint32_t Flags) { return (Flags & ~0xE ) == 0 ; }
243
+
160
244
static bool validate (LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
161
245
162
246
if (!verifyVersion (RSD.Version )) {
@@ -174,6 +258,28 @@ static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
174
258
175
259
assert (dxbc::isValidParameterType (Info.Header .ParameterType ) &&
176
260
" Invalid value for ParameterType" );
261
+
262
+ switch (Info.Header .ParameterType ) {
263
+
264
+ case llvm::to_underlying (dxbc::RootParameterType::CBV):
265
+ case llvm::to_underlying (dxbc::RootParameterType::UAV):
266
+ case llvm::to_underlying (dxbc::RootParameterType::SRV): {
267
+ const dxbc::RTS0::v2::RootDescriptor &Descriptor =
268
+ RSD.ParametersContainer .getRootDescriptor (Info.Location );
269
+ if (!verifyRegisterValue (Descriptor.ShaderRegister ))
270
+ return reportValueError (Ctx, " ShaderRegister" ,
271
+ Descriptor.ShaderRegister );
272
+
273
+ if (!verifyRegisterSpace (Descriptor.RegisterSpace ))
274
+ return reportValueError (Ctx, " RegisterSpace" , Descriptor.RegisterSpace );
275
+
276
+ if (RSD.Version > 1 ) {
277
+ if (!verifyDescriptorFlag (Descriptor.Flags ))
278
+ return reportValueError (Ctx, " DescriptorFlag" , Descriptor.Flags );
279
+ }
280
+ break ;
281
+ }
282
+ }
177
283
}
178
284
179
285
return false ;
@@ -313,6 +419,20 @@ PreservedAnalyses RootSignatureAnalysisPrinter::run(Module &M,
313
419
<< " Shader Register: " << Constants.ShaderRegister << " \n " ;
314
420
OS << indent (Space + 2 )
315
421
<< " Num 32 Bit Values: " << Constants.Num32BitValues << " \n " ;
422
+ break ;
423
+ }
424
+ case llvm::to_underlying (dxbc::RootParameterType::CBV):
425
+ case llvm::to_underlying (dxbc::RootParameterType::UAV):
426
+ case llvm::to_underlying (dxbc::RootParameterType::SRV): {
427
+ const dxbc::RTS0::v2::RootDescriptor &Descriptor =
428
+ RS.ParametersContainer .getRootDescriptor (Loc);
429
+ OS << indent (Space + 2 )
430
+ << " Register Space: " << Descriptor.RegisterSpace << " \n " ;
431
+ OS << indent (Space + 2 )
432
+ << " Shader Register: " << Descriptor.ShaderRegister << " \n " ;
433
+ if (RS.Version > 1 )
434
+ OS << indent (Space + 2 ) << " Flags: " << Descriptor.Flags << " \n " ;
435
+ break ;
316
436
}
317
437
}
318
438
Space--;
0 commit comments