Skip to content

Commit af773a1

Browse files
committed
[RISCV][VP] Lower VP_MERGE to RVV instructions
This patch adds lowering of the llvm.vp.merge.* intrinsic (ISD::VP_MERGE) to RVV vmerge/vfmerge instructions. It introduces a special pseudo form of vmerge that takes a tied merge operand, which lets us specify the tail elements as being equal to the "on false" operand by combining a tied-def constraint with a "tail undisturbed" policy. While this strategy often allows us to lower the intrinsic to just one instruction, it may be less efficient for fixed-vector types, as the number of tail elements may extend far beyond the length of the fixed vector. Another strategy could be to use a vmerge/vfmerge instruction with an AVL equal to the length of the vector type, and manipulate the condition operand such that mask elements at or beyond the operation's EVL are false. I've also observed inefficient codegen in which our 'VF' patterns don't match raw floating-point SPLAT_VECTORs, which occur in scalable-vector code. Reviewed By: craig.topper. Differential Revision: https://reviews.llvm.org/D117561
1 parent e7926e8 commit af773a1

File tree

6 files changed

+2372
-12
lines changed

6 files changed

+2372
-12
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,12 +521,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
521521
ISD::VP_SHL, ISD::VP_REDUCE_ADD, ISD::VP_REDUCE_AND,
522522
ISD::VP_REDUCE_OR, ISD::VP_REDUCE_XOR, ISD::VP_REDUCE_SMAX,
523523
ISD::VP_REDUCE_SMIN, ISD::VP_REDUCE_UMAX, ISD::VP_REDUCE_UMIN,
524-
ISD::VP_SELECT};
524+
ISD::VP_MERGE, ISD::VP_SELECT};
525525

526526
static const unsigned FloatingPointVPOps[] = {
527527
ISD::VP_FADD, ISD::VP_FSUB, ISD::VP_FMUL,
528528
ISD::VP_FDIV, ISD::VP_REDUCE_FADD, ISD::VP_REDUCE_SEQ_FADD,
529-
ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX, ISD::VP_SELECT};
529+
ISD::VP_REDUCE_FMIN, ISD::VP_REDUCE_FMAX, ISD::VP_MERGE,
530+
ISD::VP_SELECT};
530531

531532
if (!Subtarget.is64Bit()) {
532533
// We must custom-lower certain vXi64 operations on RV32 due to the vector
@@ -3441,6 +3442,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
34413442
return lowerSET_ROUNDING(Op, DAG);
34423443
case ISD::VP_SELECT:
34433444
return lowerVPOp(Op, DAG, RISCVISD::VSELECT_VL);
3445+
case ISD::VP_MERGE:
3446+
return lowerVPOp(Op, DAG, RISCVISD::VP_MERGE_VL);
34443447
case ISD::VP_ADD:
34453448
return lowerVPOp(Op, DAG, RISCVISD::ADD_VL);
34463449
case ISD::VP_SUB:
@@ -10087,6 +10090,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
1008710090
NODE_NAME_CASE(VWADDU_VL)
1008810091
NODE_NAME_CASE(SETCC_VL)
1008910092
NODE_NAME_CASE(VSELECT_VL)
10093+
NODE_NAME_CASE(VP_MERGE_VL)
1009010094
NODE_NAME_CASE(VMAND_VL)
1009110095
NODE_NAME_CASE(VMOR_VL)
1009210096
NODE_NAME_CASE(VMXOR_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,10 @@ enum NodeType : unsigned {
253253

254254
// Vector select with an additional VL operand. This operation is unmasked.
255255
VSELECT_VL,
256+
// Vector select with operand #2 (the value when the condition is false) tied
257+
// to the destination and an additional VL operand. This operation is
258+
// unmasked.
259+
VP_MERGE_VL,
256260

257261
// Mask binary operators.
258262
VMAND_VL,

llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -579,10 +579,11 @@ class PseudoToVInst<string PseudoInst> {
579579
!subst("_B64", "",
580580
!subst("_MASK", "",
581581
!subst("_TIED", "",
582+
!subst("_TU", "",
582583
!subst("F16", "F",
583584
!subst("F32", "F",
584585
!subst("F64", "F",
585-
!subst("Pseudo", "", PseudoInst))))))))))))))))))));
586+
!subst("Pseudo", "", PseudoInst)))))))))))))))))))));
586587
}
587588

588589
// The destination vector register group for a masked vector instruction cannot
@@ -928,6 +929,9 @@ class VPseudoBinaryNoMask<VReg RetClass,
928929
let BaseInstr = !cast<Instruction>(PseudoToVInst<NAME>.VInst);
929930
}
930931

