Skip to content

Commit 323112a

Browse files
[LLVM][SVE] Add isel for bfloat based constant splats. (#129550)
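There are no dedicated bfloat MOV instructions, but we can use the half-precision variants when the encoding allows, i.e. when the bfloat constant's bit pattern is also an encodable half-precision immediate (e.g. f16(1.875) == bf16(1.0)).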
1 parent 06fc7d6 commit 323112a

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,13 @@ let Predicates = [HasSVE_or_SME] in {
931931
(FDUP_ZI_S fpimm32:$imm8)>;
932932
def : Pat<(nxv2f64 (splat_vector fpimm64:$imm8)),
933933
(FDUP_ZI_D fpimm64:$imm8)>;
934+
// Some half-precision immediates share their bit pattern with a bfloat value (e.g. f16(1.875) == bf16(1.0)).
935+
def : Pat<(nxv8bf16 (splat_vector fpimmbf16:$imm8)),
936+
(FDUP_ZI_H (fpimm16XForm bf16:$imm8))>;
937+
def : Pat<(nxv4bf16 (splat_vector fpimmbf16:$imm8)),
938+
(FDUP_ZI_H (fpimm16XForm bf16:$imm8))>;
939+
def : Pat<(nxv2bf16 (splat_vector fpimmbf16:$imm8)),
940+
(FDUP_ZI_H (fpimm16XForm bf16:$imm8))>;
934941
}
935942

936943
// Select elements from either vector (predicated)

llvm/test/CodeGen/AArch64/sve-vector-splat.ll

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,33 @@ define <vscale x 2 x double> @splat_nxv2f64_imm() {
482482
ret <vscale x 2 x double> splat(double 1.0)
483483
}
484484

485+
; NOTE: f16(1.875) == bf16(1.0)
486+
define <vscale x 8 x bfloat> @splat_nxv8bf16_imm() {
487+
; CHECK-LABEL: splat_nxv8bf16_imm:
488+
; CHECK: // %bb.0:
489+
; CHECK-NEXT: fmov z0.h, #1.87500000
490+
; CHECK-NEXT: ret
491+
ret <vscale x 8 x bfloat> splat(bfloat 1.0)
492+
}
493+
494+
; NOTE: f16(-1.875) == bf16(-1.0)
495+
define <vscale x 4 x bfloat> @splat_nxv4bf16_imm() {
496+
; CHECK-LABEL: splat_nxv4bf16_imm:
497+
; CHECK: // %bb.0:
498+
; CHECK-NEXT: fmov z0.h, #-1.87500000
499+
; CHECK-NEXT: ret
500+
ret <vscale x 4 x bfloat> splat(bfloat -1.0)
501+
}
502+
503+
; NOTE: f16(1.875) == bf16(1.0)
504+
define <vscale x 2 x bfloat> @splat_nxv2bf16_imm() {
505+
; CHECK-LABEL: splat_nxv2bf16_imm:
506+
; CHECK: // %bb.0:
507+
; CHECK-NEXT: fmov z0.h, #1.87500000
508+
; CHECK-NEXT: ret
509+
ret <vscale x 2 x bfloat> splat(bfloat 1.0)
510+
}
511+
485512
define <vscale x 4 x i32> @splat_nxv4i32_fold(<vscale x 4 x i32> %x) {
486513
; CHECK-LABEL: splat_nxv4i32_fold:
487514
; CHECK: // %bb.0:
@@ -554,8 +581,8 @@ define <vscale x 2 x double> @splat_nxv2f64_imm_out_of_range() {
554581
; CHECK-LABEL: splat_nxv2f64_imm_out_of_range:
555582
; CHECK: // %bb.0:
556583
; CHECK-NEXT: ptrue p0.d
557-
; CHECK-NEXT: adrp x8, .LCPI57_0
558-
; CHECK-NEXT: add x8, x8, :lo12:.LCPI57_0
584+
; CHECK-NEXT: adrp x8, .LCPI60_0
585+
; CHECK-NEXT: add x8, x8, :lo12:.LCPI60_0
559586
; CHECK-NEXT: ld1rd { z0.d }, p0/z, [x8]
560587
; CHECK-NEXT: ret
561588
ret <vscale x 2 x double> splat(double 3.33)

0 commit comments

Comments
 (0)