Skip to content

Commit 5238395

Browse files
authored
[DAG] Replace DAGCombiner::ConstantFoldBITCASTofBUILD_VECTOR with SelectionDAG::FoldConstantBuildVector (#147037)
DAGCombiner can already constant fold build vectors of constants/undefs to a new vector type, but it has to be incredibly careful after legalization to not affect a target's canonicalized constants. This patch proposes we move the implementation inside SelectionDAG to make it easier for targets to manually use the constant folding whenever they deem it safe to do so. I've also altered the method to take the BuildVectorSDNode input directly and consistently use the same SDLoc.
1 parent 465f2b0 commit 5238395

File tree

4 files changed

+84
-85
lines changed

4 files changed

+84
-85
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,6 +2030,11 @@ class SelectionDAG {
20302030
LLVM_ABI SDValue foldConstantFPMath(unsigned Opcode, const SDLoc &DL, EVT VT,
20312031
ArrayRef<SDValue> Ops);
20322032

2033+
/// Fold a BUILD_VECTOR of constant/undef elements into a
2034+
/// BUILD_VECTOR of constant/undef elements of element type DstEltVT.
2035+
LLVM_ABI SDValue FoldConstantBuildVector(BuildVectorSDNode *BV,
2036+
const SDLoc &DL, EVT DstEltVT);
2037+
20332038
/// Constant fold a setcc to true or false.
20342039
LLVM_ABI SDValue FoldSetCC(EVT VT, SDValue N1, SDValue N2, ISD::CondCode Cond,
20352040
const SDLoc &dl);

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 2 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,6 @@ namespace {
638638
SDValue mergeInsertEltWithShuffle(SDNode *N, unsigned InsIndex);
639639
SDValue combineInsertEltToShuffle(SDNode *N, unsigned InsIndex);
640640
SDValue combineInsertEltToLoad(SDNode *N, unsigned InsIndex);
641-
SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
642641
SDValue BuildSDIV(SDNode *N);
643642
SDValue BuildSDIVPow2(SDNode *N);
644643
SDValue BuildUDIV(SDNode *N);
@@ -16431,8 +16430,8 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
1643116430
TLI.isTypeLegal(VT.getVectorElementType()))) &&
1643216431
N0.getOpcode() == ISD::BUILD_VECTOR && N0->hasOneUse() &&
1643316432
cast<BuildVectorSDNode>(N0)->isConstant())
16434-
return ConstantFoldBITCASTofBUILD_VECTOR(N0.getNode(),
16435-
VT.getVectorElementType());
16433+
return DAG.FoldConstantBuildVector(cast<BuildVectorSDNode>(N0), SDLoc(N),
16434+
VT.getVectorElementType());
1643616435