932+
// Special version of VPseudoBinaryNoMask where we pretend the first source is
933+
// tied to the destination.
934+
// This allows maskedoff and rs2 to be the same register.
931935
class VPseudoTiedBinaryNoMask<VReg RetClass,
932936
DAGOperand Op2Class,
933937
string Constraint> :
@@ -1079,6 +1083,30 @@ class VPseudoBinaryCarryIn<VReg RetClass,
10791083
let VLMul = MInfo.value;
10801084
}
10811085

1086+
// Pseudo for a binary operation with an optional carry-in whose input list
// begins with an extra $merge operand of the result register class. $merge is
// tied to the destination through the "$rd = $merge" constraint, which is
// joined with any caller-supplied Constraint string. When CarryIn is set, a
// VMV0 $carry operand is placed before $vl and $sew; otherwise it is omitted.
// The pseudo sets HasMergeOp and clears HasVecPolicyOp.
class VPseudoTiedBinaryCarryIn<VReg RetClass,
1087+
VReg Op1Class,
1088+
DAGOperand Op2Class,
1089+
LMULInfo MInfo,
1090+
bit CarryIn,
1091+
string Constraint> :
1092+
Pseudo<(outs RetClass:$rd),
1093+
!if(CarryIn,
1094+
(ins RetClass:$merge, Op1Class:$rs2, Op2Class:$rs1, VMV0:$carry, AVL:$vl,
1095+
ixlenimm:$sew),
1096+
(ins RetClass:$merge, Op1Class:$rs2, Op2Class:$rs1, AVL:$vl, ixlenimm:$sew)), []>,
1097+
RISCVVPseudo {
1098+
let mayLoad = 0;
1099+
let mayStore = 0;
1100+
let hasSideEffects = 0;
1101+
let Constraints = Join<[Constraint, "$rd = $merge"], ",">.ret;
1102+
let HasVLOp = 1;
1103+
let HasSEWOp = 1;
1104+
let HasMergeOp = 1;
1105+
let HasVecPolicyOp = 0;
1106+
let BaseInstr = !cast<Instruction>(PseudoToVInst<NAME>.VInst);
1107+
let VLMul = MInfo.value;
1108+
}
1109+
10821110
class VPseudoTernaryNoMask<VReg RetClass,
10831111
RegisterClass Op1Class,
10841112
DAGOperand Op2Class,
@@ -1741,6 +1769,16 @@ multiclass VPseudoBinaryV_VM<bit CarryOut = 0, bit CarryIn = 1,
17411769
m.vrclass, m.vrclass, m, CarryIn, Constraint>;
17421770
}
17431771

1772+
// Defines one VPseudoTiedBinaryCarryIn pseudo per entry of MxList, named
// "_VV" (or "_VVM" when CarryIn is set) followed by "_<MX>_TU". Both source
// operands use m.vrclass. The result class is VR when CarryOut is set;
// otherwise, when CarryIn is set, it is GetVRegNoV0<m.vrclass>.R; otherwise
// it is m.vrclass.
multiclass VPseudoTiedBinaryV_VM<bit CarryOut = 0, bit CarryIn = 1,
1773+
string Constraint = ""> {
1774+
foreach m = MxList in
1775+
def "_VV" # !if(CarryIn, "M", "") # "_" # m.MX # "_TU" :
1776+
VPseudoTiedBinaryCarryIn<!if(CarryOut, VR,
1777+
!if(!and(CarryIn, !not(CarryOut)),
1778+
GetVRegNoV0<m.vrclass>.R, m.vrclass)),
1779+
m.vrclass, m.vrclass, m, CarryIn, Constraint>;
1780+
}
1781+
17441782
multiclass VPseudoBinaryV_XM<bit CarryOut = 0, bit CarryIn = 1,
17451783
string Constraint = ""> {
17461784
foreach m = MxList in
@@ -1751,13 +1789,29 @@ multiclass VPseudoBinaryV_XM<bit CarryOut = 0, bit CarryIn = 1,
17511789
m.vrclass, GPR, m, CarryIn, Constraint>;
17521790
}
17531791

