Skip to content

Commit d9caef5

Browse files
authored
[NFC][SPIR-V] Refactor SpirvGroupNonUniformOps (#6596)
A follow-up change will use the PartitionedExclusiveScanNV GroupOperation, which requires that an additional operand is added to all GroupNonUniformArithmetic instructions. This means that some of the SPIR-V opcodes which are currently categorized as unary will become either unary or binary depending on the GroupOp. Since the arity distinctions between the OpGroupNonUniform* instructions were already somewhat arbitrary, I'm prefacing that change by refactoring them into a single SpirvGroupNonUniformOp instruction type for better reusability. Follow up: #6608
1 parent ff623f8 commit d9caef5

13 files changed

+125
-262
lines changed

tools/clang/include/clang/SPIRV/SpirvBuilder.h

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -238,17 +238,10 @@ class SpirvBuilder {
238238

239239
/// \brief Creates an operation with the given OpGroupNonUniform* SPIR-V
240240
/// opcode.
241-
SpirvNonUniformElect *createGroupNonUniformElect(spv::Op op,
242-
QualType resultType,
243-
spv::Scope execScope,
244-
SourceLocation);
245-
SpirvNonUniformUnaryOp *createGroupNonUniformUnaryOp(
246-
SourceLocation, spv::Op op, QualType resultType, spv::Scope execScope,
247-
SpirvInstruction *operand,
248-
llvm::Optional<spv::GroupOperation> groupOp = llvm::None);
249-
SpirvNonUniformBinaryOp *createGroupNonUniformBinaryOp(
241+
SpirvGroupNonUniformOp *createGroupNonUniformOp(
250242
spv::Op op, QualType resultType, spv::Scope execScope,
251-
SpirvInstruction *operand1, SpirvInstruction *operand2, SourceLocation);
243+
llvm::ArrayRef<SpirvInstruction *> operands, SourceLocation,
244+
llvm::Optional<spv::GroupOperation> groupOp = llvm::None);
252245

253246
/// \brief Creates an atomic instruction with the given parameters and returns
254247
/// its pointer.

tools/clang/include/clang/SPIRV/SpirvInstruction.h

Lines changed: 17 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,7 @@ class SpirvInstruction {
111111

112112
IK_SetMeshOutputsEXT, // OpSetMeshOutputsEXT
113113

114-
// The following section is for group non-uniform instructions.
115-
// Used by LLVM-style RTTI; order matters.
116-
IK_GroupNonUniformBinaryOp, // Group non-uniform binary operations
117-
IK_GroupNonUniformElect, // OpGroupNonUniformElect
118-
IK_GroupNonUniformUnaryOp, // Group non-uniform unary operations
114+
IK_GroupNonUniformOp, // Group non-uniform operations
119115

120116
IK_ImageOp, // OpImage*
121117
IK_ImageQuery, // OpImageQuery*
@@ -1495,102 +1491,43 @@ class SpirvFunctionCall : public SpirvInstruction {
14951491
llvm::SmallVector<SpirvInstruction *, 4> args;
14961492
};
14971493

1498-
/// \brief Base for OpGroupNonUniform* instructions
1494+
/// \brief OpGroupNonUniform* instructions
14991495
class SpirvGroupNonUniformOp : public SpirvInstruction {
15001496
public:
1501-
// For LLVM-style RTTI
1502-
static bool classof(const SpirvInstruction *inst) {
1503-
return inst->getKind() >= IK_GroupNonUniformBinaryOp &&
1504-
inst->getKind() <= IK_GroupNonUniformUnaryOp;
1505-
}
1506-
1507-
spv::Scope getExecutionScope() const { return execScope; }
1508-
1509-
protected:
1510-
SpirvGroupNonUniformOp(Kind kind, spv::Op opcode, QualType resultType,
1511-
SourceLocation loc, spv::Scope scope);
1512-
1513-
private:
1514-
spv::Scope execScope;
1515-
};
1516-
1517-
/// \brief OpGroupNonUniform* binary instructions.
1518-
class SpirvNonUniformBinaryOp : public SpirvGroupNonUniformOp {
1519-
public:
1520-
SpirvNonUniformBinaryOp(spv::Op opcode, QualType resultType,
1521-
SourceLocation loc, spv::Scope scope,
1522-
SpirvInstruction *arg1, SpirvInstruction *arg2);
1523-
1524-
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNonUniformBinaryOp)
1525-
1526-
// For LLVM-style RTTI
1527-
static bool classof(const SpirvInstruction *inst) {
1528-
return inst->getKind() == IK_GroupNonUniformBinaryOp;
1529-
}
1530-
1531-
bool invokeVisitor(Visitor *v) override;
1532-
1533-
SpirvInstruction *getArg1() const { return arg1; }
1534-
SpirvInstruction *getArg2() const { return arg2; }
1535-
void replaceOperand(
1536-
llvm::function_ref<SpirvInstruction *(SpirvInstruction *)> remapOp,
1537-
bool inEntryFunctionWrapper) override {
1538-
arg1 = remapOp(arg1);
1539-
arg2 = remapOp(arg2);
1540-
}
1541-
1542-
private:
1543-
SpirvInstruction *arg1;
1544-
SpirvInstruction *arg2;
1545-
};
1546-
1547-
/// \brief OpGroupNonUniformElect instruction. This is currently the only
1548-
/// non-uniform instruction that takes no other arguments.
1549-
class SpirvNonUniformElect : public SpirvGroupNonUniformOp {
1550-
public:
1551-
SpirvNonUniformElect(QualType resultType, SourceLocation loc,
1552-
spv::Scope scope);
1497+
SpirvGroupNonUniformOp(spv::Op opcode, QualType resultType, spv::Scope scope,
1498+
llvm::ArrayRef<SpirvInstruction *> operands,
1499+
SourceLocation loc,
1500+
llvm::Optional<spv::GroupOperation> group);
15531501

1554-
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNonUniformElect)
1502+
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvGroupNonUniformOp)
15551503

15561504
// For LLVM-style RTTI
15571505
static bool classof(const SpirvInstruction *inst) {
1558-
return inst->getKind() == IK_GroupNonUniformElect;
1506+
return inst->getKind() == IK_GroupNonUniformOp;
15591507
}
15601508

15611509
bool invokeVisitor(Visitor *v) override;
1562-
};
1563-
1564-
/// \brief OpGroupNonUniform* unary instructions.
1565-
class SpirvNonUniformUnaryOp : public SpirvGroupNonUniformOp {
1566-
public:
1567-
SpirvNonUniformUnaryOp(spv::Op opcode, QualType resultType,
1568-
SourceLocation loc, spv::Scope scope,
1569-
llvm::Optional<spv::GroupOperation> group,
1570-
SpirvInstruction *arg);
1571-
1572-
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNonUniformUnaryOp)
15731510

