Skip to content

[GlobalISel] Allow expanding of sdiv -> mul by constant #146504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ class CombinerHelper {
/// Query is legal on the target.
bool isLegalOrBeforeLegalizer(const LegalityQuery &Query) const;

/// \return true if \p Query is legal on the target, or if the target's
/// legalization action for \p Query is WidenScalar.
bool isLegalOrHasWidenScalar(const LegalityQuery &Query) const;

/// \return true if the combine is running prior to legalization, or if \p Ty
/// is a legal integer constant type on the target.
bool isConstantLegalOrBeforeLegalizer(const LLT Ty) const;
Expand Down
12 changes: 6 additions & 6 deletions llvm/include/llvm/Target/GlobalISel/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -1131,13 +1131,13 @@ def form_bitfield_extract : GICombineGroup<[bitfield_extract_from_sext_inreg,

def udiv_by_const : GICombineRule<
(defs root:$root),
(match (wip_match_opcode G_UDIV):$root,
(match (G_UDIV $dst, $x, $y):$root,
[{ return Helper.matchUDivorURemByConst(*${root}); }]),
(apply [{ Helper.applyUDivorURemByConst(*${root}); }])>;

def sdiv_by_const : GICombineRule<
(defs root:$root),
(match (wip_match_opcode G_SDIV):$root,
(match (G_SDIV $dst, $x, $y):$root,
[{ return Helper.matchSDivByConst(*${root}); }]),
(apply [{ Helper.applySDivByConst(*${root}); }])>;

Expand All @@ -1153,8 +1153,8 @@ def udiv_by_pow2 : GICombineRule<
[{ return Helper.matchDivByPow2(*${root}, /*IsSigned=*/false); }]),
(apply [{ Helper.applyUDivByPow2(*${root}); }])>;

def intdiv_combines : GICombineGroup<[udiv_by_const, sdiv_by_const,
sdiv_by_pow2, udiv_by_pow2]>;
def intdiv_combines : GICombineGroup<[udiv_by_pow2, sdiv_by_pow2,
udiv_by_const, sdiv_by_const,]>;

def urem_by_const : GICombineRule<
(defs root:$root),
Expand Down Expand Up @@ -2054,9 +2054,9 @@ def all_combines : GICombineGroup<[integer_reassoc_combines, trivial_combines,
div_rem_to_divrem, funnel_shift_combines, bitreverse_shift, commute_shift,
form_bitfield_extract, constant_fold_binops, constant_fold_fma,
constant_fold_cast_op, fabs_fneg_fold,
intdiv_combines, mulh_combines, redundant_neg_operands,
mulh_combines, redundant_neg_operands,
and_or_disjoint_mask, fma_combines, fold_binop_into_select,
intrem_combines, sub_add_reg, select_to_minmax,
intrem_combines, intdiv_combines, sub_add_reg, select_to_minmax,
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
simplify_neg_minmax, combine_concat_vector,
sext_trunc, zext_trunc, prefer_sign_combines, shuffle_combines,
Expand Down
134 changes: 114 additions & 20 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ bool CombinerHelper::isLegalOrBeforeLegalizer(
return isPreLegalize() || isLegal(Query);
}

bool CombinerHelper::isLegalOrHasWidenScalar(const LegalityQuery &Query) const {
  // An already-legal query needs no further inspection.
  if (isLegal(Query))
    return true;

  // Otherwise accept the query only when the legalizer's chosen action for it
  // is WidenScalar.
  const auto ChosenAction = LI->getAction(Query).Action;
  return ChosenAction == LegalizeActions::WidenScalar;
}

bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
if (!Ty.isVector())
return isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, {Ty}});
Expand Down Expand Up @@ -5522,6 +5527,8 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
Register Dst = MI.getOperand(0).getReg();
Register RHS = MI.getOperand(2).getReg();
LLT DstTy = MRI.getType(Dst);
auto SizeInBits = DstTy.getScalarSizeInBits();
LLT WideTy = DstTy.changeElementSize(SizeInBits * 2);

auto &MF = *MI.getMF();
AttributeList Attr = MF.getFunction().getAttributes();
Expand All @@ -5541,8 +5548,21 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
}

