From ac96fac10f63a3ab77c17b945273ad488f6ac8b1 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker@arm.com>
Date: Tue, 1 Jul 2025 17:00:03 +0100
Subject: [PATCH] [LLVM][CodeGen][SVE] Make bf16 fabs/fneg isel consistent
 with fp16.

Whilst at first glance there appear to be no native bfloat instructions
to modify the sign bit, this is only the case when FEAT_AFP is
implemented. Without this feature, vector FABS/FNEG does not care about
the floating-point format beyond needing to know the position of the
sign bit. From what I can see, LLVM has no support for FEAT_AFP in terms
of feature detection or ACLE builtins, and so I believe the compiler can
work under the assumption that the feature is not enabled. In fact, if
FEAT_AFP is enabled then I believe the current isel is likely broken for
half, float and double anyway.

NOTE: The main motivation behind this change is to allow existing
PatFrags to work with bfloat vectors without having to add special
handling for unpredicated fabs/fneg operations.

NOTE: For some cases using bitwise instructions can be better, but I
figure it's better to unify the operations and investigate later whether
to use bitwise instructions for all vector types.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  4 ++--
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  9 -------
 llvm/lib/Target/AArch64/SVEInstrFormats.td    |  6 +++++
 llvm/test/CodeGen/AArch64/sve-bf16-arith.ll   | 18 +++++++++-----
 .../AArch64/sve-intrinsics-fp-arith.ll        | 24 +++++++++++++++++++
 .../test/CodeGen/AArch64/sve-merging-unary.ll | 18 +++++---------
 6 files changed, 50 insertions(+), 29 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0d388fc3c787d..d4a6732b23170 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1742,9 +1742,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
       setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
-      setOperationAction(ISD::FABS, VT, Legal);
+      setOperationAction(ISD::FABS, VT, Custom);
       setOperationAction(ISD::FCOPYSIGN, VT, Custom);
-      setOperationAction(ISD::FNEG, VT, Legal);
+      setOperationAction(ISD::FNEG, VT, Custom);
       setOperationAction(ISD::FP_EXTEND, VT, Custom);
       setOperationAction(ISD::FP_ROUND, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 261df563bb2a9..0dcae26e1480a 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -737,15 +737,6 @@ let Predicates = [HasSVE_or_SME] in {
   defm FABS_ZPmZ : sve_int_un_pred_arit_bitwise_fp<0b100, "fabs", AArch64fabs_mt>;
   defm FNEG_ZPmZ : sve_int_un_pred_arit_bitwise_fp<0b101, "fneg", AArch64fneg_mt>;
 
-  foreach VT = [nxv2bf16, nxv4bf16, nxv8bf16] in {
-    // No dedicated instruction, so just clear the sign bit.
-    def : Pat<(VT (fabs VT:$op)),
-              (AND_ZI $op, (i64 (logical_imm64_XFORM (i64 0x7fff7fff7fff7fff))))>;
-    // No dedicated instruction, so just invert the sign bit.
-    def : Pat<(VT (fneg VT:$op)),
-              (EOR_ZI $op, (i64 (logical_imm64_XFORM (i64 0x8000800080008000))))>;
-  }
-
   // zext(cmpeq(x, splat(0))) -> cnot(x)
   def : Pat<(nxv16i8 (zext (nxv16i1 (AArch64setcc_z (nxv16i1 (SVEAllActive):$Pg), nxv16i8:$Op2, (SVEDup0), SETEQ)))),
             (CNOT_ZPmZ_B $Op2, $Pg, $Op2)>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index d5c12a9658113..3b7e5a6c2b1cf 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -5049,6 +5049,9 @@ multiclass sve_int_un_pred_arit_bitwise_fp<bits<3> opc, string asm,
   def : SVE_1_Op_Passthru_Pat<nxv4f32, op, nxv4i1, !cast<Instruction>(NAME # _S)>;
   def : SVE_1_Op_Passthru_Pat<nxv2f32, op, nxv2i1, !cast<Instruction>(NAME # _S)>;
   def : SVE_1_Op_Passthru_Pat<nxv2f64, op, nxv2i1, !cast<Instruction>(NAME # _D)>;
+  def : SVE_1_Op_Passthru_Pat<nxv8bf16, op, nxv8i1, !cast<Instruction>(NAME # _H)>;
+  def : SVE_1_Op_Passthru_Pat<nxv4bf16, op, nxv4i1, !cast<Instruction>(NAME # _H)>;
+  def : SVE_1_Op_Passthru_Pat<nxv2bf16, op, nxv2i1, !cast<Instruction>(NAME # _H)>;
 
   def _H_UNDEF : PredOneOpPassthruPseudo<NAME # _H, ZPR16>;
   def _S_UNDEF : PredOneOpPassthruPseudo<NAME # _S, ZPR32>;
@@ -5060,6 +5063,9 @@ multiclass sve_int_un_pred_arit_bitwise_fp<bits<3> opc, string asm,
   defm : SVE_1_Op_PassthruUndef_Pat<nxv4f32, op, nxv4i1, !cast<Pseudo>(NAME # _S_UNDEF)>;
   defm : SVE_1_Op_PassthruUndef_Pat<nxv2f32, op, nxv2i1, !cast<Pseudo>(NAME # _S_UNDEF)>;
   defm : SVE_1_Op_PassthruUndef_Pat<nxv2f64, op, nxv2i1, !cast<Pseudo>(NAME # _D_UNDEF)>;
+  defm : SVE_1_Op_PassthruUndef_Pat<nxv8bf16, op, nxv8i1, !cast<Pseudo>(NAME # _H_UNDEF)>;
+  defm : SVE_1_Op_PassthruUndef_Pat<nxv4bf16, op, nxv4i1, !cast<Pseudo>(NAME # _H_UNDEF)>;
+  defm : SVE_1_Op_PassthruUndef_Pat<nxv2bf16, op, nxv2i1, !cast<Pseudo>(NAME # _H_UNDEF)>;
 }
 
 multiclass sve_int_un_pred_arit_bitwise_fp_z<bits<3> opc, string asm, SDPatternOperator op> {
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
index e8468ddfeed18..83f4f8fc57aae 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
@@ -13,7 +13,8 @@ target triple = "aarch64-unknown-linux-gnu"
 define <vscale x 2 x bfloat> @fabs_nxv2bf16(<vscale x 2 x bfloat> %a) {
 ; CHECK-LABEL: fabs_nxv2bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    and z0.h, z0.h, #0x7fff
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fabs z0.h, p0/m, z0.h
 ; CHECK-NEXT:    ret
   %res = call <vscale x 2 x bfloat> @llvm.fabs.nxv2bf16(<vscale x 2 x bfloat> %a)
   ret <vscale x 2 x bfloat> %res
@@ -22,7 +23,8 @@ define <vscale x 2 x bfloat> @fabs_nxv2bf16(<vscale x 2 x bfloat> %a) {
 define <vscale x 4 x bfloat> @fabs_nxv4bf16(<vscale x 4 x bfloat> %a) {
 ; CHECK-LABEL: fabs_nxv4bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    and z0.h, z0.h, #0x7fff
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fabs z0.h, p0/m, z0.h
 ; CHECK-NEXT:    ret
   %res = call <vscale x 4 x bfloat> @llvm.fabs.nxv4bf16(<vscale x 4 x bfloat> %a)
   ret <vscale x 4 x bfloat> %res
@@ -31,7 +33,8 @@ define <vscale x 4 x bfloat> @fabs_nxv4bf16(<vscale x 4 x bfloat> %a) {
 define <vscale x 8 x bfloat> @fabs_nxv8bf16(<vscale x 8 x bfloat> %a) {
 ; CHECK-LABEL: fabs_nxv8bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    and z0.h, z0.h, #0x7fff
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    fabs z0.h, p0/m, z0.h
 ; CHECK-NEXT:    ret
   %res = call <vscale x 8 x bfloat> @llvm.fabs.nxv8bf16(<vscale x 8 x bfloat> %a)
   ret <vscale x 8 x bfloat> %res
@@ -586,7 +589,8 @@ define <vscale x 8 x bfloat> @fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
 define <vscale x 2 x bfloat> @fneg_nxv2bf16(<vscale x 2 x bfloat> %a) {
 ; CHECK-LABEL: fneg_nxv2bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    eor z0.h, z0.h, #0x8000
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fneg z0.h, p0/m, z0.h
 ; CHECK-NEXT:    ret
   %res = fneg <vscale x 2 x bfloat> %a
   ret <vscale x 2 x bfloat> %res
@@ -595,7 +599,8 @@ define <vscale x 2 x bfloat> @fneg_nxv2bf16(<vscale x 2 x bfloat> %a) {
 define <vscale x 4 x bfloat> @fneg_nxv4bf16(<vscale x 4 x bfloat> %a) {
 ; CHECK-LABEL: fneg_nxv4bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    eor z0.h, z0.h, #0x8000
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fneg z0.h, p0/m, z0.h
 ; CHECK-NEXT:    ret
   %res = fneg <vscale x 4 x bfloat> %a
   ret <vscale x 4 x bfloat> %res
@@ -604,7 +609,8 @@ define <vscale x 4 x bfloat> @fneg_nxv4bf16(<vscale x 4 x bfloat> %a) {
 define <vscale x 8 x bfloat> @fneg_nxv8bf16(<vscale x 8 x bfloat> %a) {
 ; CHECK-LABEL: fneg_nxv8bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    eor z0.h, z0.h, #0x8000
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    fneg z0.h, p0/m, z0.h
 ; CHECK-NEXT:    ret
   %res = fneg <vscale x 8 x bfloat> %a
   ret <vscale x 8 x bfloat> %res
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-arith.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-arith.ll
index 0aeab72096caa..5491dc274cd1f 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-fp-arith.ll
@@ -75,6 +75,17 @@ define <vscale x 2 x double> @fabs_d(<vscale x 2 x double> %a, <vscale x 2 x i1> %pg, <vscale x 2 x double> %b) {
   ret <vscale x 2 x double> %out
 }
 
+define <vscale x 8 x bfloat> @fabs_bf(<vscale x 8 x bfloat> %a, <vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %b) {
+; CHECK-LABEL: fabs_bf:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fabs z0.h, p0/m, z1.h
+; CHECK-NEXT:    ret
+  %out = call <vscale x 8 x bfloat> @llvm.aarch64.sve.fabs.nxv8bf16(<vscale x 8 x bfloat> %a,
+                                                                    <vscale x 8 x i1> %pg,
+                                                                    <vscale x 8 x bfloat> %b)
+  ret <vscale x 8 x bfloat> %out
+}
+
 ;
 ; FADD
 ;
@@ -835,6 +846,17 @@ define <vscale x 2 x double> @fneg_d(<vscale x 2 x double> %a, <vscale x 2 x i1> %pg, <vscale x 2 x double> %b) {
   ret <vscale x 2 x double> %out
 }
 
+define <vscale x 8 x bfloat> @fneg_bf(<vscale x 8 x bfloat> %a, <vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %b) {
+; CHECK-LABEL: fneg_bf:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fneg z0.h, p0/m, z1.h
+; CHECK-NEXT:    ret
+  %out = call <vscale x 8 x bfloat> @llvm.aarch64.sve.fneg.nxv8bf16(<vscale x 8 x bfloat> %a,
+                                                                    <vscale x 8 x i1> %pg,
+                                                                    <vscale x 8 x bfloat> %b)
+  ret <vscale x 8 x bfloat> %out
+}
+
 ;
 ; FNMAD
 ;
@@ -1613,6 +1635,7 @@ declare <vscale x 2 x double> @llvm.aarch64.sve.fabd.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>, <vscale x 2 x double>)
 declare <vscale x 8 x half> @llvm.aarch64.sve.fabs.nxv8f16(<vscale x 8 x half>, <vscale x 8 x i1>, <vscale x 8 x half>)
 declare <vscale x 4 x float> @llvm.aarch64.sve.fabs.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, <vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.aarch64.sve.fabs.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, <vscale x 2 x double>)
+declare <vscale x 8 x bfloat> @llvm.aarch64.sve.fabs.nxv8bf16(<vscale x 8 x bfloat>, <vscale x 8 x i1>, <vscale x 8 x bfloat>)
 
 declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
 declare <vscale x 4 x float> @llvm.aarch64.sve.fadd.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>)
@@ -1692,6 +1715,7 @@ declare <vscale x 2 x double> @llvm.aarch64.sve.fmulx.nxv2f64(<vscale x 2 x i1>, <vscale x 2 x double>, <vscale x 2 x double>)
 declare <vscale x 8 x half> @llvm.aarch64.sve.fneg.nxv8f16(<vscale x 8 x half>, <vscale x 8 x i1>, <vscale x 8 x half>)
 declare <vscale x 4 x float> @llvm.aarch64.sve.fneg.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, <vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.aarch64.sve.fneg.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, <vscale x 2 x double>)
+declare <vscale x 8 x bfloat> @llvm.aarch64.sve.fneg.nxv8bf16(<vscale x 8 x bfloat>, <vscale x 8 x i1>, <vscale x 8 x bfloat>)
 
 declare <vscale x 8 x half> @llvm.aarch64.sve.fnmad.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>, <vscale x 8 x half>)
 declare <vscale x 4 x float> @llvm.aarch64.sve.fnmad.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
diff --git a/llvm/test/CodeGen/AArch64/sve-merging-unary.ll b/llvm/test/CodeGen/AArch64/sve-merging-unary.ll
index 07f75c866d3d0..02c815844f1d6 100644
--- a/llvm/test/CodeGen/AArch64/sve-merging-unary.ll
+++ b/llvm/test/CodeGen/AArch64/sve-merging-unary.ll
@@ -187,8 +187,7 @@ define <vscale x 2 x double> @fabs_nxv2f64(<vscale x 2 x i1> %pg, <vscale x 2 x double> %a, <vscale x 2 x double> %b) {
 define <vscale x 2 x bfloat> @fabs_nxv2bf16(<vscale x 2 x i1> %pg, <vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
 ; CHECK-LABEL: fabs_nxv2bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    and z1.h, z1.h, #0x7fff
-; CHECK-NEXT:    mov z0.d, p0/m, z1.d
+; CHECK-NEXT:    fabs z0.h, p0/m, z1.h
 ; CHECK-NEXT:    ret
   %b.op = call <vscale x 2 x bfloat> @llvm.fabs.nxv2bf16(<vscale x 2 x bfloat> %b)
   %res = select <vscale x 2 x i1> %pg, <vscale x 2 x bfloat> %b.op, <vscale x 2 x bfloat> %a
@@ -198,8 +197,7 @@ define <vscale x 2 x bfloat> @fabs_nxv2bf16(<vscale x 2 x i1> %pg, <vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
 define <vscale x 4 x bfloat> @fabs_nxv4bf16(<vscale x 4 x i1> %pg, <vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
 ; CHECK-LABEL: fabs_nxv4bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    and z1.h, z1.h, #0x7fff
-; CHECK-NEXT:    mov z0.s, p0/m, z1.s
+; CHECK-NEXT:    fabs z0.h, p0/m, z1.h
 ; CHECK-NEXT:    ret
   %b.op = call <vscale x 4 x bfloat> @llvm.fabs.nxv4bf16(<vscale x 4 x bfloat> %b)
   %res = select <vscale x 4 x i1> %pg, <vscale x 4 x bfloat> %b.op, <vscale x 4 x bfloat> %a
@@ -209,8 +207,7 @@ define <vscale x 4 x bfloat> @fabs_nxv4bf16(<vscale x 4 x i1> %pg, <vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
 define <vscale x 8 x bfloat> @fabs_nxv8bf16(<vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
 ; CHECK-LABEL: fabs_nxv8bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    and z1.h, z1.h, #0x7fff
-; CHECK-NEXT:    mov z0.h, p0/m, z1.h
+; CHECK-NEXT:    fabs z0.h, p0/m, z1.h
 ; CHECK-NEXT:    ret
   %b.op = call <vscale x 8 x bfloat> @llvm.fabs.nxv8bf16(<vscale x 8 x bfloat> %b)
   %res = select <vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %b.op, <vscale x 8 x bfloat> %a
@@ -545,8 +542,7 @@ define <vscale x 2 x double> @fneg_nxv2f64(<vscale x 2 x i1> %pg, <vscale x 2 x double> %a, <vscale x 2 x double> %b) {
 define <vscale x 2 x bfloat> @fneg_nxv2bf16(<vscale x 2 x i1> %pg, <vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
 ; CHECK-LABEL: fneg_nxv2bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    eor z1.h, z1.h, #0x8000
-; CHECK-NEXT:    mov z0.d, p0/m, z1.d
+; CHECK-NEXT:    fneg z0.h, p0/m, z1.h
 ; CHECK-NEXT:    ret
   %b.op = fneg <vscale x 2 x bfloat> %b
   %res = select <vscale x 2 x i1> %pg, <vscale x 2 x bfloat> %b.op, <vscale x 2 x bfloat> %a
@@ -556,8 +552,7 @@ define <vscale x 2 x bfloat> @fneg_nxv2bf16(<vscale x 2 x i1> %pg, <vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
 define <vscale x 4 x bfloat> @fneg_nxv4bf16(<vscale x 4 x i1> %pg, <vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
 ; CHECK-LABEL: fneg_nxv4bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    eor z1.h, z1.h, #0x8000
-; CHECK-NEXT:    mov z0.s, p0/m, z1.s
+; CHECK-NEXT:    fneg z0.h, p0/m, z1.h
 ; CHECK-NEXT:    ret
   %b.op = fneg <vscale x 4 x bfloat> %b
   %res = select <vscale x 4 x i1> %pg, <vscale x 4 x bfloat> %b.op, <vscale x 4 x bfloat> %a
@@ -567,8 +562,7 @@ define <vscale x 4 x bfloat> @fneg_nxv4bf16(<vscale x 4 x i1> %pg, <vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
 define <vscale x 8 x bfloat> @fneg_nxv8bf16(<vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
 ; CHECK-LABEL: fneg_nxv8bf16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    eor z1.h, z1.h, #0x8000
-; CHECK-NEXT:    mov z0.h, p0/m, z1.h
+; CHECK-NEXT:    fneg z0.h, p0/m, z1.h
 ; CHECK-NEXT:    ret
   %b.op = fneg <vscale x 8 x bfloat> %b
   %res = select <vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %b.op, <vscale x 8 x bfloat> %a
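--
Aside, not part of the patch: the change relies on the commit message's
observation that, absent FEAT_AFP, fabs/fneg are pure sign-bit
operations independent of the floating-point format. The standalone C++
sketch below is purely illustrative (the helper names are made up); it
mirrors per 16-bit lane what the removed AND_ZI/EOR_ZI patterns did:
fneg is an XOR with 0x8000 and fabs an AND with 0x7fff.

#include <cstdint>
#include <cstdio>
#include <cstring>

// bfloat16 is the high 16 bits of an IEEE-754 float, so a truncating
// conversion is enough for this demonstration.
static uint16_t float_to_bf16(float F) {
  uint32_t Bits;
  std::memcpy(&Bits, &F, sizeof F);
  return static_cast<uint16_t>(Bits >> 16);
}

static float bf16_to_float(uint16_t H) {
  uint32_t Bits = static_cast<uint32_t>(H) << 16;
  float F;
  std::memcpy(&F, &Bits, sizeof F);
  return F;
}

int main() {
  uint16_t X = float_to_bf16(1.5f);
  uint16_t Neg = X ^ 0x8000;   // fneg: invert the sign bit (was EOR_ZI)
  uint16_t Abs = Neg & 0x7fff; // fabs: clear the sign bit (was AND_ZI)
  std::printf("%g %g\n", bf16_to_float(Neg), bf16_to_float(Abs)); // -1.5 1.5
}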