Skip to content

Commit 473bc0d

Browse files
changpeng and rampitec authored
[AMDGPU] Support V_FMA_MIX*_BF16 instructions on gfx1250 (llvm#150381)
Co-authored-by: Stanislav Mekhanoshin <Stanislav.Mekhanoshin@amd.com>
1 parent a4d4859 commit 473bc0d

File tree

12 files changed

+1874
-91
lines changed

12 files changed

+1874
-91
lines changed

llvm/lib/Target/AMDGPU/AMDGPU.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,12 @@ def FeatureFmaMixInsts : SubtargetFeature<"fma-mix-insts",
149149
"Has v_fma_mix_f32, v_fma_mixlo_f16, v_fma_mixhi_f16 instructions"
150150
>;
151151

152+
def FeatureFmaMixBF16Insts : SubtargetFeature<"fma-mix-bf16-insts",
153+
"HasFmaMixBF16Insts",
154+
"true",
155+
"Has v_fma_mix_f32_bf16, v_fma_mixlo_bf16, v_fma_mixhi_bf16 instructions"
156+
>;
157+
152158
def FeatureIEEEMinimumMaximumInsts : SubtargetFeature<"ieee-minimum-maximum-insts",
153159
"HasIEEEMinimumMaximumInsts",
154160
"true",
@@ -2007,6 +2013,7 @@ def FeatureISAVersion12_50 : FeatureSet<
20072013
FeatureBF16ConversionInsts,
20082014
FeatureBF16PackedInsts,
20092015
FeatureCvtPkF16F32Inst,
2016+
FeatureFmaMixBF16Insts,
20102017
FeatureMin3Max3PKF16,
20112018
FeatureMinimum3Maximum3PKF16,
20122019
FeaturePrngInst,
@@ -2599,6 +2606,9 @@ def HasMovrel : Predicate<"Subtarget->hasMovrel()">,
25992606
def HasFmaMixInsts : Predicate<"Subtarget->hasFmaMixInsts()">,
26002607
AssemblerPredicate<(all_of FeatureFmaMixInsts)>;
26012608

2609+
def HasFmaMixBF16Insts : Predicate<"Subtarget->hasFmaMixBF16Insts()">,
2610+
AssemblerPredicate<(all_of FeatureFmaMixBF16Insts)>;
2611+
26022612
def HasDLInsts : Predicate<"Subtarget->hasDLInsts()">,
26032613
AssemblerPredicate<(all_of FeatureDLInsts)>;
26042614

llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp

Lines changed: 104 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3861,58 +3861,114 @@ bool AMDGPUDAGToDAGISel::SelectVOP3OpSelMods(SDValue In, SDValue &Src,
38613861
return SelectVOP3Mods(In, Src, SrcMods);
38623862
}
38633863

3864+
// Match lowered fpext from bf16 to f32. This is a bit operation extending
3865+
// a 16-bit value with 16 bits of zeroes at the LSB end:
3866+
//
3867+
// 1. (f32 (bitcast (build_vector (i16 0), (i16 (bitcast bf16:val)))))
3868+
// 2. (f32 (bitcast (and i32:val, 0xffff0000))) -> IsExtractHigh = true
3869+
// 3. (f32 (bitcast (shl i32:val, 16))) -> IsExtractHigh = false
3870+
static SDValue matchBF16FPExtendLike(SDValue Op, bool &IsExtractHigh) {
3871+
if (Op.getValueType() != MVT::f32 || Op.getOpcode() != ISD::BITCAST)
3872+
return SDValue();
3873+
Op = Op.getOperand(0);
3874+
3875+
IsExtractHigh = false;
3876+
if (Op.getValueType() == MVT::v2i16 && Op.getOpcode() == ISD::BUILD_VECTOR) {
3877+
auto Low16 = dyn_cast<ConstantSDNode>(Op.getOperand(0));
3878+
if (!Low16 || !Low16->isZero())
3879+
return SDValue();
3880+
Op = stripBitcast(Op.getOperand(1));
3881+
if (Op.getValueType() != MVT::bf16)
3882+
return SDValue();
3883+
return Op;
3884+
}
3885+
3886+
if (Op.getValueType() != MVT::i32)
3887+
return SDValue();
3888+
3889+
if (Op.getOpcode() == ISD::AND) {
3890+
if (auto Mask = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
3891+
if (Mask->getZExtValue() == 0xffff0000) {
3892+
IsExtractHigh = true;
3893+
return Op.getOperand(0);
3894+
}
3895+
}
3896+
return SDValue();
3897+
}
3898+
3899+
if (Op.getOpcode() == ISD::SHL) {
3900+
if (auto Amt = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
3901+
if (Amt->getZExtValue() == 16)
3902+
return Op.getOperand(0);
3903+
}
3904+
}
3905+
3906+
return SDValue();
3907+
}
3908+
38643909
// The return value is not whether the match is possible (which it always is),
38653910
// but whether or not a conversion is really used.
38663911
bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src,
3867-
unsigned &Mods) const {
3912+
unsigned &Mods,
3913+
MVT VT) const {
38683914
Mods = 0;
38693915
SelectVOP3ModsImpl(In, Src, Mods);
38703916

3917+
bool IsExtractHigh = false;
38713918
if (Src.getOpcode() == ISD::FP_EXTEND) {
38723919
Src = Src.getOperand(0);
3873-
assert(Src.getValueType() == MVT::f16);
3874-
Src = stripBitcast(Src);
3920+
} else if (VT == MVT::bf16) {
3921+
SDValue B16 = matchBF16FPExtendLike(Src, IsExtractHigh);
3922+
if (!B16)
3923+
return false;
3924+
Src = B16;
3925+
} else
3926+
return false;
38753927

3876-
// Be careful about folding modifiers if we already have an abs. fneg is
3877-
// applied last, so we don't want to apply an earlier fneg.
3878-
if ((Mods & SISrcMods::ABS) == 0) {
3879-
unsigned ModsTmp;
3880-
SelectVOP3ModsImpl(Src, Src, ModsTmp);
3928+
if (Src.getValueType() != VT &&
3929+
(VT != MVT::bf16 || Src.getValueType() != MVT::i32))
3930+
return false;
38813931

3882-
if ((ModsTmp & SISrcMods::NEG) != 0)
3883-
Mods ^= SISrcMods::NEG;
3932+
Src = stripBitcast(Src);
38843933

3885-
if ((ModsTmp & SISrcMods::ABS) != 0)
3886-
Mods |= SISrcMods::ABS;
3887-
}
3934+
// Be careful about folding modifiers if we already have an abs. fneg is
3935+
// applied last, so we don't want to apply an earlier fneg.
3936+
if ((Mods & SISrcMods::ABS) == 0) {
3937+
unsigned ModsTmp;
3938+
SelectVOP3ModsImpl(Src, Src, ModsTmp);
3939+
3940+
if ((ModsTmp & SISrcMods::NEG) != 0)
3941+
Mods ^= SISrcMods::NEG;
38883942

3889-
// op_sel/op_sel_hi decide the source type and source.
3890-
// If the source's op_sel_hi is set, it indicates to do a conversion from fp16.
3891-
// If the sources's op_sel is set, it picks the high half of the source
3892-
// register.
3943+
if ((ModsTmp & SISrcMods::ABS) != 0)
3944+
Mods |= SISrcMods::ABS;
3945+
}
38933946

3894-
Mods |= SISrcMods::OP_SEL_1;
3895-
if (isExtractHiElt(Src, Src)) {
3896-
Mods |= SISrcMods::OP_SEL_0;
3947+
// op_sel/op_sel_hi decide the source type and source.
3948+
// If the source's op_sel_hi is set, it indicates to do a conversion from
3949+
// fp16. If the source's op_sel is set, it picks the high half of the source
3950+
// register.
38973951

3898-
// TODO: Should we try to look for neg/abs here?
3899-
}
3952+
Mods |= SISrcMods::OP_SEL_1;
3953+
if (IsExtractHigh ||
3954+
(Src.getValueSizeInBits() == 16 && isExtractHiElt(Src, Src))) {
3955+
Mods |= SISrcMods::OP_SEL_0;
39003956

3901-
// Prevent unnecessary subreg COPY to VGPR_16
3902-
if (Src.getOpcode() == ISD::TRUNCATE &&
3903-
Src.getOperand(0).getValueType() == MVT::i32) {
3904-
Src = Src.getOperand(0);
3905-
}
3906-
return true;
3957+
// TODO: Should we try to look for neg/abs here?
39073958
}
39083959

3909-
return false;
3960+
// Prevent unnecessary subreg COPY to VGPR_16
3961+
if (Src.getOpcode() == ISD::TRUNCATE &&
3962+
Src.getOperand(0).getValueType() == MVT::i32) {
3963+
Src = Src.getOperand(0);
3964+
}
3965+
return true;
39103966
}
39113967

