Skip to content

[DAGCombine] Fold vselect with splat zero #147305

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9902,11 +9902,14 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
if (SDValue Combined = visitADDLike(N))
return Combined;

// fold !(x cc y) -> (x !cc y)
// fold not (setcc x, y, cc) -> setcc x y !cc
// Avoid breaking: and (not(setcc x, y, cc), z) -> andn for vec
unsigned N0Opcode = N0.getOpcode();
SDValue LHS, RHS, CC;
if (TLI.isConstTrueVal(N1) &&
isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true)) {
isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true) &&
!(VT.isVector() && TLI.hasAndNot(SDValue(N, 0)) && N->hasOneUse() &&
N->use_begin()->getUser()->getOpcode() == ISD::AND)) {
ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
LHS.getValueType());
if (!LegalOperations ||
Expand Down Expand Up @@ -13088,10 +13091,10 @@ static SDValue combineVSelectWithAllOnesOrZeros(SDValue Cond, SDValue TVal,
EVT CondVT = Cond.getValueType();
assert(CondVT.isVector() && "Vector select expects a vector selector!");

bool IsTAllZero = ISD::isBuildVectorAllZeros(TVal.getNode());
bool IsTAllOne = ISD::isBuildVectorAllOnes(TVal.getNode());
bool IsFAllZero = ISD::isBuildVectorAllZeros(FVal.getNode());
bool IsFAllOne = ISD::isBuildVectorAllOnes(FVal.getNode());
bool IsTAllZero = ISD::isConstantSplatVectorAllZeros(TVal.getNode());
bool IsTAllOne = ISD::isConstantSplatVectorAllOnes(TVal.getNode());
bool IsFAllZero = ISD::isConstantSplatVectorAllZeros(FVal.getNode());
bool IsFAllOne = ISD::isConstantSplatVectorAllOnes(FVal.getNode());

// no vselect(cond, 0/-1, X) or vselect(cond, X, 0/-1), return
if (!IsTAllZero && !IsTAllOne && !IsFAllZero && !IsFAllOne)
Expand Down Expand Up @@ -13165,6 +13168,15 @@ static SDValue combineVSelectWithAllOnesOrZeros(SDValue Cond, SDValue TVal,
return DAG.getBitcast(VT, And);
}

// select Cond, 0, x -> and not(Cond), x
if (IsTAllZero &&
(isBitwiseNot(peekThroughBitcasts(Cond)) || TLI.hasAndNot(Cond))) {
SDValue X = DAG.getBitcast(CondVT, FVal);
SDValue And =
DAG.getNode(ISD::AND, DL, CondVT, DAG.getNOT(DL, Cond, CondVT), X);
return DAG.getBitcast(VT, And);
}

return SDValue();
}

Expand Down
54 changes: 0 additions & 54 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47260,57 +47260,6 @@ static SDValue combineToExtendBoolVectorInReg(
DAG.getConstant(EltSizeInBits - 1, DL, VT));
}

/// If a vector select has a left operand (LHS, operand 1) that is all zeros,
/// try to simplify the select to a bitwise logic operation:
///   vselect Cond, 000..., X -> andn Cond, X
/// Returns an empty SDValue when \p N is not a VSELECT or when any of the
/// preconditions checked below does not hold. \p DCI and \p Subtarget are not
/// referenced by the body.
/// TODO: Move to DAGCombiner.combineVSelectWithAllOnesOrZeros, possibly using
/// TargetLowering::hasAndNot()?
static SDValue combineVSelectWithLastZeros(SDNode *N, SelectionDAG &DAG,
const SDLoc &DL,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
SDValue Cond = N->getOperand(0);
SDValue LHS = N->getOperand(1);
SDValue RHS = N->getOperand(2);
EVT VT = LHS.getValueType();
EVT CondVT = Cond.getValueType();
const TargetLowering &TLI = DAG.getTargetLoweringInfo();

// Only vector selects are handled here.
if (N->getOpcode() != ISD::VSELECT)
return SDValue();

assert(CondVT.isVector() && "Vector select expects a vector selector!");

// To use the condition operand as a bitwise mask, it must have elements that
// are the same size as the select elements. I.e., the condition operand must
// have already been promoted from the IR select condition type <N x i1>.
// Don't check if the types themselves are equal because that excludes
// vector floating-point selects.
if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
return SDValue();

// Cond value must be 'sign splat' to be converted to a logical op.
if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
return SDValue();

// The logic nodes below are created in CondVT, so that type must be legal.
if (!TLI.isTypeLegal(CondVT))
return SDValue();

// vselect Cond, 000..., X -> andn Cond, X
if (ISD::isBuildVectorAllZeros(LHS.getNode())) {
// The logic op is performed in CondVT: bitcast X (RHS) to CondVT here and
// bitcast the result back to VT on return.
SDValue CastRHS = DAG.getBitcast(CondVT, RHS);
SDValue AndN;
// The canonical form differs for i1 vectors: X86ISD::ANDNP is not used
// there, so a generic AND of NOT(Cond) is emitted instead.
if (CondVT.getScalarType() == MVT::i1)
AndN = DAG.getNode(ISD::AND, DL, CondVT, DAG.getNOT(DL, Cond, CondVT),
CastRHS);
else
AndN = DAG.getNode(X86ISD::ANDNP, DL, CondVT, Cond, CastRHS);
return DAG.getBitcast(VT, AndN);
}

// LHS is not an all-zeros build vector: nothing to fold.
return SDValue();
}

