Skip to content

Commit 0736f33

Browse files
authored
[DAG] Handle truncated splat in isBoolConstant (#145473)
This allows truncated splat / buildvector in isBoolConstant, to allow certain not instructions to be recognized post-legalization, and allow vselect to optimize. An override for x86 avx512 predicated vectors is required to avoid an infinite recursion from the code that detects zero vectors. From: ``` // Check if the first operand is all zeros and Cond type is vXi1. // If this an avx512 target we can improve the use of zero masking by // swapping the operands and inverting the condition. ```
1 parent dd60663 commit 0736f33

File tree

14 files changed

+2400
-3151
lines changed

14 files changed

+2400
-3151
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2488,8 +2488,7 @@ class SelectionDAG {
24882488

24892489
/// Check if a value \op N is a constant using the target's BooleanContent for
24902490
/// its type.
2491-
LLVM_ABI std::optional<bool>
2492-
isBoolConstant(SDValue N, bool AllowTruncation = false) const;
2491+
LLVM_ABI std::optional<bool> isBoolConstant(SDValue N) const;
24932492

24942493
/// Set CallSiteInfo to be associated with Node.
24952494
void addCallSiteInfo(const SDNode *Node, CallSiteInfo &&CallInfo) {

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4371,6 +4371,11 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
43714371
Op.getOpcode() == ISD::SPLAT_VECTOR_PARTS;
43724372
}
43734373

4374+
/// Return true if the given select/vselect should be considered canonical and
4375+
/// not be transformed. Currently only used for "vselect (not Cond), N1, N2 ->
4376+
/// vselect Cond, N2, N1".
4377+
virtual bool isTargetCanonicalSelect(SDNode *N) const { return false; }
4378+
43744379
struct DAGCombinerInfo {
43754380
void *DC; // The DAG Combiner object.
43764381
CombineLevel Level;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13194,8 +13194,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
1319413194
return V;
1319513195

1319613196
// vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
13197-
if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
13198-
return DAG.getSelect(DL, VT, F, N2, N1);
13197+
if (!TLI.isTargetCanonicalSelect(N))
13198+
if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
13199+
return DAG.getSelect(DL, VT, F, N2, N1);
1319913200

1320013201
// select (sext m), (add X, C), X --> (add X, (and C, (sext m))))
1320113202
if (N1.getOpcode() == ISD::ADD && N1.getOperand(0) == N2 && N1->hasOneUse() &&

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10459,7 +10459,7 @@ SDValue SelectionDAG::simplifySelect(SDValue Cond, SDValue T, SDValue F) {
1045910459

1046010460
// select true, T, F --> T
1046110461
// select false, T, F --> F
10462-
if (auto C = isBoolConstant(Cond, /*AllowTruncation=*/true))
10462+
if (auto C = isBoolConstant(Cond))
1046310463
return *C ? T : F;
1046410464

1046510465
// select ?, T, T --> T
@@ -13688,13 +13688,14 @@ bool SelectionDAG::isConstantFPBuildVectorOrConstantFP(SDValue N) const {
1368813688
return false;
1368913689
}
1369013690

13691-
std::optional<bool> SelectionDAG::isBoolConstant(SDValue N,
13692-
bool AllowTruncation) const {
13693-
ConstantSDNode *Const = isConstOrConstSplat(N, false, AllowTruncation);
13691+
std::optional<bool> SelectionDAG::isBoolConstant(SDValue N) const {
13692+
ConstantSDNode *Const =
13693+
isConstOrConstSplat(N, false, /*AllowTruncation=*/true);
1369413694
if (!Const)
1369513695
return std::nullopt;
1369613696

13697-
const APInt &CVal = Const->getAPIntValue();
13697+
EVT VT = N->getValueType(0);
13698+
const APInt CVal = Const->getAPIntValue().trunc(VT.getScalarSizeInBits());
1369813699
switch (TLI->getBooleanContents(N.getValueType())) {
1369913700
case TargetLowering::ZeroOrOneBooleanContent:
1370013701
if (CVal.isOne())

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4975,6 +4975,16 @@ X86TargetLowering::getTargetConstantFromLoad(LoadSDNode *LD) const {
49754975
return getTargetConstantFromNode(LD);
49764976
}
49774977

4978+
bool X86TargetLowering::isTargetCanonicalSelect(SDNode *N) const {
4979+
// Do not fold (vselect not(C), X, 0s) to (vselect C, Os, X)
4980+
SDValue Cond = N->getOperand(0);
4981+
SDValue RHS = N->getOperand(2);
4982+
EVT CondVT = Cond.getValueType();
4983+
return N->getOpcode() == ISD::VSELECT && Subtarget.hasAVX512() &&
4984+
CondVT.getVectorElementType() == MVT::i1 &&
4985+
ISD::isBuildVectorAllZeros(RHS.getNode());
4986+
}
4987+
49784988
// Extract raw constant bits from constant pools.
49794989
static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
49804990
APInt &UndefElts,

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,6 +1356,8 @@ namespace llvm {
13561356
TargetLowering::isTargetCanonicalConstantNode(Op);
13571357
}
13581358

1359+
bool isTargetCanonicalSelect(SDNode *N) const override;
1360+
13591361
const Constant *getTargetConstantFromLoad(LoadSDNode *LD) const override;
13601362

13611363
SDValue unwrapAddress(SDValue N) const override;

0 commit comments

Comments
 (0)