Skip to content

Commit daad16a

Browse files
authored
Merge pull request #150 from sx-aurora-dev/merge/ve-legalavl
Merge/ve legalavl
2 parents 1537915 + 8172717 commit daad16a

File tree

8 files changed

+181
-141
lines changed

8 files changed

+181
-141
lines changed

llvm/lib/Target/VE/VECustomDAG.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -701,10 +701,25 @@ bool isMaskType(EVT VT) {
701701
return false;
702702

703703
// an actual bit mask type
704-
if (VT.getVectorElementType() == MVT::i1)
705-
return true;
704+
return VT.getVectorElementType() == MVT::i1;
705+
}
706706

707-
// not a mask
707+
bool maySafelyIgnoreMask(unsigned VVPOpcode) {
708+
// Most arithmetic is safe without mask.
709+
if (isVVPTernaryOp(VVPOpcode))
710+
return VVPOpcode != VEISD::VVP_SELECT;
711+
if (isVVPBinaryOp(VVPOpcode)) {
712+
switch (VVPOpcode) {
713+
default:
714+
return true;
715+
case VEISD::VVP_UREM:
716+
case VEISD::VVP_SREM:
717+
case VEISD::VVP_UDIV:
718+
case VEISD::VVP_SDIV:
719+
case VEISD::VVP_FDIV:
720+
return false;
721+
}
722+
}
708723
return false;
709724
}
710725

@@ -1052,6 +1067,17 @@ SDValue VECustomDAG::createUniformConstMask(Packing Packing, unsigned NumElement
10521067
return DAG.getNOT(DL, Res, Res.getValueType());
10531068
}
10541069

1070+
bool isLegalAVL(SDValue AVL) { return AVL->getOpcode() == VEISD::LEGALAVL; }
1071+
1072+
std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue Op) {
1073+
SDValue AVL = getNodeAVL(Op);
1074+
if (!AVL)
1075+
return {SDValue(), true};
1076+
if (isLegalAVL(AVL))
1077+
return {AVL->getOperand(0), true};
1078+
return {AVL, false};
1079+
}
1080+
10551081
SDValue VECustomDAG::getConstant(uint64_t Val, EVT VT, bool IsTarget,
10561082
bool IsOpaque) const {
10571083
return DAG.getConstant(Val, DL, VT, IsTarget, IsOpaque);
@@ -1408,6 +1434,12 @@ raw_ostream &VECustomDAG::print(raw_ostream &Out, SDValue V) const {
14081434
return Out;
14091435
}
14101436

1437+
SDValue VECustomDAG::annotateLegalAVL(SDValue AVL) const {
1438+
if (isLegalAVL(AVL))
1439+
return AVL;
1440+
return getNode(VEISD::LEGALAVL, AVL.getValueType(), AVL);
1441+
}
1442+
14111443
/// } class VECustomDAG
14121444

14131445
} // namespace llvm

llvm/lib/Target/VE/VECustomDAG.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,52 @@ SDValue getUnpackAVL(SDValue N);
195195

196196
/// } Packing
197197

198+
bool isVVPOrVEC(unsigned);
199+
200+
bool maySafelyIgnoreMask(unsigned Opc);
201+
202+
/// The VE backend uses a two-staged process to lower and legalize vector
203+
/// instructions:
204+
//
205+
/// 1. VP and standard vector SDNodes are lowered to SDNodes of the VVP_* layer.
206+
//
207+
// All VVP nodes have a mask and an Active Vector Length (AVL) parameter.
208+
// The AVL parameters refers to the element position in the vector the VVP
209+
// node operates on.
210+
//
211+
//
212+
// 2. The VVP SDNodes are legalized. The AVL in a legal VVP node refers to
213+
// chunks of 64bit. We track this by wrapping the AVL in a LEGALAVL node.
214+
//
215+
// The AVL mechanism in the VE architecture always refers to chunks of
216+
// 64bit, regardless of the actual element type vector instructions are
217+
// operating on. For vector types v256.32 or v256.64 nothing needs to be
218+
// legalized since each element occupies a 64bit chunk - there is no
219+
// difference between counting 64bit chunks or element positions. However,
220+
// all vector types with > 256 elements store more than one logical element
221+
// per 64bit chunk and need to be transformed.
222+
// However legalization is performed, the resulting legal VVP SDNodes will
223+
// have a LEGALAVL node as their AVL operand. The LEGALAVL nodes wraps
224+
// around an AVL that refers to 64 bit chunks just as the architecture
225+
// demands - that is, the wrapped AVL is the correct setting for the VL
226+
// register for this VVP operation to get the desired behavior.
227+
//
228+
/// AVL Functions {
229+
// The AVL operand position of this node.
230+
Optional<int> getAVLPos(unsigned);
231+
232+
// Whether this is a LEGALAVL node.
233+
bool isLegalAVL(SDValue AVL);
234+
235+
// The AVL operand of this node.
236+
SDValue getNodeAVL(SDValue);
237+
238+
// Return the AVL operand of this node. If it is a LEGALAVL node, unwrap it.
239+
// Return with the boolean whether unwrapping happened.
240+
std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue);
241+
242+
/// } AVL Functions
243+
198244
/// Helper class for short hand custom node creation ///
199245
struct VECustomDAG {
200246
const VELoweringInfo &VLI;
@@ -479,6 +525,9 @@ struct VECustomDAG {
479525

480526
raw_ostream &print(raw_ostream &, SDValue) const;
481527
void dump(SDValue) const;
528+
529+
// Wrap AVL in a LEGALAVL node (unless it is one already).
530+
SDValue annotateLegalAVL(SDValue AVL) const;
482531
};
483532

484533
} // namespace llvm

