Skip to content

Commit add9d62

Browse files
committed
Add the missing tests and fix formatting errors
1 parent bef6468 commit add9d62

File tree

5 files changed

+51
-19
lines changed

5 files changed

+51
-19
lines changed

source/opt/type_manager.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,17 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
337337
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {subtype}}});
338338
break;
339339
}
340+
case Type::kNodePayloadArrayAMDX: {
341+
uint32_t subtype =
342+
GetTypeInstruction(type->AsNodePayloadArrayAMDX()->element_type());
343+
if (subtype == 0) {
344+
return 0;
345+
}
346+
typeInst = MakeUnique<Instruction>(
347+
context(), spv::Op::OpTypeNodePayloadArrayAMDX, 0, id,
348+
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {subtype}}});
349+
break;
350+
}
340351
case Type::kStruct: {
341352
std::vector<Operand> ops;
342353
const Struct* structTy = type->AsStruct();
@@ -603,6 +614,13 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
603614
MakeUnique<RuntimeArray>(RebuildType(GetId(ele_ty), *ele_ty));
604615
break;
605616
}
617+
case Type::kNodePayloadArrayAMDX: {
618+
const NodePayloadArrayAMDX* array_ty = type.AsNodePayloadArrayAMDX();
619+
const Type* ele_ty = array_ty->element_type();
620+
rebuilt_ty =
621+
MakeUnique<NodePayloadArrayAMDX>(RebuildType(GetId(ele_ty), *ele_ty));
622+
break;
623+
}
606624
case Type::kStruct: {
607625
const Struct* struct_ty = type.AsStruct();
608626
std::vector<const Type*> subtypes;
@@ -806,7 +824,7 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
806824
}
807825
break;
808826
case spv::Op::OpTypeNodePayloadArrayAMDX:
809-
type = new NodePayloadArray(GetType(inst.GetSingleWordInOperand(0)));
827+
type = new NodePayloadArrayAMDX(GetType(inst.GetSingleWordInOperand(0)));
810828
if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) {
811829
incomplete_types_.emplace_back(inst.result_id(), type);
812830
id_to_incomplete_type_[inst.result_id()] = type;

source/opt/types.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ bool Type::IsUniqueType() const {
9292
case kStruct:
9393
case kArray:
9494
case kRuntimeArray:
95-
case kNodePayloadArray:
95+
case kNodePayloadArrayAMDX:
9696
return false;
9797
default:
9898
return true;
@@ -165,6 +165,7 @@ bool Type::operator==(const Type& other) const {
165165
DeclareKindCase(SampledImage);
166166
DeclareKindCase(Array);
167167
DeclareKindCase(RuntimeArray);
168+
DeclareKindCase(NodePayloadArrayAMDX);
168169
DeclareKindCase(Struct);
169170
DeclareKindCase(Opaque);
170171
DeclareKindCase(Pointer);
@@ -221,7 +222,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
221222
DeclareKindCase(SampledImage);
222223
DeclareKindCase(Array);
223224
DeclareKindCase(RuntimeArray);
224-
DeclareKindCase(NodePayloadArray);
225+
DeclareKindCase(NodePayloadArrayAMDX);
225226
DeclareKindCase(Struct);
226227
DeclareKindCase(Opaque);
227228
DeclareKindCase(Pointer);
@@ -489,29 +490,31 @@ void RuntimeArray::ReplaceElementType(const Type* type) {
489490
element_type_ = type;
490491
}
491492

492-
NodePayloadArray::NodePayloadArray(const Type* type)
493-
: Type(kNodePayloadArray), element_type_(type) {
493+
NodePayloadArrayAMDX::NodePayloadArrayAMDX(const Type* type)
494+
: Type(kNodePayloadArrayAMDX), element_type_(type) {
494495
assert(!type->AsVoid());
495496
}
496497

497-
bool NodePayloadArray::IsSameImpl(const Type* that, IsSameCache* seen) const {
498-
const NodePayloadArray* rat = that->AsNodePayloadArray();
498+
bool NodePayloadArrayAMDX::IsSameImpl(const Type* that,
499+
IsSameCache* seen) const {
500+
const NodePayloadArrayAMDX* rat = that->AsNodePayloadArrayAMDX();
499501
if (!rat) return false;
500502
return element_type_->IsSameImpl(rat->element_type_, seen) &&
501503
HasSameDecorations(that);
502504
}
503505

504-
std::string NodePayloadArray::str() const {
506+
std::string NodePayloadArrayAMDX::str() const {
505507
std::ostringstream oss;
506508
oss << "[" << element_type_->str() << "]";
507509
return oss.str();
508510
}
509511

510-
size_t NodePayloadArray::ComputeExtraStateHash(size_t hash, SeenTypes* seen) const {
512+
size_t NodePayloadArrayAMDX::ComputeExtraStateHash(size_t hash,
513+
SeenTypes* seen) const {
511514
return element_type_->ComputeHashValue(hash, seen);
512515
}
513516

514-
void NodePayloadArray::ReplaceElementType(const Type* type) {
517+
void NodePayloadArrayAMDX::ReplaceElementType(const Type* type) {
515518
element_type_ = type;
516519
}
517520

source/opt/types.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class Sampler;
4848
class SampledImage;
4949
class Array;
5050
class RuntimeArray;
51-
class NodePayloadArray;
51+
class NodePayloadArrayAMDX;
5252
class Struct;
5353
class Opaque;
5454
class Pointer;
@@ -90,7 +90,7 @@ class Type {
9090
kSampledImage,
9191
kArray,
9292
kRuntimeArray,
93-
kNodePayloadArray,
93+
kNodePayloadArrayAMDX,
9494
kStruct,
9595
kOpaque,
9696
kPointer,
@@ -193,7 +193,7 @@ class Type {
193193
DeclareCastMethod(SampledImage)
194194
DeclareCastMethod(Array)
195195
DeclareCastMethod(RuntimeArray)
196-
DeclareCastMethod(NodePayloadArray)
196+
DeclareCastMethod(NodePayloadArrayAMDX)
197197
DeclareCastMethod(Struct)
198198
DeclareCastMethod(Opaque)
199199
DeclareCastMethod(Pointer)
@@ -439,16 +439,18 @@ class RuntimeArray : public Type {
439439
const Type* element_type_;
440440
};
441441

442-
class NodePayloadArray : public Type {
442+
class NodePayloadArrayAMDX : public Type {
443443
public:
444-
NodePayloadArray(const Type* element_type);
445-
NodePayloadArray(const NodePayloadArray&) = default;
444+
NodePayloadArrayAMDX(const Type* element_type);
445+
NodePayloadArrayAMDX(const NodePayloadArrayAMDX&) = default;
446446

447447
std::string str() const override;
448448
const Type* element_type() const { return element_type_; }
449449

450-
NodePayloadArray* AsNodePayloadArray() override { return this; }
451-
const NodePayloadArray* AsNodePayloadArray() const override { return this; }
450+
NodePayloadArrayAMDX* AsNodePayloadArrayAMDX() override { return this; }
451+
const NodePayloadArrayAMDX* AsNodePayloadArrayAMDX() const override {
452+
return this;
453+
}
452454

453455
size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
454456

test/opt/type_manager_test.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
// Copyright (c) 2016 Google Inc.
2+
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
3+
// reserved.
24
//
35
// Licensed under the Apache License, Version 2.0 (the "License");
46
// you may not use this file except in compliance with the License.
@@ -175,6 +177,10 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
175177
types.emplace_back(new RayQueryKHR());
176178
types.emplace_back(new HitObjectNV());
177179

180+
// NodePayloadArrayAMDX (SPV_AMDX_shader_enqueue)
181+
types.emplace_back(
182+
new NodePayloadArrayAMDX(new Struct(std::vector<const Type*>{s32})));
183+
178184
return types;
179185
}
180186

test/val/val_id_test.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
// Copyright (c) 2015-2016 The Khronos Group Inc.
2+
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
3+
// reserved.
24
//
35
// Licensed under the Apache License, Version 2.0 (the "License");
46
// you may not use this file except in compliance with the License.
@@ -2469,7 +2471,8 @@ OpFunctionEnd
24692471
"be used with non-externally visible shader Storage Classes: "
24702472
"Workgroup, CrossWorkgroup, Private, Function, Input, Output, "
24712473
"RayPayloadKHR, IncomingRayPayloadKHR, HitAttributeKHR, "
2472-
"CallableDataKHR, IncomingCallableDataKHR, or UniformConstant")));
2474+
"CallableDataKHR, IncomingCallableDataKHR, NodePayloadAMDX, or "
2475+
"UniformConstant")));
24732476
}
24742477

24752478
TEST_P(ValidateIdWithMessage, OpVariableContainsBoolPrivateGood) {

0 commit comments

Comments
 (0)