From 1d9dabbc528285746067cdb214056e023a98ac3e Mon Sep 17 00:00:00 2001 From: Ivan Kosarev Date: Tue, 8 Jul 2025 14:09:55 +0100 Subject: [PATCH 1/2] [AMDGPU][SDAG] Use the f16 lowering for bf16 safe divisions. --- llvm/lib/Target/AMDGPU/SIISelLowering.cpp | 15 ++- llvm/test/CodeGen/AMDGPU/bf16.ll | 120 ++++++++++------------ 2 files changed, 64 insertions(+), 71 deletions(-) diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index b083a9014737b..bd0bb38570a8f 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -624,7 +624,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM, Expand); setOperationAction({ISD::FLDEXP, ISD::STRICT_FLDEXP}, MVT::f16, Custom); setOperationAction(ISD::FFREXP, MVT::f16, Custom); - setOperationAction(ISD::FDIV, MVT::f16, Custom); + setOperationAction(ISD::FDIV, {MVT::f16, MVT::bf16}, Custom); // F16 - VOP3 Actions. setOperationAction(ISD::FMA, MVT::f16, Legal); @@ -11229,6 +11229,7 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const { SDLoc SL(Op); SDValue LHS = Op.getOperand(0); SDValue RHS = Op.getOperand(1); + EVT VT = Op.getValueType(); // a32.u = opx(V_CVT_F32_F16, a.u); // CVT to F32 // b32.u = opx(V_CVT_F32_F16, b.u); // CVT to F32 @@ -11265,10 +11266,14 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const { DAG.getConstant(0xff800000, SL, MVT::i32)); Tmp = DAG.getNode(ISD::BITCAST, SL, MVT::f32, TmpCast); Quot = DAG.getNode(ISD::FADD, SL, MVT::f32, Tmp, Quot, Op->getFlags()); - SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, MVT::f16, Quot, + + EVT FixupVT = VT == MVT::bf16 ? MVT::f32 : VT; + SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, FixupVT, Quot, DAG.getTargetConstant(0, SL, MVT::i32)); - return DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, MVT::f16, RDst, RHS, LHS, - Op->getFlags()); + SDValue Fixup = DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, FixupVT, RDst, RHS, LHS, + Op->getFlags()); + return DAG.getNode(ISD::FP_ROUND, SL, VT, Fixup, + DAG.getTargetConstant(0, SL, MVT::i32)); } // Faster 2.5 ULP division that does not support denormals. @@ -11531,7 +11536,7 @@ SDValue SITargetLowering::LowerFDIV(SDValue Op, SelectionDAG &DAG) const { if (VT == MVT::f64) return LowerFDIV64(Op, DAG); - if (VT == MVT::f16) + if (VT == MVT::f16 || VT == MVT::bf16) return LowerFDIV16(Op, DAG); llvm_unreachable("Unexpected type for fdiv"); diff --git a/llvm/test/CodeGen/AMDGPU/bf16.ll b/llvm/test/CodeGen/AMDGPU/bf16.ll index 2bdf994496421..2724c16acbcc9 100644 --- a/llvm/test/CodeGen/AMDGPU/bf16.ll +++ b/llvm/test/CodeGen/AMDGPU/bf16.ll @@ -18494,18 +18494,16 @@ define bfloat @v_fdiv_bf16(bfloat %a, bfloat %b) { ; GFX8-LABEL: v_fdiv_bf16: ; GFX8: ; %bb.0: ; GFX8-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -; GFX8-NEXT: v_lshlrev_b32_e32 v0, 16, v0 -; GFX8-NEXT: v_lshlrev_b32_e32 v1, 16, v1 -; GFX8-NEXT: v_div_scale_f32 v2, s[4:5], v1, v1, v0 -; GFX8-NEXT: v_div_scale_f32 v3, vcc, v0, v1, v0 -; GFX8-NEXT: v_rcp_f32_e32 v4, v2 -; GFX8-NEXT: v_fma_f32 v5, -v2, v4, 1.0 -; GFX8-NEXT: v_fma_f32 v4, v5, v4, v4 -; GFX8-NEXT: v_mul_f32_e32 v5, v3, v4 -; GFX8-NEXT: v_fma_f32 v6, -v2, v5, v3 -; GFX8-NEXT: v_fma_f32 v5, v6, v4, v5 -; GFX8-NEXT: v_fma_f32 v2, -v2, v5, v3 -; GFX8-NEXT: v_div_fmas_f32 v2, v2, v4, v5 +; GFX8-NEXT: v_lshlrev_b32_e32 v2, 16, v1 +; GFX8-NEXT: v_rcp_f32_e32 v3, v2 +; GFX8-NEXT: v_lshlrev_b32_e32 v4, 16, v0 +; GFX8-NEXT: v_mul_f32_e32 v5, v4, v3 +; GFX8-NEXT: v_mad_f32 v6, -v2, v5, v4 +; GFX8-NEXT: v_mac_f32_e32 v5, v6, v3 +; GFX8-NEXT: v_mad_f32 v2, -v2, v5, v4 +; GFX8-NEXT: v_mul_f32_e32 v2, v2, v3 +; GFX8-NEXT: v_and_b32_e32 v2, 0xff800000, v2 +; GFX8-NEXT: v_add_f32_e32 v2, v2, v5 ; GFX8-NEXT: v_div_fixup_f32 v0, v2, v1, v0 ; GFX8-NEXT: v_bfe_u32 v1, v0, 16, 1 ; GFX8-NEXT: v_add_u32_e32 v1, vcc, v1, v0 @@ -18519,23 +18517,21 @@ define bfloat @v_fdiv_bf16(bfloat %a, bfloat %b) { ; GFX9-LABEL: v_fdiv_bf16: ; GFX9: ; %bb.0: ; GFX9-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -; GFX9-NEXT: v_lshlrev_b32_e32 v0, 16, v0 -; GFX9-NEXT: v_lshlrev_b32_e32 v1, 16, v1 -; GFX9-NEXT: v_div_scale_f32 v2, s[4:5], v1, v1, v0 -; GFX9-NEXT: v_div_scale_f32 v3, vcc, v0, v1, v0 +; GFX9-NEXT: v_lshlrev_b32_e32 v2, 16, v1 +; GFX9-NEXT: v_rcp_f32_e32 v3, v2 +; GFX9-NEXT: v_lshlrev_b32_e32 v4, 16, v0 ; GFX9-NEXT: s_movk_i32 s4, 0x7fff -; GFX9-NEXT: v_rcp_f32_e32 v4, v2 -; GFX9-NEXT: v_fma_f32 v5, -v2, v4, 1.0 -; GFX9-NEXT: v_fma_f32 v4, v5, v4, v4 -; GFX9-NEXT: v_mul_f32_e32 v5, v3, v4 -; GFX9-NEXT: v_fma_f32 v6, -v2, v5, v3 -; GFX9-NEXT: v_fma_f32 v5, v6, v4, v5 -; GFX9-NEXT: v_fma_f32 v2, -v2, v5, v3 -; GFX9-NEXT: v_div_fmas_f32 v2, v2, v4, v5 +; GFX9-NEXT: v_mul_f32_e32 v5, v4, v3 +; GFX9-NEXT: v_mad_f32 v6, -v2, v5, v4 +; GFX9-NEXT: v_mac_f32_e32 v5, v6, v3 +; GFX9-NEXT: v_mad_f32 v2, -v2, v5, v4 +; GFX9-NEXT: v_mul_f32_e32 v2, v2, v3 +; GFX9-NEXT: v_and_b32_e32 v2, 0xff800000, v2 +; GFX9-NEXT: v_add_f32_e32 v2, v2, v5 ; GFX9-NEXT: v_div_fixup_f32 v0, v2, v1, v0 ; GFX9-NEXT: v_bfe_u32 v1, v0, 16, 1 -; GFX9-NEXT: v_or_b32_e32 v2, 0x400000, v0 ; GFX9-NEXT: v_add3_u32 v1, v1, v0, s4 +; GFX9-NEXT: v_or_b32_e32 v2, 0x400000, v0 ; GFX9-NEXT: v_cmp_u_f32_e32 vcc, v0, v0 ; GFX9-NEXT: v_cndmask_b32_e32 v0, v1, v2, vcc ; GFX9-NEXT: v_lshrrev_b32_e32 v0, 16, v0 @@ -18544,18 +18540,16 @@ define bfloat @v_fdiv_bf16(bfloat %a, bfloat %b) { ; GFX10-LABEL: v_fdiv_bf16: ; GFX10: ; %bb.0: ; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -; GFX10-NEXT: v_lshlrev_b32_e32 v0, 16, v0 -; GFX10-NEXT: v_lshlrev_b32_e32 v1, 16, v1 -; GFX10-NEXT: v_div_scale_f32 v2, s4, v1, v1, v0 -; GFX10-NEXT: v_div_scale_f32 v5, vcc_lo, v0, v1, v0 +; GFX10-NEXT: v_lshlrev_b32_e32 v2, 16, v1 +; GFX10-NEXT: v_lshlrev_b32_e32 v4, 16, v0 ; GFX10-NEXT: v_rcp_f32_e32 v3, v2 -; GFX10-NEXT: v_fma_f32 v4, -v2, v3, 1.0 -; GFX10-NEXT: v_fmac_f32_e32 v3, v4, v3 -; GFX10-NEXT: v_mul_f32_e32 v4, v5, v3 -; GFX10-NEXT: v_fma_f32 v6, -v2, v4, v5 -; GFX10-NEXT: v_fmac_f32_e32 v4, v6, v3 -; GFX10-NEXT: v_fma_f32 v2, -v2, v4, v5 -; GFX10-NEXT: v_div_fmas_f32 v2, v2, v3, v4 +; GFX10-NEXT: v_mul_f32_e32 v5, v4, v3 +; GFX10-NEXT: v_mad_f32 v6, -v2, v5, v4 +; GFX10-NEXT: v_mac_f32_e32 v5, v6, v3 +; GFX10-NEXT: v_mad_f32 v2, -v2, v5, v4 +; GFX10-NEXT: v_mul_f32_e32 v2, v2, v3 +; GFX10-NEXT: v_and_b32_e32 v2, 0xff800000, v2 +; GFX10-NEXT: v_add_f32_e32 v2, v2, v5 ; GFX10-NEXT: v_div_fixup_f32 v0, v2, v1, v0 ; GFX10-NEXT: v_bfe_u32 v1, v0, 16, 1 ; GFX10-NEXT: v_or_b32_e32 v2, 0x400000, v0 @@ -18568,64 +18562,58 @@ define bfloat @v_fdiv_bf16(bfloat %a, bfloat %b) { ; GFX11TRUE16-LABEL: v_fdiv_bf16: ; GFX11TRUE16: ; %bb.0: ; GFX11TRUE16-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -; GFX11TRUE16-NEXT: v_lshlrev_b32_e32 v0, 16, v0 -; GFX11TRUE16-NEXT: v_lshlrev_b32_e32 v1, 16, v1 -; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) -; GFX11TRUE16-NEXT: v_div_scale_f32 v2, null, v1, v1, v0 +; GFX11TRUE16-NEXT: v_lshlrev_b32_e32 v4, 16, v0 +; GFX11TRUE16-NEXT: v_lshlrev_b32_e32 v2, 16, v1 +; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_1) ; GFX11TRUE16-NEXT: v_rcp_f32_e32 v3, v2 ; GFX11TRUE16-NEXT: s_waitcnt_depctr 0xfff -; GFX11TRUE16-NEXT: v_fma_f32 v4, -v2, v3, 1.0 -; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1) -; GFX11TRUE16-NEXT: v_fmac_f32_e32 v3, v4, v3 -; GFX11TRUE16-NEXT: v_div_scale_f32 v5, vcc_lo, v0, v1, v0 -; GFX11TRUE16-NEXT: v_mul_f32_e32 v4, v5, v3 +; GFX11TRUE16-NEXT: v_mul_f32_e32 v5, v4, v3 +; GFX11TRUE16-NEXT: v_fma_f32 v6, -v2, v5, v4 ; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) -; GFX11TRUE16-NEXT: v_fma_f32 v6, -v2, v4, v5 -; GFX11TRUE16-NEXT: v_fmac_f32_e32 v4, v6, v3 +; GFX11TRUE16-NEXT: v_fmac_f32_e32 v5, v6, v3 +; GFX11TRUE16-NEXT: v_fma_f32 v2, -v2, v5, v4 ; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) -; GFX11TRUE16-NEXT: v_fma_f32 v2, -v2, v4, v5 -; GFX11TRUE16-NEXT: v_div_fmas_f32 v2, v2, v3, v4 +; GFX11TRUE16-NEXT: v_mul_f32_e32 v2, v2, v3 +; GFX11TRUE16-NEXT: v_and_b32_e32 v2, 0xff800000, v2 ; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) +; GFX11TRUE16-NEXT: v_add_f32_e32 v2, v2, v5 ; GFX11TRUE16-NEXT: v_div_fixup_f32 v0, v2, v1, v0 +; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_3) ; GFX11TRUE16-NEXT: v_bfe_u32 v1, v0, 16, 1 ; GFX11TRUE16-NEXT: v_or_b32_e32 v2, 0x400000, v0 ; GFX11TRUE16-NEXT: v_cmp_u_f32_e32 vcc_lo, v0, v0 -; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_1) ; GFX11TRUE16-NEXT: v_add3_u32 v1, v1, v0, 0x7fff +; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) ; GFX11TRUE16-NEXT: v_cndmask_b32_e32 v0, v1, v2, vcc_lo -; GFX11TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) ; GFX11TRUE16-NEXT: v_mov_b16_e32 v0.l, v0.h ; GFX11TRUE16-NEXT: s_setpc_b64 s[30:31] ; ; GFX11FAKE16-LABEL: v_fdiv_bf16: ; GFX11FAKE16: ; %bb.0: ; GFX11FAKE16-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) -; GFX11FAKE16-NEXT: v_lshlrev_b32_e32 v0, 16, v0 -; GFX11FAKE16-NEXT: v_lshlrev_b32_e32 v1, 16, v1 -; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) -; GFX11FAKE16-NEXT: v_div_scale_f32 v2, null, v1, v1, v0 +; GFX11FAKE16-NEXT: v_lshlrev_b32_e32 v4, 16, v0 +; GFX11FAKE16-NEXT: v_lshlrev_b32_e32 v2, 16, v1 +; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_1) ; GFX11FAKE16-NEXT: v_rcp_f32_e32 v3, v2 ; GFX11FAKE16-NEXT: s_waitcnt_depctr 0xfff -; GFX11FAKE16-NEXT: v_fma_f32 v4, -v2, v3, 1.0 -; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1) -; GFX11FAKE16-NEXT: v_fmac_f32_e32 v3, v4, v3 -; GFX11FAKE16-NEXT: v_div_scale_f32 v5, vcc_lo, v0, v1, v0 -; GFX11FAKE16-NEXT: v_mul_f32_e32 v4, v5, v3 +; GFX11FAKE16-NEXT: v_mul_f32_e32 v5, v4, v3 +; GFX11FAKE16-NEXT: v_fma_f32 v6, -v2, v5, v4 ; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) -; GFX11FAKE16-NEXT: v_fma_f32 v6, -v2, v4, v5 -; GFX11FAKE16-NEXT: v_fmac_f32_e32 v4, v6, v3 +; GFX11FAKE16-NEXT: v_fmac_f32_e32 v5, v6, v3 +; GFX11FAKE16-NEXT: v_fma_f32 v2, -v2, v5, v4 ; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) -; GFX11FAKE16-NEXT: v_fma_f32 v2, -v2, v4, v5 -; GFX11FAKE16-NEXT: v_div_fmas_f32 v2, v2, v3, v4 +; GFX11FAKE16-NEXT: v_mul_f32_e32 v2, v2, v3 +; GFX11FAKE16-NEXT: v_and_b32_e32 v2, 0xff800000, v2 ; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) +; GFX11FAKE16-NEXT: v_add_f32_e32 v2, v2, v5 ; GFX11FAKE16-NEXT: v_div_fixup_f32 v0, v2, v1, v0 +; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_2) | instid1(VALU_DEP_3) ; GFX11FAKE16-NEXT: v_bfe_u32 v1, v0, 16, 1 ; GFX11FAKE16-NEXT: v_or_b32_e32 v2, 0x400000, v0 ; GFX11FAKE16-NEXT: v_cmp_u_f32_e32 vcc_lo, v0, v0 -; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_1) ; GFX11FAKE16-NEXT: v_add3_u32 v1, v1, v0, 0x7fff +; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1) ; GFX11FAKE16-NEXT: v_cndmask_b32_e32 v0, v1, v2, vcc_lo -; GFX11FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) ; GFX11FAKE16-NEXT: v_lshrrev_b32_e32 v0, 16, v0 ; GFX11FAKE16-NEXT: s_setpc_b64 s[30:31] %op = fdiv bfloat %a, %b From 21d37f168e5caa6686d9951d294066e010e0ab87 Mon Sep 17 00:00:00 2001 From: Ivan Kosarev Date: Wed, 9 Jul 2025 13:03:36 +0100 Subject: [PATCH 2/2] Address feedback. --- llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp | 1 + llvm/lib/Target/AMDGPU/SIISelLowering.cpp | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp index f82e6df9bcbfc..8cd34aed52dbe 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp @@ -4513,6 +4513,7 @@ bool AMDGPULegalizerInfo::legalizeFDIV(MachineInstr &MI, LLT S32 = LLT::scalar(32); LLT S64 = LLT::scalar(64); + // TODO: Adapt the f16 logic to work for bf16 too as we do in SDAG. if (DstTy == S16) return legalizeFDIV16(MI, MRI, B); if (DstTy == S32) diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp index bd0bb38570a8f..eb27f10547e1d 100644 --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -11268,12 +11268,11 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const { Quot = DAG.getNode(ISD::FADD, SL, MVT::f32, Tmp, Quot, Op->getFlags()); EVT FixupVT = VT == MVT::bf16 ? MVT::f32 : VT; - SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, FixupVT, Quot, - DAG.getTargetConstant(0, SL, MVT::i32)); + SDValue RoundFlags = DAG.getTargetConstant(0, SL, MVT::i32); + SDValue RDst = DAG.getNode(ISD::FP_ROUND, SL, FixupVT, Quot, RoundFlags); SDValue Fixup = DAG.getNode(AMDGPUISD::DIV_FIXUP, SL, FixupVT, RDst, RHS, LHS, Op->getFlags()); - return DAG.getNode(ISD::FP_ROUND, SL, VT, Fixup, - DAG.getTargetConstant(0, SL, MVT::i32)); + return DAG.getNode(ISD::FP_ROUND, SL, VT, Fixup, RoundFlags); } // Faster 2.5 ULP division that does not support denormals.