llvm/lib/Target/VE/VEISelDAGToDAG.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,10 @@ void VEDAGToDAGISel::Select(SDNode *N) {
431431
return;
432432
}
433433

434+
// Late eliminate the LEGALAVL wrapper
435+
case VEISD::LEGALAVL:
436+
ReplaceNode(N, N->getOperand(0).getNode());
437+
return;
434438

435439
case VEISD::GLOBAL_BASE_REG:
436440
ReplaceNode(N, getGlobalBaseReg());

llvm/lib/Target/VE/VEISelLowering.cpp

Lines changed: 56 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,8 @@ const char *VETargetLowering::getTargetNodeName(unsigned Opcode) const {
10921092

10931093
TARGET_NODE_CASE(REPL_F32)
10941094
TARGET_NODE_CASE(REPL_I32)
1095+
TARGET_NODE_CASE(LEGALAVL)
1096+
10951097
// Register the VVP_* SDNodes.
10961098
#define REGISTER_VVP_OP(VVP_NAME) TARGET_NODE_CASE(VVP_NAME)
10971099
#include "VVPNodes.def"
@@ -1972,6 +1974,60 @@ SDValue VETargetLowering::lowerINTRINSIC_W_CHAIN(SDValue Op,
19721974
}
19731975
}
19741976