1574-
// For LLVM-style RTTI
1575-
static bool classof(const SpirvInstruction *inst) {
1576-
return inst->getKind() == IK_GroupNonUniformUnaryOp;
1577-
}
1511+
spv::Scope getExecutionScope() const { return execScope; }
15781512

1579-
bool invokeVisitor(Visitor *v) override;
1513+
llvm::ArrayRef<SpirvInstruction *> getOperands() const { return operands; }
15801514

1581-
SpirvInstruction *getArg() const { return arg; }
15821515
bool hasGroupOp() const { return groupOp.hasValue(); }
15831516
spv::GroupOperation getGroupOp() const { return groupOp.getValue(); }
1517+
15841518
void replaceOperand(
15851519
llvm::function_ref<SpirvInstruction *(SpirvInstruction *)> remapOp,
15861520
bool inEntryFunctionWrapper) override {
1587-
arg = remapOp(arg);
1521+
for (auto *operand : getOperands()) {
1522+
operand = remapOp(operand);
1523+
}
15881524
if (inEntryFunctionWrapper)
1589-
setAstResultType(arg->getAstResultType());
1525+
setAstResultType(getOperands()[0]->getAstResultType());
15901526
}
15911527

15921528
private:
1593-
SpirvInstruction *arg;
1529+
spv::Scope execScope;
1530+
llvm::SmallVector<SpirvInstruction *, 4> operands;
15941531
llvm::Optional<spv::GroupOperation> groupOp;
15951532
};
15961533

