Skip to content

[LLVM][CodeGen][SVE] Make bf16 fabs/fneg isel consistent with fp16. #147543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1742,9 +1742,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
setOperationAction(ISD::BITCAST, VT, Custom);
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::FABS, VT, Legal);
setOperationAction(ISD::FABS, VT, Custom);
setOperationAction(ISD::FCOPYSIGN, VT, Custom);
setOperationAction(ISD::FNEG, VT, Legal);
setOperationAction(ISD::FNEG, VT, Custom);
setOperationAction(ISD::FP_EXTEND, VT, Custom);
setOperationAction(ISD::FP_ROUND, VT, Custom);
setOperationAction(ISD::MLOAD, VT, Custom);
Expand Down
9 changes: 0 additions & 9 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -737,15 +737,6 @@ let Predicates = [HasSVE_or_SME] in {
defm FABS_ZPmZ : sve_int_un_pred_arit_bitwise_fp<0b100, "fabs", AArch64fabs_mt>;
defm FNEG_ZPmZ : sve_int_un_pred_arit_bitwise_fp<0b101, "fneg", AArch64fneg_mt>;

foreach VT = [nxv2bf16, nxv4bf16, nxv8bf16] in {
// No dedicated instruction, so just clear the sign bit.
def : Pat<(VT (fabs VT:$op)),
(AND_ZI $op, (i64 (logical_imm64_XFORM(i64 0x7fff7fff7fff7fff))))>;
// No dedicated instruction, so just invert the sign bit.
def : Pat<(VT (fneg VT:$op)),
(EOR_ZI $op, (i64 (logical_imm64_XFORM(i64 0x8000800080008000))))>;
}

// zext(cmpeq(x, splat(0))) -> cnot(x)
def : Pat<(nxv16i8 (zext (nxv16i1 (AArch64setcc_z (nxv16i1 (SVEAllActive):$Pg), nxv16i8:$Op2, (SVEDup0), SETEQ)))),
(CNOT_ZPmZ_B $Op2, $Pg, $Op2)>;
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/SVEInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -5049,6 +5049,9 @@ multiclass sve_int_un_pred_arit_bitwise_fp<bits<3> opc, string asm,
def : SVE_1_Op_Passthru_Pat<nxv4f32, op, nxv4i1, nxv4f32, !cast<Instruction>(NAME # _S)>;
def : SVE_1_Op_Passthru_Pat<nxv2f32, op, nxv2i1, nxv2f32, !cast<Instruction>(NAME # _S)>;
def : SVE_1_Op_Passthru_Pat<nxv2f64, op, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _D)>;
def : SVE_1_Op_Passthru_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, !cast<Instruction>(NAME # _H)>;
def : SVE_1_Op_Passthru_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, !cast<Instruction>(NAME # _H)>;
def : SVE_1_Op_Passthru_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, !cast<Instruction>(NAME # _H)>;

def _H_UNDEF : PredOneOpPassthruPseudo<NAME # _H, ZPR16>;
def _S_UNDEF : PredOneOpPassthruPseudo<NAME # _S, ZPR32>;
Expand All @@ -5060,6 +5063,9 @@ multiclass sve_int_un_pred_arit_bitwise_fp<bits<3> opc, string asm,
defm : SVE_1_Op_PassthruUndef_Pat<nxv4f32, op, nxv4i1, nxv4f32, !cast<Pseudo>(NAME # _S_UNDEF)>;
defm : SVE_1_Op_PassthruUndef_Pat<nxv2f32, op, nxv2i1, nxv2f32, !cast<Pseudo>(NAME # _S_UNDEF)>;
defm : SVE_1_Op_PassthruUndef_Pat<nxv2f64, op, nxv2i1, nxv2f64, !cast<Pseudo>(NAME # _D_UNDEF)>;
defm : SVE_1_Op_PassthruUndef_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, !cast<Pseudo>(NAME # _H_UNDEF)>;
defm : SVE_1_Op_PassthruUndef_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, !cast<Pseudo>(NAME # _H_UNDEF)>;
defm : SVE_1_Op_PassthruUndef_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, !cast<Pseudo>(NAME # _H_UNDEF)>;
}

multiclass sve_int_un_pred_arit_bitwise_fp_z<bits<3> opc, string asm, SDPatternOperator op> {
Expand Down
18 changes: 12 additions & 6 deletions llvm/test/CodeGen/AArch64/sve-bf16-arith.ll
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ target triple = "aarch64-unknown-linux-gnu"
define <vscale x 2 x bfloat> @fabs_nxv2bf16(<vscale x 2 x bfloat> %a) {
; CHECK-LABEL: fabs_nxv2bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: and z0.h, z0.h, #0x7fff
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: fabs z0.h, p0/m, z0.h
; CHECK-NEXT: ret
%res = call <vscale x 2 x bfloat> @llvm.fabs.nxv2bf16(<vscale x 2 x bfloat> %a)
ret <vscale x 2 x bfloat> %res
Expand All @@ -22,7 +23,8 @@ define <vscale x 2 x bfloat> @fabs_nxv2bf16(<vscale x 2 x bfloat> %a) {
define <vscale x 4 x bfloat> @fabs_nxv4bf16(<vscale x 4 x bfloat> %a) {
; CHECK-LABEL: fabs_nxv4bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: and z0.h, z0.h, #0x7fff
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: fabs z0.h, p0/m, z0.h
; CHECK-NEXT: ret
%res = call <vscale x 4 x bfloat> @llvm.fabs.nxv4bf16(<vscale x 4 x bfloat> %a)
ret <vscale x 4 x bfloat> %res
Expand All @@ -31,7 +33,8 @@ define <vscale x 4 x bfloat> @fabs_nxv4bf16(<vscale x 4 x bfloat> %a) {
define <vscale x 8 x bfloat> @fabs_nxv8bf16(<vscale x 8 x bfloat> %a) {
; CHECK-LABEL: fabs_nxv8bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: and z0.h, z0.h, #0x7fff
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: fabs z0.h, p0/m, z0.h
; CHECK-NEXT: ret
%res = call <vscale x 8 x bfloat> @llvm.fabs.nxv8bf16(<vscale x 8 x bfloat> %a)
ret <vscale x 8 x bfloat> %res
Expand Down Expand Up @@ -586,7 +589,8 @@ define <vscale x 8 x bfloat> @fmul_nxv8bf16(<vscale x 8 x bfloat> %a, <vscale x
define <vscale x 2 x bfloat> @fneg_nxv2bf16(<vscale x 2 x bfloat> %a) {
; CHECK-LABEL: fneg_nxv2bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: eor z0.h, z0.h, #0x8000
; CHECK-NEXT: ptrue p0.d
; CHECK-NEXT: fneg z0.h, p0/m, z0.h
; CHECK-NEXT: ret
%res = fneg <vscale x 2 x bfloat> %a
ret <vscale x 2 x bfloat> %res
Expand All @@ -595,7 +599,8 @@ define <vscale x 2 x bfloat> @fneg_nxv2bf16(<vscale x 2 x bfloat> %a) {
define <vscale x 4 x bfloat> @fneg_nxv4bf16(<vscale x 4 x bfloat> %a) {
; CHECK-LABEL: fneg_nxv4bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: eor z0.h, z0.h, #0x8000
; CHECK-NEXT: ptrue p0.s
; CHECK-NEXT: fneg z0.h, p0/m, z0.h
; CHECK-NEXT: ret
%res = fneg <vscale x 4 x bfloat> %a
ret <vscale x 4 x bfloat> %res
Expand All @@ -604,7 +609,8 @@ define <vscale x 4 x bfloat> @fneg_nxv4bf16(<vscale x 4 x bfloat> %a) {
define <vscale x 8 x bfloat> @fneg_nxv8bf16(<vscale x 8 x bfloat> %a) {
; CHECK-LABEL: fneg_nxv8bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: eor z0.h, z0.h, #0x8000
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: fneg z0.h, p0/m, z0.h
; CHECK-NEXT: ret
%res = fneg <vscale x 8 x bfloat> %a
ret <vscale x 8 x bfloat> %res
Expand Down
24 changes: 24 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-intrinsics-fp-arith.ll
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@ define <vscale x 2 x double> @fabs_d(<vscale x 2 x double> %a, <vscale x 2 x i1>
ret <vscale x 2 x double> %out
}

define <vscale x 8 x bfloat> @fabs_bf(<vscale x 8 x bfloat> %a, <vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %b) {
; CHECK-LABEL: fabs_bf:
; CHECK: // %bb.0:
; CHECK-NEXT: fabs z0.h, p0/m, z1.h
; CHECK-NEXT: ret
%out = call <vscale x 8 x bfloat> @llvm.aarch64.sve.fabs.nxv8bf16(<vscale x 8 x bfloat> %a,
<vscale x 8 x i1> %pg,
<vscale x 8 x bfloat> %b)
ret <vscale x 8 x bfloat> %out
}

;
; FADD
;
Expand Down Expand Up @@ -835,6 +846,17 @@ define <vscale x 2 x double> @fneg_d(<vscale x 2 x double> %a, <vscale x 2 x i1>
ret <vscale x 2 x double> %out
}

define <vscale x 8 x bfloat> @fneg_bf(<vscale x 8 x bfloat> %a, <vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %b) {
; CHECK-LABEL: fneg_bf:
; CHECK: // %bb.0:
; CHECK-NEXT: fneg z0.h, p0/m, z1.h
; CHECK-NEXT: ret
%out = call <vscale x 8 x bfloat> @llvm.aarch64.sve.fneg.nxv8bf16(<vscale x 8 x bfloat> %a,
<vscale x 8 x i1> %pg,
<vscale x 8 x bfloat> %b)
ret <vscale x 8 x bfloat> %out
}

;
; FNMAD
;
Expand Down Expand Up @@ -1613,6 +1635,7 @@ declare <vscale x 2 x double> @llvm.aarch64.sve.fabd.nxv2f64(<vscale x 2 x i1>,
declare <vscale x 8 x half> @llvm.aarch64.sve.fabs.nxv8f16(<vscale x 8 x half>, <vscale x 8 x i1>, <vscale x 8 x half>)
declare <vscale x 4 x float> @llvm.aarch64.sve.fabs.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, <vscale x 4 x float>)
declare <vscale x 2 x double> @llvm.aarch64.sve.fabs.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, <vscale x 2 x double>)
declare <vscale x 8 x bfloat> @llvm.aarch64.sve.fabs.nxv8bf16(<vscale x 8 x bfloat>, <vscale x 8 x i1>, <vscale x 8 x bfloat>)

declare <vscale x 8 x half> @llvm.aarch64.sve.fadd.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 4 x float> @llvm.aarch64.sve.fadd.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>)
Expand Down Expand Up @@ -1692,6 +1715,7 @@ declare <vscale x 2 x double> @llvm.aarch64.sve.fmulx.nxv2f64(<vscale x 2 x i1>,
declare <vscale x 8 x half> @llvm.aarch64.sve.fneg.nxv8f16(<vscale x 8 x half>, <vscale x 8 x i1>, <vscale x 8 x half>)
declare <vscale x 4 x float> @llvm.aarch64.sve.fneg.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, <vscale x 4 x float>)
declare <vscale x 2 x double> @llvm.aarch64.sve.fneg.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, <vscale x 2 x double>)
declare <vscale x 8 x bfloat> @llvm.aarch64.sve.fneg.nxv8bf16(<vscale x 8 x bfloat>, <vscale x 8 x i1>, <vscale x 8 x bfloat>)

declare <vscale x 8 x half> @llvm.aarch64.sve.fnmad.nxv8f16(<vscale x 8 x i1>, <vscale x 8 x half>, <vscale x 8 x half>, <vscale x 8 x half>)
declare <vscale x 4 x float> @llvm.aarch64.sve.fnmad.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
Expand Down
18 changes: 6 additions & 12 deletions llvm/test/CodeGen/AArch64/sve-merging-unary.ll
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ define <vscale x 2 x double> @fabs_nxv2f64(<vscale x 2 x i1> %pg, <vscale x 2 x
define <vscale x 2 x bfloat> @fabs_nxv2bf16(<vscale x 2 x i1> %pg, <vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
; CHECK-LABEL: fabs_nxv2bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: and z1.h, z1.h, #0x7fff
; CHECK-NEXT: mov z0.d, p0/m, z1.d
; CHECK-NEXT: fabs z0.h, p0/m, z1.h
; CHECK-NEXT: ret
%b.op = call <vscale x 2 x bfloat> @llvm.fabs.nxv2bf16(<vscale x 2 x bfloat> %b)
%res = select <vscale x 2 x i1> %pg, <vscale x 2 x bfloat> %b.op, <vscale x 2 x bfloat> %a
Expand All @@ -198,8 +197,7 @@ define <vscale x 2 x bfloat> @fabs_nxv2bf16(<vscale x 2 x i1> %pg, <vscale x 2 x
define <vscale x 4 x bfloat> @fabs_nxv4bf16(<vscale x 4 x i1> %pg, <vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
; CHECK-LABEL: fabs_nxv4bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: and z1.h, z1.h, #0x7fff
; CHECK-NEXT: mov z0.s, p0/m, z1.s
; CHECK-NEXT: fabs z0.h, p0/m, z1.h
; CHECK-NEXT: ret
%b.op = call <vscale x 4 x bfloat> @llvm.fabs.nxv4bf16(<vscale x 4 x bfloat> %b)
%res = select <vscale x 4 x i1> %pg, <vscale x 4 x bfloat> %b.op, <vscale x 4 x bfloat> %a
Expand All @@ -209,8 +207,7 @@ define <vscale x 4 x bfloat> @fabs_nxv4bf16(<vscale x 4 x i1> %pg, <vscale x 4 x
define <vscale x 8 x bfloat> @fabs_nxv8bf16(<vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
; CHECK-LABEL: fabs_nxv8bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: and z1.h, z1.h, #0x7fff
; CHECK-NEXT: mov z0.h, p0/m, z1.h
; CHECK-NEXT: fabs z0.h, p0/m, z1.h
; CHECK-NEXT: ret
%b.op = call <vscale x 8 x bfloat> @llvm.fabs.nxv8bf16(<vscale x 8 x bfloat> %b)
%res = select <vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %b.op, <vscale x 8 x bfloat> %a
Expand Down Expand Up @@ -545,8 +542,7 @@ define <vscale x 2 x double> @fneg_nxv2f64(<vscale x 2 x i1> %pg, <vscale x 2 x
define <vscale x 2 x bfloat> @fneg_nxv2bf16(<vscale x 2 x i1> %pg, <vscale x 2 x bfloat> %a, <vscale x 2 x bfloat> %b) {
; CHECK-LABEL: fneg_nxv2bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: eor z1.h, z1.h, #0x8000
; CHECK-NEXT: mov z0.d, p0/m, z1.d
; CHECK-NEXT: fneg z0.h, p0/m, z1.h
; CHECK-NEXT: ret
%b.op = fneg <vscale x 2 x bfloat> %b
%res = select <vscale x 2 x i1> %pg, <vscale x 2 x bfloat> %b.op, <vscale x 2 x bfloat> %a
Expand All @@ -556,8 +552,7 @@ define <vscale x 2 x bfloat> @fneg_nxv2bf16(<vscale x 2 x i1> %pg, <vscale x 2 x
define <vscale x 4 x bfloat> @fneg_nxv4bf16(<vscale x 4 x i1> %pg, <vscale x 4 x bfloat> %a, <vscale x 4 x bfloat> %b) {
; CHECK-LABEL: fneg_nxv4bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: eor z1.h, z1.h, #0x8000
; CHECK-NEXT: mov z0.s, p0/m, z1.s
; CHECK-NEXT: fneg z0.h, p0/m, z1.h
; CHECK-NEXT: ret
%b.op = fneg <vscale x 4 x bfloat> %b
%res = select <vscale x 4 x i1> %pg, <vscale x 4 x bfloat> %b.op, <vscale x 4 x bfloat> %a
Expand All @@ -567,8 +562,7 @@ define <vscale x 4 x bfloat> @fneg_nxv4bf16(<vscale x 4 x i1> %pg, <vscale x 4 x
define <vscale x 8 x bfloat> @fneg_nxv8bf16(<vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %a, <vscale x 8 x bfloat> %b) {
; CHECK-LABEL: fneg_nxv8bf16:
; CHECK: // %bb.0:
; CHECK-NEXT: eor z1.h, z1.h, #0x8000
; CHECK-NEXT: mov z0.h, p0/m, z1.h
; CHECK-NEXT: fneg z0.h, p0/m, z1.h
; CHECK-NEXT: ret
%b.op = fneg <vscale x 8 x bfloat> %b
%res = select <vscale x 8 x i1> %pg, <vscale x 8 x bfloat> %b.op, <vscale x 8 x bfloat> %a
Expand Down
Loading