1643716436
// If the input is a constant, let getNode fold it.
1643816437
if (isIntOrFPConstant(N0)) {
@@ -16825,83 +16824,6 @@ SDValue DAGCombiner::visitFREEZE(SDNode *N) {
1682516824
return DAG.getNode(N0.getOpcode(), DL, N0->getVTList(), Ops, SafeFlags);
1682616825
}
1682716826

16828-
/// We know that BV is a build_vector node with Constant, ConstantFP or Undef
16829-
/// operands. DstEltVT indicates the destination element value type.
16830-
SDValue DAGCombiner::
16831-
ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) {
16832-
EVT SrcEltVT = BV->getValueType(0).getVectorElementType();
16833-
16834-
// If this is already the right type, we're done.
16835-
if (SrcEltVT == DstEltVT) return SDValue(BV, 0);
16836-
16837-
unsigned SrcBitSize = SrcEltVT.getSizeInBits();
16838-
unsigned DstBitSize = DstEltVT.getSizeInBits();
16839-
16840-
// If this is a conversion of N elements of one type to N elements of another
16841-
// type, convert each element. This handles FP<->INT cases.
16842-
if (SrcBitSize == DstBitSize) {
16843-
SmallVector<SDValue, 8> Ops;
16844-
for (SDValue Op : BV->op_values()) {
16845-
// If the vector element type is not legal, the BUILD_VECTOR operands
16846-
// are promoted and implicitly truncated. Make that explicit here.
16847-
if (Op.getValueType() != SrcEltVT)
16848-
Op = DAG.getNode(ISD::TRUNCATE, SDLoc(BV), SrcEltVT, Op);
16849-
Ops.push_back(DAG.getBitcast(DstEltVT, Op));
16850-
AddToWorklist(Ops.back().getNode());
16851-
}
16852-
EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT,
16853-
BV->getValueType(0).getVectorNumElements());
16854-
return DAG.getBuildVector(VT, SDLoc(BV), Ops);
16855-
}
16856-
16857-
// Otherwise, we're growing or shrinking the elements. To avoid having to
16858-
// handle annoying details of growing/shrinking FP values, we convert them to
16859-
// int first.
16860-
if (SrcEltVT.isFloatingPoint()) {
16861-
// Convert the input float vector to a int vector where the elements are the
16862-
// same sizes.
16863-
EVT IntVT = EVT::getIntegerVT(*DAG.getContext(), SrcEltVT.getSizeInBits());
16864-
BV = ConstantFoldBITCASTofBUILD_VECTOR(BV, IntVT).getNode();
16865-
SrcEltVT = IntVT;
16866-
}
16867-
16868-
// Now we know the input is an integer vector. If the output is a FP type,
16869-
// convert to integer first, then to FP of the right size.
16870-
if (DstEltVT.isFloatingPoint()) {
16871-
EVT TmpVT = EVT::getIntegerVT(*DAG.getContext(), DstEltVT.getSizeInBits());
16872-
SDNode *Tmp = ConstantFoldBITCASTofBUILD_VECTOR(BV, TmpVT).getNode();
16873-
16874-
// Next, convert to FP elements of the same size.
16875-
return ConstantFoldBITCASTofBUILD_VECTOR(Tmp, DstEltVT);
16876-
}
16877-
16878-
// Okay, we know the src/dst types are both integers of differing types.
16879-
assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
16880-
16881-
// TODO: Should ConstantFoldBITCASTofBUILD_VECTOR always take a
16882-
// BuildVectorSDNode?
16883-
auto *BVN = cast<BuildVectorSDNode>(BV);
16884-
16885-
// Extract the constant raw bit data.
16886-
BitVector UndefElements;
16887-
SmallVector<APInt> RawBits;
16888-
bool IsLE = DAG.getDataLayout().isLittleEndian();
16889-
if (!BVN->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements))
16890-
return SDValue();
16891-
16892-
SDLoc DL(BV);
16893-
SmallVector<SDValue, 8> Ops;
16894-
for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
16895-
if (UndefElements[I])
16896-
Ops.push_back(DAG.getUNDEF(DstEltVT));
16897-
else
16898-
Ops.push_back(DAG.getConstant(RawBits[I], DL, DstEltVT));
16899-
}
16900-
16901-
EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, Ops.size());
16902-
return DAG.getBuildVector(VT, DL, Ops);
16903-
}
16904-
1690516827
// Returns true if floating point contraction is allowed on the FMUL-SDValue
1690616828
// `N`
1690716829
static bool isContractableFMUL(const TargetOptions &Options, SDValue N) {

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7280,6 +7280,78 @@ SDValue SelectionDAG::foldConstantFPMath(unsigned Opcode, const SDLoc &DL,
72807280
return SDValue();
72817281
}
72827282

7283+
SDValue SelectionDAG::FoldConstantBuildVector(BuildVectorSDNode *BV,
7284+
const SDLoc &DL, EVT DstEltVT) {
7285+
EVT SrcEltVT = BV->getValueType(0).getVectorElementType();
7286+
7287+
// If this is already the right type, we're done.
7288+
if (SrcEltVT == DstEltVT)
7289+
return SDValue(BV, 0);
7290+
7291+
unsigned SrcBitSize = SrcEltVT.getSizeInBits();
7292+
unsigned DstBitSize = DstEltVT.getSizeInBits();
7293+
7294+
// If this is a conversion of N elements of one type to N elements of another
7295+
// type, convert each element. This handles FP<->INT cases.
7296+
if (SrcBitSize == DstBitSize) {
7297+
SmallVector<SDValue, 8> Ops;
7298+
for (SDValue Op : BV->op_values()) {
7299+
// If the vector element type is not legal, the BUILD_VECTOR operands
7300+
// are promoted and implicitly truncated. Make that explicit here.
7301+
if (Op.getValueType() != SrcEltVT)
7302+
Op = getNode(ISD::TRUNCATE, DL, SrcEltVT, Op);
7303+
Ops.push_back(getBitcast(DstEltVT, Op));
7304+
}
7305+
EVT VT = EVT::getVectorVT(*getContext(), DstEltVT,
7306+
BV->getValueType(0).getVectorNumElements());
7307+
return getBuildVector(VT, DL, Ops);
7308+
}
7309+
7310+
// Otherwise, we're growing or shrinking the elements. To avoid having to
7311+
// handle annoying details of growing/shrinking FP values, we convert them to
7312+
// int first.
7313+
if (SrcEltVT.isFloatingPoint()) {
7314+
// Convert the input float vector to a int vector where the elements are the
7315+
// same sizes.
7316+
EVT IntEltVT = EVT::getIntegerVT(*getContext(), SrcEltVT.getSizeInBits());
7317+
if (SDValue Tmp = FoldConstantBuildVector(BV, DL, IntEltVT))
7318+
return FoldConstantBuildVector(cast<BuildVectorSDNode>(Tmp), DL,
7319+
DstEltVT);
7320+
return SDValue();
7321+
}
7322+
7323+
// Now we know the input is an integer vector. If the output is a FP type,
7324+
// convert to integer first, then to FP of the right size.
7325+
if (DstEltVT.isFloatingPoint()) {
7326+
EVT IntEltVT = EVT::getIntegerVT(*getContext(), DstEltVT.getSizeInBits());
7327+
if (SDValue Tmp = FoldConstantBuildVector(BV, DL, IntEltVT))
7328+
return FoldConstantBuildVector(cast<BuildVectorSDNode>(Tmp), DL,
7329+
DstEltVT);
7330+
return SDValue();
7331+
}
7332+
7333+
// Okay, we know the src/dst types are both integers of differing types.
7334+
assert(SrcEltVT.isInteger() && DstEltVT.isInteger());
7335+
7336+
// Extract the constant raw bit data.
7337+
BitVector UndefElements;
7338+
SmallVector<APInt> RawBits;
7339+
bool IsLE = getDataLayout().isLittleEndian();
7340+
if (!BV->getConstantRawBits(IsLE, DstBitSize, RawBits, UndefElements))
7341+
return SDValue();
7342+
7343+
SmallVector<SDValue, 8> Ops;
7344+
for (unsigned I = 0, E = RawBits.size(); I != E; ++I) {
7345+
if (UndefElements[I])
7346+
Ops.push_back(getUNDEF(DstEltVT));
7347+
else
7348+
Ops.push_back(getConstant(RawBits[I], DL, DstEltVT));
7349+
}
7350+
7351+
EVT VT = EVT::getVectorVT(*getContext(), DstEltVT, Ops.size());
7352+
return getBuildVector(VT, DL, Ops);
7353+
}
7354+
72837355
SDValue SelectionDAG::getAssertAlign(const SDLoc &DL, SDValue Val, Align A) {
72847356
assert(Val.getValueType().isInteger() && "Invalid AssertAlign!");
72857357

llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.AFLCustomIRMutator.opt.ll

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,19 @@ define amdgpu_kernel void @test_iglp_opt_rev_mfma_gemm(<1 x i64> %L1) {
55
; GCN-LABEL: test_iglp_opt_rev_mfma_gemm:
66
; GCN: ; %bb.0: ; %entry
77
; GCN-NEXT: v_mov_b32_e32 v32, 0
8+
; GCN-NEXT: ds_read_b128 v[0:3], v32
89
; GCN-NEXT: s_load_dwordx2 s[0:1], s[8:9], 0x0
910
; GCN-NEXT: ds_read_b128 v[28:31], v32 offset:112
1011
; GCN-NEXT: ds_read_b128 v[24:27], v32 offset:96
1112
; GCN-NEXT: ds_read_b128 v[20:23], v32 offset:80
1213
; GCN-NEXT: ds_read_b128 v[16:19], v32 offset:64
13-
; GCN-NEXT: ds_read_b128 v[0:3], v32
1414
; GCN-NEXT: ds_read_b128 v[4:7], v32 offset:16
1515
; GCN-NEXT: ds_read_b128 v[8:11], v32 offset:32
1616
; GCN-NEXT: ds_read_b128 v[12:15], v32 offset:48
17-
; GCN-NEXT: v_mov_b32_e32 v34, 0
18-
; GCN-NEXT: v_mov_b32_e32 v35, v34
1917
; GCN-NEXT: s_waitcnt lgkmcnt(0)
18+
; GCN-NEXT: ds_write_b128 v32, v[0:3]
19+
; GCN-NEXT: v_mov_b32_e32 v0, 0
20+
; GCN-NEXT: v_mov_b32_e32 v1, v0
2021
; GCN-NEXT: s_cmp_lg_u64 s[0:1], 0
2122
; GCN-NEXT: ; iglp_opt mask(0x00000001)
2223
; GCN-NEXT: ds_write_b128 v32, v[28:31] offset:112
@@ -26,8 +27,7 @@ define amdgpu_kernel void @test_iglp_opt_rev_mfma_gemm(<1 x i64> %L1) {
2627
; GCN-NEXT: ds_write_b128 v32, v[12:15] offset:48
2728
; GCN-NEXT: ds_write_b128 v32, v[8:11] offset:32
2829
; GCN-NEXT: ds_write_b128 v32, v[4:7] offset:16
29-
; GCN-NEXT: ds_write_b128 v32, v[0:3]
30-
; GCN-NEXT: ds_write_b64 v32, v[34:35]
30+
; GCN-NEXT: ds_write_b64 v32, v[0:1]
3131
; GCN-NEXT: s_endpgm
3232
entry:
3333
call void @llvm.amdgcn.iglp.opt(i32 1)

0 commit comments

Comments
 (0)