Skip to content

Commit 7d926b7

Browse files
author
Simon Moll
committed
[VE] LEGALAVL and staged VVP legalization
The new LEGALAVL node annotates that the AVL refers to packs of 64 bits. We use a two-stage lowering approach with LEGALAVL: First, standard SDNodes are translated into illegal VVP layer nodes. Regardless of source (VP or standard), all VVP nodes have a mask and an AVL parameter. The AVL parameter refers to the element position (just as the EVL does in VP intrinsics). Second, we legalize the AVL usage in VVP layer nodes. If the element size is < 64 bits, the AVL parameter has to be adjusted to refer to packs of 64 bits. We wrap the legalized AVL in a LEGALAVL node to track this. Reviewed By: kaz7 Differential Revision: https://reviews.llvm.org/D118321
1 parent 44ee986 commit 7d926b7

File tree

9 files changed

+240
-5
lines changed

9 files changed

+240
-5
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,11 @@ class TargetLoweringBase {
10711071
return false;
10721072
}
10731073

1074+
/// How to legalize this custom operation?
1075+
virtual LegalizeAction getCustomOperationAction(SDNode &Op) const {
1076+
return Legal;
1077+
}
1078+
10741079
/// Return how this operation should be treated: either it is legal, needs to
10751080
/// be promoted to a larger size, needs to be expanded to some other code
10761081
/// sequence, or the target has a custom expander for it.

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1212,7 +1212,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
12121212
break;
12131213
default:
12141214
if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
1215-
Action = TargetLowering::Legal;
1215+
Action = TLI.getCustomOperationAction(*Node);
12161216
} else {
12171217
Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
12181218
}