tools/clang/include/clang/SPIRV/SpirvVisitor.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,7 @@ class Visitor {
9696
DEFINE_VISIT_METHOD(SpirvEndPrimitive)
9797
DEFINE_VISIT_METHOD(SpirvExtInst)
9898
DEFINE_VISIT_METHOD(SpirvFunctionCall)
99-
DEFINE_VISIT_METHOD(SpirvNonUniformBinaryOp)
100-
DEFINE_VISIT_METHOD(SpirvNonUniformElect)
101-
DEFINE_VISIT_METHOD(SpirvNonUniformUnaryOp)
99+
DEFINE_VISIT_METHOD(SpirvGroupNonUniformOp)
102100
DEFINE_VISIT_METHOD(SpirvImageOp)
103101
DEFINE_VISIT_METHOD(SpirvImageQuery)
104102
DEFINE_VISIT_METHOD(SpirvImageSparseTexelsResident)

tools/clang/lib/SPIRV/EmitVisitor.cpp

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,35 +1087,7 @@ bool EmitVisitor::visit(SpirvFunctionCall *inst) {
10871087
return true;
10881088
}
10891089

1090-
bool EmitVisitor::visit(SpirvNonUniformBinaryOp *inst) {
1091-
initInstruction(inst);
1092-
curInst.push_back(inst->getResultTypeId());
1093-
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
1094-
curInst.push_back(typeHandler.getOrCreateConstantInt(
1095-
llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
1096-
context.getUIntType(32), /* isSpecConst */ false));
1097-
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getArg1()));
1098-
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getArg2()));
1099-
finalizeInstruction(&mainBinary);
1100-
emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
1101-
inst->getDebugName());
1102-
return true;
1103-
}
1104-
1105-
bool EmitVisitor::visit(SpirvNonUniformElect *inst) {
1106-
initInstruction(inst);
1107-
curInst.push_back(inst->getResultTypeId());
1108-
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
1109-
curInst.push_back(typeHandler.getOrCreateConstantInt(
1110-
llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
1111-
context.getUIntType(32), /* isSpecConst */ false));
1112-
finalizeInstruction(&mainBinary);
1113-
emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
1114-
inst->getDebugName());
1115-
return true;
1116-
}
1117-
1118-
bool EmitVisitor::visit(SpirvNonUniformUnaryOp *inst) {
1090+
bool EmitVisitor::visit(SpirvGroupNonUniformOp *inst) {
11191091
initInstruction(inst);
11201092
curInst.push_back(inst->getResultTypeId());
11211093
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
@@ -1124,7 +1096,8 @@ bool EmitVisitor::visit(SpirvNonUniformUnaryOp *inst) {
11241096
context.getUIntType(32), /* isSpecConst */ false));
11251097
if (inst->hasGroupOp())
11261098
curInst.push_back(static_cast<uint32_t>(inst->getGroupOp()));
1127-
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getArg()));
1099+
for (auto *operand : inst->getOperands())
1100+
curInst.push_back(getOrAssignResultId<SpirvInstruction>(operand));
11281101
finalizeInstruction(&mainBinary);
11291102
emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
11301103
inst->getDebugName());