1792+
// Defines one VPseudoTiedBinaryCarryIn pseudo per entry of MxList, named
// "_VX" (or "_VXM" when CarryIn is set) followed by "_<MX>_TU". The first
// source uses m.vrclass and the second is a GPR. The result class is VR when
// CarryOut is set; otherwise, when CarryIn is set, it is
// GetVRegNoV0<m.vrclass>.R; otherwise it is m.vrclass.
multiclass VPseudoTiedBinaryV_XM<bit CarryOut = 0, bit CarryIn = 1,
1793+
string Constraint = ""> {
1794+
foreach m = MxList in
1795+
def "_VX" # !if(CarryIn, "M", "") # "_" # m.MX # "_TU":
1796+
VPseudoTiedBinaryCarryIn<!if(CarryOut, VR,
1797+
!if(!and(CarryIn, !not(CarryOut)),
1798+
GetVRegNoV0<m.vrclass>.R, m.vrclass)),
1799+
m.vrclass, GPR, m, CarryIn, Constraint>;
1800+
}
1801+
17541802
multiclass VPseudoVMRG_FM {
17551803
foreach f = FPList in
1756-
foreach m = f.MxList in
1804+
foreach m = f.MxList in {
17571805
def "_V" # f.FX # "M_" # m.MX :
17581806
VPseudoBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
17591807
m.vrclass, f.fprclass, m, /*CarryIn=*/1, "">,
17601808
Sched<[WriteVFMergeV, ReadVFMergeV, ReadVFMergeF, ReadVMask]>;
1809+
// Tied version to allow codegen control over the tail elements
1810+
def "_V" # f.FX # "M_" # m.MX # "_TU":
1811+
VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
1812+
m.vrclass, f.fprclass, m, /*CarryIn=*/1, "">,
1813+
Sched<[WriteVFMergeV, ReadVFMergeV, ReadVFMergeF, ReadVMask]>;
1814+
}
17611815
}
17621816

17631817
multiclass VPseudoBinaryV_IM<bit CarryOut = 0, bit CarryIn = 1,
@@ -1770,6 +1824,16 @@ multiclass VPseudoBinaryV_IM<bit CarryOut = 0, bit CarryIn = 1,
17701824
m.vrclass, simm5, m, CarryIn, Constraint>;
17711825
}
17721826