llvm/lib/Target/VE/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_llvm_target(VECodeGen
2626
VERegisterInfo.cpp
2727
VESubtarget.cpp
2828
VETargetMachine.cpp
29+
VVPISelLowering.cpp
2930

3031
LINK_COMPONENTS
3132
Analysis

llvm/lib/Target/VE/VECustomDAG.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,32 @@ Optional<unsigned> getVVPOpcode(unsigned Opcode) {
4242
return None;
4343
}
4444

45+
/// Report whether the mask of \p Op may be ignored.
///
/// The opcode of \p Op is first translated to its VVP counterpart when
/// getVVPOpcode() knows one; otherwise the opcode is used as is. Returns
/// false for VVP_SDIV, VVP_UDIV, VVP_FDIV and VVP_SELECT, true for anything
/// else.
bool maySafelyIgnoreMask(SDValue Op) {
  const auto MappedOpc = getVVPOpcode(Op->getOpcode());
  const auto EffectiveOpc = MappedOpc.getValueOr(Op->getOpcode());

  // Divisions and selects are the operations whose mask must be kept.
  const bool MaskRequired =
      EffectiveOpc == VEISD::VVP_SDIV || EffectiveOpc == VEISD::VVP_UDIV ||
      EffectiveOpc == VEISD::VVP_FDIV || EffectiveOpc == VEISD::VVP_SELECT;
  return !MaskRequired;
}
60+
61+
/// Report whether \p Opcode is VEC_BROADCAST or one of the VVP_* opcodes
/// registered through ADD_VVP_OP in VVPNodes.def.
bool isVVPOrVEC(unsigned Opcode) {
  switch (Opcode) {
  default:
    // Anything that is not listed below is neither a VEC nor a VVP node.
    return false;

  case VEISD::VEC_BROADCAST:
// Every ADD_VVP_OP entry contributes one more case label to this group.
#define ADD_VVP_OP(VVPNAME, ...) case VEISD::VVPNAME:
#include "VVPNodes.def"
    return true;
  }
}
70+
4571
bool isVVPBinaryOp(unsigned VVPOpcode) {
4672
switch (VVPOpcode) {
4773
#define ADD_BINARY_VVP_OP(VVPNAME, ...) \
@@ -52,6 +78,44 @@ bool isVVPBinaryOp(unsigned VVPOpcode) {
5278
return false;
5379
}
5480

81+
// Return the AVL operand position for this VVP or VEC Op.
82+
Optional<int> getAVLPos(unsigned Opc) {
83+
// This is only available for VP SDNodes
84+
auto PosOpt = ISD::getVPExplicitVectorLengthIdx(Opc);
85+
if (PosOpt)
86+
return *PosOpt;
87+
88+
// VVP Opcodes.
89+
if (isVVPBinaryOp(Opc))
90+
return 3;
91+
92+
// VM Opcodes.
93+
switch (Opc) {
94+
case VEISD::VEC_BROADCAST:
95+
return 1;
96+
case VEISD::VVP_SELECT:
97+
return 3;
98+
}
99+
100+
return None;
101+
}
102+
103+
/// Report whether \p AVL is a LEGALAVL node. \p AVL must not be empty, since
/// its node is dereferenced to read the opcode.
bool isLegalAVL(SDValue AVL) {
  const unsigned AVLOpc = AVL->getOpcode();
  return AVLOpc == VEISD::LEGALAVL;
}
104+
105+
/// Return the AVL operand of \p Op, or an empty SDValue when getAVLPos()
/// reports no AVL position for the opcode of \p Op.
SDValue getNodeAVL(SDValue Op) {
  if (auto AVLIdx = getAVLPos(Op->getOpcode()))
    return Op->getOperand(*AVLIdx);
  return SDValue();
}
109+
110+
/// Return the AVL of \p Op together with a flag.
///
/// - No AVL operand:            {empty SDValue, true}
/// - AVL is a LEGALAVL wrapper: {wrapped operand, true}
/// - Any other AVL:             {the AVL itself, false}
///
/// Note that the flag is also true when \p Op has no AVL operand at all.
std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue Op) {
  SDValue RawAVL = getNodeAVL(Op);

  // Without an AVL operand there is nothing to unwrap.
  if (!RawAVL)
    return {SDValue(), true};

  const bool IsWrapped = isLegalAVL(RawAVL);
  return {IsWrapped ? RawAVL->getOperand(0) : RawAVL, IsWrapped};
}
118+
55119
SDValue VECustomDAG::getConstant(uint64_t Val, EVT VT, bool IsTarget,
56120
bool IsOpaque) const {
57121
return DAG.getConstant(Val, DL, VT, IsTarget, IsOpaque);
@@ -78,4 +142,10 @@ SDValue VECustomDAG::getBroadcast(EVT ResultVT, SDValue Scalar,
78142
return getNode(VEISD::VEC_BROADCAST, ResultVT, {Scalar, AVL});
79143
}
80144

145+
/// Wrap \p AVL in a LEGALAVL node of the same value type. An \p AVL that is
/// already a LEGALAVL node is returned unchanged, so wrapping never nests.
SDValue VECustomDAG::annotateLegalAVL(SDValue AVL) const {
  return isLegalAVL(AVL) ? AVL
                         : getNode(VEISD::LEGALAVL, AVL.getValueType(), AVL);
}
150+
81151
} // namespace llvm

llvm/lib/Target/VE/VECustomDAG.h

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,52 @@ bool isVVPBinaryOp(unsigned Opcode);
2727

2828
bool isPackedVectorType(EVT SomeVT);
2929

30+
bool isVVPOrVEC(unsigned);
31+
32+
bool maySafelyIgnoreMask(SDValue Op);
33+
34+
/// The VE backend uses a two-staged process to lower and legalize vector
35+
/// instructions:
36+
//
37+
/// 1. VP and standard vector SDNodes are lowered to SDNodes of the VVP_* layer.
38+
//
39+
// All VVP nodes have a mask and an Active Vector Length (AVL) parameter.
40+
// The AVL parameter refers to the element position in the vector the VVP
41+
// node operates on.
42+
//
43+
//
44+
// 2. The VVP SDNodes are legalized. The AVL in a legal VVP node refers to
45+
// chunks of 64bit. We track this by wrapping the AVL in a LEGALAVL node.
46+
//
47+
// The AVL mechanism in the VE architecture always refers to chunks of
48+
// 64bit, regardless of the actual element type vector instructions are
49+
// operating on. For vector types v256.32 or v256.64 nothing needs to be
50+
// legalized since each element occupies a 64bit chunk - there is no
51+
// difference between counting 64bit chunks or element positions. However,
52+
// all vector types with > 256 elements store more than one logical element
53+
// per 64bit chunk and need to be transformed.
54+
// However legalization is performed, the resulting legal VVP SDNodes will
55+
// have a LEGALAVL node as their AVL operand. The LEGALAVL node wraps
56+
// around an AVL that refers to 64 bit chunks just as the architecture
57+
// demands - that is, the wrapped AVL is the correct setting for the VL
58+
// register for this VVP operation to get the desired behavior.
59+
//
60+
/// AVL Functions {
61+
// The AVL operand position of this node.
62+
Optional<int> getAVLPos(unsigned);
63+
64+
// Whether this is a LEGALAVL node.
65+
bool isLegalAVL(SDValue AVL);
66+
67+
// The AVL operand of this node.
68+
SDValue getNodeAVL(SDValue);
69+
70+
// Return the AVL operand of this node. If it is a LEGALAVL node, unwrap it.
71+
// Return with the boolean whether unwrapping happened.
72+
std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue);
73+
74+
/// } AVL Functions
75+
3076
class VECustomDAG {
3177
SelectionDAG &DAG;
3278
SDLoc DL;
@@ -72,6 +118,9 @@ class VECustomDAG {
72118
bool IsOpaque = false) const;
73119

74120
SDValue getBroadcast(EVT ResultVT, SDValue Scalar, SDValue AVL) const;
121+
122+
// Wrap AVL in a LEGALAVL node (unless it is one already).
123+
SDValue annotateLegalAVL(SDValue AVL) const;
75124
};
76125

77126
} // namespace llvm