39123968
bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src,
39133969
SDValue &SrcMods) const {
39143970
unsigned Mods = 0;
3915-
if (!SelectVOP3PMadMixModsImpl(In, Src, Mods))
3971+
if (!SelectVOP3PMadMixModsImpl(In, Src, Mods, MVT::f16))
39163972
return false;
39173973
SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
39183974
return true;
@@ -3921,7 +3977,24 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src,
39213977
bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixMods(SDValue In, SDValue &Src,
39223978
SDValue &SrcMods) const {
39233979
unsigned Mods = 0;
3924-
SelectVOP3PMadMixModsImpl(In, Src, Mods);
3980+
SelectVOP3PMadMixModsImpl(In, Src, Mods, MVT::f16);
3981+
SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
3982+
return true;
3983+
}
3984+
3985+
bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixBF16ModsExt(SDValue In, SDValue &Src,
3986+
SDValue &SrcMods) const {
3987+
unsigned Mods = 0;
3988+
if (!SelectVOP3PMadMixModsImpl(In, Src, Mods, MVT::bf16))
3989+
return false;
3990+
SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
3991+
return true;
3992+
}
3993+
3994+
bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixBF16Mods(SDValue In, SDValue &Src,
3995+
SDValue &SrcMods) const {
3996+
unsigned Mods = 0;
3997+
SelectVOP3PMadMixModsImpl(In, Src, Mods, MVT::bf16);
39253998
SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
39263999
return true;
39274000
}

llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,11 +254,15 @@ class AMDGPUDAGToDAGISel : public SelectionDAGISel {
254254
bool SelectVOP3OpSel(SDValue In, SDValue &Src, SDValue &SrcMods) const;
255255

256256
bool SelectVOP3OpSelMods(SDValue In, SDValue &Src, SDValue &SrcMods) const;
257-
bool SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src,
258-
unsigned &Mods) const;
257+
bool SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src, unsigned &Mods,
258+
MVT VT) const;
259259
bool SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src,
260260
SDValue &SrcMods) const;
261261
bool SelectVOP3PMadMixMods(SDValue In, SDValue &Src, SDValue &SrcMods) const;
262+
bool SelectVOP3PMadMixBF16ModsExt(SDValue In, SDValue &Src,
263+
SDValue &SrcMods) const;
264+
bool SelectVOP3PMadMixBF16Mods(SDValue In, SDValue &Src,
265+
SDValue &SrcMods) const;
262266

