Skip to content

Commit 607485f

Browse files
[LLVM][SVE] Lower bfloat extends the same as other types. (#129544)
1 parent 323112a commit 607485f

File tree

2 files changed

+6
-15
lines changed

2 files changed

+6
-15
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4507,18 +4507,9 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
45074507
if (VT.isScalableVector()) {
45084508
SDValue SrcVal = Op.getOperand(0);
45094509

4510-
if (SrcVal.getValueType().getScalarType() == MVT::bf16) {
4511-
// bf16 and f32 share the same exponent range so the conversion requires
4512-
// them to be aligned with the new mantissa bits zero'd. This is just a
4513-
// left shift that is best to isel directly.
4514-
if (VT == MVT::nxv2f32 || VT == MVT::nxv4f32)
4515-
return Op;
4516-
4517-
if (VT != MVT::nxv2f64)
4518-
return SDValue();
4519-
4520-
// Break other conversions in two with the first part converting to f32
4521-
// and the second using native f32->VT instructions.
4510+
if (VT == MVT::nxv2f64 && SrcVal.getValueType() == MVT::nxv2bf16) {
4511+
// Break conversion in two with the first part converting to f32 and the
4512+
// second using native f32->VT instructions.
45224513
SDLoc DL(Op);
45234514
return DAG.getNode(ISD::FP_EXTEND, DL, VT,
45244515
DAG.getNode(ISD::FP_EXTEND, DL, MVT::nxv2f32, SrcVal));

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def AArch64fclamp : PatFrags<(ops node:$Zd, node:$Zn, node:$Zm),
345345

346346
def SDT_AArch64FCVT : SDTypeProfile<1, 3, [
347347
SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>,
348-
SDTCVecEltisVT<1,i1>
348+
SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0,1>, SDTCisSameAs<0,3>
349349
]>;
350350

351351
def SDT_AArch64FCVTR : SDTypeProfile<1, 4, [
@@ -2377,9 +2377,9 @@ let Predicates = [HasSVE_or_SME] in {
23772377
def : Pat<(nxv2f16 (AArch64fcvtr_mt (nxv2i1 (SVEAllActive:$Pg)), nxv2f32:$Zs, (i64 timm0_1), nxv2f16:$Zd)),
23782378
(FCVT_ZPmZ_StoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
23792379

2380-
def : Pat<(nxv4f32 (fpextend nxv4bf16:$op)),
2380+
def : Pat<(nxv4f32 (AArch64fcvte_mt (SVEAnyPredicate), nxv4bf16:$op, undef)),
23812381
(LSL_ZZI_S $op, (i32 16))>;
2382-
def : Pat<(nxv2f32 (fpextend nxv2bf16:$op)),
2382+
def : Pat<(nxv2f32 (AArch64fcvte_mt (SVEAnyPredicate), nxv2bf16:$op, undef)),
23832383
(LSL_ZZI_S $op, (i32 16))>;
23842384

23852385
// Signed integer -> Floating-point

0 commit comments

Comments
 (0)