1827+
// Defines one VPseudoTiedBinaryCarryIn pseudo per entry of MxList, named
// "_VI" (or "_VIM" when CarryIn is set) followed by "_<MX>_TU". The first
// source uses m.vrclass and the second is a simm5 operand. The result class
// is VR when CarryOut is set; otherwise, when CarryIn is set, it is
// GetVRegNoV0<m.vrclass>.R; otherwise it is m.vrclass.
multiclass VPseudoTiedBinaryV_IM<bit CarryOut = 0, bit CarryIn = 1,
1828+
string Constraint = ""> {
1829+
foreach m = MxList in
1830+
def "_VI" # !if(CarryIn, "M", "") # "_" # m.MX # "_TU":
1831+
VPseudoTiedBinaryCarryIn<!if(CarryOut, VR,
1832+
!if(!and(CarryIn, !not(CarryOut)),
1833+
GetVRegNoV0<m.vrclass>.R, m.vrclass)),
1834+
m.vrclass, simm5, m, CarryIn, Constraint>;
1835+
}
1836+
17731837
multiclass VPseudoUnaryVMV_V_X_I {
17741838
foreach m = MxList in {
17751839
let VLMul = m.value in {
@@ -2104,6 +2168,13 @@ multiclass VPseudoVMRG_VM_XM_IM {
21042168
Sched<[WriteVIMergeX, ReadVIMergeV, ReadVIMergeX, ReadVMask]>;
21052169
defm "" : VPseudoBinaryV_IM,
21062170
Sched<[WriteVIMergeI, ReadVIMergeV, ReadVMask]>;
2171+
// Tied versions to allow codegen control over the tail elements
2172+
defm "" : VPseudoTiedBinaryV_VM,
2173+
Sched<[WriteVIMergeV, ReadVIMergeV, ReadVIMergeV, ReadVMask]>;
2174+
defm "" : VPseudoTiedBinaryV_XM,
2175+
Sched<[WriteVIMergeX, ReadVIMergeV, ReadVIMergeX, ReadVMask]>;
2176+
defm "" : VPseudoTiedBinaryV_IM,
2177+
Sched<[WriteVIMergeI, ReadVIMergeV, ReadVMask]>;
21072178
}
21082179

21092180
multiclass VPseudoVCALU_VM_XM_IM {

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,13 @@ def riscv_vrgatherei16_vv_vl : SDNode<"RISCVISD::VRGATHEREI16_VV_VL",
177177
SDTCisSameNumEltsAs<0, 3>,
178178
SDTCisVT<4, XLenVT>]>>;
179179

180-
def riscv_vselect_vl : SDNode<"RISCVISD::VSELECT_VL",
181-
SDTypeProfile<1, 4, [SDTCisVec<0>,
182-
SDTCisVec<1>,
183-
SDTCisSameNumEltsAs<0, 1>,
184-
SDTCVecEltisVT<1, i1>,
185-
SDTCisSameAs<0, 2>,
186-
SDTCisSameAs<2, 3>,
187-
SDTCisVT<4, XLenVT>]>>;
180+
def SDT_RISCVSelect_VL : SDTypeProfile<1, 4, [
181+
SDTCisVec<0>, SDTCisVec<1>, SDTCisSameNumEltsAs<0, 1>, SDTCVecEltisVT<1, i1>,
182+
SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisVT<4, XLenVT>
183+
]>;
184+
185+
def riscv_vselect_vl : SDNode<"RISCVISD::VSELECT_VL", SDT_RISCVSelect_VL>;
186+
def riscv_vp_merge_vl : SDNode<"RISCVISD::VP_MERGE_VL", SDT_RISCVSelect_VL>;
188187

189188
def SDT_RISCVMaskBinOp_VL : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
190189
SDTCisSameAs<0, 2>,
@@ -976,6 +975,30 @@ foreach vti = AllIntegerVectors in {
976975
VLOpFrag)),
977976
(!cast<Instruction>("PseudoVMERGE_VIM_"#vti.LMul.MX)
978977
vti.RegClass:$rs2, simm5:$rs1, (vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
978+
979+
def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
980+
vti.RegClass:$rs1,
981+
vti.RegClass:$rs2,
982+
VLOpFrag)),
983+
(!cast<Instruction>("PseudoVMERGE_VVM_"#vti.LMul.MX#"_TU")
984+
vti.RegClass:$rs2, vti.RegClass:$rs2, vti.RegClass:$rs1,
985+
(vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
986+
987+
def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
988+
(SplatPat XLenVT:$rs1),
989+
vti.RegClass:$rs2,
990+
VLOpFrag)),
991+
(!cast<Instruction>("PseudoVMERGE_VXM_"#vti.LMul.MX#"_TU")
992+
vti.RegClass:$rs2, vti.RegClass:$rs2, GPR:$rs1,
993+
(vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
994+
995+
def : Pat<(vti.Vector (riscv_vp_merge_vl (vti.Mask V0),
996+
(SplatPat_simm5 simm5:$rs1),
997+
vti.RegClass:$rs2,
998+
VLOpFrag)),
999+
(!cast<Instruction>("PseudoVMERGE_VIM_"#vti.LMul.MX#"_TU")
1000+
vti.RegClass:$rs2, vti.RegClass:$rs2, simm5:$rs1,
1001+
(vti.Mask V0), GPR:$vl, vti.Log2SEW)>;
9791002
}
9801003

9811004
// 12.16. Vector Integer Move Instructions
@@ -1223,6 +1246,31 @@ foreach fvti = AllFloatVectors in {
12231246
(!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX)
12241247
fvti.RegClass:$rs2, 0, (fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
12251248

1249+
def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
1250+
fvti.RegClass:$rs1,
1251+
fvti.RegClass:$rs2,
1252+
VLOpFrag)),
1253+
(!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX#"_TU")
1254+
fvti.RegClass:$rs2, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask V0),
1255+
GPR:$vl, fvti.Log2SEW)>;
1256+
1257+
def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
1258+
(SplatFPOp fvti.ScalarRegClass:$rs1),
1259+
fvti.RegClass:$rs2,
1260+
VLOpFrag)),
1261+
(!cast<Instruction>("PseudoVFMERGE_V"#fvti.ScalarSuffix#"M_"#fvti.LMul.MX#"_TU")
1262+
fvti.RegClass:$rs2, fvti.RegClass:$rs2,
1263+
(fvti.Scalar fvti.ScalarRegClass:$rs1),
1264+
(fvti.Mask V0), GPR:$vl, fvti.Log2SEW)>;
1265+
1266+
def : Pat<(fvti.Vector (riscv_vp_merge_vl (fvti.Mask V0),
1267+
(SplatFPOp (fvti.Scalar fpimm0)),
1268+
fvti.RegClass:$rs2,
1269+
VLOpFrag)),
1270+
(!cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX#"_TU")
1271+
fvti.RegClass:$rs2, fvti.RegClass:$rs2, 0, (fvti.Mask V0),
1272+
GPR:$vl, fvti.Log2SEW)>;
1273+
12261274
// 14.16. Vector Floating-Point Move Instruction
12271275
// If we're splatting fpimm0, use vmv.v.x vd, x0.
12281276
def : Pat<(fvti.Vector (riscv_vfmv_v_f_vl

0 commit comments

Comments
 (0)