1977+
SDValue VETargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1978+
LLVM_DEBUG(dbgs() << "::LowerOperation"; Op->print(dbgs()););
1979+
unsigned Opcode = Op.getOpcode();
1980+
switch (Opcode) {
1981+
default:
1982+
if (Subtarget->enableVPU())
1983+
return LowerOperation_VVP(Op, DAG);
1984+
else if (Subtarget->simd())
1985+
return LowerOperation_SIMD(Op, DAG);
1986+
llvm_unreachable("Unexpected Opcode in LowerOperation");
1987+
1988+
case ISD::ATOMIC_FENCE:
1989+
return lowerATOMIC_FENCE(Op, DAG);
1990+
case ISD::ATOMIC_SWAP:
1991+
return lowerATOMIC_SWAP(Op, DAG);
1992+
case ISD::BlockAddress:
1993+
return lowerBlockAddress(Op, DAG);
1994+
case ISD::ConstantPool:
1995+
return lowerConstantPool(Op, DAG);
1996+
case ISD::DYNAMIC_STACKALLOC:
1997+
return lowerDYNAMIC_STACKALLOC(Op, DAG);
1998+
case ISD::EH_SJLJ_LONGJMP:
1999+
return lowerEH_SJLJ_LONGJMP(Op, DAG);
2000+
case ISD::EH_SJLJ_SETJMP:
2001+
return lowerEH_SJLJ_SETJMP(Op, DAG);
2002+
case ISD::EH_SJLJ_SETUP_DISPATCH:
2003+
return lowerEH_SJLJ_SETUP_DISPATCH(Op, DAG);
2004+
case ISD::FRAMEADDR:
2005+
return lowerFRAMEADDR(Op, DAG, *this, Subtarget);
2006+
case ISD::GlobalAddress:
2007+
return lowerGlobalAddress(Op, DAG);
2008+
case ISD::GlobalTLSAddress:
2009+
return lowerGlobalTLSAddress(Op, DAG);
2010+
case ISD::INTRINSIC_VOID:
2011+
return lowerINTRINSIC_VOID(Op, DAG);
2012+
case ISD::INTRINSIC_W_CHAIN:
2013+
return lowerINTRINSIC_W_CHAIN(Op, DAG);
2014+
case ISD::INTRINSIC_WO_CHAIN:
2015+
return lowerINTRINSIC_WO_CHAIN(Op, DAG);
2016+
case ISD::JumpTable:
2017+
return lowerJumpTable(Op, DAG);
2018+
case ISD::LOAD:
2019+
return lowerLOAD(Op, DAG);
2020+
case ISD::RETURNADDR:
2021+
return lowerRETURNADDR(Op, DAG, *this, Subtarget);
2022+
case ISD::STORE:
2023+
return lowerSTORE(Op, DAG);
2024+
case ISD::VASTART:
2025+
return lowerVASTART(Op, DAG);
2026+
case ISD::VAARG:
2027+
return lowerVAARG(Op, DAG);
2028+
}
2029+
}
2030+
19752031
SDValue VETargetLowering::lowerINTRINSIC_VOID(SDValue Op,
19762032
SelectionDAG &DAG) const {
19772033
SDLoc dl(Op);
@@ -3427,8 +3483,6 @@ SDValue VETargetLowering::PerformDAGCombine(SDNode *N,
34273483
SDLoc dl(N);
34283484
unsigned Opcode = N->getOpcode();
34293485
switch (Opcode) {
3430-
case ISD::EntryToken:
3431-
return combineEntryToken_VVP(N, DCI);
34323486
default:
34333487
if (!Subtarget->enableVPU())
34343488
return SDValue();
@@ -3683,61 +3737,6 @@ bool VETargetLowering::hasAndNot(SDValue Y) const {
36833737
return true;
36843738
}
36853739

3686-
SDValue VETargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
3687-
LLVM_DEBUG(dbgs() << "LowerOp: "; Op.dump(&DAG); dbgs() << "\n";);
3688-
3689-
switch (Op.getOpcode()) {
3690-
default:
3691-
if (Subtarget->enableVPU())
3692-
return LowerOperation_VVP(Op, DAG);
3693-
else if (Subtarget->simd())
3694-
return LowerOperation_SIMD(Op, DAG);
3695-
llvm_unreachable("Unexpected Opcode in LowerOperation");
3696-
3697-
// Mostly all-scalar lowerings below.
3698-
case ISD::ATOMIC_FENCE:
3699-
return lowerATOMIC_FENCE(Op, DAG);
3700-
case ISD::ATOMIC_SWAP:
3701-
return lowerATOMIC_SWAP(Op, DAG);
3702-
case ISD::BlockAddress:
3703-
return lowerBlockAddress(Op, DAG);
3704-
case ISD::ConstantPool:
3705-
return lowerConstantPool(Op, DAG);
3706-
case ISD::DYNAMIC_STACKALLOC:
3707-
return lowerDYNAMIC_STACKALLOC(Op, DAG);
3708-
case ISD::EH_SJLJ_LONGJMP:
3709-
return lowerEH_SJLJ_LONGJMP(Op, DAG);
3710-
case ISD::EH_SJLJ_SETJMP:
3711-
return lowerEH_SJLJ_SETJMP(Op, DAG);
3712-
case ISD::EH_SJLJ_SETUP_DISPATCH:
3713-
return lowerEH_SJLJ_SETUP_DISPATCH(Op, DAG);
3714-
case ISD::FRAMEADDR:
3715-
return lowerFRAMEADDR(Op, DAG, *this, Subtarget);
3716-
case ISD::GlobalAddress:
3717-
return lowerGlobalAddress(Op, DAG);
3718-
case ISD::GlobalTLSAddress:
3719-
return lowerGlobalTLSAddress(Op, DAG);
3720-
case ISD::INTRINSIC_VOID:
3721-
return lowerINTRINSIC_VOID(Op, DAG);
3722-
case ISD::INTRINSIC_W_CHAIN:
3723-
return lowerINTRINSIC_W_CHAIN(Op, DAG);
3724-
case ISD::INTRINSIC_WO_CHAIN:
3725-
return lowerINTRINSIC_WO_CHAIN(Op, DAG);
3726-
case ISD::JumpTable:
3727-
return lowerJumpTable(Op, DAG);
3728-
case ISD::LOAD:
3729-
return lowerLOAD(Op, DAG);
3730-
case ISD::RETURNADDR:
3731-
return lowerRETURNADDR(Op, DAG, *this, Subtarget);
3732-
case ISD::STORE:
3733-
return lowerSTORE(Op, DAG);
3734-
case ISD::VASTART:
3735-
return lowerVASTART(Op, DAG);
3736-
case ISD::VAARG:
3737-
return lowerVAARG(Op, DAG);
3738-
}
3739-
}
3740-
37413740
static bool isPackableElemVT(EVT VT) {
37423741
if (VT.isVector())
37433742
return false;

llvm/lib/Target/VE/VEISelLowering.h

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ enum NodeType : unsigned {
9999
/// MCSymbol and TargetBlockAddress.
100100
Wrapper,
101101

102+
// Annotation as a wrapper. LEGALAVL(VL) means that VL refers to 64bit of
103+
// data, whereas the raw EVL coming in from VP nodes always refers to number
104+
// of elements, regardless of their size.
105+
LEGALAVL,
106+
102107
// VVP_* nodes.
103108
#define REGISTER_VVP_OP(VVP_NAME) VVP_NAME,
104109
#include "VVPNodes.def"
@@ -127,15 +132,6 @@ struct VVPWideningInfo {
127132
};
128133

129134
class VETargetLowering final : public TargetLowering, public VELoweringInfo {
130-
// FIXME: Find a more robust solution for this.
131-
mutable std::set<const SDNode *> LegalizedVectorNodes;
132-
bool isPackLegalizedInternalNode(const SDNode *N) const {
133-
return LegalizedVectorNodes.count(N);
134-
}
135-
void addPackLegalizedNode(const SDNode *N) const {
136-
LegalizedVectorNodes.insert(N);
137-
}
138-
139135
const VESubtarget *Subtarget;
140136

141137
void initRegisterClasses();
@@ -227,9 +223,9 @@ class VETargetLowering final : public TargetLowering, public VELoweringInfo {
227223
/// } Custom CC Mapping
228224

229225
/// Custom Lower {
230-
231226
Optional<LegalizeKind> getCustomTypeConversion(LLVMContext &Context,
232227
EVT VT) const override;
228+
233229
const MCExpr *LowerCustomJumpTableEntry(const MachineJumpTableInfo *MJTI,
234230
const MachineBasicBlock *MBB,
235231
unsigned uid,
@@ -305,7 +301,6 @@ class VETargetLowering final : public TargetLowering, public VELoweringInfo {
305301
/// VVP Lowering {
306302
// internal node tracker reset checkpoint.
307303

308-
SDValue combineEntryToken_VVP(SDNode *N, DAGCombinerInfo &DCI) const;
309304
// Expand SETCC operands directly used in vector arithmetic ops.
310305
SDValue lowerSETCCInVectorArithmetic(SDValue Op, SelectionDAG &DAG) const;
311306
SDValue expandSELECT(SDValue MaskV, SDValue OnTrueV, SDValue OnFalseV,

llvm/lib/Target/VE/VVPCombine.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ SDValue VETargetLowering::combineVVP(SDNode *N, DAGCombinerInfo &DCI) const {
171171
dbgs() << "\n";);
172172

173173
VECustomDAG CDAG(*this, DCI.DAG, N);
174-
bool RootIsPackLegalized = isPackLegalizedInternalNode(N);
175174
SDNodeFlags Flags = N->getFlags();
176175
switch (N->getOpcode()) {
177176

@@ -182,8 +181,6 @@ SDValue VETargetLowering::combineVVP(SDNode *N, DAGCombinerInfo &DCI) const {
182181
MVT ResVT = N->getSimpleValueType(0);
183182
auto N =
184183
CDAG.getNode(VEISD::VVP_FFMA, ResVT, {VY, VZ, VW, Mask, AVL}, Flags);
185-
if (RootIsPackLegalized)
186-
addPackLegalizedNode(N.getNode());
187184
return N;
188185
}
189186
} break;
@@ -194,8 +191,6 @@ SDValue VETargetLowering::combineVVP(SDNode *N, DAGCombinerInfo &DCI) const {
194191
MVT ResVT = N->getSimpleValueType(0);
195192
unsigned Opcode = Negated ? VEISD::VVP_FFMSN : VEISD::VVP_FFMS;
196193
auto N = CDAG.getNode(Opcode, ResVT, {VY, VZ, VW, Mask, AVL}, Flags);
197-
if (RootIsPackLegalized)
198-
addPackLegalizedNode(N.getNode());
199194
return N;
200195
}
201196
} break;
@@ -215,19 +210,13 @@ SDValue VETargetLowering::combineVVP(SDNode *N, DAGCombinerInfo &DCI) const {
215210
// 1 / vy
216211
if (match_FPOne(VX)) {
217212
auto N = CDAG.getNode(VEISD::VVP_FRCP, ResVT, {VY, Mask, AVL}, Flags);
218-
if (RootIsPackLegalized)
219-
addPackLegalizedNode(N.getNode());
220213
return N;
221214
}
222215
// vx * VRCP(vy)
223216
auto RecipV =
224217
CDAG.getNode(VEISD::VVP_FRCP, ResVT, {VY, Mask, AVL}, Flags);
225218
auto MulV =
226219
CDAG.getNode(VEISD::VVP_FMUL, ResVT, {VX, RecipV, Mask, AVL}, Flags);
227-
if (RootIsPackLegalized) {
228-
addPackLegalizedNode(RecipV.getNode());
229-
addPackLegalizedNode(MulV.getNode());
230-
}
231220
return MulV;
232221
} break;
233222
default:

0 commit comments

Comments
 (0)