263267
bool SelectBITOP3(SDValue In, SDValue &Src0, SDValue &Src1, SDValue &Src2,
264268
SDValue &Tbl) const;

llvm/lib/Target/AMDGPU/GCNSubtarget.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ class GCNSubtarget final : public AMDGPUGenSubtargetInfo,
123123
bool HasSMemRealTime = false;
124124
bool HasIntClamp = false;
125125
bool HasFmaMixInsts = false;
126+
bool HasFmaMixBF16Insts = false;
126127
bool HasMovrel = false;
127128
bool HasVGPRIndexMode = false;
128129
bool HasScalarDwordx3Loads = false;
@@ -462,6 +463,8 @@ class GCNSubtarget final : public AMDGPUGenSubtargetInfo,
462463
return HasFmaMixInsts;
463464
}
464465

466+
bool hasFmaMixBF16Insts() const { return HasFmaMixBF16Insts; }
467+
465468
bool hasCARRY() const {
466469
return true;
467470
}

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,10 +1061,12 @@ ArrayRef<MCPhysReg> SITargetLowering::getRoundingControlRegisters() const {
10611061
// where this is OK to use.
10621062
bool SITargetLowering::isFPExtFoldable(const SelectionDAG &DAG, unsigned Opcode,
10631063
EVT DestVT, EVT SrcVT) const {
1064-
return ((Opcode == ISD::FMAD && Subtarget->hasMadMixInsts()) ||
1065-
(Opcode == ISD::FMA && Subtarget->hasFmaMixInsts())) &&
1066-
DestVT.getScalarType() == MVT::f32 &&
1067-
SrcVT.getScalarType() == MVT::f16 &&
1064+
return DestVT.getScalarType() == MVT::f32 &&
1065+
((((Opcode == ISD::FMAD && Subtarget->hasMadMixInsts()) ||
1066+
(Opcode == ISD::FMA && Subtarget->hasFmaMixInsts())) &&
1067+
SrcVT.getScalarType() == MVT::f16) ||
1068+
(Opcode == ISD::FMA && Subtarget->hasFmaMixBF16Insts() &&
1069+
SrcVT.getScalarType() == MVT::bf16)) &&
10681070
// TODO: This probably only requires no input flushing?
10691071
denormalModeIsFlushAllF32(DAG.getMachineFunction());
10701072
}

llvm/lib/Target/AMDGPU/SIInstrInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1662,6 +1662,8 @@ def VOP3OpSelMods : ComplexPattern<untyped, 2, "SelectVOP3OpSelMods">;
16621662

16631663
def VOP3PMadMixModsExt : ComplexPattern<untyped, 2, "SelectVOP3PMadMixModsExt">;
16641664
def VOP3PMadMixMods : ComplexPattern<untyped, 2, "SelectVOP3PMadMixMods">;
1665+
def VOP3PMadMixBF16ModsExt : ComplexPattern<untyped, 2, "SelectVOP3PMadMixBF16ModsExt">;
1666+
def VOP3PMadMixBF16Mods : ComplexPattern<untyped, 2, "SelectVOP3PMadMixBF16Mods">;
16651667

16661668
def VINTERPMods : ComplexPattern<untyped, 2, "SelectVINTERPMods">;
16671669
def VINTERPModsHi : ComplexPattern<untyped, 2, "SelectVINTERPModsHi">;
@@ -2866,6 +2868,7 @@ def VOP_I16_I16_I16_ARITH : VOPProfile <[i16, i16, i16, untyped], /*EnableClamp=
28662868

28672869
def VOP_I16_I16_I16_I16 : VOPProfile <[i16, i16, i16, i16, untyped]>;
28682870
def VOP_F16_F16_F16_F16 : VOPProfile <[f16, f16, f16, f16, untyped]>;
2871+
def VOP_BF16_BF16_BF16_BF16 : VOPProfile <[bf16, bf16, bf16, bf16, untyped]>;
28692872

28702873
def VOP_I32_I16_I16_I32 : VOPProfile <[i32, i16, i16, i32, untyped]>;
28712874
def VOP_I32_I16 : VOPProfile <[i32, i16, untyped, untyped]>;
@@ -2917,6 +2920,7 @@ def VOP_I32_I32_I32_ARITH : VOPProfile <[i32, i32, i32, untyped], /*EnableClamp=
29172920
def VOP_I64_I64_I64_ARITH : VOPProfile <[i64, i64, i64, untyped], /*EnableClamp=*/1>;
29182921
def VOP_V2F16_F32_F32 : VOPProfile <[v2f16, f32, f32, untyped]>;
29192922
def VOP_F32_F16_F16_F16 : VOPProfile <[f32, f16, f16, f16]>;
2923+
def VOP_F32_BF16_BF16_BF16 : VOPProfile <[f32, bf16, bf16, bf16]>;
29202924
def VOP_V2BF16_F32_F32 : VOPProfile <[v2bf16, f32, f32, untyped]>;
29212925
def VOP_V32F32_V6I32_F32 : VOPProfile <[v32f32, v6i32, f32, untyped]>;
29222926
def VOP_V32F16_V6I32_F32 : VOPProfile <[v32f16, v6i32, f32, untyped]>;

0 commit comments

Comments
 (0)