/// If both arms of a vector select are concatenated vectors, split the select,
/// and concatenate the result to eliminate a wide (256-bit) vector instruction:
/// vselect Cond, (concat T0, T1), (concat F0, F1) -->
Expand Down Expand Up @@ -48052,9 +48001,6 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
if (!TLI.isTypeLegal(VT) || isSoftF16(VT, Subtarget))
return SDValue();

if (SDValue V = combineVSelectWithLastZeros(N, DAG, DL, DCI, Subtarget))
return V;

if (SDValue V = combineVSelectToBLENDV(N, DAG, DL, DCI, Subtarget))
return V;

Expand Down
29 changes: 10 additions & 19 deletions llvm/test/CodeGen/AArch64/sve-fixed-length-shuffles.ll
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ define void @crash_when_lowering_extract_shuffle(ptr %dst, i1 %cond) vscale_rang
; CHECK-NEXT: // %bb.1: // %vector.body
; CHECK-NEXT: movi v0.2d, #0000000000000000
; CHECK-NEXT: movi v1.2d, #0000000000000000
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: ldr z4, [x0]
; CHECK-NEXT: ldr z5, [x0, #2, mul vl]
; CHECK-NEXT: ldr z6, [x0, #3, mul vl]
; CHECK-NEXT: umov w8, v0.b[8]
; CHECK-NEXT: mov v1.b[1], v0.b[1]
; CHECK-NEXT: fmov s2, w8
Expand Down Expand Up @@ -60,31 +62,20 @@ define void @crash_when_lowering_extract_shuffle(ptr %dst, i1 %cond) vscale_rang
; CHECK-NEXT: asr z1.s, z1.s, #31
; CHECK-NEXT: uunpklo z3.s, z3.h
; CHECK-NEXT: lsl z0.s, z0.s, #31
; CHECK-NEXT: and z1.s, z1.s, #0x1
; CHECK-NEXT: bic z1.d, z4.d, z1.d
; CHECK-NEXT: lsl z2.s, z2.s, #31
; CHECK-NEXT: ldr z4, [x0, #1, mul vl]
; CHECK-NEXT: asr z0.s, z0.s, #31
; CHECK-NEXT: cmpne p1.s, p0/z, z1.s, #0
; CHECK-NEXT: ldr z1, [x0]
; CHECK-NEXT: str z1, [x0]
; CHECK-NEXT: lsl z3.s, z3.s, #31
; CHECK-NEXT: asr z2.s, z2.s, #31
; CHECK-NEXT: and z0.s, z0.s, #0x1
; CHECK-NEXT: bic z0.d, z5.d, z0.d
; CHECK-NEXT: asr z3.s, z3.s, #31
; CHECK-NEXT: and z2.s, z2.s, #0x1
; CHECK-NEXT: mov z1.s, p1/m, #0 // =0x0
; CHECK-NEXT: cmpne p2.s, p0/z, z0.s, #0
; CHECK-NEXT: ldr z0, [x0, #2, mul vl]
; CHECK-NEXT: and z3.s, z3.s, #0x1
; CHECK-NEXT: str z1, [x0]
; CHECK-NEXT: cmpne p3.s, p0/z, z3.s, #0
; CHECK-NEXT: cmpne p0.s, p0/z, z2.s, #0
; CHECK-NEXT: ldr z3, [x0, #3, mul vl]
; CHECK-NEXT: ldr z2, [x0, #1, mul vl]
; CHECK-NEXT: mov z0.s, p2/m, #0 // =0x0
; CHECK-NEXT: mov z3.s, p3/m, #0 // =0x0
; CHECK-NEXT: mov z2.s, p0/m, #0 // =0x0
; CHECK-NEXT: bic z1.d, z4.d, z2.d
; CHECK-NEXT: str z0, [x0, #2, mul vl]
; CHECK-NEXT: bic z3.d, z6.d, z3.d
; CHECK-NEXT: str z1, [x0, #1, mul vl]
; CHECK-NEXT: str z3, [x0, #3, mul vl]
; CHECK-NEXT: str z2, [x0, #1, mul vl]
; CHECK-NEXT: .LBB1_2: // %exit
; CHECK-NEXT: ret
%broadcast.splat = shufflevector <32 x i1> zeroinitializer, <32 x i1> zeroinitializer, <32 x i32> zeroinitializer
Expand Down
6 changes: 2 additions & 4 deletions llvm/test/CodeGen/AArch64/vselect-constants.ll
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,9 @@ define <4 x i32> @cmp_sel_1_or_0_vec(<4 x i32> %x, <4 x i32> %y) {
define <4 x i32> @sel_0_or_1_vec(<4 x i1> %cond) {
; CHECK-LABEL: sel_0_or_1_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: ushll v0.4s, v0.4h, #0
; CHECK-NEXT: movi v1.4s, #1
; CHECK-NEXT: shl v0.4s, v0.4s, #31
; CHECK-NEXT: cmge v0.4s, v0.4s, #0
; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
; CHECK-NEXT: ushll v0.4s, v0.4h, #0
; CHECK-NEXT: bic v0.16b, v1.16b, v0.16b
; CHECK-NEXT: ret
%add = select <4 x i1> %cond, <4 x i32> <i32 0, i32 0, i32 0, i32 0>, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
ret <4 x i32> %add
Expand Down
12 changes: 4 additions & 8 deletions llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,9 @@ define <2 x i32> @ustest_f64i32(<2 x double> %x) {
; CHECK-NEXT: v128.bitselect
; CHECK-NEXT: local.tee 0
; CHECK-NEXT: v128.const 0, 0
; CHECK-NEXT: local.tee 1
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i64x2.gt_s
; CHECK-NEXT: v128.bitselect
; CHECK-NEXT: local.get 0
; CHECK-NEXT: v128.and
; CHECK-NEXT: local.get 0
; CHECK-NEXT: i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3
; CHECK-NEXT: # fallthrough-return
Expand Down Expand Up @@ -1558,11 +1556,9 @@ define <2 x i32> @ustest_f64i32_mm(<2 x double> %x) {
; CHECK-NEXT: v128.bitselect
; CHECK-NEXT: local.tee 0
; CHECK-NEXT: v128.const 0, 0
; CHECK-NEXT: local.tee 1
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
; CHECK-NEXT: i64x2.gt_s
; CHECK-NEXT: v128.bitselect
; CHECK-NEXT: local.get 0
; CHECK-NEXT: v128.and
; CHECK-NEXT: local.get 0
; CHECK-NEXT: i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3
; CHECK-NEXT: # fallthrough-return
Expand Down
71 changes: 71 additions & 0 deletions llvm/test/CodeGen/WebAssembly/simd-bitselect.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -O3 -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
target triple = "wasm32-unknown-unknown"

; select (icmp eq (and %input, splat), 0), zeroinitializer, %input
; The zero vector is in the true arm and the condition is a compare. The
; expected output inverts the compare (i32x4.ne) and combines it with %input
; via v128.and; no v128.bitselect appears.
define <4 x i32> @bitselect_splat_first_zero_and_icmp(<4 x i32> %input) {
; CHECK-LABEL: bitselect_splat_first_zero_and_icmp:
; CHECK: .functype bitselect_splat_first_zero_and_icmp (v128) -> (v128)
; CHECK-NEXT: # %bb.0: # %start
; CHECK-NEXT: v128.const $push0=, 2139095040, 2139095040, 2139095040, 2139095040
; CHECK-NEXT: v128.and $push1=, $0, $pop0
; CHECK-NEXT: v128.const $push2=, 0, 0, 0, 0
; CHECK-NEXT: i32x4.ne $push3=, $pop1, $pop2
; CHECK-NEXT: v128.and $push4=, $pop3, $0
; CHECK-NEXT: return $pop4
start:
%0 = and <4 x i32> %input, splat (i32 2139095040)
%1 = icmp eq <4 x i32> %0, zeroinitializer
%2 = select <4 x i1> %1, <4 x i32> zeroinitializer, <4 x i32> %input
ret <4 x i32> %2
}


; select (icmp eq (and %input, splat), 0), %input, zeroinitializer
; The zero vector is in the false arm and the condition is a compare. The
; expected output keeps the compare as i32x4.eq and combines it with %input
; via v128.and; no v128.bitselect appears.
define <4 x i32> @bitselect_splat_second_zero_and_icmp(<4 x i32> %input) {
; CHECK-LABEL: bitselect_splat_second_zero_and_icmp:
; CHECK: .functype bitselect_splat_second_zero_and_icmp (v128) -> (v128)
; CHECK-NEXT: # %bb.0: # %start
; CHECK-NEXT: v128.const $push0=, 2139095040, 2139095040, 2139095040, 2139095040
; CHECK-NEXT: v128.and $push1=, $0, $pop0
; CHECK-NEXT: v128.const $push2=, 0, 0, 0, 0
; CHECK-NEXT: i32x4.eq $push3=, $pop1, $pop2
; CHECK-NEXT: v128.and $push4=, $pop3, $0
; CHECK-NEXT: return $pop4
start:
%0 = and <4 x i32> %input, splat (i32 2139095040)
%1 = icmp eq <4 x i32> %0, zeroinitializer
%2 = select <4 x i1> %1, <4 x i32> %input, <4 x i32> zeroinitializer
ret <4 x i32> %2
}


; select %cond, zeroinitializer, %input
; The zero vector is in the true arm and the condition is a plain <4 x i1>
; argument rather than a compare. The expected output widens the mask with
; shl / shr_s by 31 and still uses v128.bitselect against a zero constant.
define <4 x i32> @bitselect_splat_first_zero_cond_input(<4 x i1> %cond, <4 x i32> %input) {
; CHECK-LABEL: bitselect_splat_first_zero_cond_input:
; CHECK: .functype bitselect_splat_first_zero_cond_input (v128, v128) -> (v128)
; CHECK-NEXT: # %bb.0: # %start
; CHECK-NEXT: v128.const $push3=, 0, 0, 0, 0
; CHECK-NEXT: i32.const $push0=, 31
; CHECK-NEXT: i32x4.shl $push1=, $0, $pop0
; CHECK-NEXT: i32.const $push5=, 31
; CHECK-NEXT: i32x4.shr_s $push2=, $pop1, $pop5
; CHECK-NEXT: v128.bitselect $push4=, $pop3, $1, $pop2
; CHECK-NEXT: return $pop4
start:
%2 = select <4 x i1> %cond, <4 x i32> zeroinitializer, <4 x i32> %input
ret <4 x i32> %2
}

; select %cond, %input, zeroinitializer
; The zero vector is in the false arm and the condition is a plain <4 x i1>
; argument. The expected output widens the mask with shl / shr_s by 31 and
; combines it with %input via v128.and; no v128.bitselect appears.
define <4 x i32> @bitselect_splat_second_zero_cond_input(<4 x i1> %cond, <4 x i32> %input) {
; CHECK-LABEL: bitselect_splat_second_zero_cond_input:
; CHECK: .functype bitselect_splat_second_zero_cond_input (v128, v128) -> (v128)
; CHECK-NEXT: # %bb.0: # %start
; CHECK-NEXT: i32.const $push0=, 31
; CHECK-NEXT: i32x4.shl $push1=, $0, $pop0
; CHECK-NEXT: i32.const $push4=, 31
; CHECK-NEXT: i32x4.shr_s $push2=, $pop1, $pop4
; CHECK-NEXT: v128.and $push3=, $pop2, $1
; CHECK-NEXT: return $pop3
start:
%2 = select <4 x i1> %cond, <4 x i32> %input, <4 x i32> zeroinitializer
ret <4 x i32> %2
}