Skip to content

Commit 540d255

Browse files
[SPIRV] Add vector reduction instructions (llvm#82786)
This PR adds vector reduction instructions according to https://llvm.org/docs/GlobalISel/GenericOpcode.html#vector-reduction-operations and thereby widens the range of successfully supported conversions, covering new cases of vector reduction instructions that IRTranslator is unable to resolve. Legalizing vector reduction instructions introduces new instruction patterns that must be addressed, including patterns that are otherwise delegated to the pre-legalize step. To address this problem, a new pass is added that brings the instructions newly generated by legalization into the form required by instruction selection. The expected overhead for existing cases is minimal, because the new pass acts only on newly introduced instructions; otherwise it is just an additional code traversal without any actions.
1 parent cad6ad2 commit 540d255

File tree

22 files changed

+3472
-53
lines changed

22 files changed

+3472
-53
lines changed

llvm/lib/Target/SPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ add_llvm_target(SPIRVCodeGen
2929
SPIRVMetadata.cpp
3030
SPIRVModuleAnalysis.cpp
3131
SPIRVPreLegalizer.cpp
32+
SPIRVPostLegalizer.cpp
3233
SPIRVPrepareFunctions.cpp
3334
SPIRVRegisterBankInfo.cpp
3435
SPIRVRegisterInfo.cpp

llvm/lib/Target/SPIRV/SPIRV.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
2323
FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
2424
FunctionPass *createSPIRVRegularizerPass();
2525
FunctionPass *createSPIRVPreLegalizerPass();
26+
FunctionPass *createSPIRVPostLegalizerPass();
2627
FunctionPass *createSPIRVEmitIntrinsicsPass(SPIRVTargetMachine *TM);
2728
InstructionSelector *
2829
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
@@ -32,6 +33,7 @@ createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
3233
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
3334
void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
3435
void initializeSPIRVPreLegalizerPass(PassRegistry &);
36+
void initializeSPIRVPostLegalizerPass(PassRegistry &);
3537
void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
3638
} // namespace llvm
3739

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ class SPIRVInstructionSelector : public InstructionSelector {
191191
bool selectLog10(Register ResVReg, const SPIRVType *ResType,
192192
MachineInstr &I) const;
193193

194+
bool selectUnmergeValues(MachineInstr &I) const;
195+
194196
Register buildI32Constant(uint32_t Val, MachineInstr &I,
195197
const SPIRVType *ResType = nullptr) const;
196198

@@ -243,7 +245,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
243245
if (Opcode == SPIRV::ASSIGN_TYPE) { // These pseudos aren't needed any more.
244246
auto *Def = MRI->getVRegDef(I.getOperand(1).getReg());
245247
if (isTypeFoldingSupported(Def->getOpcode())) {
246-
auto Res = selectImpl(I, *CoverageInfo);
248+
bool Res = selectImpl(I, *CoverageInfo);
247249
assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
248250
if (Res)
249251
return Res;
@@ -271,7 +273,8 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
271273
assert(!HasDefs || ResType || I.getOpcode() == TargetOpcode::G_GLOBAL_VALUE);
272274
if (spvSelect(ResVReg, ResType, I)) {
273275
if (HasDefs) // Make all vregs 32 bits (for SPIR-V IDs).
274-
MRI->setType(ResVReg, LLT::scalar(32));
276+
for (unsigned i = 0; i < I.getNumDefs(); ++i)
277+
MRI->setType(I.getOperand(i).getReg(), LLT::scalar(32));
275278
I.removeFromParent();
276279
return true;
277280
}
@@ -281,9 +284,9 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
281284
bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
282285
const SPIRVType *ResType,
283286
MachineInstr &I) const {
284-
assert(!isTypeFoldingSupported(I.getOpcode()) ||
285-
I.getOpcode() == TargetOpcode::G_CONSTANT);
286287
const unsigned Opcode = I.getOpcode();
288+
if (isTypeFoldingSupported(Opcode) && Opcode != TargetOpcode::G_CONSTANT)
289+
return selectImpl(I, *CoverageInfo);
287290
switch (Opcode) {
288291
case TargetOpcode::G_CONSTANT:
289292
return selectConst(ResVReg, ResType, I.getOperand(1).getCImm()->getValue(),
@@ -519,6 +522,9 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
519522
case TargetOpcode::G_STACKRESTORE:
520523
return selectStackRestore(I);
521524

525+
case TargetOpcode::G_UNMERGE_VALUES:
526+
return selectUnmergeValues(I);
527+
522528
default:
523529
return false;
524530
}
@@ -777,6 +783,41 @@ bool SPIRVInstructionSelector::selectAtomicRMW(Register ResVReg,
777783
return Result;
778784
}
779785

786+
// Selects G_UNMERGE_VALUES by emitting one OpCompositeExtract per result
// register, using the def's operand index as the extraction index.
//
// The source value is read from the last operand of I. It must be a register
// whose SPIR-V type is an OpTypeVector; otherwise a fatal error is reported.
// Result registers that have no SPIR-V type yet receive the type registered
// for operand 1 of the vector type instruction.
//
// Returns true only if I has at least one def and every emitted
// OpCompositeExtract was constrained successfully. A failure on any single
// extract is reported as a failure of the whole selection instead of being
// masked by another extract that succeeded.
bool SPIRVInstructionSelector::selectUnmergeValues(MachineInstr &I) const {
  unsigned ArgI = I.getNumOperands() - 1;
  Register SrcReg =
      I.getOperand(ArgI).isReg() ? I.getOperand(ArgI).getReg() : Register(0);
  SPIRVType *DefType =
      SrcReg.isValid() ? GR.getSPIRVTypeForVReg(SrcReg) : nullptr;
  if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
    report_fatal_error(
        "cannot select G_UNMERGE_VALUES with a non-vector argument");

  SPIRVType *ScalarType =
      GR.getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
  MachineBasicBlock &BB = *I.getParent();
  const unsigned NumDefs = I.getNumDefs();
  // Start from "success" only when there is something to select, so that an
  // instruction without defs still yields false.
  bool Res = NumDefs != 0;
  for (unsigned i = 0; i < NumDefs; ++i) {
    Register ResVReg = I.getOperand(i).getReg();
    SPIRVType *ResType = GR.getSPIRVTypeForVReg(ResVReg);
    if (!ResType) {
      // No "assign type" action was applied to this register; fix that now.
      ResType = ScalarType;
      MRI->setRegClass(ResVReg, &SPIRV::IDRegClass);
      MRI->setType(ResVReg, LLT::scalar(GR.getScalarOrVectorBitWidth(ResType)));
      GR.assignSPIRVTypeToVReg(ResType, ResVReg, *GR.CurMF);
    }
    auto MIB =
        BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
            .addDef(ResVReg)
            .addUse(GR.getSPIRVTypeID(ResType))
            .addUse(SrcReg)
            .addImm(static_cast<int64_t>(i));
    // Keep emitting the remaining extracts, but remember any failure.
    if (!MIB.constrainAllUses(TII, TRI, RBI))
      Res = false;
  }
  return Res;
}
820+
780821
bool SPIRVInstructionSelector::selectFence(MachineInstr &I) const {
781822
AtomicOrdering AO = AtomicOrdering(I.getOperand(0).getImm());
782823
uint32_t MemSem = static_cast<uint32_t>(getMemSemantics(AO));

llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
113113
v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
114114
v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
115115

116+
auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
117+
v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,
118+
v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,
119+
v16s8, v16s16, v16s32, v16s64};
120+
116121
auto allScalarsAndVectors = {
117122
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
118123
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
@@ -146,6 +151,24 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
146151
// TODO: add proper rules for vectors legalization.
147152
getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
148153

154+
// Vector Reduction Operations
155+
getActionDefinitionsBuilder(
156+
{G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
157+
G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
158+
G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
159+
G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
160+
.legalFor(allVectors)
161+
.scalarize(1)
162+
.lower();
163+
164+
getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
165+
.scalarize(2)
166+
.lower();
167+
168+
// Merge/Unmerge
169+
// TODO: add proper legalization rules.
170+
getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();
171+
149172
getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
150173
.legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
151174

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
//===-- SPIRVPostLegalizer.cpp - amend info after legalization --*- C++ -*-===//
2+
//
3+
// which may appear after the legalizer pass
4+
//
5+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6+
// See https://llvm.org/LICENSE.txt for license information.
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
//
11+
// The pass partly applies pre-legalization logic to new instructions inserted
12+
// as a result of legalization:
13+
// - assigns SPIR-V types to registers for new instructions.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#include "SPIRV.h"
18+
#include "SPIRVSubtarget.h"
19+
#include "SPIRVUtils.h"
20+
#include "llvm/ADT/PostOrderIterator.h"
21+
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
22+
#include "llvm/IR/Attributes.h"
23+
#include "llvm/IR/Constants.h"
24+
#include "llvm/IR/DebugInfoMetadata.h"
25+
#include "llvm/IR/IntrinsicsSPIRV.h"
26+
#include "llvm/Target/TargetIntrinsicInfo.h"
27+
28+
#define DEBUG_TYPE "spirv-postlegalizer"
29+
30+
using namespace llvm;
31+
32+
namespace {
33+
class SPIRVPostLegalizer : public MachineFunctionPass {
34+
public:
35+
static char ID;
36+
SPIRVPostLegalizer() : MachineFunctionPass(ID) {
37+
initializeSPIRVPostLegalizerPass(*PassRegistry::getPassRegistry());
38+
}
39+
bool runOnMachineFunction(MachineFunction &MF) override;
40+
};
41+
} // namespace
42+
43+
// Defined in SPIRVLegalizerInfo.cpp.
44+
extern bool isTypeFoldingSupported(unsigned Opcode);
45+
46+
namespace llvm {
47+
// Defined in SPIRVPreLegalizer.cpp.
48+
extern Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
49+
SPIRVGlobalRegistry *GR,
50+
MachineIRBuilder &MIB,
51+
MachineRegisterInfo &MRI);
52+
extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
53+
MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR);
54+
} // namespace llvm
55+
56+
static bool isMetaInstrGET(unsigned Opcode) {
57+
return Opcode == SPIRV::GET_ID || Opcode == SPIRV::GET_fID ||
58+
Opcode == SPIRV::GET_pID || Opcode == SPIRV::GET_vID ||
59+
Opcode == SPIRV::GET_vfID;
60+
}
61+
62+
static bool mayBeInserted(unsigned Opcode) {
63+
switch (Opcode) {
64+
case TargetOpcode::G_SMAX:
65+
case TargetOpcode::G_UMAX:
66+
case TargetOpcode::G_SMIN:
67+
case TargetOpcode::G_UMIN:
68+
case TargetOpcode::G_FMINNUM:
69+
case TargetOpcode::G_FMINIMUM:
70+
case TargetOpcode::G_FMAXNUM:
71+
case TargetOpcode::G_FMAXIMUM:
72+
return true;
73+
default:
74+
return isTypeFoldingSupported(Opcode);
75+
}
76+
}
77+
78+
static void processNewInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
79+
MachineIRBuilder MIB) {
80+
MachineRegisterInfo &MRI = MF.getRegInfo();
81+
82+
for (MachineBasicBlock &MBB : MF) {
83+
for (MachineInstr &I : MBB) {
84+
const unsigned Opcode = I.getOpcode();
85+
if (Opcode == TargetOpcode::G_UNMERGE_VALUES) {
86+
unsigned ArgI = I.getNumOperands() - 1;
87+
Register SrcReg = I.getOperand(ArgI).isReg()
88+
? I.getOperand(ArgI).getReg()
89+
: Register(0);
90+
SPIRVType *DefType =
91+
SrcReg.isValid() ? GR->getSPIRVTypeForVReg(SrcReg) : nullptr;
92+
if (!DefType || DefType->getOpcode() != SPIRV::OpTypeVector)
93+
report_fatal_error(
94+
"cannot select G_UNMERGE_VALUES with a non-vector argument");
95+
SPIRVType *ScalarType =
96+
GR->getSPIRVTypeForVReg(DefType->getOperand(1).getReg());
97+
for (unsigned i = 0; i < I.getNumDefs(); ++i) {
98+
Register ResVReg = I.getOperand(i).getReg();
99+
SPIRVType *ResType = GR->getSPIRVTypeForVReg(ResVReg);
100+
if (!ResType) {
101+
// There was no "assign type" actions, let's fix this now
102+
ResType = ScalarType;
103+
MRI.setRegClass(ResVReg, &SPIRV::IDRegClass);
104+
MRI.setType(ResVReg,
105+
LLT::scalar(GR->getScalarOrVectorBitWidth(ResType)));
106+
GR->assignSPIRVTypeToVReg(ResType, ResVReg, *GR->CurMF);
107+
}
108+
}
109+
} else if (mayBeInserted(Opcode) && I.getNumDefs() == 1 &&
110+
I.getNumOperands() > 1 && I.getOperand(1).isReg()) {
111+
// Legalizer may have added a new instructions and introduced new
112+
// registers, we must decorate them as if they were introduced in a
113+
// non-automatic way
114+
Register ResVReg = I.getOperand(0).getReg();
115+
SPIRVType *ResVType = GR->getSPIRVTypeForVReg(ResVReg);
116+
// Check if the register defined by the instruction is newly generated
117+
// or already processed
118+
if (!ResVType) {
119+
// Set type of the defined register
120+
ResVType = GR->getSPIRVTypeForVReg(I.getOperand(1).getReg());
121+
// Check if we have type defined for operands of the new instruction
122+
if (!ResVType)
123+
continue;
124+
// Set type & class
125+
MRI.setRegClass(ResVReg, &SPIRV::IDRegClass);
126+
MRI.setType(ResVReg,
127+
LLT::scalar(GR->getScalarOrVectorBitWidth(ResVType)));
128+
GR->assignSPIRVTypeToVReg(ResVType, ResVReg, *GR->CurMF);
129+
}
130+
// If this is a simple operation that is to be reduced by TableGen
131+
// definition we must apply some of pre-legalizer rules here
132+
if (isTypeFoldingSupported(Opcode)) {
133+
// Check if the instruction newly generated or already processed
134+
MachineInstr *NextMI = I.getNextNode();
135+
if (NextMI && isMetaInstrGET(NextMI->getOpcode()))
136+
continue;
137+
// Restore usual instructions pattern for the newly inserted
138+
// instruction
139+
MRI.setRegClass(ResVReg, MRI.getType(ResVReg).isVector()
140+
? &SPIRV::IDRegClass
141+
: &SPIRV::ANYIDRegClass);
142+
MRI.setType(ResVReg, LLT::scalar(32));
143+
insertAssignInstr(ResVReg, nullptr, ResVType, GR, MIB, MRI);
144+
processInstr(I, MIB, MRI, GR);
145+
}
146+
}
147+
}
148+
}
149+
}
150+
151+
bool SPIRVPostLegalizer::runOnMachineFunction(MachineFunction &MF) {
152+
// Initialize the type registry.
153+
const SPIRVSubtarget &ST = MF.getSubtarget<SPIRVSubtarget>();
154+
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
155+
GR->setCurrentFunc(MF);
156+
MachineIRBuilder MIB(MF);
157+
158+
processNewInstrs(MF, GR, MIB);
159+
160+
return true;
161+
}
162+
163+
INITIALIZE_PASS(SPIRVPostLegalizer, DEBUG_TYPE, "SPIRV post legalizer", false,
164+
false)
165+
166+
char SPIRVPostLegalizer::ID = 0;
167+
168+
// Factory: returns a newly allocated SPIRVPostLegalizer pass instance.
FunctionPass *llvm::createSPIRVPostLegalizerPass() {
  FunctionPass *Pass = new SPIRVPostLegalizer();
  return Pass;
}

0 commit comments

Comments
 (0)