// Don't support the general case for now.
return false;
auto *RHSDef = MRI.getVRegDef(RHS);
if (!isConstantOrConstantVector(*RHSDef, MRI))
return false;

// Don't do this if the types are not going to be legal.
if (LI) {
if (!isLegalOrBeforeLegalizer({TargetOpcode::G_MUL, {DstTy, DstTy}}))
return false;
if (!isLegal({TargetOpcode::G_SMULH, {DstTy}}) &&
!isLegalOrHasWidenScalar({TargetOpcode::G_MUL, {WideTy, WideTy}}))
return false;
}

return matchUnaryPredicate(
MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
}

void CombinerHelper::applySDivByConst(MachineInstr &MI) const {
Expand All @@ -5558,21 +5578,22 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
Register RHS = SDiv.getReg(2);
LLT Ty = MRI.getType(Dst);
LLT ScalarTy = Ty.getScalarType();
const unsigned EltBits = ScalarTy.getScalarSizeInBits();
LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
auto &MIB = Builder;

bool UseSRA = false;
SmallVector<Register, 16> Shifts, Factors;
SmallVector<Register, 16> ExactShifts, ExactFactors;

auto *RHSDef = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
bool IsSplat = getIConstantSplatVal(*RHSDef, MRI).has_value();
auto *RHSDefInstr = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
bool IsSplat = getIConstantSplatVal(*RHSDefInstr, MRI).has_value();

auto BuildSDIVPattern = [&](const Constant *C) {
auto BuildExactSDIVPattern = [&](const Constant *C) {
// Don't recompute inverses for each splat element.
if (IsSplat && !Factors.empty()) {
Shifts.push_back(Shifts[0]);
Factors.push_back(Factors[0]);
if (IsSplat && !ExactFactors.empty()) {
ExactShifts.push_back(ExactShifts[0]);
ExactFactors.push_back(ExactFactors[0]);
return true;
}

Expand All @@ -5587,31 +5608,104 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
// Calculate the multiplicative inverse modulo BW.
// 2^W requires W + 1 bits, so we have to extend and then truncate.
APInt Factor = Divisor.multiplicativeInverse();
Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
ExactShifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
ExactFactors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
return true;
};

// Collect all magic values from the build vector.
if (MI.getFlag(MachineInstr::MIFlag::IsExact)) {
// Collect all magic values from the build vector.
bool Matched = matchUnaryPredicate(MRI, RHS, BuildExactSDIVPattern);
(void)Matched;
assert(Matched && "Expected unary predicate match to succeed");

Register Shift, Factor;
if (Ty.isVector()) {
Shift = MIB.buildBuildVector(ShiftAmtTy, ExactShifts).getReg(0);
Factor = MIB.buildBuildVector(Ty, ExactFactors).getReg(0);
} else {
Shift = ExactShifts[0];
Factor = ExactFactors[0];
}

Register Res = LHS;

if (UseSRA)
Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);

return MIB.buildMul(Ty, Res, Factor);
}

SmallVector<Register, 16> MagicFactors, Factors, Shifts, ShiftMasks;

auto BuildSDIVPattern = [&](const Constant *C) {
auto *CI = cast<ConstantInt>(C);
const APInt &Divisor = CI->getValue();

SignedDivisionByConstantInfo magics =
SignedDivisionByConstantInfo::get(Divisor);
int NumeratorFactor = 0;
int ShiftMask = -1;

if (Divisor.isOne() || Divisor.isAllOnes()) {
// If d is +1/-1, we just multiply the numerator by +1/-1.
NumeratorFactor = Divisor.getSExtValue();
magics.Magic = 0;
magics.ShiftAmount = 0;
ShiftMask = 0;
} else if (Divisor.isStrictlyPositive() && magics.Magic.isNegative()) {
// If d > 0 and m < 0, add the numerator.
NumeratorFactor = 1;
} else if (Divisor.isNegative() && magics.Magic.isStrictlyPositive()) {
// If d < 0 and m > 0, subtract the numerator.
NumeratorFactor = -1;
}

MagicFactors.push_back(MIB.buildConstant(ScalarTy, magics.Magic).getReg(0));
Factors.push_back(MIB.buildConstant(ScalarTy, NumeratorFactor).getReg(0));
Shifts.push_back(
MIB.buildConstant(ScalarShiftAmtTy, magics.ShiftAmount).getReg(0));
ShiftMasks.push_back(MIB.buildConstant(ScalarTy, ShiftMask).getReg(0));

return true;
};

// Collect the shifts/magic values from each element.
bool Matched = matchUnaryPredicate(MRI, RHS, BuildSDIVPattern);
(void)Matched;
assert(Matched && "Expected unary predicate match to succeed");

Register Shift, Factor;
if (Ty.isVector()) {
Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
Register MagicFactor, Factor, Shift, ShiftMask;
auto *RHSDef = getOpcodeDef<GBuildVector>(RHS, MRI);
if (RHSDef) {
MagicFactor = MIB.buildBuildVector(Ty, MagicFactors).getReg(0);
Factor = MIB.buildBuildVector(Ty, Factors).getReg(0);
Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
ShiftMask = MIB.buildBuildVector(Ty, ShiftMasks).getReg(0);
} else {
Shift = Shifts[0];
assert(MRI.getType(RHS).isScalar() &&
"Non-build_vector operation should have been a scalar");
MagicFactor = MagicFactors[0];
Factor = Factors[0];
Shift = Shifts[0];
ShiftMask = ShiftMasks[0];
}

Register Res = LHS;
Register Q = LHS;
Q = MIB.buildSMulH(Ty, LHS, MagicFactor).getReg(0);

// (Optionally) Add/subtract the numerator using Factor.
Factor = MIB.buildMul(Ty, LHS, Factor).getReg(0);
Q = MIB.buildAdd(Ty, Q, Factor).getReg(0);

if (UseSRA)
Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);
// Shift right algebraic by shift value.
Q = MIB.buildAShr(Ty, Q, Shift).getReg(0);

return MIB.buildMul(Ty, Res, Factor);
// Extract the sign bit, mask it and add it to the quotient.
auto SignShift = MIB.buildConstant(ShiftAmtTy, EltBits - 1);
auto T = MIB.buildLShr(Ty, Q, SignShift);
T = MIB.buildAnd(Ty, T, ShiftMask);
return MIB.buildAdd(Ty, Q, T);
}

bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const {
Expand Down
11 changes: 8 additions & 3 deletions llvm/test/CodeGen/AArch64/GlobalISel/combine-sdiv.mir
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,14 @@ body: |
; CHECK: liveins: $w0
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $w0
; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 104
; CHECK-NEXT: [[SDIV:%[0-9]+]]:_(s32) = G_SDIV [[COPY]], [[C]]
; CHECK-NEXT: $w0 = COPY [[SDIV]](s32)
; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 1321528399
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 5
; CHECK-NEXT: [[SMULH:%[0-9]+]]:_(s32) = G_SMULH [[COPY]], [[C]]
; CHECK-NEXT: [[ASHR:%[0-9]+]]:_(s32) = G_ASHR [[SMULH]], [[C1]](s32)
; CHECK-NEXT: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 31
; CHECK-NEXT: [[LSHR:%[0-9]+]]:_(s32) = G_LSHR [[ASHR]], [[C2]](s32)
; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[ASHR]], [[LSHR]]
; CHECK-NEXT: $w0 = COPY [[ADD]](s32)
; CHECK-NEXT: RET_ReallyLR implicit $w0
%0:_(s32) = COPY $w0
%1:_(s32) = G_CONSTANT i32 104
Expand Down
114 changes: 24 additions & 90 deletions llvm/test/CodeGen/AArch64/arm64-neon-mul-div-cte.ll
Original file line number Diff line number Diff line change
Expand Up @@ -15,56 +15,13 @@ define <16 x i8> @div16xi8(<16 x i8> %x) {
;
; CHECK-GI-LABEL: div16xi8:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: smov w9, v0.b[0]
; CHECK-GI-NEXT: mov w8, #25 // =0x19
; CHECK-GI-NEXT: smov w10, v0.b[1]
; CHECK-GI-NEXT: smov w11, v0.b[2]
; CHECK-GI-NEXT: smov w12, v0.b[3]
; CHECK-GI-NEXT: smov w13, v0.b[4]
; CHECK-GI-NEXT: smov w14, v0.b[5]
; CHECK-GI-NEXT: smov w15, v0.b[6]
; CHECK-GI-NEXT: smov w16, v0.b[7]
; CHECK-GI-NEXT: smov w17, v0.b[8]
; CHECK-GI-NEXT: smov w18, v0.b[9]
; CHECK-GI-NEXT: sdiv w9, w9, w8
; CHECK-GI-NEXT: sdiv w10, w10, w8
; CHECK-GI-NEXT: fmov s1, w9
; CHECK-GI-NEXT: sdiv w11, w11, w8
; CHECK-GI-NEXT: mov v1.b[1], w10
; CHECK-GI-NEXT: smov w10, v0.b[10]
; CHECK-GI-NEXT: sdiv w12, w12, w8
; CHECK-GI-NEXT: mov v1.b[2], w11
; CHECK-GI-NEXT: smov w11, v0.b[11]
; CHECK-GI-NEXT: sdiv w13, w13, w8
; CHECK-GI-NEXT: mov v1.b[3], w12
; CHECK-GI-NEXT: smov w12, v0.b[12]
; CHECK-GI-NEXT: sdiv w14, w14, w8
; CHECK-GI-NEXT: mov v1.b[4], w13
; CHECK-GI-NEXT: smov w13, v0.b[13]
; CHECK-GI-NEXT: sdiv w15, w15, w8
; CHECK-GI-NEXT: mov v1.b[5], w14
; CHECK-GI-NEXT: sdiv w16, w16, w8
; CHECK-GI-NEXT: mov v1.b[6], w15
; CHECK-GI-NEXT: sdiv w17, w17, w8
; CHECK-GI-NEXT: mov v1.b[7], w16
; CHECK-GI-NEXT: sdiv w9, w18, w8
; CHECK-GI-NEXT: mov v1.b[8], w17
; CHECK-GI-NEXT: sdiv w10, w10, w8
; CHECK-GI-NEXT: mov v1.b[9], w9
; CHECK-GI-NEXT: smov w9, v0.b[14]
; CHECK-GI-NEXT: sdiv w11, w11, w8
; CHECK-GI-NEXT: mov v1.b[10], w10
; CHECK-GI-NEXT: smov w10, v0.b[15]
; CHECK-GI-NEXT: sdiv w12, w12, w8
; CHECK-GI-NEXT: mov v1.b[11], w11
; CHECK-GI-NEXT: sdiv w13, w13, w8
; CHECK-GI-NEXT: mov v1.b[12], w12
; CHECK-GI-NEXT: sdiv w9, w9, w8
; CHECK-GI-NEXT: mov v1.b[13], w13
; CHECK-GI-NEXT: sdiv w8, w10, w8
; CHECK-GI-NEXT: mov v1.b[14], w9
; CHECK-GI-NEXT: mov v1.b[15], w8
; CHECK-GI-NEXT: mov v0.16b, v1.16b
; CHECK-GI-NEXT: movi v1.16b, #41
; CHECK-GI-NEXT: smull2 v2.8h, v0.16b, v1.16b
; CHECK-GI-NEXT: smull v0.8h, v0.8b, v1.8b
; CHECK-GI-NEXT: uzp2 v1.16b, v0.16b, v2.16b
; CHECK-GI-NEXT: sshr v0.16b, v1.16b, #2
; CHECK-GI-NEXT: ushr v0.16b, v0.16b, #7
; CHECK-GI-NEXT: ssra v0.16b, v1.16b, #2
; CHECK-GI-NEXT: ret
%div = sdiv <16 x i8> %x, <i8 25, i8 25, i8 25, i8 25, i8 25, i8 25, i8 25, i8 25, i8 25, i8 25, i8 25, i8 25, i8 25, i8 25, i8 25, i8 25>
ret <16 x i8> %div
Expand All @@ -85,32 +42,15 @@ define <8 x i16> @div8xi16(<8 x i16> %x) {
;
; CHECK-GI-LABEL: div8xi16:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: smov w9, v0.h[0]
; CHECK-GI-NEXT: mov w8, #6577 // =0x19b1
; CHECK-GI-NEXT: smov w10, v0.h[1]
; CHECK-GI-NEXT: smov w11, v0.h[2]
; CHECK-GI-NEXT: smov w12, v0.h[3]
; CHECK-GI-NEXT: smov w13, v0.h[4]
; CHECK-GI-NEXT: smov w14, v0.h[5]
; CHECK-GI-NEXT: sdiv w9, w9, w8
; CHECK-GI-NEXT: sdiv w10, w10, w8
; CHECK-GI-NEXT: fmov s1, w9
; CHECK-GI-NEXT: sdiv w11, w11, w8
; CHECK-GI-NEXT: mov v1.h[1], w10
; CHECK-GI-NEXT: smov w10, v0.h[6]
; CHECK-GI-NEXT: sdiv w12, w12, w8
; CHECK-GI-NEXT: mov v1.h[2], w11
; CHECK-GI-NEXT: smov w11, v0.h[7]
; CHECK-GI-NEXT: sdiv w13, w13, w8
; CHECK-GI-NEXT: mov v1.h[3], w12
; CHECK-GI-NEXT: sdiv w9, w14, w8
; CHECK-GI-NEXT: mov v1.h[4], w13
; CHECK-GI-NEXT: sdiv w10, w10, w8
; CHECK-GI-NEXT: mov v1.h[5], w9
; CHECK-GI-NEXT: sdiv w8, w11, w8
; CHECK-GI-NEXT: mov v1.h[6], w10
; CHECK-GI-NEXT: mov v1.h[7], w8
; CHECK-GI-NEXT: mov v0.16b, v1.16b
; CHECK-GI-NEXT: adrp x8, .LCPI1_0
; CHECK-GI-NEXT: ldr q1, [x8, :lo12:.LCPI1_0]
; CHECK-GI-NEXT: smull2 v2.4s, v0.8h, v1.8h
; CHECK-GI-NEXT: smull v1.4s, v0.4h, v1.4h
; CHECK-GI-NEXT: uzp2 v1.8h, v1.8h, v2.8h
; CHECK-GI-NEXT: add v1.8h, v1.8h, v0.8h
; CHECK-GI-NEXT: sshr v0.8h, v1.8h, #12
; CHECK-GI-NEXT: ushr v0.8h, v0.8h, #15
; CHECK-GI-NEXT: ssra v0.8h, v1.8h, #12
; CHECK-GI-NEXT: ret
%div = sdiv <8 x i16> %x, <i16 6577, i16 6577, i16 6577, i16 6577, i16 6577, i16 6577, i16 6577, i16 6577>
ret <8 x i16> %div
Expand All @@ -131,20 +71,14 @@ define <4 x i32> @div32xi4(<4 x i32> %x) {
;
; CHECK-GI-LABEL: div32xi4:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: fmov w9, s0
; CHECK-GI-NEXT: mov w8, #39957 // =0x9c15
; CHECK-GI-NEXT: mov w10, v0.s[1]
; CHECK-GI-NEXT: movk w8, #145, lsl #16
; CHECK-GI-NEXT: mov w11, v0.s[2]
; CHECK-GI-NEXT: mov w12, v0.s[3]
; CHECK-GI-NEXT: sdiv w9, w9, w8
; CHECK-GI-NEXT: sdiv w10, w10, w8
; CHECK-GI-NEXT: fmov s0, w9
; CHECK-GI-NEXT: sdiv w11, w11, w8
; CHECK-GI-NEXT: mov v0.s[1], w10
; CHECK-GI-NEXT: sdiv w8, w12, w8
; CHECK-GI-NEXT: mov v0.s[2], w11
; CHECK-GI-NEXT: mov v0.s[3], w8
; CHECK-GI-NEXT: adrp x8, .LCPI2_0
; CHECK-GI-NEXT: ldr q1, [x8, :lo12:.LCPI2_0]
; CHECK-GI-NEXT: smull2 v2.2d, v0.4s, v1.4s
; CHECK-GI-NEXT: smull v0.2d, v0.2s, v1.2s
; CHECK-GI-NEXT: uzp2 v1.4s, v0.4s, v2.4s
; CHECK-GI-NEXT: sshr v0.4s, v1.4s, #22
; CHECK-GI-NEXT: ushr v0.4s, v0.4s, #31
; CHECK-GI-NEXT: ssra v0.4s, v1.4s, #22
; CHECK-GI-NEXT: ret
%div = sdiv <4 x i32> %x, <i32 9542677, i32 9542677, i32 9542677, i32 9542677>
ret <4 x i32> %div
Expand Down
Loading
Loading