tools/clang/lib/SPIRV/EmitVisitor.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,7 @@ class EmitVisitor : public Visitor {
257257
bool visit(SpirvCompositeInsert *) override;
258258
bool visit(SpirvExtInst *) override;
259259
bool visit(SpirvFunctionCall *) override;
260-
bool visit(SpirvNonUniformBinaryOp *) override;
261-
bool visit(SpirvNonUniformElect *) override;
262-
bool visit(SpirvNonUniformUnaryOp *) override;
260+
bool visit(SpirvGroupNonUniformOp *) override;
263261
bool visit(SpirvImageOp *) override;
264262
bool visit(SpirvImageQuery *) override;
265263
bool visit(SpirvImageSparseTexelsResident *) override;

tools/clang/lib/SPIRV/LiteralTypeVisitor.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -294,17 +294,9 @@ bool LiteralTypeVisitor::visit(SpirvVectorShuffle *inst) {
294294
return true;
295295
}
296296

297-
bool LiteralTypeVisitor::visit(SpirvNonUniformUnaryOp *inst) {
298-
// Went through each non-uniform binary operation and made sure the following
299-
// does not result in a wrong type deduction.
300-
tryToUpdateInstLitType(inst->getArg(), inst->getAstResultType());
301-
return true;
302-
}
303-
304-
bool LiteralTypeVisitor::visit(SpirvNonUniformBinaryOp *inst) {
305-
// Went through each non-uniform unary operation and made sure the following
306-
// does not result in a wrong type deduction.
307-
tryToUpdateInstLitType(inst->getArg1(), inst->getAstResultType());
297+
bool LiteralTypeVisitor::visit(SpirvGroupNonUniformOp *inst) {
298+
for (auto *operand : inst->getOperands())
299+
tryToUpdateInstLitType(operand, inst->getAstResultType());
308300
return true;
309301
}
310302

tools/clang/lib/SPIRV/LiteralTypeVisitor.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ class LiteralTypeVisitor : public Visitor {
3232
bool visit(SpirvBitFieldExtract *) override;
3333
bool visit(SpirvSelect *) override;
3434
bool visit(SpirvVectorShuffle *) override;
35-
bool visit(SpirvNonUniformUnaryOp *) override;
36-
bool visit(SpirvNonUniformBinaryOp *) override;
35+
bool visit(SpirvGroupNonUniformOp *) override;
3736
bool visit(SpirvLoad *) override;
3837
bool visit(SpirvStore *) override;
3938
bool visit(SpirvConstantComposite *) override;

tools/clang/lib/SPIRV/PervertexInputVisitor.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ class PervertexInputVisitor : public Visitor {
9898
REMAP_FUNC_OP(ImageOp)
9999
REMAP_FUNC_OP(ExtInst)
100100
REMAP_FUNC_OP(Atomic)
101-
REMAP_FUNC_OP(NonUniformBinaryOp)
102101
REMAP_FUNC_OP(BitFieldInsert)
103102
REMAP_FUNC_OP(BitFieldExtract)
104103
REMAP_FUNC_OP(IntrinsicInstruction)
@@ -115,7 +114,7 @@ class PervertexInputVisitor : public Visitor {
115114
REMAP_FUNC_OP(Select)
116115
REMAP_FUNC_OP(Switch)
117116
REMAP_FUNC_OP(CopyObject)
118-
REMAP_FUNC_OP(NonUniformUnaryOp)
117+
REMAP_FUNC_OP(GroupNonUniformOp)
119118

120119
private:
121120
///< Whether in entry function wrapper, which will influence replace steps.

tools/clang/lib/SPIRV/PreciseVisitor.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,14 +233,9 @@ bool PreciseVisitor::visit(SpirvUnaryOp *inst) {
233233
return true;
234234
}
235235

236-
bool PreciseVisitor::visit(SpirvNonUniformBinaryOp *inst) {
237-
inst->getArg1()->setPrecise(inst->isPrecise());
238-
inst->getArg2()->setPrecise(inst->isPrecise());
239-
return true;
240-
}
241-
242-
bool PreciseVisitor::visit(SpirvNonUniformUnaryOp *inst) {
243-
inst->getArg()->setPrecise(inst->isPrecise());
236+
bool PreciseVisitor::visit(SpirvGroupNonUniformOp *inst) {
237+
for (auto *operand : inst->getOperands())
238+
operand->setPrecise(inst->isPrecise());
244239
return true;
245240
}
246241

tools/clang/lib/SPIRV/PreciseVisitor.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ class PreciseVisitor : public Visitor {
3636
bool visit(SpirvStore *) override;
3737
bool visit(SpirvBinaryOp *) override;
3838
bool visit(SpirvUnaryOp *) override;
39-
bool visit(SpirvNonUniformBinaryOp *) override;
40-
bool visit(SpirvNonUniformUnaryOp *) override;
39+
bool visit(SpirvGroupNonUniformOp *) override;
4140
bool visit(SpirvExtInst *) override;
4241
bool visit(SpirvFunctionCall *) override;
4342

0 commit comments

Comments
 (0)