llvm/lib/Target/VE/VEISelDAGToDAG.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,12 @@ void VEDAGToDAGISel::Select(SDNode *N) {
335335
}
336336

337337
switch (N->getOpcode()) {
338+
339+
// Late eliminate the LEGALAVL wrapper
340+
case VEISD::LEGALAVL:
341+
ReplaceNode(N, N->getOperand(0).getNode());
342+
return;
343+
338344
case VEISD::GLOBAL_BASE_REG:
339345
ReplaceNode(N, getGlobalBaseReg());
340346
return;

llvm/lib/Target/VE/VEISelLowering.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,8 @@ const char *VETargetLowering::getTargetNodeName(unsigned Opcode) const {
902902
TARGET_NODE_CASE(REPL_I32)
903903
TARGET_NODE_CASE(REPL_F32)
904904

905+
TARGET_NODE_CASE(LEGALAVL)
906+
905907
// Register the VVP_* SDNodes.
906908
#define ADD_VVP_OP(VVP_NAME, ...) TARGET_NODE_CASE(VVP_NAME)
907909
#include "VVPNodes.def"
@@ -1658,18 +1660,24 @@ SDValue VETargetLowering::lowerBUILD_VECTOR(SDValue Op,
16581660
// Else emit a broadcast.
16591661
if (SDValue ScalarV = getSplatValue(Op.getNode())) {
16601662
unsigned NumEls = ResultVT.getVectorNumElements();
1661-
// TODO: Legalize packed-mode AVL.
1662-
// For now, cap the AVL at 256.
1663-
auto CappedLength = std::min<unsigned>(256, NumEls);
1664-
auto AVL = CDAG.getConstant(CappedLength, MVT::i32);
1663+
auto AVL = CDAG.getConstant(NumEls, MVT::i32);
16651664
return CDAG.getBroadcast(ResultVT, Op.getOperand(0), AVL);
16661665
}
16671666

16681667
// Expand
16691668
return SDValue();
16701669
}
16711670

1671+
/// Request custom lowering for VEC_BROADCAST and VVP_* nodes so that their
/// AVL can be legalized for packed mode; every other node is reported as
/// Legal.
TargetLowering::LegalizeAction
VETargetLowering::getCustomOperationAction(SDNode &Op) const {
  const bool NeedsAVLLegalization = isVVPOrVEC(Op.getOpcode());
  return NeedsAVLLegalization ? Custom : Legal;
}
1678+
16721679
SDValue VETargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
1680+
LLVM_DEBUG(dbgs() << "::LowerOperation"; Op->print(dbgs()););
16731681
unsigned Opcode = Op.getOpcode();
16741682
if (ISD::isVPOpcode(Opcode))
16751683
return lowerToVVP(Op, DAG);
@@ -1721,6 +1729,16 @@ SDValue VETargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
17211729
case ISD::EXTRACT_VECTOR_ELT:
17221730
return lowerEXTRACT_VECTOR_ELT(Op, DAG);
17231731

1732+
// Legalize the AVL of this internal node.
1733+
case VEISD::VEC_BROADCAST:
1734+
#define ADD_VVP_OP(VVP_NAME, ...) case VEISD::VVP_NAME:
1735+
#include "VVPNodes.def"
1736+
// AVL already legalized.
1737+
if (getAnnotatedNodeAVL(Op).second)
1738+
return Op;
1739+
return legalizeInternalVectorOp(Op, DAG);
1740+
1741+
// Translate into a VEC_*/VVP_* layer operation.
17241742
#define ADD_VVP_OP(VVP_NAME, ISD_NAME) case ISD::ISD_NAME:
17251743
#include "VVPNodes.def"
17261744
return lowerToVVP(Op, DAG);

llvm/lib/Target/VE/VEISelLowering.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,19 @@ enum NodeType : unsigned {
4343
REPL_I32,
4444
REPL_F32, // Replicate subregister to other half.
4545

46+
// Annotation as a wrapper. LEGALAVL(VL) means that VL refers to 64bit of
47+
// data, whereas the raw EVL coming in from VP nodes always refers to number
48+
// of elements, regardless of their size.
49+
LEGALAVL,
50+
4651
// VVP_* nodes.
4752
#define ADD_VVP_OP(VVP_NAME, ...) VVP_NAME,
4853
#include "VVPNodes.def"
4954
};
5055
}
5156

57+
class VECustomDAG;
58+
5259
class VETargetLowering : public TargetLowering {
5360
const VESubtarget *Subtarget;
5461

@@ -105,6 +112,9 @@ class VETargetLowering : public TargetLowering {
105112
}
106113

107114
/// Custom Lower {
115+
TargetLoweringBase::LegalizeAction
116+
getCustomOperationAction(SDNode &) const override;
117+
108118
SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;
109119
unsigned getJumpTableEncoding() const override;
110120
const MCExpr *LowerCustomJumpTableEntry(const MachineJumpTableInfo *MJTI,
@@ -170,6 +180,8 @@ class VETargetLowering : public TargetLowering {
170180

171181
/// VVP Lowering {
172182
SDValue lowerToVVP(SDValue Op, SelectionDAG &DAG) const;
183+
SDValue legalizeInternalVectorOp(SDValue Op, SelectionDAG &DAG) const;
184+
SDValue legalizePackedAVL(SDValue Op, VECustomDAG &CDAG) const;
173185
/// } VVPLowering
174186

175187
/// Custom DAGCombine {
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
//===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the lowering and legalization of vector instructions to
10+
// VVP_* layer SDNodes.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "VECustomDAG.h"
15+
#include "VEISelLowering.h"
16+
17+
using namespace llvm;
18+
19+
#define DEBUG_TYPE "ve-lower"
20+
21+
/// Legalize the internal vector node \p Op. Currently this only forwards to
/// legalizePackedAVL() with a VECustomDAG built from \p DAG and \p Op.
SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
                                                   SelectionDAG &DAG) const {
  // TODO: Implement odd/even splitting.
  VECustomDAG Builder(DAG, Op);
  return legalizePackedAVL(Op, Builder);
}
27+
28+
/// Legalize the AVL operand of the VEC/VVP node \p Op.
///
/// \p Op is returned unchanged when it is not a VEC/VVP node, when no AVL
/// position is known for its opcode, or when its AVL is already wrapped in a
/// LEGALAVL node. Otherwise a clone of \p Op is returned whose AVL operand is
/// replaced by a LEGALAVL wrapper. For packed vector result types the wrapped
/// value is the original AVL halved and rounded up; for all other types it is
/// the original AVL.
///
/// \param Op   The node to legalize.
/// \param CDAG Builder used to create the new nodes.
/// \returns \p Op itself or its clone with a legal AVL.
SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
                                            VECustomDAG &CDAG) const {
  LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
  // Only required for VEC and VVP ops.
  if (!isVVPOrVEC(Op->getOpcode()))
    return Op;

  // A node without a known AVL position has nothing to legalize. Checking
  // this first avoids handing an empty SDValue to isLegalAVL(), which
  // dereferences its argument.
  const auto AVLPos = getAVLPos(Op->getOpcode());
  if (!AVLPos)
    return Op;

  // Operation already has a legal AVL.
  SDValue AVL = Op->getOperand(*AVLPos);
  if (isLegalAVL(AVL))
    return Op;

  // Halve and round up the AVL for packed vector types.
  SDValue LegalAVL = AVL;
  if (isPackedVectorType(Op.getValueType())) {
    assert(maySafelyIgnoreMask(Op) &&
           "TODO Shift predication from EVL into Mask");

    if (auto *ConstAVL = dyn_cast<ConstantSDNode>(AVL)) {
      // Constant AVL: fold the adjustment right away.
      LegalAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32);
    } else {
      // Dynamic AVL: emit (AVL + 1) >> 1.
      auto ConstOne = CDAG.getConstant(1, MVT::i32);
      auto PlusOne = CDAG.getNode(ISD::ADD, MVT::i32, {AVL, ConstOne});
      LegalAVL = CDAG.getNode(ISD::SRL, MVT::i32, {PlusOne, ConstOne});
    }
  }

  SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(LegalAVL);

  // Copy the operand list and swap in the annotated AVL.
  const unsigned NumOps = Op->getNumOperands();
  std::vector<SDValue> FixedOperands;
  FixedOperands.reserve(NumOps);
  for (unsigned Idx = 0; Idx != NumOps; ++Idx)
    FixedOperands.push_back(Op->getOperand(Idx));
  FixedOperands[*AVLPos] = AnnotatedLegalAVL;

  // Clone the operation with fixed operands, keeping its flags.
  return CDAG.getNode(Op->getOpcode(), Op->getVTList(), FixedOperands,
                      Op->getFlags());
}

0 commit comments

Comments
 (0)